[
  {
    "path": ".gitignore",
    "content": "*.pyc\n\nmodels/\ndata/\noutput/\nwandb/\n"
  },
  {
    "path": "README.md",
    "content": "# PhysDreamer: Physics-Based Interaction with 3D Objects via Video Generation [[website](https://physdreamer.github.io/)]\n\n![teaser-figure](figures/figure_teaser.png)\n\n## Useage\n\n### Setup enviroment\n\nInstall diff-gaussian-rasterization at: https://github.com/graphdeco-inria/diff-gaussian-rasterization\n   \n```bash\nconda create -n physdreamer python\nconda activate physdreamer\n\npip install -r requirements.txt\n\npython setup.py install\n```\n\n### Download the scenes and optimized models from Hugging Face\n\nDownload the scenes and optimized velocity and material fields from: https://huggingface.co/datasets/YunjinZhang/PhysDreamer/tree/main\n\nPut folders of these scenes to `data/physics_dreamer/xxx`, e.g. `data/physics_dreamer/carnations`\n\nPut pretrained models to `./models`. \n\nSee `dataset_dir` and `model_list` in  `inference/configs/carnation.py` to match the path of dataset and pretrained models. \n\n\n### Run inference\n\n```bash\ncd projects/inference\nbash run.sh\n```\n\n\n## Acknowledgement\nThis codebase used lots of source code from: \n1. https://github.com/graphdeco-inria/gaussian-splatting\n2. https://github.com/zeshunzong/warp-mpm\n3. https://github.com/PingchuanMa/NCLaw\n\nWe thank the authors of these projects.\n\n\n## Citations\n```\n@article{zhang2024physdreamer,\n    title={{PhysDreamer}: Physics-Based Interaction with 3D Objects via Video Generation},\n    author={Tianyuan Zhang and Hong-Xing Yu and Rundi Wu and\n            Brandon Y. Feng and Changxi Zheng and Noah Snavely and Jiajun Wu and William T. Freeman},\n    journal={arxiv},\n    year={2024}\n}\n```\n"
  },
  {
    "path": "physdreamer/field_components/encoding.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Optional, Sequence, Tuple, List\nfrom physdreamer.losses.smoothness_loss import (\n    compute_plane_smoothness,\n    compute_plane_tv,\n)\n\n\nclass TemporalKplanesEncoding(nn.Module):\n    \"\"\"\n\n    Args:\n        resolutions (Sequence[int]): xyzt resolutions.\n    \"\"\"\n\n    def __init__(\n        self,\n        resolutions: Sequence[int],\n        feat_dim: int = 32,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"sum\",  # Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n    ):\n        super().__init__()\n\n        self.resolutions = resolutions\n\n        if reduce == \"cat\":\n            feat_dim = feat_dim // 3\n        self.feat_dim = feat_dim\n\n        self.reduce = reduce\n\n        self.in_dim = 4\n\n        self.plane_coefs = nn.ParameterList()\n\n        self.coo_combs = [[0, 3], [1, 3], [2, 3]]\n        # [(x, t), (y, t), (z, t)]\n        for coo_comb in self.coo_combs:\n            # [feat_dim, time_resolution, spatial_resolution]\n            new_plane_coef = nn.Parameter(\n                torch.empty(\n                    [\n                        self.feat_dim,\n                        resolutions[coo_comb[1]],\n                        resolutions[coo_comb[0]],  # flip?\n                    ]\n                )\n            )\n\n            # when init to ones?\n\n            nn.init.uniform_(new_plane_coef, a=init_a, b=init_b)\n            self.plane_coefs.append(new_plane_coef)\n\n    def forward(self, inp: Float[Tensor, \"*bs 4\"]):\n        output = 1.0 if self.reduce == \"product\" else 0.0\n        if self.reduce == \"cat\":\n            output = []\n        for ci, coo_comb in enumerate(self.coo_combs):\n            grid = self.plane_coefs[ci].unsqueeze(0)  # [1, feature_dim, reso1, reso2]\n            coords = inp[..., coo_comb].view(1, 1, -1, 
2)  # [1, 1, flattened_bs, 2]\n\n            interp = F.grid_sample(\n                grid, coords, align_corners=True, padding_mode=\"border\"\n            )  # [1, output_dim, 1, flattened_bs]\n            interp = interp.view(self.feat_dim, -1).T  # [flattened_bs, output_dim]\n\n            if self.reduce == \"product\":\n                output = output * interp\n            elif self.reduce == \"sum\":\n                output = output + interp\n            elif self.reduce == \"cat\":\n                output.append(interp)\n\n        if self.reduce == \"cat\":\n            # [flattened_bs, output_dim * 3]\n            output = torch.cat(output, dim=-1)\n\n        return output\n\n    def compute_temporal_smoothness(\n        self,\n    ):\n        ret_loss = 0.0\n\n        for plane_coef in self.plane_coefs:\n            ret_loss += compute_plane_smoothness(plane_coef)\n\n        return ret_loss\n\n    def compute_plane_tv(\n        self,\n    ):\n        ret_loss = 0.0\n\n        for plane_coef in self.plane_coefs:\n            ret_loss += compute_plane_tv(plane_coef)\n\n        return ret_loss\n\n    def visualize(\n        self,\n    ) -> Tuple[Float[Tensor, \"3 H W\"]]:\n        \"\"\"Visualize the encoding as a RGB images\n\n        Returns:\n            Tuple[Float[Tensor, \"3 H W\"]]\n        \"\"\"\n        pass\n\n    @staticmethod\n    def functional_forward(\n        plane_coefs: List[Float[Tensor, \"feat_dim H W\"]],\n        inp: Float[Tensor, \"*bs 4\"],\n        reduce: str = \"sum\",\n        coo_combs: Optional[List[List[int]]] = [[0, 3], [1, 3], [2, 3]],\n    ):\n        assert reduce in [\"sum\", \"product\", \"cat\"]\n        output = 1.0 if reduce == \"product\" else 0.0\n\n        if reduce == \"cat\":\n            output = []\n        for ci, coo_comb in enumerate(coo_combs):\n            grid = plane_coefs[ci].unsqueeze(0)  # [1, feature_dim, reso1, reso2]\n            feat_dim = grid.shape[1]\n            coords = inp[..., 
coo_comb].view(1, 1, -1, 2)  # [1, 1, flattened_bs, 2]\n\n            interp = F.grid_sample(\n                grid, coords, align_corners=True, padding_mode=\"border\"\n            )  # [1, output_dim, 1, flattened_bs]\n            interp = interp.view(feat_dim, -1).T  # [flattened_bs, output_dim]\n\n            if reduce == \"product\":\n                output = output * interp\n            elif reduce == \"sum\":\n                output = output + interp\n            elif reduce == \"cat\":\n                output.append(interp)\n\n        if reduce == \"cat\":\n            # [flattened_bs, output_dim * 3]\n            output = torch.cat(output, dim=-1)\n\n        return output\n\n\nclass TriplanesEncoding(nn.Module):\n    \"\"\"\n\n    Args:\n        resolutions (Sequence[int]): xyz resolutions.\n    \"\"\"\n\n    def __init__(\n        self,\n        resolutions: Sequence[int],\n        feat_dim: int = 32,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"sum\",  # Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n    ):\n        super().__init__()\n\n        self.resolutions = resolutions\n\n        if reduce == \"cat\":\n            feat_dim = feat_dim  #  // 3\n        self.feat_dim = feat_dim\n\n        self.reduce = reduce\n\n        self.in_dim = 3\n\n        self.plane_coefs = nn.ParameterList()\n\n        self.coo_combs = [[0, 1], [0, 2], [1, 2]]\n        # [(x, t), (y, t), (z, t)]\n        for coo_comb in self.coo_combs:\n            new_plane_coef = nn.Parameter(\n                torch.empty(\n                    [\n                        self.feat_dim,\n                        resolutions[coo_comb[1]],\n                        resolutions[coo_comb[0]],\n                    ]\n                )\n            )\n\n            # when init to ones?\n\n            nn.init.uniform_(new_plane_coef, a=init_a, b=init_b)\n            self.plane_coefs.append(new_plane_coef)\n\n    def forward(self, inp: Float[Tensor, \"*bs 3\"]):\n 
       output = 1.0 if self.reduce == \"product\" else 0.0\n        if self.reduce == \"cat\":\n            output = []\n        for ci, coo_comb in enumerate(self.coo_combs):\n            grid = self.plane_coefs[ci].unsqueeze(0)  # [1, feature_dim, reso1, reso2]\n            coords = inp[..., coo_comb].view(1, 1, -1, 2)  # [1, 1, flattened_bs, 2]\n\n            interp = F.grid_sample(\n                grid, coords, align_corners=True, padding_mode=\"border\"\n            )  # [1, output_dim, 1, flattened_bs]\n            interp = interp.view(self.feat_dim, -1).T  # [flattened_bs, output_dim]\n\n            if self.reduce == \"product\":\n                output = output * interp\n            elif self.reduce == \"sum\":\n                output = output + interp\n            elif self.reduce == \"cat\":\n                output.append(interp)\n\n        if self.reduce == \"cat\":\n            # [flattened_bs, output_dim * 3]\n            output = torch.cat(output, dim=-1)\n\n        return output\n\n    def compute_plane_tv(\n        self,\n    ):\n        ret_loss = 0.0\n\n        for plane_coef in self.plane_coefs:\n            ret_loss += compute_plane_tv(plane_coef)\n\n        return ret_loss\n\n\nclass PlaneEncoding(nn.Module):\n    \"\"\"\n\n    Args:\n        resolutions (Sequence[int]): xyz resolutions.\n    \"\"\"\n\n    def __init__(\n        self,\n        resolutions: Sequence[int],  # [y_res, x_res]\n        feat_dim: int = 32,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n    ):\n        super().__init__()\n\n        self.resolutions = resolutions\n\n        self.feat_dim = feat_dim\n        self.in_dim = 2\n\n        self.plane_coefs = nn.ParameterList()\n\n        self.coo_combs = [[0, 1]]\n        for coo_comb in self.coo_combs:\n            new_plane_coef = nn.Parameter(\n                torch.empty(\n                    [\n                        self.feat_dim,\n                        resolutions[coo_comb[1]],\n                      
  resolutions[coo_comb[0]],\n                    ]\n                )\n            )\n\n            # when init to ones?\n\n            nn.init.uniform_(new_plane_coef, a=init_a, b=init_b)\n            self.plane_coefs.append(new_plane_coef)\n\n    def forward(self, inp: Float[Tensor, \"*bs 2\"]):\n\n        for ci, coo_comb in enumerate(self.coo_combs):\n            grid = self.plane_coefs[ci].unsqueeze(0)  # [1, feature_dim, reso1, reso2]\n            coords = inp[..., coo_comb].view(1, 1, -1, 2)  # [1, 1, flattened_bs, 2]\n\n            interp = F.grid_sample(\n                grid, coords, align_corners=True, padding_mode=\"border\"\n            )  # [1, output_dim, 1, flattened_bs]\n            interp = interp.view(self.feat_dim, -1).T  # [flattened_bs, output_dim]\n\n            output = interp\n\n        return output\n\n    def compute_plane_tv(\n        self,\n    ):\n        ret_loss = 0.0\n\n        for plane_coef in self.plane_coefs:\n            ret_loss += compute_plane_tv(plane_coef)\n\n        return ret_loss\n\n\nclass TemporalNeRFEncoding(nn.Module):\n    def __init__(\n        self,\n        in_dim,  # : int,\n        num_frequencies: int,\n        min_freq_exp: float,\n        max_freq_exp: float,\n        log_scale: bool = False,\n        include_input: bool = False,\n    ) -> None:\n        super().__init__()\n        self.in_dim = in_dim\n        self.num_frequencies = num_frequencies\n        self.min_freq = min_freq_exp\n        self.max_freq = max_freq_exp\n        self.log_scale = log_scale\n        self.include_input = include_input\n\n    def get_out_dim(self) -> int:\n        if self.in_dim is None:\n            raise ValueError(\"Input dimension has not been set\")\n        out_dim = self.in_dim * self.num_frequencies * 2\n        if self.include_input:\n            out_dim += self.in_dim\n        return out_dim\n\n    def forward(\n        self,\n        in_tensor: Float[Tensor, \"*bs input_dim\"],\n    ) -> Float[Tensor, \"*bs 
output_dim\"]:\n        \"\"\"Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed\n            in mip-NeRF.\n\n        Args:\n            in_tensor: For best performance, the input tensor should be between 0 and 1.\n            covs: Covariances of input points.\n        Returns:\n            Output values will be between -1 and 1\n        \"\"\"\n        scaled_in_tensor = 2 * torch.pi * in_tensor  # scale to [0, 2pi]\n\n        # freqs = 2 ** torch.linspace(\n        freqs = torch.linspace(\n            self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device\n        )\n        if self.log_scale:\n            freqs = 2**freqs\n        scaled_inputs = (\n            scaled_in_tensor[..., None] * freqs\n        )  # [..., \"input_dim\", \"num_scales\"]\n        scaled_inputs = scaled_inputs.view(\n            *scaled_inputs.shape[:-2], -1\n        )  # [..., \"input_dim\" * \"num_scales\"]\n\n        encoded_inputs = torch.sin(\n            torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1)\n        )\n        return encoded_inputs\n"
  },
  {
    "path": "physdreamer/field_components/mlp.py",
    "content": "\"\"\"\nMostly from nerfstudio: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/field_components/mlp.py\n\"\"\"\nfrom typing import Optional, Set, Tuple, Union\n\nimport torch\nfrom jaxtyping import Float\nfrom torch import Tensor, nn\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        in_dim: int,\n        num_layers: int,\n        layer_width: int,\n        out_dim: Optional[int] = None,\n        skip_connections: Optional[Tuple[int]] = None,\n        activation: Optional[nn.Module] = nn.ReLU(),\n        out_activation: Optional[nn.Module] = None,\n        zero_init = False,\n    ) -> None:\n        super().__init__()\n        self.in_dim = in_dim\n        assert self.in_dim > 0\n        self.out_dim = out_dim if out_dim is not None else layer_width\n        self.num_layers = num_layers\n        self.layer_width = layer_width\n        self.skip_connections = skip_connections\n        self._skip_connections: Set[int] = (\n            set(skip_connections) if skip_connections else set()\n        )\n        self.activation = activation\n        self.out_activation = out_activation\n        self.net = None\n        self.zero_init = zero_init\n\n        self.build_nn_modules()\n\n    def build_nn_modules(self) -> None:\n        \"\"\"Initialize multi-layer perceptron.\"\"\"\n        layers = []\n        if self.num_layers == 1:\n            layers.append(nn.Linear(self.in_dim, self.out_dim))\n        else:\n            for i in range(self.num_layers - 1):\n                if i == 0:\n                    assert (\n                        i not in self._skip_connections\n                    ), \"Skip connection at layer 0 doesn't make sense.\"\n                    layers.append(nn.Linear(self.in_dim, self.layer_width))\n                elif i in self._skip_connections:\n                    layers.append(\n                        nn.Linear(self.layer_width + self.in_dim, self.layer_width)\n                    )\n    
            else:\n                    layers.append(nn.Linear(self.layer_width, self.layer_width))\n            layers.append(nn.Linear(self.layer_width, self.out_dim))\n        self.layers = nn.ModuleList(layers)\n\n        if self.zero_init:\n            torch.nn.init.zeros_(self.layers[-1].weight)\n            torch.nn.init.zeros_(self.layers[-1].bias)\n\n    def pytorch_fwd(\n        self, in_tensor: Float[Tensor, \"*bs in_dim\"]\n    ) -> Float[Tensor, \"*bs out_dim\"]:\n        \"\"\"Process input with a multilayer perceptron.\n\n        Args:\n            in_tensor: Network input\n\n        Returns:\n            MLP network output\n        \"\"\"\n        x = in_tensor\n        for i, layer in enumerate(self.layers):\n            # as checked in `build_nn_modules`, 0 should not be in `_skip_connections`\n            if i in self._skip_connections:\n                x = torch.cat([in_tensor, x], -1)\n            x = layer(x)\n            if self.activation is not None and i < len(self.layers) - 1:\n                x = self.activation(x)\n        if self.out_activation is not None:\n            x = self.out_activation(x)\n        return x\n\n    def forward(\n        self, in_tensor: Float[Tensor, \"*bs in_dim\"]\n    ) -> Float[Tensor, \"*bs out_dim\"]:\n        return self.pytorch_fwd(in_tensor)\n"
  },
  {
    "path": "physdreamer/fields/mul_offset_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Literal, Optional, Sequence, Tuple, List\nfrom physdreamer.field_components.encoding import (\n    TemporalKplanesEncoding,\n    TriplanesEncoding,\n)\nfrom physdreamer.field_components.mlp import MLP\nfrom physdreamer.operators.rotation import rotation_6d_to_matrix, quaternion_to_matrix\nfrom physdreamer.data.scene_box import SceneBox\n\n\nclass MulTemporalKplanesOffsetfields(nn.Module):\n    \"\"\"Multiple Temporal Kplanes SE(3) fields.\n\n        Decoder is shared, but plane coefs are different.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z ,t].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions_list: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        add_spatial_triplane: bool = True,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = 3\n\n        self.temporal_kplanes_encoding_list = nn.ModuleList(\n            [\n                TemporalKplanesEncoding(resolutions, feat_dim, init_a, init_b, reduce)\n                for resolutions in resolutions_list\n            ]\n        )\n\n        self.add_spatial_triplane = add_spatial_triplane\n        if add_spatial_triplane:\n            self.spatial_kplanes_encoding_list = nn.ModuleList(\n                [\n                    TriplanesEncoding(\n                        resolutions[:-1], feat_dim, init_a, init_b, reduce\n                    )\n                    for 
resolutions in resolutions_list\n                ]\n            )\n            feat_dim = feat_dim * 2\n\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n        )\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 4\"], dataset_indx: Int[Tensor, \"1\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n\n        # for loop in batch dimension\n\n        output = self.temporal_kplanes_encoding_list[dataset_indx](inp)\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding_list[dataset_indx](inp)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        return output\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        temporal_smoothness_loss = 0.0\n        for temporal_kplanes_encoding in self.temporal_kplanes_encoding_list:\n            temporal_smoothness_loss += (\n                temporal_kplanes_encoding.compute_temporal_smoothness()\n            )\n\n        smothness_loss = 0.0\n        for temporal_kplanes_encoding in self.temporal_kplanes_encoding_list:\n            smothness_loss += temporal_kplanes_encoding.compute_plane_tv()\n\n        if self.add_spatial_triplane:\n            for spatial_kplanes_encoding in self.spatial_kplanes_encoding_list:\n                smothness_loss += spatial_kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss, temporal_smoothness_loss\n\n    def compute_loss(\n        self,\n        inp: Float[Tensor, \"*bs 
4\"],\n        trajectory: Float[Tensor, \"*bs 3\"],\n        loss_func,\n    ):\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        output = self(inp)\n\n        rec_traj = inpx + output\n\n        rec_loss = loss_func(rec_traj, trajectory)\n\n        return rec_loss\n\n    def arap_loss(self, inp):\n        pass\n"
  },
  {
    "path": "physdreamer/fields/mul_se3_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Literal, Optional, Sequence, Tuple\nfrom physdreamer.field_components.encoding import (\n    TemporalKplanesEncoding,\n    TriplanesEncoding,\n)\nfrom physdreamer.field_components.mlp import MLP\nfrom physdreamer.operators.rotation import rotation_6d_to_matrix, quaternion_to_matrix\nfrom physdreamer.data.scene_box import SceneBox\n\n\nclass MulTemporalKplanesSE3fields(nn.Module):\n    \"\"\"Multiple Temporal Kplanes SE(3) fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z ,t].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions_list: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        rotation_type: Literal[\"quaternion\", \"6d\"] = \"6d\",\n        add_spatial_triplane: bool = True,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        output_dim_dict = {\"quaternion\": 4 + 3, \"6d\": 6 + 3}\n        self.output_dim = output_dim_dict[rotation_type]\n        self.rotation_type = rotation_type\n\n        self.temporal_kplanes_encoding_list = nn.ModuleList(\n            [\n                TemporalKplanesEncoding(resolutions, feat_dim, init_a, init_b, reduce)\n                for resolutions in resolutions_list\n            ]\n        )\n\n        self.add_spatial_triplane = add_spatial_triplane\n        if add_spatial_triplane:\n            self.spatial_kplanes_encoding_list = nn.ModuleList(\n                [\n                    
TriplanesEncoding(\n                        resolutions[:-1], feat_dim, init_a, init_b, reduce\n                    )\n                    for resolutions in resolutions_list\n                ]\n            )\n            feat_dim = feat_dim * 2\n\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n        )\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 4\"], dataset_indx: Int[Tensor, \"1\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n\n        # for loop in batch dimension\n\n        output = self.temporal_kplanes_encoding_list[dataset_indx](inp)\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding_list[dataset_indx](inp)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        if self.rotation_type == \"6d\":\n            rotation_6d, translation = output[:, :6], output[:, 6:]\n            R_mat = rotation_6d_to_matrix(rotation_6d)\n\n        elif self.rotation_type == \"quaternion\":\n            quat, translation = output[:, :4], output[:, 4:]\n\n            # tanh and normalize\n            quat = torch.tanh(quat)\n\n            R_mat = quaternion_to_matrix(quat)\n\n        return R_mat, translation\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        temporal_smoothness_loss = 0.0\n        for temporal_kplanes_encoding in self.temporal_kplanes_encoding_list:\n            temporal_smoothness_loss += (\n                
temporal_kplanes_encoding.compute_temporal_smoothness()\n            )\n\n        smothness_loss = 0.0\n        for temporal_kplanes_encoding in self.temporal_kplanes_encoding_list:\n            smothness_loss += temporal_kplanes_encoding.compute_plane_tv()\n\n        if self.add_spatial_triplane:\n            for spatial_kplanes_encoding in self.spatial_kplanes_encoding_list:\n                smothness_loss += spatial_kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss, temporal_smoothness_loss\n\n    def compute_loss(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        trajectory: Float[Tensor, \"*bs 3\"],\n        loss_func,\n    ):\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        R, t = self(inp)\n\n        rec_traj = torch.bmm(R, inpx.unsqueeze(-1)).squeeze(-1) + t\n\n        rec_loss = loss_func(rec_traj, trajectory)\n\n        return rec_loss\n"
  },
  {
    "path": "physdreamer/fields/offset_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Literal, Optional, Sequence, Tuple, List\nfrom physdreamer.field_components.encoding import (\n    TemporalKplanesEncoding,\n    TriplanesEncoding,\n)\nfrom physdreamer.field_components.mlp import MLP\nfrom physdreamer.operators.rotation import rotation_6d_to_matrix, quaternion_to_matrix\nfrom physdreamer.data.scene_box import SceneBox\n\n\nclass TemporalKplanesOffsetfields(nn.Module):\n    \"\"\"Temporal Offsets fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z ,t].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        add_spatial_triplane: bool = True,\n        zero_init: bool = True,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = 3\n\n        self.temporal_kplanes_encoding = TemporalKplanesEncoding(\n            resolutions, feat_dim, init_a, init_b, reduce\n        )\n\n        self.add_spatial_triplane = add_spatial_triplane\n        if add_spatial_triplane:\n            self.spatial_kplanes_encoding = TriplanesEncoding(\n                resolutions[:-1], feat_dim, init_a, init_b, reduce\n            )\n            feat_dim = feat_dim * 2\n\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            
activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 4\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n        output = self.temporal_kplanes_encoding(inp)\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding(inpx)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        return output\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.temporal_kplanes_encoding.compute_plane_tv()\n        temporal_smoothness_loss = (\n            self.temporal_kplanes_encoding.compute_temporal_smoothness()\n        )\n\n        if self.add_spatial_triplane:\n            smothness_loss += self.spatial_kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss + temporal_smoothness_loss\n\n    def compute_loss(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        trajectory: Float[Tensor, \"*bs 3\"],\n        loss_func,\n    ):\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        output = self(inp)\n\n        rec_traj = inpx + output\n\n        rec_loss = loss_func(rec_traj, trajectory)\n\n        return rec_loss\n\n    def arap_loss(self, inp):\n        pass\n\n    def forward_with_plane_coefs(\n        self,\n        plane_coefs: List[Float[Tensor, \"feat_dim H W\"]],\n        inp: Float[Tensor, \"*bs 4\"],\n    ):\n        \"\"\"\n        Args:\n            pass\n        \"\"\"\n\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = 
inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n        output = self.temporal_kplanes_encoding.functional_forward(\n            plane_coefs, inp, reduce=self.temporal_kplanes_encoding.reduce\n        )\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding(inpx)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        return output\n"
  },
  {
    "path": "physdreamer/fields/se3_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Literal, Optional, Sequence, Tuple\nfrom physdreamer.field_components.encoding import (\n    TemporalKplanesEncoding,\n    TriplanesEncoding,\n)\nfrom physdreamer.field_components.mlp import MLP\nfrom physdreamer.operators.rotation import rotation_6d_to_matrix, quaternion_to_matrix\nfrom physdreamer.data.scene_box import SceneBox\n\n\nclass TemporalKplanesSE3fields(nn.Module):\n    \"\"\"Temporal Kplanes SE(3) fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z ,t].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        rotation_type: Literal[\"quaternion\", \"6d\"] = \"6d\",\n        add_spatial_triplane: bool = True,\n        zero_init: bool = True,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        output_dim_dict = {\"quaternion\": 4 + 3, \"6d\": 6 + 3}\n        self.output_dim = output_dim_dict[rotation_type]\n        self.rotation_type = rotation_type\n\n        self.temporal_kplanes_encoding = TemporalKplanesEncoding(\n            resolutions, feat_dim, init_a, init_b, reduce\n        )\n\n        self.add_spatial_triplane = add_spatial_triplane\n        if add_spatial_triplane:\n            self.spatial_kplanes_encoding = TriplanesEncoding(\n                resolutions[:-1], feat_dim, init_a, init_b, reduce\n            )\n            feat_dim = feat_dim * 2\n\n        self.decoder = 
MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n    def forward(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        compute_smoothess_loss: bool = False,\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        if compute_smoothess_loss:\n            smothness_loss, temporal_smoothness_loss = self.compute_smoothess_loss()\n            return smothness_loss + temporal_smoothness_loss\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n        output = self.temporal_kplanes_encoding(inp)\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding(inpx)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        if self.rotation_type == \"6d\":\n            rotation_6d, translation = output[:, :6], output[:, 6:]\n            R_mat = rotation_6d_to_matrix(rotation_6d)\n\n        elif self.rotation_type == \"quaternion\":\n            quat, translation = output[:, :4], output[:, 4:]\n\n            # tanh and normalize\n            quat = torch.tanh(quat)\n\n            R_mat = quaternion_to_matrix(quat)\n\n            # --------------- remove below --------------- #\n            # add normalization\n            # r = quat\n            # norm = torch.sqrt(\n            #     r[:, 0] * r[:, 0]\n            #     + r[:, 1] * r[:, 1]\n            #     + r[:, 2] * r[:, 2]\n            #     + r[:, 3] * r[:, 3]\n            # )\n            # q = r / norm[:, None]\n            # R_mat = q\n            # 
--------------- remove above --------------- #\n\n        return R_mat, translation\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.temporal_kplanes_encoding.compute_plane_tv()\n        temporal_smoothness_loss = (\n            self.temporal_kplanes_encoding.compute_temporal_smoothness()\n        )\n\n        if self.add_spatial_triplane:\n            smothness_loss += self.spatial_kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss, temporal_smoothness_loss\n\n    def compute_loss(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        trajectory: Float[Tensor, \"*bs 3\"],\n        loss_func,\n    ):\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        R, t = self(inp)\n\n        rec_traj = torch.bmm(R, inpx.unsqueeze(-1)).squeeze(-1) + t\n\n        rec_loss = loss_func(rec_traj, trajectory)\n\n        return rec_loss\n"
  },
  {
    "path": "physdreamer/fields/triplane_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Optional, Sequence, Tuple, List\nfrom physdreamer.field_components.encoding import TriplanesEncoding\nfrom physdreamer.field_components.mlp import MLP\nfrom physdreamer.data.scene_box import SceneBox\n\n\nclass TriplaneFields(nn.Module):\n    \"\"\"Temporal Kplanes SE(3) fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"sum\",  #: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        output_dim: int = 96,\n        zero_init: bool = False,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = output_dim\n\n        self.kplanes_encoding = TriplanesEncoding(\n            resolutions, feat_dim, init_a, init_b, reduce\n        )\n\n        if reduce == \"cat\":\n            feat_dim = feat_dim * 3\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 3\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inp, self.aabb) * 2.0 - 1.0\n\n        output = 
self.kplanes_encoding(inpx)\n\n        output = self.decoder(output)\n\n        # split_size = output.shape[-1] // 3\n        # output = torch.stack(torch.split(output, split_size, dim=-1), dim=-1)\n\n        return output\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss\n\n\ndef compute_entropy(p):\n    return -torch.sum(\n        p * torch.log(p + 1e-5), dim=1\n    ).mean()  # Adding a small constant to prevent log(0)\n\n\nclass TriplaneFieldsWithEntropy(nn.Module):\n    \"\"\"Temporal Kplanes SE(3) fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"sum\",  #: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        output_dim: int = 96,\n        zero_init: bool = False,\n        num_cls: int = 3,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = output_dim\n        self.num_cls = num_cls\n\n        self.kplanes_encoding = TriplanesEncoding(\n            resolutions, feat_dim, init_a, init_b, reduce\n        )\n\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.num_cls,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n        self.cls_embedding = torch.nn.Embedding(num_cls, output_dim)\n\n    def forward(\n     
   self, inp: Float[Tensor, \"*bs 3\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"1\"]]:\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inp, self.aabb) * 2.0 - 1.0\n\n        output = self.kplanes_encoding(inpx)\n\n        output = self.decoder(output)\n\n        prob = F.softmax(output, dim=-1)\n\n        entropy = compute_entropy(prob)\n\n        cls_index = torch.tensor([0, 1, 2]).to(inp.device)\n        cls_emb = self.cls_embedding(cls_index)\n\n        output = torch.matmul(prob, cls_emb)\n\n        return output, entropy\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss\n"
  },
  {
    "path": "physdreamer/gaussian_3d/README.md",
    "content": "This folder is mainly a copy-paste from https://github.com/graphdeco-inria/gaussian-splatting\n\nWe add some functions to render the applied external force."
  },
  {
    "path": "physdreamer/gaussian_3d/arguments/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom argparse import ArgumentParser, Namespace\nimport sys\nimport os\n\nclass GroupParams:\n    pass\n\nclass ParamGroup:\n    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):\n        group = parser.add_argument_group(name)\n        for key, value in vars(self).items():\n            shorthand = False\n            if key.startswith(\"_\"):\n                shorthand = True\n                key = key[1:]\n            t = type(value)\n            value = value if not fill_none else None \n            if shorthand:\n                if t == bool:\n                    group.add_argument(\"--\" + key, (\"-\" + key[0:1]), default=value, action=\"store_true\")\n                else:\n                    group.add_argument(\"--\" + key, (\"-\" + key[0:1]), default=value, type=t)\n            else:\n                if t == bool:\n                    group.add_argument(\"--\" + key, default=value, action=\"store_true\")\n                else:\n                    group.add_argument(\"--\" + key, default=value, type=t)\n\n    def extract(self, args):\n        group = GroupParams()\n        for arg in vars(args).items():\n            if arg[0] in vars(self) or (\"_\" + arg[0]) in vars(self):\n                setattr(group, arg[0], arg[1])\n        return group\n\nclass ModelParams(ParamGroup): \n    def __init__(self, parser, sentinel=False):\n        self.sh_degree = 3\n        self._source_path = \"\"\n        self._model_path = \"\"\n        self._images = \"images\"\n        self._resolution = -1\n        self._white_background = False\n        self.data_device = \"cuda\"\n        self.eval = False\n        super().__init__(parser, 
\"Loading Parameters\", sentinel)\n\n    def extract(self, args):\n        g = super().extract(args)\n        g.source_path = os.path.abspath(g.source_path)\n        return g\n\nclass PipelineParams(ParamGroup):\n    def __init__(self, parser):\n        self.convert_SHs_python = False\n        self.compute_cov3D_python = False\n        self.debug = False\n        super().__init__(parser, \"Pipeline Parameters\")\n\nclass OptimizationParams(ParamGroup):\n    def __init__(self, parser):\n        self.iterations = 30_000\n        self.position_lr_init = 0.00016\n        self.position_lr_final = 0.0000016\n        self.position_lr_delay_mult = 0.01\n        self.position_lr_max_steps = 30_000\n        self.feature_lr = 0.0025\n        self.opacity_lr = 0.05\n        self.scaling_lr = 0.005\n        self.rotation_lr = 0.001\n        self.percent_dense = 0.01\n        self.lambda_dssim = 0.2\n        self.densification_interval = 100\n        self.opacity_reset_interval = 3000\n        self.densify_from_iter = 500\n        self.densify_until_iter = 15_000\n        self.densify_grad_threshold = 0.0002\n        super().__init__(parser, \"Optimization Parameters\")\n\ndef get_combined_args(parser : ArgumentParser):\n    cmdlne_string = sys.argv[1:]\n    cfgfile_string = \"Namespace()\"\n    args_cmdline = parser.parse_args(cmdlne_string)\n\n    try:\n        cfgfilepath = os.path.join(args_cmdline.model_path, \"cfg_args\")\n        print(\"Looking for config file in\", cfgfilepath)\n        with open(cfgfilepath) as cfg_file:\n            print(\"Config file found: {}\".format(cfgfilepath))\n            cfgfile_string = cfg_file.read()\n    except TypeError:\n        print(\"Config file not found at\")\n        pass\n    args_cfgfile = eval(cfgfile_string)\n\n    merged_dict = vars(args_cfgfile).copy()\n    for k,v in vars(args_cmdline).items():\n        if v != None:\n            merged_dict[k] = v\n    return Namespace(**merged_dict)\n"
  },
  {
    "path": "physdreamer/gaussian_3d/gaussian_renderer/__init__.py",
    "content": ""
  },
  {
    "path": "physdreamer/gaussian_3d/gaussian_renderer/depth_uv_render.py",
    "content": "import torch\nfrom physdreamer.gaussian_3d.scene.gaussian_model import GaussianModel\nimport math\n\nfrom diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nfrom typing import Callable\n\n\ndef render_uv_depth_w_gaussian(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n\n    Args:\n        point_disp: [N, 3]\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. 
If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n\n    shs = None\n    colors_precomp = None\n\n    # project point motion to 2D using camera:\n    w2c = viewpoint_camera.world_view_transform.transpose(0, 1)\n    cam_plane_2_img = viewpoint_camera.cam_plane_2_img  # [2, 2]\n\n    R = w2c[:3, :3].unsqueeze(0)  # [1, 3, 3]\n    t = w2c[:3, 3].unsqueeze(0)  # [1, 3]\n\n    # [N, 3, 1]\n    pts = torch.cat([pc._xyz, torch.ones_like(pc._xyz[:, 0:1])], dim=-1)\n    pts_cam = w2c.unsqueeze(0) @ pts.unsqueeze(-1)  # [N, 4, 1]\n    # pts_cam = R @ (pc._xyz.unsqueeze(-1)) + t[:, None]\n    depth = pts_cam[:, 2, 0]  # [N]\n    # print(\"depth\", depth.shape, depth.max(), depth.mean(), depth.min())\n\n    # [N, 2]\n    pts_cam_xy = pts_cam[:, :2, 0] / depth.unsqueeze(-1)\n\n    pts_cam_xy_pixel = cam_plane_2_img.unsqueeze(0) @ pts_cam_xy.unsqueeze(\n        -1\n    )  # [N, 2, 1]\n    pts_cam_xy_pixel = pts_cam_xy_pixel.squeeze(-1)  # [N, 2]\n\n    colors_precomp = torch.cat(\n        [pts_cam_xy_pixel, depth.unsqueeze(dim=-1)], dim=-1\n    )  # [N, 3]\n\n    # print(\"converted 2D motion precompute: \", colors_precomp.shape, shs, colors_precomp.max(), colors_precomp.min(), colors_precomp.mean())\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    rendered_image, radii = rasterizer(\n        means3D=means3D,\n        means2D=means2D,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n      
  cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n\n    return {\n        \"render\": rendered_image,\n        \"visibility_filter\": radii > 0,\n        \"radii\": radii,\n        \"pts_depth\": depth,\n        \"pts_cam_xy_pixel\": pts_cam_xy_pixel,\n    }\n"
  },
  {
    "path": "physdreamer/gaussian_3d/gaussian_renderer/feat_render.py",
    "content": "import torch\nfrom physdreamer.gaussian_3d.scene.gaussian_model import GaussianModel\nimport math\n\nfrom diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nfrom typing import Callable\n\n\ndef render_feat_gaussian(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    bg_color: torch.Tensor,\n    points_feat: torch.Tensor,\n    scaling_modifier=1.0,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n\n    Args:\n        point_disp: [N, 3]\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. 
If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n\n    shs = None\n    colors_precomp = points_feat\n    assert (points_feat.shape[1] == 3) and (points_feat.shape[0] == means3D.shape[0])\n\n    # print(\"converted 2D motion precompute: \", colors_precomp.shape, shs, colors_precomp.max(), colors_precomp.min(), colors_precomp.mean())\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    rendered_image, radii = rasterizer(\n        means3D=means3D,\n        means2D=means2D,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n        cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n\n    return {\n        \"render\": rendered_image,\n        \"visibility_filter\": radii > 0,\n        \"radii\": radii,\n    }\n"
  },
  {
    "path": "physdreamer/gaussian_3d/gaussian_renderer/flow_depth_render.py",
    "content": "import torch\nfrom physdreamer.gaussian_3d.scene.gaussian_model import GaussianModel\nimport math\n\nfrom diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nfrom typing import Callable\n\n\ndef render_flow_depth_w_gaussian(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    point_disp: torch.Tensor,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n\n    Args:\n        point_disp: [N, 3]\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. 
If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n\n    # project point motion to 2D using camera:\n    w2c = viewpoint_camera.world_view_transform.transpose(0, 1)\n    cam_plane_2_img = viewpoint_camera.cam_plane_2_img  # [2, 2]\n\n    R = w2c[:3, :3].unsqueeze(0)  # [1, 3, 3]\n    t = w2c[:3, 3].unsqueeze(0)  # [1, 3]\n\n    # [N, 3, 1]\n    pts = torch.cat([pc._xyz, torch.ones_like(pc._xyz[:, 0:1])], dim=-1)\n    pts_cam = w2c.unsqueeze(0) @ pts.unsqueeze(-1)  # [N, 4, 1]\n    # pts_cam = R @ (pc._xyz.unsqueeze(-1)) + t[:, None]\n    depth = pts_cam[:, 2, 0]  # [N]\n    # print(\"depth\", depth.shape, depth.max(), depth.mean(), depth.min())\n\n    point_disp_pad = torch.cat(\n        [point_disp, torch.zeros_like(point_disp[:, 0:1])], dim=-1\n    )  # [N, 4]\n\n    pts_motion = w2c.unsqueeze(0) @ point_disp_pad.unsqueeze(-1)  # [N, 4, 1]\n\n    # [N, 2]\n    pts_motion_xy = pts_motion[:, :2, 0] / depth.unsqueeze(-1)\n\n    pts_motion_xy_pixel = cam_plane_2_img.unsqueeze(0) @ pts_motion_xy.unsqueeze(\n        -1\n    )  # [N, 2, 1]\n    pts_motion_xy_pixel = pts_motion_xy_pixel.squeeze(-1)  # [N, 2]\n\n    colors_precomp = torch.cat(\n        [pts_motion_xy_pixel, depth.unsqueeze(dim=-1)], dim=-1\n    )  # [N, 3]\n\n    # print(\"converted 2D motion precompute: \", colors_precomp.shape, shs, colors_precomp.max(), colors_precomp.min(), colors_precomp.mean())\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    
rendered_image, radii = rasterizer(\n        means3D=means3D,\n        means2D=means2D,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n        cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n\n    # return {\n    #     \"render\": rendered_image,\n    #     \"viewspace_points\": screenspace_points,\n    #     \"visibility_filter\": radii > 0,\n    #     \"radii\": radii,\n    # }\n\n    return {\"render\": rendered_image}\n"
  },
  {
    "path": "physdreamer/gaussian_3d/gaussian_renderer/render.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport math\nfrom diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nfrom physdreamer.gaussian_3d.scene.gaussian_model import GaussianModel\n\n\ndef render_gaussian(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n    override_color=None,\n    cov3D_precomp=None,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = 
screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n\n    if pipe.compute_cov3D_python or cov3D_precomp is None:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    elif cov3D_precomp is None:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n    if override_color is None:\n        if pipe.convert_SHs_python:\n            shs_view = pc.get_features.transpose(1, 2).view(\n                -1, 3, (pc.max_sh_degree + 1) ** 2\n            )\n            dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(\n                pc.get_features.shape[0], 1\n            )\n            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)\n            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)\n            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)\n        else:\n            shs = pc.get_features\n    else:\n        colors_precomp = override_color\n\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    rendered_image, radii = rasterizer(\n        means3D=means3D,\n        means2D=means2D,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n        cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n    return {\n        \"render\": rendered_image,\n        \"viewspace_points\": screenspace_points,\n        
\"visibility_filter\": radii > 0,\n        \"radii\": radii,\n    }\n    # return {\"render\": rendered_image}\n\n\ndef gaussian_intrin_scale(x_or_y: torch.Tensor, w_or_h: float):\n\n    ret = ((x_or_y + 1.0) * w_or_h - 1.0) * 0.5\n\n    return ret\n\n\ndef render_arrow_in_screen(viewpoint_camera, points_3d):\n\n    # project point motion to 2D using camera:\n    w2c = viewpoint_camera.world_view_transform.transpose(0, 1)\n    cam_plane_2_img = viewpoint_camera.cam_plane_2_img  # [2, 2]\n    cam_plane_2_img = viewpoint_camera.projection_matrix.transpose(0, 1)  # [4, 4]\n\n    full_proj_mat = viewpoint_camera.full_proj_transform\n\n    # [N, 4]\n    pts = torch.cat([points_3d, torch.ones_like(points_3d[:, 0:1])], dim=-1)\n    # [N, 1, 4] <-  [N, 1, 4] @ [1, 4, 4]\n    pts_cam = pts.unsqueeze(-2) @ full_proj_mat.unsqueeze(0)  # [N, 1, 4]\n\n    # start here\n\n    # pts: [N, 4]\n    # [1, 4, 4] @ [N, 4, 1] -> [N, 4, 1]\n    # from IPython import embed\n\n    # embed()\n    # pts_cam = torch.bmm(\n    #     full_proj_mat.T.unsqueeze(0), pts.unsqueeze(-1)\n    # )  # K*[R,T]*[x,y,z,1]^T to get 2D projection of Gaussians\n    # end here\n    pts_cam = full_proj_mat.T.unsqueeze(0) @ pts.unsqueeze(-1)\n\n    # print(pts_cam.shape)\n\n    pts_cam = pts_cam.squeeze(-1)  # [N, 4]\n    pts_cam = pts_cam[:, :3] / pts_cam[:, 3:]  # [N, 1, 3]\n\n    # print(pts_cam, \"after proj\")\n\n    pts_cam_yx_pixel = pts_cam[:, :2]\n    #  [N, 2] yx => xy\n    # pts_cam_xy_pixel = torch.cat(\n    #     [pts_cam_xy_pixel[:, [1]], pts_cam_xy_pixel[:, [0]]], dim=-1\n    # )\n\n    pts_cam_x, pts_cam_y = pts_cam_yx_pixel[:, 0], pts_cam_yx_pixel[:, 1]\n\n    w, h = viewpoint_camera.image_width, viewpoint_camera.image_height\n\n    pts_cam_x = gaussian_intrin_scale(pts_cam_x, w)\n    pts_cam_y = gaussian_intrin_scale(pts_cam_y, h)\n\n    ret_pts_cam_xy = torch.cat(\n        [pts_cam_x.unsqueeze(-1), pts_cam_y.unsqueeze(-1)], dim=-1\n    )\n\n    # print(ret_pts_cam_xy)\n\n    return 
ret_pts_cam_xy\n\n\ndef render_arrow_in_screen_back(viewpoint_camera, points_3d):\n\n    # project point motion to 2D using camera:\n    w2c = viewpoint_camera.world_view_transform.transpose(0, 1)\n    cam_plane_2_img = viewpoint_camera.cam_plane_2_img  # [2, 2]\n    cam_plane_2_img = viewpoint_camera.projection_matrix.transpose(0, 1)\n\n    from IPython import embed\n\n    embed()\n\n    R = w2c[:3, :3].unsqueeze(0)  # [1, 3, 3]\n    t = w2c[:3, 3].unsqueeze(0)  # [1, 3]\n\n    # [N, 3, 1]\n    pts = torch.cat([points_3d, torch.ones_like(points_3d[:, 0:1])], dim=-1)\n    pts_cam = w2c.unsqueeze(0) @ pts.unsqueeze(-1)  # [N, 4, 1]\n    # pts_cam = R @ (pc._xyz.unsqueeze(-1)) + t[:, None]\n    depth = pts_cam[:, 2, 0]  # [N]\n    # print(\"depth\", depth.shape, depth.max(), depth.mean(), depth.min())\n\n    # [N, 2]\n    pts_cam_xy = pts_cam[:, :2, 0] / depth.unsqueeze(-1)\n\n    pts_cam_xy_pixel = cam_plane_2_img.unsqueeze(0) @ pts_cam_xy.unsqueeze(\n        -1\n    )  # [N, 2, 1]\n    pts_cam_xy_pixel = pts_cam_xy_pixel.squeeze(-1)  # [N, 2]\n\n    #  [N, 2] yx => xy\n    pts_cam_xy_pixel = torch.cat(\n        [pts_cam_xy_pixel[:, [1]], pts_cam_xy_pixel[:, [0]]], dim=-1\n    )\n\n    return pts_cam_xy_pixel\n\n\n# for spherecal harmonics\n\n\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n    1.0925484305920792,\n    -1.0925484305920792,\n    0.31539156525252005,\n    -1.0925484305920792,\n    0.5462742152960396,\n]\nC3 = [\n    -0.5900435899266435,\n    2.890611442640554,\n    -0.4570457994644658,\n    0.3731763325901154,\n    -0.4570457994644658,\n    1.445305721320277,\n    -0.5900435899266435,\n]\nC4 = [\n    2.5033429417967046,\n    -1.7701307697799304,\n    0.9461746957575601,\n    -0.6690465435572892,\n    0.10578554691520431,\n    -0.6690465435572892,\n    0.47308734787878004,\n    -1.7701307697799304,\n    0.6258357354491761,\n]\n\n\ndef eval_sh(deg, sh, dirs):\n    \"\"\"\n    Evaluate spherical harmonics at unit directions\n    using 
hardcoded SH polynomials.\n    Works with torch/np/jnp.\n    ... Can be 0 or more batch dimensions.\n    Args:\n        deg: int SH deg. Currently, 0-3 supported\n        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]\n        dirs: jnp.ndarray unit directions [..., 3]\n    Returns:\n        [..., C]\n    \"\"\"\n    assert deg <= 4 and deg >= 0\n    coeff = (deg + 1) ** 2\n    assert sh.shape[-1] >= coeff\n\n    result = C0 * sh[..., 0]\n    if deg > 0:\n        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]\n        result = (\n            result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]\n        )\n\n        if deg > 1:\n            xx, yy, zz = x * x, y * y, z * z\n            xy, yz, xz = x * y, y * z, x * z\n            result = (\n                result\n                + C2[0] * xy * sh[..., 4]\n                + C2[1] * yz * sh[..., 5]\n                + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]\n                + C2[3] * xz * sh[..., 7]\n                + C2[4] * (xx - yy) * sh[..., 8]\n            )\n\n            if deg > 2:\n                result = (\n                    result\n                    + C3[0] * y * (3 * xx - yy) * sh[..., 9]\n                    + C3[1] * xy * z * sh[..., 10]\n                    + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]\n                    + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]\n                    + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]\n                    + C3[5] * z * (xx - yy) * sh[..., 14]\n                    + C3[6] * x * (xx - 3 * yy) * sh[..., 15]\n                )\n\n                if deg > 3:\n                    result = (\n                        result\n                        + C4[0] * xy * (xx - yy) * sh[..., 16]\n                        + C4[1] * yz * (3 * xx - yy) * sh[..., 17]\n                        + C4[2] * xy * (7 * zz - 1) * sh[..., 18]\n                        + C4[3] * yz * (7 * zz - 3) * sh[..., 19]\n                        
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]\n                        + C4[5] * xz * (7 * zz - 3) * sh[..., 21]\n                        + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]\n                        + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]\n                        + C4[8]\n                        * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))\n                        * sh[..., 24]\n                    )\n    return result\n\n\ndef RGB2SH(rgb):\n    return (rgb - 0.5) / C0\n\n\ndef SH2RGB(sh):\n    return sh * C0 + 0.5\n"
  },
  {
    "path": "physdreamer/gaussian_3d/scene/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport random\nimport numpy as np\nimport json\nfrom physdreamer.gaussian_3d.utils.system_utils import searchForMaxIteration\nfrom physdreamer.gaussian_3d.scene.dataset_readers import sceneLoadTypeCallbacks\nfrom physdreamer.gaussian_3d.scene.gaussian_model import GaussianModel\nfrom physdreamer.gaussian_3d.arguments import ModelParams\nfrom physdreamer.gaussian_3d.utils.camera_utils import (\n    cameraList_from_camInfos,\n    camera_to_JSON,\n)\n\n\nclass Scene:\n    gaussians: GaussianModel\n\n    def __init__(\n        self,\n        args: ModelParams,\n        gaussians: GaussianModel,\n        load_iteration=None,\n        shuffle=True,\n        resolution_scales=[1.0],\n    ):\n        \"\"\"b\n        :param path: Path to colmap scene main folder.\n        \"\"\"\n        self.model_path = args.model_path\n        self.loaded_iter = None\n        self.gaussians = gaussians\n\n        if load_iteration:\n            if load_iteration == -1:\n                self.loaded_iter = searchForMaxIteration(\n                    os.path.join(self.model_path, \"point_cloud\")\n                )\n            else:\n                self.loaded_iter = load_iteration\n            print(\"Loading trained model at iteration {}\".format(self.loaded_iter))\n\n        self.train_cameras = {}\n        self.test_cameras = {}\n\n        if os.path.exists(os.path.join(args.source_path, \"sparse\")):\n            scene_info = sceneLoadTypeCallbacks[\"Colmap\"](\n                args.source_path, args.images, args.eval\n            )\n        elif os.path.exists(os.path.join(args.source_path, \"transforms_train.json\")):\n            print(\"Found 
transforms_train.json file, assuming Blender data set!\")\n            scene_info = sceneLoadTypeCallbacks[\"Blender\"](\n                args.source_path, args.white_background, args.eval\n            )\n        else:\n            assert False, \"Could not recognize scene type!\"\n\n        if not self.loaded_iter:\n            with open(scene_info.ply_path, \"rb\") as src_file, open(\n                os.path.join(self.model_path, \"input.ply\"), \"wb\"\n            ) as dest_file:\n                dest_file.write(src_file.read())\n            json_cams = []\n            camlist = []\n            if scene_info.test_cameras:\n                camlist.extend(scene_info.test_cameras)\n            if scene_info.train_cameras:\n                camlist.extend(scene_info.train_cameras)\n            for id, cam in enumerate(camlist):\n                json_cams.append(camera_to_JSON(id, cam))\n            with open(os.path.join(self.model_path, \"cameras.json\"), \"w\") as file:\n                json.dump(json_cams, file)\n\n        if shuffle:\n            random.shuffle(\n                scene_info.train_cameras\n            )  # Multi-res consistent random shuffling\n            random.shuffle(\n                scene_info.test_cameras\n            )  # Multi-res consistent random shuffling\n\n        self.cameras_extent = scene_info.nerf_normalization[\"radius\"]\n\n        for resolution_scale in resolution_scales:\n            print(\"Loading Training Cameras\")\n            self.train_cameras[resolution_scale] = cameraList_from_camInfos(\n                scene_info.train_cameras, resolution_scale, args\n            )\n            print(\"Loading Test Cameras\")\n            self.test_cameras[resolution_scale] = cameraList_from_camInfos(\n                scene_info.test_cameras, resolution_scale, args\n            )\n\n        if self.loaded_iter:\n            self.gaussians.load_ply(\n                os.path.join(\n                    self.model_path,\n                
    \"point_cloud\",\n                    \"iteration_\" + str(self.loaded_iter),\n                    \"point_cloud.ply\",\n                )\n            )\n        else:\n            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)\n\n    def save(self, iteration):\n        point_cloud_path = os.path.join(\n            self.model_path, \"point_cloud/iteration_{}\".format(iteration)\n        )\n        self.gaussians.save_ply(os.path.join(point_cloud_path, \"point_cloud.ply\"))\n\n    def getTrainCameras(self, scale=1.0):\n        return self.train_cameras[scale]\n\n    def getTestCameras(self, scale=1.0):\n        return self.test_cameras[scale]\n"
  },
  {
    "path": "physdreamer/gaussian_3d/scene/cameras.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nfrom torch import nn\nimport numpy as np\nfrom physdreamer.gaussian_3d.utils.graphics_utils import (\n    getWorld2View2,\n    getProjectionMatrix,\n)\n\n\nclass Camera(nn.Module):\n    def __init__(\n        self,\n        colmap_id,\n        R,\n        T,\n        FoVx,\n        FoVy,\n        image,\n        gt_alpha_mask,\n        image_name,\n        uid,\n        trans=np.array([0.0, 0.0, 0.0]),\n        scale=1.0,\n        data_device=\"cuda\",\n    ):\n        super(Camera, self).__init__()\n\n        self.uid = uid\n        self.colmap_id = colmap_id\n        self.R = R\n        self.T = T\n        self.FoVx = FoVx\n        self.FoVy = FoVy\n        self.image_name = image_name\n\n        try:\n            self.data_device = torch.device(data_device)\n        except Exception as e:\n            print(e)\n            print(\n                f\"[Warning] Custom device {data_device} failed, fallback to default cuda device\"\n            )\n            self.data_device = torch.device(\"cuda\")\n\n        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)\n        self.image_width = self.original_image.shape[2]\n        self.image_height = self.original_image.shape[1]\n\n        if gt_alpha_mask is not None:\n            self.original_image *= gt_alpha_mask.to(self.data_device)\n        else:\n            self.original_image *= torch.ones(\n                (1, self.image_height, self.image_width), device=self.data_device\n            )\n\n        self.zfar = 100.0\n        self.znear = 0.01\n\n        self.trans = trans\n        self.scale = scale\n\n        self.world_view_transform = (\n            
torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()\n        )\n        self.projection_matrix = (\n            getProjectionMatrix(\n                znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy\n            )\n            .transpose(0, 1)\n            .cuda()\n        )\n        self.full_proj_transform = (\n            self.world_view_transform.unsqueeze(0).bmm(\n                self.projection_matrix.unsqueeze(0)\n            )\n        ).squeeze(0)\n        self.camera_center = self.world_view_transform.inverse()[3, :3]\n\n\nclass MiniCam:\n    def __init__(\n        self,\n        width,\n        height,\n        fovy,\n        fovx,\n        znear,\n        zfar,\n        world_view_transform,\n        full_proj_transform,\n    ):\n        self.image_width = width\n        self.image_height = height\n        self.FoVy = fovy\n        self.FoVx = fovx\n        self.znear = znear\n        self.zfar = zfar\n        self.world_view_transform = world_view_transform\n        self.full_proj_transform = full_proj_transform\n        view_inv = torch.inverse(self.world_view_transform)\n        self.camera_center = view_inv[3][:3]\n"
  },
  {
    "path": "physdreamer/gaussian_3d/scene/colmap_loader.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport numpy as np\nimport collections\nimport struct\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"])\nCamera = collections.namedtuple(\n    \"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"])\nPoint3D = collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"])\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12)\n}\nCAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)\n                         for camera_model in CAMERA_MODELS])\nCAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)\n                           for camera_model in CAMERA_MODELS])\n\n\ndef qvec2rotmat(qvec):\n    return np.array([\n        [1 - 2 * 
qvec[2]**2 - 2 * qvec[3]**2,\n         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],\n        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,\n         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],\n        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])\n\ndef rotmat2qvec(R):\n    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat\n    K = np.array([\n        [Rxx - Ryy - Rzz, 0, 0, 0],\n        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],\n        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],\n        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0\n    eigvals, eigvecs = np.linalg.eigh(K)\n    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]\n    if qvec[0] < 0:\n        qvec *= -1\n    return qvec\n\nclass Image(BaseImage):\n    def qvec2rotmat(self):\n        return qvec2rotmat(self.qvec)\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\ndef read_points3D_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DText(const std::string& path)\n        void Reconstruction::WritePoints3DText(const std::string& path)\n    \"\"\"\n    xyzs = None\n    rgbs = None\n    errors = None\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                xyz = np.array(tuple(map(float, elems[1:4])))\n                rgb = np.array(tuple(map(int, elems[4:7])))\n                error = np.array(float(elems[7]))\n                if xyzs is None:\n                    xyzs = xyz[None, ...]\n                    rgbs = rgb[None, ...]\n                    errors = error[None, ...]\n                else:\n                    xyzs = np.append(xyzs, xyz[None, ...], axis=0)\n                    rgbs = np.append(rgbs, rgb[None, ...], axis=0)\n                    errors = np.append(errors, error[None, ...], axis=0)\n    return xyzs, rgbs, errors\n\ndef read_points3D_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n\n\n    with open(path_to_model_file, \"rb\") as fid:\n        num_points = read_next_bytes(fid, 8, \"Q\")[0]\n\n        xyzs = np.empty((num_points, 3))\n        rgbs = np.empty((num_points, 3))\n        errors = np.empty((num_points, 1))\n\n        for p_id 
in range(num_points):\n            binary_point_line_properties = read_next_bytes(\n                fid, num_bytes=43, format_char_sequence=\"QdddBBBd\")\n            xyz = np.array(binary_point_line_properties[1:4])\n            rgb = np.array(binary_point_line_properties[4:7])\n            error = np.array(binary_point_line_properties[7])\n            track_length = read_next_bytes(\n                fid, num_bytes=8, format_char_sequence=\"Q\")[0]\n            track_elems = read_next_bytes(\n                fid, num_bytes=8*track_length,\n                format_char_sequence=\"ii\"*track_length)\n            xyzs[p_id] = xyz\n            rgbs[p_id] = rgb\n            errors[p_id] = error\n    return xyzs, rgbs, errors\n\ndef read_intrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                assert model == \"PINHOLE\", \"While the loader support other types, the rest of the code assumes PINHOLE\"\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(id=camera_id, model=model,\n                                            width=width, height=height,\n                                            params=params)\n    return cameras\n\ndef read_extrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    
images = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_reg_images = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_reg_images):\n            binary_image_properties = read_next_bytes(\n                fid, num_bytes=64, format_char_sequence=\"idddddddi\")\n            image_id = binary_image_properties[0]\n            qvec = np.array(binary_image_properties[1:5])\n            tvec = np.array(binary_image_properties[5:8])\n            camera_id = binary_image_properties[8]\n            image_name = \"\"\n            current_char = read_next_bytes(fid, 1, \"c\")[0]\n            while current_char != b\"\\x00\":   # look for the ASCII 0 entry\n                image_name += current_char.decode(\"utf-8\")\n                current_char = read_next_bytes(fid, 1, \"c\")[0]\n            num_points2D = read_next_bytes(fid, num_bytes=8,\n                                           format_char_sequence=\"Q\")[0]\n            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,\n                                       format_char_sequence=\"ddq\"*num_points2D)\n            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),\n                                   tuple(map(float, x_y_id_s[1::3]))])\n            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))\n            images[image_id] = Image(\n                id=image_id, qvec=qvec, tvec=tvec,\n                camera_id=camera_id, name=image_name,\n                xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_intrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_cameras):\n            camera_properties = 
read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\")\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(fid, num_bytes=8*num_params,\n                                     format_char_sequence=\"d\"*num_params)\n            cameras[camera_id] = Camera(id=camera_id,\n                                        model=model_name,\n                                        width=width,\n                                        height=height,\n                                        params=np.array(params))\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_extrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    images = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                image_id = int(elems[0])\n                qvec = np.array(tuple(map(float, elems[1:5])))\n                tvec = np.array(tuple(map(float, elems[5:8])))\n                camera_id = int(elems[8])\n                image_name = elems[9]\n                elems = fid.readline().split()\n                xys = np.column_stack([tuple(map(float, elems[0::3])),\n                                       tuple(map(float, elems[1::3]))])\n                point3D_ids = np.array(tuple(map(int, elems[2::3])))\n                images[image_id] = Image(\n                    id=image_id, qvec=qvec, tvec=tvec,\n                    camera_id=camera_id, 
name=image_name,\n                    xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_colmap_bin_array(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py\n\n    :param path: path to the colmap binary file.\n    :return: nd array with the floating point values in the value\n    \"\"\"\n    with open(path, \"rb\") as fid:\n        width, height, channels = np.genfromtxt(fid, delimiter=\"&\", max_rows=1,\n                                                usecols=(0, 1, 2), dtype=int)\n        fid.seek(0)\n        num_delimiter = 0\n        byte = fid.read(1)\n        while True:\n            if byte == b\"&\":\n                num_delimiter += 1\n                if num_delimiter >= 3:\n                    break\n            byte = fid.read(1)\n        array = np.fromfile(fid, np.float32)\n    array = array.reshape((width, height, channels), order=\"F\")\n    return np.transpose(array, (1, 0, 2)).squeeze()\n"
  },
  {
    "path": "physdreamer/gaussian_3d/scene/dataset_readers.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport sys\nfrom PIL import Image\nfrom typing import NamedTuple\nfrom physdreamer.gaussian_3d.scene.colmap_loader import (\n    read_extrinsics_text,\n    read_intrinsics_text,\n    qvec2rotmat,\n    read_extrinsics_binary,\n    read_intrinsics_binary,\n    read_points3D_binary,\n    read_points3D_text,\n)\nfrom physdreamer.gaussian_3d.utils.graphics_utils import (\n    getWorld2View2,\n    focal2fov,\n    fov2focal,\n)\nimport numpy as np\nimport math\nimport json\nfrom pathlib import Path\nfrom plyfile import PlyData, PlyElement\nfrom physdreamer.gaussian_3d.utils.sh_utils import SH2RGB\nfrom physdreamer.gaussian_3d.scene.gaussian_model import BasicPointCloud\nimport torch\nimport torch.nn as nn\nfrom physdreamer.gaussian_3d.utils.graphics_utils import (\n    getWorld2View2,\n    getProjectionMatrix,\n)\n\n\nclass CameraInfo(NamedTuple):\n    uid: int\n    R: np.array\n    T: np.array\n    FovY: np.array\n    FovX: np.array\n    image: np.array\n    image_path: str\n    image_name: str\n    width: int\n    height: int\n\n\nclass SceneInfo(NamedTuple):\n    point_cloud: BasicPointCloud\n    train_cameras: list\n    test_cameras: list\n    nerf_normalization: dict\n    ply_path: str\n\n\ndef getNerfppNorm(cam_info):\n    def get_center_and_diag(cam_centers):\n        cam_centers = np.hstack(cam_centers)\n        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)\n        center = avg_cam_center\n        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)\n        diagonal = np.max(dist)\n        return center.flatten(), diagonal\n\n    cam_centers = []\n\n    for cam in cam_info:\n        W2C = getWorld2View2(cam.R, 
cam.T)\n        C2W = np.linalg.inv(W2C)\n        cam_centers.append(C2W[:3, 3:4])\n\n    center, diagonal = get_center_and_diag(cam_centers)\n    radius = diagonal * 1.1\n\n    translate = -center\n\n    return {\"translate\": translate, \"radius\": radius}\n\n\ndef readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):\n    cam_infos = []\n    for idx, key in enumerate(cam_extrinsics):\n        sys.stdout.write(\"\\r\")\n        # the exact output you're looking for:\n        sys.stdout.write(\"Reading camera {}/{}\".format(idx + 1, len(cam_extrinsics)))\n        sys.stdout.flush()\n\n        extr = cam_extrinsics[key]\n        intr = cam_intrinsics[extr.camera_id]\n        height = intr.height\n        width = intr.width\n\n        uid = intr.id\n        R = np.transpose(qvec2rotmat(extr.qvec))\n        T = np.array(extr.tvec)\n\n        if intr.model == \"SIMPLE_PINHOLE\":\n            focal_length_x = intr.params[0]\n            FovY = focal2fov(focal_length_x, height)\n            FovX = focal2fov(focal_length_x, width)\n        elif intr.model == \"PINHOLE\":\n            focal_length_x = intr.params[0]\n            focal_length_y = intr.params[1]\n            FovY = focal2fov(focal_length_y, height)\n            FovX = focal2fov(focal_length_x, width)\n        else:\n            assert (\n                False\n            ), \"Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!\"\n\n        image_path = os.path.join(images_folder, os.path.basename(extr.name))\n        image_name = os.path.basename(image_path).split(\".\")[0]\n        image = Image.open(image_path)\n\n        cam_info = CameraInfo(\n            uid=uid,\n            R=R,\n            T=T,\n            FovY=FovY,\n            FovX=FovX,\n            image=image,\n            image_path=image_path,\n            image_name=image_name,\n            width=width,\n            height=height,\n        )\n        
cam_infos.append(cam_info)\n    sys.stdout.write(\"\\n\")\n    return cam_infos\n\n\ndef fetchPly(path):\n    plydata = PlyData.read(path)\n    vertices = plydata[\"vertex\"]\n    positions = np.vstack([vertices[\"x\"], vertices[\"y\"], vertices[\"z\"]]).T\n    colors = np.vstack([vertices[\"red\"], vertices[\"green\"], vertices[\"blue\"]]).T / 255.0\n    normals = np.vstack([vertices[\"nx\"], vertices[\"ny\"], vertices[\"nz\"]]).T\n    return BasicPointCloud(points=positions, colors=colors, normals=normals)\n\n\ndef storePly(path, xyz, rgb):\n    # Define the dtype for the structured array\n    dtype = [\n        (\"x\", \"f4\"),\n        (\"y\", \"f4\"),\n        (\"z\", \"f4\"),\n        (\"nx\", \"f4\"),\n        (\"ny\", \"f4\"),\n        (\"nz\", \"f4\"),\n        (\"red\", \"u1\"),\n        (\"green\", \"u1\"),\n        (\"blue\", \"u1\"),\n    ]\n\n    normals = np.zeros_like(xyz)\n\n    elements = np.empty(xyz.shape[0], dtype=dtype)\n    attributes = np.concatenate((xyz, normals, rgb), axis=1)\n    elements[:] = list(map(tuple, attributes))\n\n    # Create the PlyData object and write to file\n    vertex_element = PlyElement.describe(elements, \"vertex\")\n    ply_data = PlyData([vertex_element])\n    ply_data.write(path)\n\n\ndef readColmapSceneInfo(path, images, eval, llffhold=8):\n    try:\n        cameras_extrinsic_file = os.path.join(path, \"sparse/0\", \"images.bin\")\n        cameras_intrinsic_file = os.path.join(path, \"sparse/0\", \"cameras.bin\")\n        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)\n        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)\n    except:\n        cameras_extrinsic_file = os.path.join(path, \"sparse/0\", \"images.txt\")\n        cameras_intrinsic_file = os.path.join(path, \"sparse/0\", \"cameras.txt\")\n        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)\n        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)\n\n    reading_dir = \"images\" if images 
== None else images\n    cam_infos_unsorted = readColmapCameras(\n        cam_extrinsics=cam_extrinsics,\n        cam_intrinsics=cam_intrinsics,\n        images_folder=os.path.join(path, reading_dir),\n    )\n    cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)\n\n    if eval:\n        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]\n        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]\n    else:\n        train_cam_infos = cam_infos\n        test_cam_infos = []\n\n    nerf_normalization = getNerfppNorm(train_cam_infos)\n\n    ply_path = os.path.join(path, \"sparse/0/points3D.ply\")\n    bin_path = os.path.join(path, \"sparse/0/points3D.bin\")\n    txt_path = os.path.join(path, \"sparse/0/points3D.txt\")\n    if not os.path.exists(ply_path):\n        print(\n            \"Converting point3d.bin to .ply, will happen only the first time you open the scene.\"\n        )\n        try:\n            xyz, rgb, _ = read_points3D_binary(bin_path)\n        except:\n            xyz, rgb, _ = read_points3D_text(txt_path)\n        storePly(ply_path, xyz, rgb)\n    try:\n        pcd = fetchPly(ply_path)\n    except:\n        pcd = None\n\n    scene_info = SceneInfo(\n        point_cloud=pcd,\n        train_cameras=train_cam_infos,\n        test_cameras=test_cam_infos,\n        nerf_normalization=nerf_normalization,\n        ply_path=ply_path,\n    )\n    return scene_info\n\n\ndef readCamerasFromTransforms(path, transformsfile, white_background, extension=\".png\"):\n    cam_infos = []\n\n    with open(os.path.join(path, transformsfile)) as json_file:\n        contents = json.load(json_file)\n\n        # camera_angle_x is the horizontal field of view\n        # frames.file_path is the image name\n        # frame.transform_matrix is the camera-to-world transform\n\n        fovx = contents[\"camera_angle_x\"]\n\n        frames = contents[\"frames\"]\n        for idx, frame in 
enumerate(frames):\n            cam_name = os.path.join(path, frame[\"file_path\"] + extension)\n\n            # NeRF 'transform_matrix' is a camera-to-world transform\n            c2w = np.array(frame[\"transform_matrix\"])\n            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            c2w[:3, 1:3] *= -1\n\n            # get the world-to-camera transform and set R, T\n            w2c = np.linalg.inv(c2w)\n            R = np.transpose(\n                w2c[:3, :3]\n            )  # R is stored transposed due to 'glm' in CUDA code\n            T = w2c[:3, 3]\n\n            image_path = os.path.join(path, cam_name)\n            image_name = Path(cam_name).stem\n            image = Image.open(image_path)\n\n            im_data = np.array(image.convert(\"RGBA\"))\n\n            bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])\n\n            norm_data = im_data / 255.0\n            arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (\n                1 - norm_data[:, :, 3:4]\n            )\n            image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), \"RGB\")\n\n            fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])\n            FovY = fovy\n            FovX = fovx\n\n            cam_infos.append(\n                CameraInfo(\n                    uid=idx,\n                    R=R,\n                    T=T,\n                    FovY=FovY,\n                    FovX=FovX,\n                    image=image,\n                    image_path=image_path,\n                    image_name=image_name,\n                    width=image.size[0],\n                    height=image.size[1],\n                )\n            )\n\n    return cam_infos\n\n\ndef readNerfSyntheticInfo(path, white_background, eval, extension=\".png\"):\n    print(\"Reading Training Transforms\")\n    train_cam_infos = readCamerasFromTransforms(\n        path, \"transforms_train.json\", white_background, 
extension\n    )\n    print(\"Reading Test Transforms\")\n    test_cam_infos = readCamerasFromTransforms(\n        path, \"transforms_test.json\", white_background, extension\n    )\n\n    if not eval:\n        train_cam_infos.extend(test_cam_infos)\n        test_cam_infos = []\n\n    nerf_normalization = getNerfppNorm(train_cam_infos)\n\n    ply_path = os.path.join(path, \"points3d.ply\")\n    if not os.path.exists(ply_path):\n        # Since this data set has no colmap data, we start with random points\n        num_pts = 100_000\n        print(f\"Generating random point cloud ({num_pts})...\")\n\n        # We create random points inside the bounds of the synthetic Blender scenes\n        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3\n        shs = np.random.random((num_pts, 3)) / 255.0\n        pcd = BasicPointCloud(\n            points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))\n        )\n\n        storePly(ply_path, xyz, SH2RGB(shs) * 255)\n    try:\n        pcd = fetchPly(ply_path)\n    except:\n        pcd = None\n\n    scene_info = SceneInfo(\n        point_cloud=pcd,\n        train_cameras=train_cam_infos,\n        test_cameras=test_cam_infos,\n        nerf_normalization=nerf_normalization,\n        ply_path=ply_path,\n    )\n    return scene_info\n\n\nsceneLoadTypeCallbacks = {\n    \"Colmap\": readColmapSceneInfo,\n    \"Blender\": readNerfSyntheticInfo,\n}\n\n\n# below used for easy rendering\nclass NoImageCamera(nn.Module):\n    def __init__(\n        self,\n        colmap_id,\n        R,\n        T,\n        FoVx,\n        FoVy,\n        width,\n        height,\n        uid,\n        trans=np.array([0.0, 0.0, 0.0]),\n        scale=1.0,\n        data_device=\"cuda\",\n        img_path=None,  # not needed\n    ):\n        super(NoImageCamera, self).__init__()\n\n        self.uid = uid\n        self.colmap_id = colmap_id\n        self.R = R\n        self.T = T\n        self.FoVx = FoVx\n        self.FoVy = FoVy\n        self.img_path = 
img_path\n\n        try:\n            self.data_device = torch.device(data_device)\n        except Exception as e:\n            print(e)\n            print(\n                f\"[Warning] Custom device {data_device} failed, fallback to default cuda device\"\n            )\n            self.data_device = torch.device(\"cuda\")\n\n        self.image_width = width\n        self.image_height = height\n\n        self.zfar = 100.0\n        self.znear = 0.01\n\n        self.trans = trans\n        self.scale = scale\n\n        # world to camera, then transpose.  # [4, 4]\n        #  w2c.transpose\n        self.world_view_transform = (\n            torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()\n        )\n\n        # [4, 4]\n        self.projection_matrix = (\n            getProjectionMatrix(\n                znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy\n            )\n            .transpose(0, 1)\n            .cuda()\n        )\n\n        # # [4, 4].  
points @ full_proj_transform => screen space.\n        self.full_proj_transform = (\n            self.world_view_transform.unsqueeze(0).bmm(\n                self.projection_matrix.unsqueeze(0)\n            )\n        ).squeeze(0)\n        self.camera_center = self.world_view_transform.inverse()[3, :3]\n\n        # [2, 2].\n        #  (w2c @ p) / depth => cam_plane\n        #  (p_in_cam / depth)[:2] @  cam_plane_2_img => [pixel_x, pixel_y]    cam_plane => img_plane\n        self.cam_plane_2_img = torch.tensor(\n            [\n                [0.5 * width / math.tan(self.FoVx / 2.0), 0.0],\n                [0.0, 0.5 * height / math.tan(self.FoVy / 2.0)],\n            ]\n        ).cuda()\n\n\ndef fast_read_cameras_from_transform_file(file_path, width=1080, height=720):\n    cam_infos = []\n\n    dir_name = os.path.dirname(file_path)\n\n    with open(file_path) as json_file:\n        contents = json.load(json_file)\n\n        # camera_angle_x is the horizontal field of view\n        # frames.file_path is the image name\n        # frame.transform_matrix is the camera-to-world transform\n\n        fovx = contents[\"camera_angle_x\"]\n\n        frames = contents[\"frames\"]\n        for idx, frame in enumerate(frames):\n            # NeRF 'transform_matrix' is a camera-to-world transform\n            c2w = np.array(frame[\"transform_matrix\"])\n            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            c2w[:3, 1:3] *= -1\n\n            # get the world-to-camera transform and set R, T\n            w2c = np.linalg.inv(c2w)\n            R = np.transpose(\n                w2c[:3, :3]\n            )  # R is stored transposed due to 'glm' in CUDA code\n            T = w2c[:3, 3]\n\n            fovy = focal2fov(fov2focal(fovx, width), height)\n            FovY = fovy\n            FovX = fovx\n\n            img_path = os.path.join(dir_name, frame[\"file_path\"] + \".png\")\n            cam_ = NoImageCamera(\n                
colmap_id=idx,\n                R=R,\n                T=T,\n                FoVx=FovX,\n                FoVy=FovY,\n                width=width,\n                height=height,\n                uid=id,\n                data_device=\"cuda\",\n                img_path=img_path,\n            )\n\n            cam_infos.append(cam_)\n\n    return cam_infos\n"
  },
  {
    "path": "physdreamer/gaussian_3d/scene/gaussian_model.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport numpy as np\nfrom physdreamer.gaussian_3d.utils.general_utils import (\n    inverse_sigmoid,\n    get_expon_lr_func,\n    build_rotation,\n)\nfrom torch import nn\nimport os\nfrom physdreamer.gaussian_3d.utils.system_utils import mkdir_p\nfrom plyfile import PlyData, PlyElement\nfrom physdreamer.gaussian_3d.utils.sh_utils import RGB2SH\nfrom simple_knn._C import distCUDA2\nfrom physdreamer.gaussian_3d.utils.graphics_utils import BasicPointCloud\nfrom physdreamer.gaussian_3d.utils.general_utils import (\n    strip_symmetric,\n    build_scaling_rotation,\n)\nfrom physdreamer.gaussian_3d.utils.rigid_body_utils import (\n    get_rigid_transform,\n    matrix_to_quaternion,\n    quaternion_multiply,\n)\n\n\nclass GaussianModel:\n    def setup_functions(self):\n        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):\n            L = build_scaling_rotation(scaling_modifier * scaling, rotation)\n            actual_covariance = L @ L.transpose(1, 2)\n            symm = strip_symmetric(actual_covariance)\n            return symm\n\n        self.scaling_activation = torch.exp\n        self.scaling_inverse_activation = torch.log\n\n        self.covariance_activation = build_covariance_from_scaling_rotation\n\n        self.opacity_activation = torch.sigmoid\n        self.inverse_opacity_activation = inverse_sigmoid\n\n        self.rotation_activation = torch.nn.functional.normalize\n\n    def __init__(self, sh_degree: int = 3):\n        self.active_sh_degree = 0\n        self.max_sh_degree = sh_degree\n        self._xyz = torch.empty(0)\n        self._features_dc = torch.empty(0)\n        self._features_rest = 
torch.empty(0)\n        self._scaling = torch.empty(0)\n        self._rotation = torch.empty(0)\n        self._opacity = torch.empty(0)\n        self.max_radii2D = torch.empty(0)\n        self.xyz_gradient_accum = torch.empty(0)\n        self.denom = torch.empty(0)\n        self.optimizer = None\n        self.percent_dense = 0\n        self.spatial_lr_scale = 0\n        self.setup_functions()\n\n        self.matched_inds = None\n\n    def capture(self):\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        return (\n            self.active_sh_degree,\n            self._xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n    def restore(self, model_args, training_args):\n        (\n            self.active_sh_degree,\n            self._xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self.max_radii2D,\n            xyz_gradient_accum,\n            denom,\n            opt_dict,\n            self.spatial_lr_scale,\n        ) = model_args\n\n        if training_args is not None:\n            self.training_setup(training_args)\n        self.xyz_gradient_accum = xyz_gradient_accum\n        self.denom = denom\n        if opt_dict is not None:\n            self.optimizer.load_state_dict(opt_dict)\n\n    def capture_training_args(\n        self,\n    ):\n        pass\n\n    @property\n    def get_scaling(self):\n        return self.scaling_activation(self._scaling)\n\n    @property\n    def get_rotation(self):\n        return self.rotation_activation(self._rotation)\n\n    @property\n    def 
get_xyz(self):\n        return self._xyz\n\n    @property\n    def get_features(self):\n        features_dc = self._features_dc\n        features_rest = self._features_rest\n        return torch.cat((features_dc, features_rest), dim=1)\n\n    @property\n    def get_opacity(self):\n        return self.opacity_activation(self._opacity)\n\n    def get_covariance(self, scaling_modifier=1):\n        return self.covariance_activation(\n            self.get_scaling, scaling_modifier, self._rotation\n        )\n\n    def oneupSHdegree(self):\n        if self.active_sh_degree < self.max_sh_degree:\n            self.active_sh_degree += 1\n\n    def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):\n        self.spatial_lr_scale = spatial_lr_scale\n        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()\n        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())\n        features = (\n            torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))\n            .float()\n            .cuda()\n        )\n        features[:, :3, 0] = fused_color\n        # typo here?\n        features[:, 3:, 1:] = 0.0\n\n        print(\"Number of points at initialisation : \", fused_point_cloud.shape[0])\n\n        dist2 = torch.clamp_min(\n            distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()),\n            0.0000001,\n        )\n        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)\n        rots = torch.zeros((fused_point_cloud.shape[0], 4), device=\"cuda\")\n        rots[:, 0] = 1\n\n        opacities = inverse_sigmoid(\n            0.1\n            * torch.ones(\n                (fused_point_cloud.shape[0], 1), dtype=torch.float, device=\"cuda\"\n            )\n        )\n\n        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))\n        self._features_dc = nn.Parameter(\n            features[:, :, 0:1].transpose(1, 
2).contiguous().requires_grad_(True)\n        )\n        self._features_rest = nn.Parameter(\n            features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)\n        )\n        self._scaling = nn.Parameter(scales.requires_grad_(True))\n        self._rotation = nn.Parameter(rots.requires_grad_(True))\n        self._opacity = nn.Parameter(opacities.requires_grad_(True))\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def training_setup(self, training_args):\n        self.percent_dense = training_args.percent_dense\n        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n\n        l = [\n            {\n                \"params\": [self._xyz],\n                \"lr\": training_args.position_lr_init * self.spatial_lr_scale,\n                \"name\": \"xyz\",\n            },\n            {\n                \"params\": [self._features_dc],\n                \"lr\": training_args.feature_lr,\n                \"name\": \"f_dc\",\n            },\n            {\n                \"params\": [self._features_rest],\n                \"lr\": training_args.feature_lr / 20.0,\n                \"name\": \"f_rest\",\n            },\n            {\n                \"params\": [self._opacity],\n                \"lr\": training_args.opacity_lr,\n                \"name\": \"opacity\",\n            },\n            {\n                \"params\": [self._scaling],\n                \"lr\": training_args.scaling_lr,\n                \"name\": \"scaling\",\n            },\n            {\n                \"params\": [self._rotation],\n                \"lr\": training_args.rotation_lr,\n                \"name\": \"rotation\",\n            },\n        ]\n\n        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)\n        self.xyz_scheduler_args = get_expon_lr_func(\n            lr_init=training_args.position_lr_init * 
self.spatial_lr_scale,\n            lr_final=training_args.position_lr_final * self.spatial_lr_scale,\n            lr_delay_mult=training_args.position_lr_delay_mult,\n            max_steps=training_args.position_lr_max_steps,\n        )\n\n    def update_learning_rate(self, iteration):\n        \"\"\"Learning rate scheduling per step\"\"\"\n        for param_group in self.optimizer.param_groups:\n            if param_group[\"name\"] == \"xyz\":\n                lr = self.xyz_scheduler_args(iteration)\n                param_group[\"lr\"] = lr\n                return lr\n\n    def construct_list_of_attributes(self):\n        l = [\"x\", \"y\", \"z\", \"nx\", \"ny\", \"nz\"]\n        # All channels except the 3 DC\n        for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):\n            l.append(\"f_dc_{}\".format(i))\n        for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]):\n            l.append(\"f_rest_{}\".format(i))\n        l.append(\"opacity\")\n        for i in range(self._scaling.shape[1]):\n            l.append(\"scale_{}\".format(i))\n        for i in range(self._rotation.shape[1]):\n            l.append(\"rot_{}\".format(i))\n        return l\n\n    def save_ply(self, path):\n        mkdir_p(os.path.dirname(path))\n\n        xyz = self._xyz.detach().cpu().numpy()\n        normals = np.zeros_like(xyz)\n        f_dc = (\n            self._features_dc.detach()\n            .transpose(1, 2)\n            .flatten(start_dim=1)\n            .contiguous()\n            .cpu()\n            .numpy()\n        )\n        f_rest = (\n            self._features_rest.detach()\n            .transpose(1, 2)\n            .flatten(start_dim=1)\n            .contiguous()\n            .cpu()\n            .numpy()\n        )\n        opacities = self._opacity.detach().cpu().numpy()\n        scale = self._scaling.detach().cpu().numpy()\n        rotation = self._rotation.detach().cpu().numpy()\n\n        dtype_full = [\n            
(attribute, \"f4\") for attribute in self.construct_list_of_attributes()\n        ]\n\n        elements = np.empty(xyz.shape[0], dtype=dtype_full)\n        attributes = np.concatenate(\n            (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1\n        )\n        elements[:] = list(map(tuple, attributes))\n        el = PlyElement.describe(elements, \"vertex\")\n        PlyData([el]).write(path)\n\n    def reset_opacity(self):\n        opacities_new = inverse_sigmoid(\n            torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01)\n        )\n        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, \"opacity\")\n        self._opacity = optimizable_tensors[\"opacity\"]\n\n    def load_ply(self, path):\n        plydata = PlyData.read(path)\n\n        xyz = np.stack(\n            (\n                np.asarray(plydata.elements[0][\"x\"]),\n                np.asarray(plydata.elements[0][\"y\"]),\n                np.asarray(plydata.elements[0][\"z\"]),\n            ),\n            axis=1,\n        )\n        opacities = np.asarray(plydata.elements[0][\"opacity\"])[..., np.newaxis]\n\n        features_dc = np.zeros((xyz.shape[0], 3, 1))\n        features_dc[:, 0, 0] = np.asarray(plydata.elements[0][\"f_dc_0\"])\n        features_dc[:, 1, 0] = np.asarray(plydata.elements[0][\"f_dc_1\"])\n        features_dc[:, 2, 0] = np.asarray(plydata.elements[0][\"f_dc_2\"])\n\n        extra_f_names = [\n            p.name\n            for p in plydata.elements[0].properties\n            if p.name.startswith(\"f_rest_\")\n        ]\n        extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split(\"_\")[-1]))\n        assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3\n        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))\n        for idx, attr_name in enumerate(extra_f_names):\n            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])\n        # Reshape (P,F*SH_coeffs) 
to (P, F, SH_coeffs except DC)\n        features_extra = features_extra.reshape(\n            (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)\n        )\n\n        scale_names = [\n            p.name\n            for p in plydata.elements[0].properties\n            if p.name.startswith(\"scale_\")\n        ]\n        scale_names = sorted(scale_names, key=lambda x: int(x.split(\"_\")[-1]))\n        scales = np.zeros((xyz.shape[0], len(scale_names)))\n        for idx, attr_name in enumerate(scale_names):\n            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        rot_names = [\n            p.name for p in plydata.elements[0].properties if p.name.startswith(\"rot\")\n        ]\n        rot_names = sorted(rot_names, key=lambda x: int(x.split(\"_\")[-1]))\n        rots = np.zeros((xyz.shape[0], len(rot_names)))\n        for idx, attr_name in enumerate(rot_names):\n            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        self._xyz = nn.Parameter(\n            torch.tensor(xyz, dtype=torch.float, device=\"cuda\").requires_grad_(True)\n        )\n        self._features_dc = nn.Parameter(\n            torch.tensor(features_dc, dtype=torch.float, device=\"cuda\")\n            .transpose(1, 2)\n            .contiguous()\n            .requires_grad_(True)\n        )\n        self._features_rest = nn.Parameter(\n            torch.tensor(features_extra, dtype=torch.float, device=\"cuda\")\n            .transpose(1, 2)\n            .contiguous()\n            .requires_grad_(True)\n        )\n        self._opacity = nn.Parameter(\n            torch.tensor(opacities, dtype=torch.float, device=\"cuda\").requires_grad_(\n                True\n            )\n        )\n        self._scaling = nn.Parameter(\n            torch.tensor(scales, dtype=torch.float, device=\"cuda\").requires_grad_(True)\n        )\n        self._rotation = nn.Parameter(\n            torch.tensor(rots, dtype=torch.float, 
device=\"cuda\").requires_grad_(True)\n        )\n\n        self.active_sh_degree = self.max_sh_degree\n\n    def replace_tensor_to_optimizer(self, tensor, name):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            if group[\"name\"] == name:\n                stored_state = self.optimizer.state.get(group[\"params\"][0], None)\n                stored_state[\"exp_avg\"] = torch.zeros_like(tensor)\n                stored_state[\"exp_avg_sq\"] = torch.zeros_like(tensor)\n\n                del self.optimizer.state[group[\"params\"][0]]\n                group[\"params\"][0] = nn.Parameter(tensor.requires_grad_(True))\n                self.optimizer.state[group[\"params\"][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def _prune_optimizer(self, mask):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            stored_state = self.optimizer.state.get(group[\"params\"][0], None)\n            if stored_state is not None:\n                stored_state[\"exp_avg\"] = stored_state[\"exp_avg\"][mask]\n                stored_state[\"exp_avg_sq\"] = stored_state[\"exp_avg_sq\"][mask]\n\n                del self.optimizer.state[group[\"params\"][0]]\n                group[\"params\"][0] = nn.Parameter(\n                    (group[\"params\"][0][mask].requires_grad_(True))\n                )\n                self.optimizer.state[group[\"params\"][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n                group[\"params\"][0] = nn.Parameter(\n                    group[\"params\"][0][mask].requires_grad_(True)\n                )\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def prune_points(self, mask):\n        valid_points_mask = ~mask\n        
optimizable_tensors = self._prune_optimizer(valid_points_mask)\n\n        self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n\n        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]\n\n        self.denom = self.denom[valid_points_mask]\n        self.max_radii2D = self.max_radii2D[valid_points_mask]\n\n    def cat_tensors_to_optimizer(self, tensors_dict):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            assert len(group[\"params\"]) == 1\n            extension_tensor = tensors_dict[group[\"name\"]]\n            stored_state = self.optimizer.state.get(group[\"params\"][0], None)\n            if stored_state is not None:\n                stored_state[\"exp_avg\"] = torch.cat(\n                    (stored_state[\"exp_avg\"], torch.zeros_like(extension_tensor)), dim=0\n                )\n                stored_state[\"exp_avg_sq\"] = torch.cat(\n                    (stored_state[\"exp_avg_sq\"], torch.zeros_like(extension_tensor)),\n                    dim=0,\n                )\n\n                del self.optimizer.state[group[\"params\"][0]]\n                group[\"params\"][0] = nn.Parameter(\n                    torch.cat(\n                        (group[\"params\"][0], extension_tensor), dim=0\n                    ).requires_grad_(True)\n                )\n                self.optimizer.state[group[\"params\"][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n                group[\"params\"][0] = nn.Parameter(\n                    torch.cat(\n                        (group[\"params\"][0], extension_tensor), dim=0\n                    
).requires_grad_(True)\n                )\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n\n        return optimizable_tensors\n\n    def densification_postfix(\n        self,\n        new_xyz,\n        new_features_dc,\n        new_features_rest,\n        new_opacities,\n        new_scaling,\n        new_rotation,\n    ):\n        d = {\n            \"xyz\": new_xyz,\n            \"f_dc\": new_features_dc,\n            \"f_rest\": new_features_rest,\n            \"opacity\": new_opacities,\n            \"scaling\": new_scaling,\n            \"rotation\": new_rotation,\n        }\n\n        optimizable_tensors = self.cat_tensors_to_optimizer(d)\n        self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n\n        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):\n        n_init_points = self.get_xyz.shape[0]\n        # Extract points that satisfy the gradient condition\n        padded_grad = torch.zeros((n_init_points), device=\"cuda\")\n        padded_grad[: grads.shape[0]] = grads.squeeze()\n        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)\n        selected_pts_mask = torch.logical_and(\n            selected_pts_mask,\n            torch.max(self.get_scaling, dim=1).values\n            > self.percent_dense * scene_extent,\n        )\n\n        stds = self.get_scaling[selected_pts_mask].repeat(N, 1)\n        means = torch.zeros((stds.size(0), 3), 
device=\"cuda\")\n        samples = torch.normal(mean=means, std=stds)\n        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)\n        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[\n            selected_pts_mask\n        ].repeat(N, 1)\n        new_scaling = self.scaling_inverse_activation(\n            self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)\n        )\n        new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)\n        new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)\n        new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)\n        new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)\n\n        self.densification_postfix(\n            new_xyz,\n            new_features_dc,\n            new_features_rest,\n            new_opacity,\n            new_scaling,\n            new_rotation,\n        )\n\n        prune_filter = torch.cat(\n            (\n                selected_pts_mask,\n                torch.zeros(N * selected_pts_mask.sum(), device=\"cuda\", dtype=bool),\n            )\n        )\n        self.prune_points(prune_filter)\n\n    def densify_and_clone(self, grads, grad_threshold, scene_extent):\n        # Extract points that satisfy the gradient condition\n        selected_pts_mask = torch.where(\n            torch.norm(grads, dim=-1) >= grad_threshold, True, False\n        )\n        selected_pts_mask = torch.logical_and(\n            selected_pts_mask,\n            torch.max(self.get_scaling, dim=1).values\n            <= self.percent_dense * scene_extent,\n        )\n\n        new_xyz = self._xyz[selected_pts_mask]\n        new_features_dc = self._features_dc[selected_pts_mask]\n        new_features_rest = self._features_rest[selected_pts_mask]\n        new_opacities = self._opacity[selected_pts_mask]\n        new_scaling = self._scaling[selected_pts_mask]\n        new_rotation = 
self._rotation[selected_pts_mask]\n\n        self.densification_postfix(\n            new_xyz,\n            new_features_dc,\n            new_features_rest,\n            new_opacities,\n            new_scaling,\n            new_rotation,\n        )\n\n    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):\n        grads = self.xyz_gradient_accum / self.denom\n        grads[grads.isnan()] = 0.0\n\n        self.densify_and_clone(grads, max_grad, extent)\n        self.densify_and_split(grads, max_grad, extent)\n\n        prune_mask = (self.get_opacity < min_opacity).squeeze()\n        if max_screen_size:\n            big_points_vs = self.max_radii2D > max_screen_size\n            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent\n            prune_mask = torch.logical_or(\n                torch.logical_or(prune_mask, big_points_vs), big_points_ws\n            )\n        self.prune_points(prune_mask)\n\n        torch.cuda.empty_cache()\n\n    def add_densification_stats(self, viewspace_point_tensor, update_filter):\n        self.xyz_gradient_accum[update_filter] += torch.norm(\n            viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True\n        )\n        self.denom[update_filter] += 1\n\n    def apply_discrete_offset_filds(self, origin_points, offsets):\n        \"\"\"\n        Args:\n            origin_points: (N_r, 3)\n            offsets: (N_r, 3)\n        \"\"\"\n\n        # since origin points and self._xyz might not be matched, we need to first\n        #   compute the distance between origin points and self._xyz\n        #   then find the nearest point in self._xyz for each origin point\n\n        # compute the distance between origin points and self._xyz\n        # [N_r, num_points]\n        dist = torch.cdist(origin_points, self._xyz)\n        # find the nearest point in self._xyz for each origin point\n        _, idx = torch.min(dist, dim=0)\n\n        # apply offsets\n\n        new_xyz = 
self._xyz + offsets[idx]\n\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def apply_discrete_offset_filds_with_R(self, origin_points, offsets, topk=6):\n        \"\"\"\n        Args:\n            origin_points: (N_r, 3)\n            offsets: (N_r, 3)\n        \"\"\"\n\n        # since origin points and self._xyz might not be matched, we need to first\n        #   compute the distance between origin points and self._xyz\n        #   then find the nearest point in self._xyz for each origin point\n\n        if self.matched_inds is None:\n            # compute the distance between origin points and self._xyz\n            # [N_r, num_points]\n            dist = torch.cdist(origin_points, self._xyz) * -1.0\n            # find the nearest point in self._xyz for each origin point\n\n            # idxs: [topk, num_points]\n            print(dist.shape, topk, dist[0])\n            _, idxs = torch.topk(dist, topk, dim=0)\n\n            self.matched_inds = idxs\n        else:\n            idxs = self.matched_inds\n\n        # [topk, num_points, 3] => [num_points, topk, 3]\n        matched_topk_offsets = offsets[idxs].transpose(0, 1)\n        source_points = origin_points[idxs].transpose(0, 1)\n\n        # [num_points, 3, 3/1]\n        R, t = get_rigid_transform(source_points, source_points + matched_topk_offsets)\n\n        # new_xyz = R @ 
self._xyz.unsqueeze(dim=-1) + t\n        # new_xyz = new_xyz.squeeze(dim=-1)\n\n        avg_offsets = matched_topk_offsets.mean(dim=1)\n        new_xyz = self._xyz + avg_offsets  # offset directly\n\n        new_rotation = quaternion_multiply(matrix_to_quaternion(R), self._rotation)\n\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            new_rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def apply_se3_fields(\n        self,\n        se3_model,\n        timestamp: float,\n        freeze_mask=None,\n    ):\n        \"\"\"\n        Args:\n            se3_model: SE3Model\n            timestamp: float.  in range [0, 1]\n            freeze_mask: [N]\n        \"\"\"\n\n        inp_time = torch.ones_like(self._xyz[:, 0:1]) * timestamp\n        inp = torch.cat([self._xyz, inp_time], dim=-1)\n\n        if freeze_mask is not None:\n            moving_mask = torch.logical_not(freeze_mask)\n            inp = inp[moving_mask, ...]\n        # [bs, 3, 3]. [bs, 3]\n        R, t = se3_model(inp)\n\n        # print(\"abs t mean\", torch.abs(t).mean(dim=0))\n        # new_xyz = (R @ self._xyz.unsqueeze(dim=-1)).squeeze(dim=-1) + t\n\n        if freeze_mask is None:\n            new_xyz = self._xyz + t\n            new_rotation = quaternion_multiply(matrix_to_quaternion(R), self._rotation)\n        else:\n            new_xyz = self._xyz.clone()\n            new_xyz[moving_mask, ...] 
+= t\n            new_rotation = self._rotation.clone()\n            new_rotation[moving_mask, ...] = quaternion_multiply(\n                matrix_to_quaternion(R), self._rotation[moving_mask, ...]\n            )\n\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            new_rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def apply_offset_fields(self, offset_field, timestamp: float):\n        \"\"\"\n        Args:\n            se3_model: SE3Model\n            timestamp: float.  in range [0, 1]\n        \"\"\"\n\n        inp_time = torch.ones_like(self._xyz[:, 0:1]) * timestamp\n        inp = torch.cat([self._xyz, inp_time], dim=-1)\n        # [bs, 3, 3]. 
[bs, 3]\n        offsets = offset_field(inp)\n\n        # print(\"abs t mean\", torch.abs(t).mean(dim=0))\n        new_xyz = self._xyz + offsets\n\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def apply_offset_fields_with_R(self, offset_field, timestamp: float, eps=1e-2):\n        \"\"\"\n        Args:\n            se3_model: SE3Model\n            timestamp: float.  
in range [0, 1]\n        \"\"\"\n\n        # [4, 3]\n        inp_perterb = (\n            torch.tensor(\n                [\n                    [0.0, 0.0, 0.0],  # add this will coplanar?\n                    [+eps, -eps, -eps],\n                    [-eps, -eps, +eps],\n                    [-eps, +eps, -eps],\n                    [+eps, +eps, +eps],\n                ],\n            )\n            .to(self._xyz.device)\n            .float()\n        )\n        #  => [N, 4, 3]\n        source_points = self._xyz.unsqueeze(dim=1) + inp_perterb.unsqueeze(dim=0)\n        num_points = source_points.shape[0]\n\n        inpx = source_points.flatten(end_dim=1)\n        inp_time = torch.ones_like(inpx[:, 0:1]) * timestamp\n\n        inp = torch.cat([inpx, inp_time], dim=-1)\n\n        sampled_offsets = offset_field(inp).reshape((num_points, -1, 3))\n\n        R, t = get_rigid_transform(source_points, source_points + sampled_offsets)\n\n        # new_xyz = R @ self._xyz.unsqueeze(dim=-1) + t\n        # new_xyz = new_xyz.squeeze(dim=-1)\n\n        avg_offsets = sampled_offsets.mean(dim=1)\n        new_xyz = self._xyz + avg_offsets  # offset directly\n\n        new_rotation = quaternion_multiply(matrix_to_quaternion(R), self._rotation)\n\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            new_rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def init_from_mesh(\n        self,\n        mesh_path: str,\n   
     num_gaussians: int = 10000,\n    ):\n        import point_cloud_utils as pcu\n\n        mesh = pcu.load_triangle_mesh(mesh_path)\n\n        v, f = mesh.v, mesh.f\n\n        v_n = pcu.estimate_mesh_normals(v, f)\n        vert_colors = mesh.vertex_data.colors\n\n        fid, bc = pcu.sample_mesh_random(v, f, num_gaussians)\n\n        # Interpolate the vertex positions and normals using the returned barycentric coordinates\n        # to get sample positions and normals\n        rand_positions = pcu.interpolate_barycentric_coords(f, fid, bc, v)\n        rand_normals = pcu.interpolate_barycentric_coords(f, fid, bc, v_n)\n        rand_colors = pcu.interpolate_barycentric_coords(f, fid, bc, vert_colors)[:, :3]\n\n        # copy original pointcloud init functions\n\n        fused_point_cloud = torch.tensor(np.asarray(rand_positions)).float().cuda()\n        fused_color = RGB2SH(torch.tensor(np.asarray(rand_colors)).float().cuda())\n        features = (\n            torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))\n            .float()\n            .cuda()\n        )\n        features[:, :3, 0] = fused_color\n        # typo here?\n        features[:, 3:, 1:] = 0.0\n\n        print(\"Number of points at initialisation : \", fused_point_cloud.shape[0])\n\n        dist2 = torch.clamp_min(\n            distCUDA2(torch.from_numpy(np.asarray(rand_positions)).float().cuda()),\n            0.0000001,\n        )\n        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)\n        rots = torch.zeros((fused_point_cloud.shape[0], 4), device=\"cuda\")\n        rots[:, 0] = 1\n\n        opacities = inverse_sigmoid(\n            0.1\n            * torch.ones(\n                (fused_point_cloud.shape[0], 1), dtype=torch.float, device=\"cuda\"\n            )\n        )\n\n        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))\n        self._features_dc = nn.Parameter(\n            features[:, :, 0:1].transpose(1, 
2).contiguous().requires_grad_(True)\n        )\n        self._features_rest = nn.Parameter(\n            features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)\n        )\n        self._scaling = nn.Parameter(scales.requires_grad_(True))\n        self._rotation = nn.Parameter(rots.requires_grad_(True))\n        self._opacity = nn.Parameter(opacities.requires_grad_(True))\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def detach_grad(\n        self,\n    ):\n        self._xyz.requires_grad = False\n        self._features_dc.requires_grad = False\n        self._features_rest.requires_grad = False\n        self._scaling.requires_grad = False\n        self._rotation.requires_grad = False\n        self._opacity.requires_grad = False\n\n    def apply_mask(self, mask):\n        new_xyz = self._xyz[mask]\n        if self.xyz_gradient_accum.shape == self._xyz.shape:\n            new_xyz_gradient_accum = self.xyz_gradient_accum[mask]\n            new_denom = self.denom[mask]\n        else:\n            new_xyz_gradient_accum = self.xyz_gradient_accum\n            new_denom = self.denom\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc[mask],\n            self._features_rest[mask],\n            self._scaling[mask],\n            self._rotation[mask],\n            self._opacity[mask],\n            self.max_radii2D,\n            new_xyz_gradient_accum,\n            new_denom,\n            None,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    @torch.no_grad()\n    def extract_fields(self, resolution=128, num_blocks=16, relax_ratio=1.5):\n        # resolution: resolution of field\n\n        block_size = 2 / num_blocks\n\n        assert resolution % block_size == 0\n        split_size = resolution // 
num_blocks\n\n        opacities = self.get_opacity\n\n        # pre-filter low opacity gaussians to save computation\n        mask = (opacities > 0.005).squeeze(1)\n\n        opacities = opacities[mask]\n        xyzs = self.get_xyz[mask]\n        stds = self.get_scaling[mask]\n\n        # normalize to ~ [-1, 1]\n        mn, mx = xyzs.amin(0), xyzs.amax(0)\n        self.center = (mn + mx) / 2\n        self.scale = 1.0 / (mx - mn).amax().item()\n\n        print(\"gaussian center, scale\", self.center, self.scale)\n        xyzs = (xyzs - self.center) * self.scale\n        stds = stds * self.scale\n\n        covs = self.covariance_activation(stds, 1, self._rotation[mask])\n\n        # tile\n        device = opacities.device\n        occ = torch.zeros([resolution] * 3, dtype=torch.float32, device=device)\n\n        X = torch.linspace(-1, 1, resolution).split(split_size)\n        Y = torch.linspace(-1, 1, resolution).split(split_size)\n        Z = torch.linspace(-1, 1, resolution).split(split_size)\n\n        # loop blocks (assume max size of gaussian is small than relax_ratio * block_size !!!)\n        for xi, xs in enumerate(X):\n            for yi, ys in enumerate(Y):\n                for zi, zs in enumerate(Z):\n                    xx, yy, zz = torch.meshgrid(xs, ys, zs)\n                    # sample points [M, 3]\n                    pts = torch.cat(\n                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],\n                        dim=-1,\n                    ).to(device)\n                    # in-tile gaussians mask\n                    vmin, vmax = pts.amin(0), pts.amax(0)\n                    vmin -= block_size * relax_ratio\n                    vmax += block_size * relax_ratio\n                    mask = (xyzs < vmax).all(-1) & (xyzs > vmin).all(-1)\n                    # if hit no gaussian, continue to next block\n                    if not mask.any():\n                        continue\n                    mask_xyzs = xyzs[mask]  # 
[L, 3]\n                    mask_covs = covs[mask]  # [L, 6]\n                    mask_opas = opacities[mask].view(1, -1)  # [L, 1] --> [1, L]\n\n                    # query per point-gaussian pair.\n                    g_pts = pts.unsqueeze(1).repeat(\n                        1, mask_covs.shape[0], 1\n                    ) - mask_xyzs.unsqueeze(\n                        0\n                    )  # [M, L, 3]\n                    g_covs = mask_covs.unsqueeze(0).repeat(\n                        pts.shape[0], 1, 1\n                    )  # [M, L, 6]\n\n                    # batch on gaussian to avoid OOM\n                    batch_g = 1024\n                    val = 0\n                    for start in range(0, g_covs.shape[1], batch_g):\n                        end = min(start + batch_g, g_covs.shape[1])\n                        w = gaussian_3d_coeff(\n                            g_pts[:, start:end].reshape(-1, 3),\n                            g_covs[:, start:end].reshape(-1, 6),\n                        ).reshape(\n                            pts.shape[0], -1\n                        )  # [M, l]\n                        val += (mask_opas[:, start:end] * w).sum(-1)\n\n                    # kiui.lo(val, mask_opas, w)\n\n                    occ[\n                        xi * split_size : xi * split_size + len(xs),\n                        yi * split_size : yi * split_size + len(ys),\n                        zi * split_size : zi * split_size + len(zs),\n                    ] = val.reshape(len(xs), len(ys), len(zs))\n\n        return occ\n\n    def extract_mesh(self, path, density_thresh=1, resolution=128, decimate_target=1e5):\n        os.makedirs(os.path.dirname(path), exist_ok=True)\n\n        from physdreamer.gaussian_3d.scene.mesh import Mesh\n        from physdreamer.gaussian_3d.scene.mesh_utils import decimate_mesh, clean_mesh\n\n        occ = self.extract_fields(resolution).detach().cpu().numpy()\n\n        print(occ.shape, occ.min(), occ.max(), occ.mean(), \"occ 
stats\")\n        print(np.percentile(occ, [0, 1, 5, 10, 50, 90, 95, 99, 100]), \"occ percentiles\")\n        import mcubes\n\n        vertices, triangles = mcubes.marching_cubes(occ, density_thresh)\n        vertices = vertices / (resolution - 1.0) * 2 - 1\n\n        # transform back to the original space\n        vertices = vertices / self.scale + self.center.detach().cpu().numpy()\n\n        vertices, triangles = clean_mesh(\n            vertices, triangles, remesh=True, remesh_size=0.015\n        )\n        if decimate_target > 0 and triangles.shape[0] > decimate_target:\n            vertices, triangles = decimate_mesh(vertices, triangles, decimate_target)\n\n        v = torch.from_numpy(vertices.astype(np.float32)).contiguous().cuda()\n        f = torch.from_numpy(triangles.astype(np.int32)).contiguous().cuda()\n\n        print(\n            f\"[INFO] marching cubes result: {v.shape} ({v.min().item()}-{v.max().item()}), {f.shape}\"\n        )\n\n        mesh = Mesh(v=v, f=f, device=\"cuda\")\n\n        return mesh\n\n\ndef gaussian_3d_coeff(xyzs, covs):\n    # xyzs: [N, 3]\n    # covs: [N, 6]\n    x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]\n    a, b, c, d, e, f = (\n        covs[:, 0],\n        covs[:, 1],\n        covs[:, 2],\n        covs[:, 3],\n        covs[:, 4],\n        covs[:, 5],\n    )\n\n    # eps must be small enough !!!\n    inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24)\n    inv_a = (d * f - e**2) * inv_det\n    inv_b = (e * c - b * f) * inv_det\n    inv_c = (e * b - c * d) * inv_det\n    inv_d = (a * f - c**2) * inv_det\n    inv_e = (b * c - e * a) * inv_det\n    inv_f = (a * d - b**2) * inv_det\n\n    power = (\n        -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f)\n        - x * y * inv_b\n        - x * z * inv_c\n        - y * z * inv_e\n    )\n\n    power[power > 0] = -1e10  # abnormal values... make weights 0\n\n    return torch.exp(power)\n"
  },
  {
    "path": "physdreamer/gaussian_3d/scene/mesh.py",
    "content": "import os\nimport cv2\nimport torch\nimport trimesh\nimport numpy as np\n\n\ndef dot(x, y):\n    return torch.sum(x * y, -1, keepdim=True)\n\n\ndef length(x, eps=1e-20):\n    return torch.sqrt(torch.clamp(dot(x, x), min=eps))\n\n\ndef safe_normalize(x, eps=1e-20):\n    return x / length(x, eps)\n\n\nclass Mesh:\n    def __init__(\n        self,\n        v=None,\n        f=None,\n        vn=None,\n        fn=None,\n        vt=None,\n        ft=None,\n        albedo=None,\n        vc=None,  # vertex color\n        device=None,\n    ):\n        self.device = device\n        self.v = v\n        self.vn = vn\n        self.vt = vt\n        self.f = f\n        self.fn = fn\n        self.ft = ft\n        # only support a single albedo\n        self.albedo = albedo\n        # support vertex color is no albedo\n        self.vc = vc\n\n        self.ori_center = 0\n        self.ori_scale = 1\n\n    @classmethod\n    def load(\n        cls,\n        path=None,\n        resize=True,\n        renormal=True,\n        retex=False,\n        front_dir=\"+z\",\n        **kwargs,\n    ):\n        # assume init with kwargs\n        if path is None:\n            mesh = cls(**kwargs)\n        # obj supports face uv\n        elif path.endswith(\".obj\"):\n            mesh = cls.load_obj(path, **kwargs)\n        # trimesh only supports vertex uv, but can load more formats\n        else:\n            mesh = cls.load_trimesh(path, **kwargs)\n\n        print(f\"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}\")\n        # auto-normalize\n        if resize:\n            mesh.auto_size()\n        # auto-fix normal\n        if renormal or mesh.vn is None:\n            mesh.auto_normal()\n            print(f\"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}\")\n        # auto-fix texcoords\n        if retex or (mesh.albedo is not None and mesh.vt is None):\n            mesh.auto_uv(cache_path=path)\n            print(f\"[Mesh loading] vt: {mesh.vt.shape}, ft: 
{mesh.ft.shape}\")\n\n        # rotate front dir to +z\n        if front_dir != \"+z\":\n            # axis switch\n            if \"-z\" in front_dir:\n                T = torch.tensor(\n                    [[1, 0, 0], [0, 1, 0], [0, 0, -1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"+x\" in front_dir:\n                T = torch.tensor(\n                    [[0, 0, 1], [0, 1, 0], [1, 0, 0]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"-x\" in front_dir:\n                T = torch.tensor(\n                    [[0, 0, -1], [0, 1, 0], [1, 0, 0]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"+y\" in front_dir:\n                T = torch.tensor(\n                    [[1, 0, 0], [0, 0, 1], [0, 1, 0]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"-y\" in front_dir:\n                T = torch.tensor(\n                    [[1, 0, 0], [0, 0, -1], [0, 1, 0]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            else:\n                T = torch.tensor(\n                    [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            # rotation (how many 90 degrees)\n            if \"1\" in front_dir:\n                T @= torch.tensor(\n                    [[0, -1, 0], [1, 0, 0], [0, 0, 1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"2\" in front_dir:\n                T @= torch.tensor(\n                    [[1, 0, 0], [0, -1, 0], [0, 0, 1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n 
           elif \"3\" in front_dir:\n                T @= torch.tensor(\n                    [[0, 1, 0], [-1, 0, 0], [0, 0, 1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            mesh.v @= T\n            mesh.vn @= T\n\n        return mesh\n\n    # load from obj file\n    @classmethod\n    def load_obj(cls, path, albedo_path=None, device=None):\n        assert os.path.splitext(path)[-1] == \".obj\"\n\n        mesh = cls()\n\n        # device\n        if device is None:\n            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        mesh.device = device\n\n        # load obj\n        with open(path, \"r\") as f:\n            lines = f.readlines()\n\n        def parse_f_v(fv):\n            # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)\n            # supported forms:\n            # f v1 v2 v3\n            # f v1/vt1 v2/vt2 v3/vt3\n            # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3\n            # f v1//vn1 v2//vn2 v3//vn3\n            xs = [int(x) - 1 if x != \"\" else -1 for x in fv.split(\"/\")]\n            xs.extend([-1] * (3 - len(xs)))\n            return xs[0], xs[1], xs[2]\n\n        # NOTE: we ignore usemtl, and assume the mesh ONLY uses one material (first in mtl)\n        vertices, texcoords, normals = [], [], []\n        faces, tfaces, nfaces = [], [], []\n        mtl_path = None\n\n        for line in lines:\n            split_line = line.split()\n            # empty line\n            if len(split_line) == 0:\n                continue\n            prefix = split_line[0].lower()\n            # mtllib\n            if prefix == \"mtllib\":\n                mtl_path = split_line[1]\n            # usemtl\n            elif prefix == \"usemtl\":\n                pass  # ignored\n            # v/vn/vt\n            elif prefix == \"v\":\n                vertices.append([float(v) for v in split_line[1:]])\n            elif prefix == \"vn\":\n 
               normals.append([float(v) for v in split_line[1:]])\n            elif prefix == \"vt\":\n                val = [float(v) for v in split_line[1:]]\n                texcoords.append([val[0], 1.0 - val[1]])\n            elif prefix == \"f\":\n                vs = split_line[1:]\n                nv = len(vs)\n                v0, t0, n0 = parse_f_v(vs[0])\n                for i in range(nv - 2):  # triangulate (assume vertices are ordered)\n                    v1, t1, n1 = parse_f_v(vs[i + 1])\n                    v2, t2, n2 = parse_f_v(vs[i + 2])\n                    faces.append([v0, v1, v2])\n                    tfaces.append([t0, t1, t2])\n                    nfaces.append([n0, n1, n2])\n\n        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)\n        mesh.vt = (\n            torch.tensor(texcoords, dtype=torch.float32, device=device)\n            if len(texcoords) > 0\n            else None\n        )\n        mesh.vn = (\n            torch.tensor(normals, dtype=torch.float32, device=device)\n            if len(normals) > 0\n            else None\n        )\n\n        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)\n        mesh.ft = (\n            torch.tensor(tfaces, dtype=torch.int32, device=device)\n            if len(texcoords) > 0\n            else None\n        )\n        mesh.fn = (\n            torch.tensor(nfaces, dtype=torch.int32, device=device)\n            if len(normals) > 0\n            else None\n        )\n\n        # see if there is vertex color\n        use_vertex_color = False\n        if mesh.v.shape[1] == 6:\n            use_vertex_color = True\n            mesh.vc = mesh.v[:, 3:]\n            mesh.v = mesh.v[:, :3]\n            print(f\"[load_obj] use vertex color: {mesh.vc.shape}\")\n\n        # try to load texture image\n        if not use_vertex_color:\n            # try to retrieve mtl file\n            mtl_path_candidates = []\n            if mtl_path is not None:\n                
mtl_path_candidates.append(mtl_path)\n                mtl_path_candidates.append(\n                    os.path.join(os.path.dirname(path), mtl_path)\n                )\n            mtl_path_candidates.append(path.replace(\".obj\", \".mtl\"))\n\n            mtl_path = None\n            for candidate in mtl_path_candidates:\n                if os.path.exists(candidate):\n                    mtl_path = candidate\n                    break\n\n            # if albedo_path is not provided, try retrieve it from mtl\n            if mtl_path is not None and albedo_path is None:\n                with open(mtl_path, \"r\") as f:\n                    lines = f.readlines()\n                for line in lines:\n                    split_line = line.split()\n                    # empty line\n                    if len(split_line) == 0:\n                        continue\n                    prefix = split_line[0]\n                    # NOTE: simply use the first map_Kd as albedo!\n                    if \"map_Kd\" in prefix:\n                        albedo_path = os.path.join(os.path.dirname(path), split_line[1])\n                        print(f\"[load_obj] use texture from: {albedo_path}\")\n                        break\n\n            # still not found albedo_path, or the path doesn't exist\n            if albedo_path is None or not os.path.exists(albedo_path):\n                # init an empty texture\n                print(f\"[load_obj] init empty albedo!\")\n                # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)\n                albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array(\n                    [0.5, 0.5, 0.5]\n                )  # default color\n            else:\n                albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)\n                albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)\n                albedo = albedo.astype(np.float32) / 255\n                print(f\"[load_obj] load texture: {albedo.shape}\")\n\n             
   # import matplotlib.pyplot as plt\n                # plt.imshow(albedo)\n                # plt.show()\n\n            mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)\n\n        return mesh\n\n    @classmethod\n    def load_trimesh(cls, path, device=None):\n        mesh = cls()\n\n        # device\n        if device is None:\n            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        mesh.device = device\n\n        # use trimesh to load ply/glb, assume only has one single RootMesh...\n        _data = trimesh.load(path)\n        if isinstance(_data, trimesh.Scene):\n            if len(_data.geometry) == 1:\n                _mesh = list(_data.geometry.values())[0]\n            else:\n                # manual concat, will lose texture\n                _concat = []\n                for g in _data.geometry.values():\n                    if isinstance(g, trimesh.Trimesh):\n                        _concat.append(g)\n                _mesh = trimesh.util.concatenate(_concat)\n        else:\n            _mesh = _data\n\n        if _mesh.visual.kind == \"vertex\":\n            vertex_colors = _mesh.visual.vertex_colors\n            vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255\n            mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)\n            print(f\"[load_trimesh] use vertex color: {mesh.vc.shape}\")\n        elif _mesh.visual.kind == \"texture\":\n            _material = _mesh.visual.material\n            if isinstance(_material, trimesh.visual.material.PBRMaterial):\n                texture = np.array(_material.baseColorTexture).astype(np.float32) / 255\n            elif isinstance(_material, trimesh.visual.material.SimpleMaterial):\n                texture = (\n                    np.array(_material.to_pbr().baseColorTexture).astype(np.float32)\n                    / 255\n                )\n            else:\n                raise 
NotImplementedError(\n                    f\"material type {type(_material)} not supported!\"\n                )\n            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)\n            print(f\"[load_trimesh] load texture: {texture.shape}\")\n        else:\n            texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array(\n                [0.5, 0.5, 0.5]\n            )\n            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)\n            print(f\"[load_trimesh] failed to load texture.\")\n\n        vertices = _mesh.vertices\n\n        try:\n            texcoords = _mesh.visual.uv\n            texcoords[:, 1] = 1 - texcoords[:, 1]\n        except Exception as e:\n            texcoords = None\n\n        try:\n            normals = _mesh.vertex_normals\n        except Exception as e:\n            normals = None\n\n        # trimesh only support vertex uv...\n        faces = tfaces = nfaces = _mesh.faces\n\n        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)\n        mesh.vt = (\n            torch.tensor(texcoords, dtype=torch.float32, device=device)\n            if texcoords is not None\n            else None\n        )\n        mesh.vn = (\n            torch.tensor(normals, dtype=torch.float32, device=device)\n            if normals is not None\n            else None\n        )\n\n        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)\n        mesh.ft = (\n            torch.tensor(tfaces, dtype=torch.int32, device=device)\n            if texcoords is not None\n            else None\n        )\n        mesh.fn = (\n            torch.tensor(nfaces, dtype=torch.int32, device=device)\n            if normals is not None\n            else None\n        )\n\n        return mesh\n\n    # aabb\n    def aabb(self):\n        return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values\n\n    # unit size\n    @torch.no_grad()\n    def auto_size(self):\n        
vmin, vmax = self.aabb()\n        self.ori_center = (vmax + vmin) / 2\n        self.ori_scale = 1.2 / torch.max(vmax - vmin).item()\n        self.v = (self.v - self.ori_center) * self.ori_scale\n\n    def auto_normal(self):\n        i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()\n        v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]\n\n        face_normals = torch.cross(v1 - v0, v2 - v0)\n\n        # Splat face normals to vertices\n        vn = torch.zeros_like(self.v)\n        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)\n        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)\n        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)\n\n        # Normalize, replace zero (degenerated) normals with some default value\n        vn = torch.where(\n            dot(vn, vn) > 1e-20,\n            vn,\n            torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),\n        )\n        vn = safe_normalize(vn)\n\n        self.vn = vn\n        self.fn = self.f\n\n    def auto_uv(self, cache_path=None, vmap=True):\n        # try to load cache\n        if cache_path is not None:\n            cache_path = os.path.splitext(cache_path)[0] + \"_uv.npz\"\n        if cache_path is not None and os.path.exists(cache_path):\n            data = np.load(cache_path)\n            vt_np, ft_np, vmapping = data[\"vt\"], data[\"ft\"], data[\"vmapping\"]\n        else:\n            import xatlas\n\n            v_np = self.v.detach().cpu().numpy()\n            f_np = self.f.detach().int().cpu().numpy()\n            atlas = xatlas.Atlas()\n            atlas.add_mesh(v_np, f_np)\n            chart_options = xatlas.ChartOptions()\n            # chart_options.max_iterations = 4\n            atlas.generate(chart_options=chart_options)\n            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]\n\n            # save to cache\n            if cache_path is not None:\n                np.savez(cache_path, 
vt=vt_np, ft=ft_np, vmapping=vmapping)\n\n        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)\n        ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)\n        self.vt = vt\n        self.ft = ft\n\n        if vmap:\n            # remap v/f to vt/ft, so each v correspond to a unique vt. (necessary for gltf)\n            vmapping = (\n                torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)\n            )\n            self.align_v_to_vt(vmapping)\n\n    def align_v_to_vt(self, vmapping=None):\n        # remap v/f and vn/vn to vt/ft.\n        if vmapping is None:\n            ft = self.ft.view(-1).long()\n            f = self.f.view(-1).long()\n            vmapping = torch.zeros(\n                self.vt.shape[0], dtype=torch.long, device=self.device\n            )\n            vmapping[ft] = f  # scatter, randomly choose one if index is not unique\n\n        self.v = self.v[vmapping]\n        self.f = self.ft\n        # assume fn == f\n        if self.vn is not None:\n            self.vn = self.vn[vmapping]\n            self.fn = self.ft\n\n    def to(self, device):\n        self.device = device\n        for name in [\"v\", \"f\", \"vn\", \"fn\", \"vt\", \"ft\", \"albedo\"]:\n            tensor = getattr(self, name)\n            if tensor is not None:\n                setattr(self, name, tensor.to(device))\n        return self\n\n    def write(self, path):\n        if path.endswith(\".ply\"):\n            self.write_ply(path)\n        elif path.endswith(\".obj\"):\n            self.write_obj(path)\n        elif path.endswith(\".glb\") or path.endswith(\".gltf\"):\n            self.write_glb(path)\n        else:\n            raise NotImplementedError(f\"format {path} not supported!\")\n\n    # write to ply file (only geom)\n    def write_ply(self, path):\n        v_np = self.v.detach().cpu().numpy()\n        f_np = self.f.detach().cpu().numpy()\n\n        _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)\n   
     _mesh.export(path)\n\n    # write to gltf/glb file (geom + texture)\n    def write_glb(self, path):\n        assert (\n            self.vn is not None and self.vt is not None\n        )  # should be improved to support export without texture...\n\n        # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]\n        if self.v.shape[0] != self.vt.shape[0]:\n            self.align_v_to_vt()\n\n        # assume f == fn == ft\n\n        import pygltflib\n\n        f_np = self.f.detach().cpu().numpy().astype(np.uint32)\n        v_np = self.v.detach().cpu().numpy().astype(np.float32)\n        # vn_np = self.vn.detach().cpu().numpy().astype(np.float32)\n        vt_np = self.vt.detach().cpu().numpy().astype(np.float32)\n\n        albedo = self.albedo.detach().cpu().numpy()\n        albedo = (albedo * 255).astype(np.uint8)\n        albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)\n\n        f_np_blob = f_np.flatten().tobytes()\n        v_np_blob = v_np.tobytes()\n        # vn_np_blob = vn_np.tobytes()\n        vt_np_blob = vt_np.tobytes()\n        albedo_blob = cv2.imencode(\".png\", albedo)[1].tobytes()\n\n        gltf = pygltflib.GLTF2(\n            scene=0,\n            scenes=[pygltflib.Scene(nodes=[0])],\n            nodes=[pygltflib.Node(mesh=0)],\n            meshes=[\n                pygltflib.Mesh(\n                    primitives=[\n                        pygltflib.Primitive(\n                            # indices to accessors (0 is triangles)\n                            attributes=pygltflib.Attributes(\n                                POSITION=1,\n                                TEXCOORD_0=2,\n                            ),\n                            indices=0,\n                            material=0,\n                        )\n                    ]\n                )\n            ],\n            materials=[\n                pygltflib.Material(\n                    pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(\n     
                   baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),\n                        metallicFactor=0.0,\n                        roughnessFactor=1.0,\n                    ),\n                    alphaCutoff=0,\n                    doubleSided=True,\n                )\n            ],\n            textures=[\n                pygltflib.Texture(sampler=0, source=0),\n            ],\n            samplers=[\n                pygltflib.Sampler(\n                    magFilter=pygltflib.LINEAR,\n                    minFilter=pygltflib.LINEAR_MIPMAP_LINEAR,\n                    wrapS=pygltflib.REPEAT,\n                    wrapT=pygltflib.REPEAT,\n                ),\n            ],\n            images=[\n                # use embedded (buffer) image\n                pygltflib.Image(bufferView=3, mimeType=\"image/png\"),\n            ],\n            buffers=[\n                pygltflib.Buffer(\n                    byteLength=len(f_np_blob)\n                    + len(v_np_blob)\n                    + len(vt_np_blob)\n                    + len(albedo_blob)\n                )\n            ],\n            # buffer view (based on dtype)\n            bufferViews=[\n                # triangles; as flatten (element) array\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteLength=len(f_np_blob),\n                    target=pygltflib.ELEMENT_ARRAY_BUFFER,  # GL_ELEMENT_ARRAY_BUFFER (34963)\n                ),\n                # positions; as vec3 array\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteOffset=len(f_np_blob),\n                    byteLength=len(v_np_blob),\n                    byteStride=12,  # vec3\n                    target=pygltflib.ARRAY_BUFFER,  # GL_ARRAY_BUFFER (34962)\n                ),\n                # texcoords; as vec2 array\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteOffset=len(f_np_blob) + 
len(v_np_blob),\n                    byteLength=len(vt_np_blob),\n                    byteStride=8,  # vec2\n                    target=pygltflib.ARRAY_BUFFER,\n                ),\n                # texture; as none target\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteOffset=len(f_np_blob) + len(v_np_blob) + len(vt_np_blob),\n                    byteLength=len(albedo_blob),\n                ),\n            ],\n            accessors=[\n                # 0 = triangles\n                pygltflib.Accessor(\n                    bufferView=0,\n                    componentType=pygltflib.UNSIGNED_INT,  # GL_UNSIGNED_INT (5125)\n                    count=f_np.size,\n                    type=pygltflib.SCALAR,\n                    max=[int(f_np.max())],\n                    min=[int(f_np.min())],\n                ),\n                # 1 = positions\n                pygltflib.Accessor(\n                    bufferView=1,\n                    componentType=pygltflib.FLOAT,  # GL_FLOAT (5126)\n                    count=len(v_np),\n                    type=pygltflib.VEC3,\n                    max=v_np.max(axis=0).tolist(),\n                    min=v_np.min(axis=0).tolist(),\n                ),\n                # 2 = texcoords\n                pygltflib.Accessor(\n                    bufferView=2,\n                    componentType=pygltflib.FLOAT,\n                    count=len(vt_np),\n                    type=pygltflib.VEC2,\n                    max=vt_np.max(axis=0).tolist(),\n                    min=vt_np.min(axis=0).tolist(),\n                ),\n            ],\n        )\n\n        # set actual data\n        gltf.set_binary_blob(f_np_blob + v_np_blob + vt_np_blob + albedo_blob)\n\n        # glb = b\"\".join(gltf.save_to_bytes())\n        gltf.save(path)\n\n    # write to obj file (geom + texture)\n    def write_obj(self, path):\n        mtl_path = path.replace(\".obj\", \".mtl\")\n        albedo_path = path.replace(\".obj\", 
\"_albedo.png\")\n\n        v_np = self.v.detach().cpu().numpy()\n        vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None\n        vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None\n        f_np = self.f.detach().cpu().numpy()\n        ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None\n        fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None\n\n        with open(path, \"w\") as fp:\n            fp.write(f\"mtllib {os.path.basename(mtl_path)} \\n\")\n\n            for v in v_np:\n                fp.write(f\"v {v[0]} {v[1]} {v[2]} \\n\")\n\n            if vt_np is not None:\n                for v in vt_np:\n                    fp.write(f\"vt {v[0]} {1 - v[1]} \\n\")\n\n            if vn_np is not None:\n                for v in vn_np:\n                    fp.write(f\"vn {v[0]} {v[1]} {v[2]} \\n\")\n\n            fp.write(f\"usemtl defaultMat \\n\")\n            for i in range(len(f_np)):\n                fp.write(\n                    f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else \"\"}/{fn_np[i, 0] + 1 if fn_np is not None else \"\"} \\\n                             {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else \"\"}/{fn_np[i, 1] + 1 if fn_np is not None else \"\"} \\\n                             {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else \"\"}/{fn_np[i, 2] + 1 if fn_np is not None else \"\"} \\n'\n                )\n\n        with open(mtl_path, \"w\") as fp:\n            fp.write(f\"newmtl defaultMat \\n\")\n            fp.write(f\"Ka 1 1 1 \\n\")\n            fp.write(f\"Kd 1 1 1 \\n\")\n            fp.write(f\"Ks 0 0 0 \\n\")\n            fp.write(f\"Tr 1 \\n\")\n            fp.write(f\"illum 1 \\n\")\n            fp.write(f\"Ns 0 \\n\")\n            fp.write(f\"map_Kd {os.path.basename(albedo_path)} \\n\")\n\n        if not (False or self.albedo is None):\n            albedo = self.albedo.detach().cpu().numpy()\n            
albedo = (albedo * 255).astype(np.uint8)\n            cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))\n"
  },
  {
    "path": "physdreamer/gaussian_3d/scene/mesh_utils.py",
    "content": "import numpy as np\nimport pymeshlab as pml\n\n\ndef poisson_mesh_reconstruction(points, normals=None):\n    # points/normals: [N, 3] np.ndarray\n\n    import open3d as o3d\n\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(points)\n\n    # outlier removal\n    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)\n\n    # normals\n    if normals is None:\n        pcd.estimate_normals()\n    else:\n        pcd.normals = o3d.utility.Vector3dVector(normals[ind])\n\n    # visualize\n    o3d.visualization.draw_geometries([pcd], point_show_normal=False)\n\n    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(\n        pcd, depth=9\n    )\n    vertices_to_remove = densities < np.quantile(densities, 0.1)\n    mesh.remove_vertices_by_mask(vertices_to_remove)\n\n    # visualize\n    o3d.visualization.draw_geometries([mesh])\n\n    vertices = np.asarray(mesh.vertices)\n    triangles = np.asarray(mesh.triangles)\n\n    print(\n        f\"[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}\"\n    )\n\n    return vertices, triangles\n\n\ndef decimate_mesh(\n    verts, faces, target, backend=\"pymeshlab\", remesh=False, optimalplacement=True\n):\n    # optimalplacement: default is True, but for flat mesh must turn False to prevent spike artifect.\n\n    _ori_vert_shape = verts.shape\n    _ori_face_shape = faces.shape\n\n    if backend == \"pyfqmr\":\n        import pyfqmr\n\n        solver = pyfqmr.Simplify()\n        solver.setMesh(verts, faces)\n        solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)\n        verts, faces, normals = solver.getMesh()\n    else:\n        m = pml.Mesh(verts, faces)\n        ms = pml.MeshSet()\n        ms.add_mesh(m, \"mesh\")  # will copy!\n\n        # filters\n        # ms.meshing_decimation_clustering(threshold=pml.PercentageValue(1))\n        
ms.meshing_decimation_quadric_edge_collapse(\n            targetfacenum=int(target), optimalplacement=optimalplacement\n        )\n\n        if remesh:\n            # ms.apply_coord_taubin_smoothing()\n            ms.meshing_isotropic_explicit_remeshing(\n                iterations=3, targetlen=pml.PercentageValue(1)\n            )\n\n        # extract mesh\n        m = ms.current_mesh()\n        verts = m.vertex_matrix()\n        faces = m.face_matrix()\n\n    print(\n        f\"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}\"\n    )\n\n    return verts, faces\n\n\ndef clean_mesh(\n    verts,\n    faces,\n    v_pct=1,\n    min_f=64,\n    min_d=20,\n    repair=True,\n    remesh=True,\n    remesh_size=0.01,\n):\n    # verts: [N, 3]\n    # faces: [N, 3]\n\n    _ori_vert_shape = verts.shape\n    _ori_face_shape = faces.shape\n\n    m = pml.Mesh(verts, faces)\n    ms = pml.MeshSet()\n    ms.add_mesh(m, \"mesh\")  # will copy!\n\n    # filters\n    ms.meshing_remove_unreferenced_vertices()  # verts not refed by any faces\n\n    if v_pct > 0:\n        ms.meshing_merge_close_vertices(\n            threshold=pml.PercentageValue(v_pct)\n        )  # 1/10000 of bounding box diagonal\n\n    ms.meshing_remove_duplicate_faces()  # faces defined by the same verts\n    ms.meshing_remove_null_faces()  # faces with area == 0\n\n    if min_d > 0:\n        ms.meshing_remove_connected_component_by_diameter(\n            mincomponentdiag=pml.PercentageValue(min_d)\n        )\n\n    if min_f > 0:\n        ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)\n\n    if repair:\n        # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)\n        ms.meshing_repair_non_manifold_edges(method=0)\n        ms.meshing_repair_non_manifold_vertices(vertdispratio=0)\n\n    if remesh:\n        # ms.apply_coord_taubin_smoothing()\n        ms.meshing_isotropic_explicit_remeshing(\n            iterations=3, 
targetlen=pml.PureValue(remesh_size)\n        )\n\n    # extract mesh\n    m = ms.current_mesh()\n    verts = m.vertex_matrix()\n    faces = m.face_matrix()\n\n    print(\n        f\"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}\"\n    )\n\n    return verts, faces\n"
  },
  {
    "path": "physdreamer/gaussian_3d/utils/camera_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom physdreamer.gaussian_3d.scene.cameras import Camera\nimport numpy as np\nfrom physdreamer.gaussian_3d.utils.general_utils import PILtoTorch\nfrom physdreamer.gaussian_3d.utils.graphics_utils import fov2focal\nimport torch\n\nWARNED = False\n\n\ndef loadCam(args, id, cam_info, resolution_scale):\n    orig_w, orig_h = cam_info.image.size\n\n    if args.resolution in [1, 2, 4, 8]:\n        resolution = round(orig_w / (resolution_scale * args.resolution)), round(\n            orig_h / (resolution_scale * args.resolution)\n        )\n    else:  # should be a type that converts to float\n        if args.resolution == -1:\n            if orig_w > 1600:\n                global WARNED\n                if not WARNED:\n                    print(\n                        \"[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\\n \"\n                        \"If this is not desired, please explicitly specify '--resolution/-r' as 1\"\n                    )\n                    WARNED = True\n                global_down = orig_w / 1600\n            else:\n                global_down = 1\n        else:\n            global_down = orig_w / args.resolution\n\n        scale = float(global_down) * float(resolution_scale)\n        resolution = (int(orig_w / scale), int(orig_h / scale))\n\n    resized_image_rgb = PILtoTorch(cam_info.image, resolution)\n\n    gt_image = resized_image_rgb[:3, ...]\n    loaded_mask = None\n\n    if resized_image_rgb.shape[1] == 4:\n        loaded_mask = resized_image_rgb[3:4, ...]\n\n    return Camera(\n        colmap_id=cam_info.uid,\n        R=cam_info.R,\n        T=cam_info.T,\n        FoVx=cam_info.FovX,\n     
   FoVy=cam_info.FovY,\n        image=gt_image,\n        gt_alpha_mask=loaded_mask,\n        image_name=cam_info.image_name,\n        uid=id,\n        data_device=args.data_device,\n    )\n\n\ndef cameraList_from_camInfos(cam_infos, resolution_scale, args):\n    camera_list = []\n\n    for id, c in enumerate(cam_infos):\n        camera_list.append(loadCam(args, id, c, resolution_scale))\n\n    return camera_list\n\n\ndef camera_to_JSON(id, camera: Camera):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = camera.R.transpose()\n    Rt[:3, 3] = camera.T\n    Rt[3, 3] = 1.0\n\n    W2C = np.linalg.inv(Rt)\n    pos = W2C[:3, 3]\n    rot = W2C[:3, :3]\n    serializable_array_2d = [x.tolist() for x in rot]\n    camera_entry = {\n        \"id\": id,\n        \"img_name\": camera.image_name,\n        \"width\": camera.width,\n        \"height\": camera.height,\n        \"position\": pos.tolist(),\n        \"rotation\": serializable_array_2d,\n        \"fy\": fov2focal(camera.FovY, camera.height),\n        \"fx\": fov2focal(camera.FovX, camera.width),\n    }\n    return camera_entry\n\n\ndef look_at(from_point, to_point, up_vector=(0, 1, 0)):\n    \"\"\"\n    Compute the look-at matrix for a camera.\n\n    :param from_point: The position of the camera.\n    :param to_point: The point the camera is looking at.\n    :param up_vector: The up direction of the camera.\n    :return: The 4x4 look-at matrix.\n    \"\"\"\n\n    # minus z for opengl. 
z for colmap\n    forward = np.array(to_point) - np.array(from_point)\n    forward = forward / (np.linalg.norm(forward) + 1e-5)\n\n    # x-axis\n    # Right direction is the cross product of the forward vector and the up vector\n    right = np.cross(up_vector, forward)\n    right = right / (np.linalg.norm(right) + 1e-5)\n\n    # y axis\n    # True up direction is the cross product of the right vector and the forward vector\n    true_up = np.cross(forward, right)\n    true_up = true_up / (np.linalg.norm(true_up) + 1e-5)\n\n    # camera to world\n    rotation = np.array(\n        [\n            [right[0], true_up[0], forward[0]],\n            [right[1], true_up[1], forward[1]],\n            [right[2], true_up[2], forward[2]],\n        ]\n    )\n\n    # Construct the translation matrix\n    translation = np.array(\n        [\n            [-from_point[0]],\n            [-from_point[1]],\n            [-from_point[2]],\n        ]\n    )\n\n    # Combine the rotation and translation to get the look-at matrix\n    T = 1.0 * rotation.transpose() @ translation\n\n    return rotation.transpose(), T\n\n\ndef create_cameras_around_sphere(\n    radius=6,\n    elevation=0,\n    fovx=35,\n    resolutions=(720, 1080),\n    num_cams=60,\n    center=(0, 0, 0),\n):\n    \"\"\"\n    Create cameras around a sphere.\n\n    :param radius: The radius of the circle on which cameras are placed.\n    :param elevation: The elevation angle in degrees.\n    :param fovx: The horizontal field of view of the cameras.\n    :param resolutions: The resolution of the cameras.\n    :param num_cams: The number of cameras.\n    :param center: The center of the sphere.\n    :return: A list of camera extrinsics (world2camera transformations).\n    \"\"\"\n    extrinsics = []\n\n    # Convert elevation to radians\n    elevation_rad = np.radians(elevation)\n\n    # Compute the z-coordinate of the cameras based on the elevation\n    z = radius * np.sin(elevation_rad)\n\n    # Compute the radius of the circle 
at the given elevation\n    circle_radius = radius * np.cos(elevation_rad)\n\n    for i in range(num_cams):\n        # Compute the angle for the current camera\n        angle = 2 * np.pi * i / num_cams\n\n        # Compute the x and y coordinates of the camera\n        x = circle_radius * np.cos(angle) + center[0]\n        y = circle_radius * np.sin(angle) + center[1]\n\n        # Create the look-at matrix for the camera\n        R, T = look_at((x, y, z + center[2]), center)\n        extrinsics.append([R, T.squeeze(axis=-1)])\n\n    cam_list = []\n    dummy_image = torch.tensor(\n        np.zeros((3, resolutions[0], resolutions[1]), dtype=np.uint8)\n    )\n    for i in range(num_cams):\n        R, T = extrinsics[i]\n\n        # R is stored transposed due to 'glm' in CUDA code\n        R = R.transpose()\n        cam = Camera(\n            colmap_id=i,\n            R=R,\n            T=T,\n            FoVx=fovx,\n            FoVy=fovx * resolutions[1] / resolutions[0],\n            image_name=\"\",\n            uid=i,\n            data_device=\"cuda\",\n            image=dummy_image,\n            gt_alpha_mask=None,\n        )\n\n        cam_list.append(cam)\n\n    return cam_list\n"
  },
  {
    "path": "physdreamer/gaussian_3d/utils/general_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport sys\nfrom datetime import datetime\nimport numpy as np\nimport random\n\ndef inverse_sigmoid(x):\n    return torch.log(x/(1-x))\n\ndef PILtoTorch(pil_image, resolution):\n    resized_image_PIL = pil_image.resize(resolution)\n    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0\n    if len(resized_image.shape) == 3:\n        return resized_image.permute(2, 0, 1)\n    else:\n        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)\n\ndef get_expon_lr_func(\n    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000\n):\n    \"\"\"\n    Copied from Plenoxels\n\n    Continuous learning rate decay function. Adapted from JaxNeRF\n    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and\n    is log-linearly interpolated elsewhere (equivalent to exponential decay).\n    If lr_delay_steps>0 then the learning rate will be scaled by some smooth\n    function of lr_delay_mult, such that the initial learning rate is\n    lr_init*lr_delay_mult at the beginning of optimization but will be eased back\n    to the normal learning rate when steps>lr_delay_steps.\n    :param conf: config subtree 'lr' or similar\n    :param max_steps: int, the number of steps during optimization.\n    :return HoF which takes step as input\n    \"\"\"\n\n    def helper(step):\n        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):\n            # Disable this parameter\n            return 0.0\n        if lr_delay_steps > 0:\n            # A kind of reverse cosine decay.\n            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(\n                0.5 * np.pi * np.clip(step / 
lr_delay_steps, 0, 1)\n            )\n        else:\n            delay_rate = 1.0\n        t = np.clip(step / max_steps, 0, 1)\n        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)\n        return delay_rate * log_lerp\n\n    return helper\n\ndef strip_lowerdiag(L):\n    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=\"cuda\")\n\n    uncertainty[:, 0] = L[:, 0, 0]\n    uncertainty[:, 1] = L[:, 0, 1]\n    uncertainty[:, 2] = L[:, 0, 2]\n    uncertainty[:, 3] = L[:, 1, 1]\n    uncertainty[:, 4] = L[:, 1, 2]\n    uncertainty[:, 5] = L[:, 2, 2]\n    return uncertainty\n\ndef strip_symmetric(sym):\n    return strip_lowerdiag(sym)\n\ndef build_rotation(r):\n    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])\n\n    q = r / norm[:, None]\n\n    R = torch.zeros((q.size(0), 3, 3), device='cuda')\n\n    r = q[:, 0]\n    x = q[:, 1]\n    y = q[:, 2]\n    z = q[:, 3]\n\n    R[:, 0, 0] = 1 - 2 * (y*y + z*z)\n    R[:, 0, 1] = 2 * (x*y - r*z)\n    R[:, 0, 2] = 2 * (x*z + r*y)\n    R[:, 1, 0] = 2 * (x*y + r*z)\n    R[:, 1, 1] = 1 - 2 * (x*x + z*z)\n    R[:, 1, 2] = 2 * (y*z - r*x)\n    R[:, 2, 0] = 2 * (x*z - r*y)\n    R[:, 2, 1] = 2 * (y*z + r*x)\n    R[:, 2, 2] = 1 - 2 * (x*x + y*y)\n    return R\n\ndef build_scaling_rotation(s, r):\n    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=\"cuda\")\n    R = build_rotation(r)\n\n    L[:,0,0] = s[:,0]\n    L[:,1,1] = s[:,1]\n    L[:,2,2] = s[:,2]\n\n    L = R @ L\n    return L\n\ndef safe_state(silent):\n    old_f = sys.stdout\n    class F:\n        def __init__(self, silent):\n            self.silent = silent\n\n        def write(self, x):\n            if not self.silent:\n                if x.endswith(\"\\n\"):\n                    old_f.write(x.replace(\"\\n\", \" [{}]\\n\".format(str(datetime.now().strftime(\"%d/%m %H:%M:%S\")))))\n                else:\n                    old_f.write(x)\n\n        def flush(self):\n            
old_f.flush()\n\n    sys.stdout = F(silent)\n\n    random.seed(0)\n    np.random.seed(0)\n    torch.manual_seed(0)\n    torch.cuda.set_device(torch.device(\"cuda:0\"))\n"
  },
  {
    "path": "physdreamer/gaussian_3d/utils/graphics_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport math\nimport numpy as np\nfrom typing import NamedTuple\n\nclass BasicPointCloud(NamedTuple):\n    points : np.array\n    colors : np.array\n    normals : np.array\n\ndef geom_transform_points(points, transf_matrix):\n    P, _ = points.shape\n    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)\n    points_hom = torch.cat([points, ones], dim=1)\n    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))\n\n    denom = points_out[..., 3:] + 0.0000001\n    return (points_out[..., :3] / denom).squeeze(dim=0)\n\ndef getWorld2View(R, t):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()\n    Rt[:3, 3] = t\n    Rt[3, 3] = 1.0\n    return np.float32(Rt)\n\ndef getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()\n    Rt[:3, 3] = t\n    Rt[3, 3] = 1.0\n\n    C2W = np.linalg.inv(Rt)\n    cam_center = C2W[:3, 3]\n    cam_center = (cam_center + translate) * scale\n    C2W[:3, 3] = cam_center\n    Rt = np.linalg.inv(C2W)\n    return np.float32(Rt)\n\ndef getProjectionMatrix(znear, zfar, fovX, fovY):\n    tanHalfFovY = math.tan((fovY / 2))\n    tanHalfFovX = math.tan((fovX / 2))\n\n    top = tanHalfFovY * znear\n    bottom = -top\n    right = tanHalfFovX * znear\n    left = -right\n\n    P = torch.zeros(4, 4)\n\n    z_sign = 1.0\n\n    P[0, 0] = 2.0 * znear / (right - left)\n    P[1, 1] = 2.0 * znear / (top - bottom)\n    P[0, 2] = (right + left) / (right - left)\n    P[1, 2] = (top + bottom) / (top - bottom)\n    P[3, 2] = z_sign\n    P[2, 2] = z_sign * zfar / (zfar - znear)\n    P[2, 3] = -(zfar * znear) / (zfar - znear)\n    return 
P\n\ndef fov2focal(fov, pixels):\n    return pixels / (2 * math.tan(fov / 2))\n\ndef focal2fov(focal, pixels):\n    return 2*math.atan(pixels/(2*focal))"
  },
  {
    "path": "physdreamer/gaussian_3d/utils/image_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\n\ndef mse(img1, img2):\n    return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n\ndef psnr(img1, img2):\n    mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n    return 20 * torch.log10(1.0 / torch.sqrt(mse))\n"
  },
  {
    "path": "physdreamer/gaussian_3d/utils/loss_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom math import exp\n\ndef l1_loss(network_output, gt):\n    return torch.abs((network_output - gt)).mean()\n\ndef l2_loss(network_output, gt):\n    return ((network_output - gt) ** 2).mean()\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])\n    return gauss / gauss.sum()\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())\n    return window\n\ndef ssim(img1, img2, window_size=11, size_average=True):\n    channel = img1.size(-3)\n    window = create_window(window_size, channel)\n\n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n\n    return _ssim(img1, img2, window, window_size, channel, size_average)\n\ndef _ssim(img1, img2, window, window_size, channel, size_average=True):\n    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2\n\n    
C1 = 0.01 ** 2\n    C2 = 0.03 ** 2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1).mean(1).mean(1)\n\n"
  },
  {
    "path": "physdreamer/gaussian_3d/utils/rigid_body_utils.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\ndef get_rigid_transform(A, B):\n    \"\"\"\n    Estimate the rigid body transformation between two sets of 3D points.\n    A and B are Nx3 matrices where each row is a 3D point.\n    Returns a rotation matrix R and translation vector t.\n    Args:\n        A, B: [batch, N, 3] matrix of 3D points\n    Outputs:\n        R, t: [batch, 3, 3/1]\n        target = R @ source (source shape [3, 1]) + t\n    \"\"\"\n    assert A.shape == B.shape, \"Input matrices must have the same shape\"\n    assert A.shape[-1] == 3, \"Input matrices must have 3 columns (x, y, z coordinates)\"\n\n    # Compute centroids. [..., 1, 3]\n    centroid_A = torch.mean(A, dim=-2, keepdim=True)\n    centroid_B = torch.mean(B, dim=-2, keepdim=True)\n\n    # Center the point sets\n    A_centered = A - centroid_A\n    B_centered = B - centroid_B\n\n    # Compute the cross-covariance matrix. [..., 3, 3]\n    H = A_centered.transpose(-2, -1) @ B_centered\n\n    # Compute the Singular Value Decomposition. 
Along last two dimensions\n    U, S, Vt = torch.linalg.svd(H)\n\n    # Compute the rotation matrix\n    R = Vt.transpose(-2, -1) @ U.transpose(-2, -1)\n\n    # Ensure a right-handed coordinate system\n    flip_mask = (torch.det(R) < 0) * -2.0 + 1.0\n    # Vt[:, 2, :] *= flip_mask[..., None]\n\n    # [N] => [N, 3]\n    pad_flip_mask = torch.stack(\n        [torch.ones_like(flip_mask), torch.ones_like(flip_mask), flip_mask], dim=-1\n    )\n    Vt = Vt * pad_flip_mask[..., None]\n\n    # Compute the rotation matrix\n    R = Vt.transpose(-2, -1) @ U.transpose(-2, -1)\n\n    # print(R.shape, centroid_A.shape, centroid_B.shape, flip_mask.shape)\n    # Compute the translation\n    t = centroid_B - (R @ centroid_A.transpose(-2, -1)).transpose(-2, -1)\n    t = t.transpose(-2, -1)\n    return R, t\n\n\ndef _test_rigid_transform():\n    # Example usage:\n    A = torch.tensor([[1, 2, 3], [4, 5, 6], [9, 8, 10], [10, -5, 1]]) * 1.0\n\n    R_synthesized = torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) * 1.0\n    # init a random rotation matrix:\n\n    B = (R_synthesized @ A.T).T + 2.0  # Just an example offset\n\n    R, t = get_rigid_transform(A[None, ...], B[None, ...])\n    print(\"Rotation matrix R:\")\n    print(R)\n    print(\"\\nTranslation vector t:\")\n    print(t)\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Returns torch.sqrt(torch.max(0, x))\n    but with a zero subgradient where x is 0.\n    \"\"\"\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    ret[positive_mask] = torch.sqrt(x[positive_mask])\n    return ret\n\n\ndef matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    from pytorch3d. 
Based on trace_method like: https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L205\n    Convert rotations given as rotation matrices to quaternions.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n        matrix.reshape(batch_dim + (9,)), dim=-1\n    )\n\n    q_abs = _sqrt_positive_part(\n        torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    # we produce the desired quaternion multiplied by each of r, i, j, k\n    quat_by_rijk = torch.stack(\n        [\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),\n        ],\n        dim=-2,\n    )\n\n    # We floor here at 0.1 but the exact level is not important; if q_abs is small,\n    # the candidate won't be 
picked.\n    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))\n\n    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),\n    # forall i; we pick the best-conditioned one (with the largest denominator)\n\n    return quat_candidates[\n        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :\n    ].reshape(batch_dim + (4,))\n\n\ndef quternion_to_matrix(r):\n    norm = torch.sqrt(\n        r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]\n    )\n\n    q = r / norm[:, None]\n\n    R = torch.zeros((q.size(0), 3, 3), device=\"cuda\")\n\n    r = q[:, 0]\n    x = q[:, 1]\n    y = q[:, 2]\n    z = q[:, 3]\n\n    R[:, 0, 0] = 1 - 2 * (y * y + z * z)\n    R[:, 0, 1] = 2 * (x * y - r * z)\n    R[:, 0, 2] = 2 * (x * z + r * y)\n    R[:, 1, 0] = 2 * (x * y + r * z)\n    R[:, 1, 1] = 1 - 2 * (x * x + z * z)\n    R[:, 1, 2] = 2 * (y * z - r * x)\n    R[:, 2, 0] = 2 * (x * z - r * y)\n    R[:, 2, 1] = 2 * (y * z + r * x)\n    R[:, 2, 2] = 1 - 2 * (x * x + y * y)\n    return R\n\n\ndef standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    from Pytorch3d\n    Convert a unit quaternion to a standard form: one in which the real\n    part is non negative.\n\n    Args:\n        quaternions: Quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Standardized quaternions as tensor of shape (..., 4).\n    \"\"\"\n    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)\n\n\ndef quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    From pytorch3d\n    Multiply two quaternions.\n    Usual torch rules for broadcasting apply.\n\n    Args:\n        a: Quaternions as tensor of shape (..., 4), real part first.\n        b: Quaternions as tensor of shape (..., 4), real part first.\n\n    Returns:\n        The product 
of a and b, a tensor of quaternions shape (..., 4).\n    \"\"\"\n    aw, ax, ay, az = torch.unbind(a, -1)\n    bw, bx, by, bz = torch.unbind(b, -1)\n    ow = aw * bw - ax * bx - ay * by - az * bz\n    ox = aw * bx + ax * bw + ay * bz - az * by\n    oy = aw * by - ax * bz + ay * bw + az * bx\n    oz = aw * bz + ax * by - ay * bx + az * bw\n    ret = torch.stack((ow, ox, oy, oz), -1)\n    ret = standardize_quaternion(ret)\n    return ret\n\n\ndef _test_matrix_to_quaternion():\n    # init a random batch of quaternion\n    r = torch.randn((10, 4)).cuda()\n\n    norm = torch.sqrt(\n        r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]\n    )\n\n    q = r / norm[:, None]\n\n    q = standardize_quaternion(q)\n\n    R = quternion_to_matrix(q)\n\n    I_rec = R @ R.transpose(-2, -1)\n    I_rec_error = torch.abs(I_rec - torch.eye(3, device=\"cuda\")[None, ...]).max()\n\n    q_recovered = matrix_to_quaternion(R)\n    norm_ = torch.linalg.norm(q_recovered, dim=-1)\n    q_recovered = q_recovered / norm_[..., None]\n    q_recovered = standardize_quaternion(q_recovered)\n\n    print(q_recovered.shape, q.shape, R.shape)\n\n    rec = (q - q_recovered).abs().max()\n\n    print(\"rotation to I error:\", I_rec_error, \"quant rec error: \", rec)\n\n\ndef _test_matrix_to_quaternion_2():\n    R = (\n        torch.tensor(\n            [[[1, 0, 0], [0, -1, 0], [0, 0, -1]], [[1, 0, 0], [0, 0, 1], [0, -1, 0]]]\n        )\n        * 1.0\n    )\n\n    q_rec = matrix_to_quaternion(R.transpose(-2, -1))\n\n    R_rec = quternion_to_matrix(q_rec)\n\n    print(R_rec)\n\n\nif __name__ == \"__main__\":\n    # _test_rigid_transform()\n    _test_matrix_to_quaternion()\n\n    _test_matrix_to_quaternion_2()\n"
  },
  {
    "path": "physdreamer/gaussian_3d/utils/sh_utils.py",
    "content": "#  Copyright 2021 The PlenOctree Authors.\n#  Redistribution and use in source and binary forms, with or without\n#  modification, are permitted provided that the following conditions are met:\n#\n#  1. Redistributions of source code must retain the above copyright notice,\n#  this list of conditions and the following disclaimer.\n#\n#  2. Redistributions in binary form must reproduce the above copyright notice,\n#  this list of conditions and the following disclaimer in the documentation\n#  and/or other materials provided with the distribution.\n#\n#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE\n#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n#  POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\n\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n    1.0925484305920792,\n    -1.0925484305920792,\n    0.31539156525252005,\n    -1.0925484305920792,\n    0.5462742152960396\n]\nC3 = [\n    -0.5900435899266435,\n    2.890611442640554,\n    -0.4570457994644658,\n    0.3731763325901154,\n    -0.4570457994644658,\n    1.445305721320277,\n    -0.5900435899266435\n]\nC4 = [\n    2.5033429417967046,\n    -1.7701307697799304,\n    0.9461746957575601,\n    -0.6690465435572892,\n    0.10578554691520431,\n    -0.6690465435572892,\n    0.47308734787878004,\n    -1.7701307697799304,\n    
0.6258357354491761,\n]   \n\n\ndef eval_sh(deg, sh, dirs):\n    \"\"\"\n    Evaluate spherical harmonics at unit directions\n    using hardcoded SH polynomials.\n    Works with torch/np/jnp.\n    ... Can be 0 or more batch dimensions.\n    Args:\n        deg: int SH deg. Currently, 0-4 supported\n        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]\n        dirs: jnp.ndarray unit directions [..., 3]\n    Returns:\n        [..., C]\n    \"\"\"\n    assert deg <= 4 and deg >= 0\n    coeff = (deg + 1) ** 2\n    assert sh.shape[-1] >= coeff\n\n    result = C0 * sh[..., 0]\n    if deg > 0:\n        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]\n        result = (result -\n                C1 * y * sh[..., 1] +\n                C1 * z * sh[..., 2] -\n                C1 * x * sh[..., 3])\n\n        if deg > 1:\n            xx, yy, zz = x * x, y * y, z * z\n            xy, yz, xz = x * y, y * z, x * z\n            result = (result +\n                    C2[0] * xy * sh[..., 4] +\n                    C2[1] * yz * sh[..., 5] +\n                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +\n                    C2[3] * xz * sh[..., 7] +\n                    C2[4] * (xx - yy) * sh[..., 8])\n\n            if deg > 2:\n                result = (result +\n                C3[0] * y * (3 * xx - yy) * sh[..., 9] +\n                C3[1] * xy * z * sh[..., 10] +\n                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +\n                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +\n                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +\n                C3[5] * z * (xx - yy) * sh[..., 14] +\n                C3[6] * x * (xx - 3 * yy) * sh[..., 15])\n\n                if deg > 3:\n                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +\n                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +\n                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +\n                            C4[3] * yz * (7 * zz - 
3) * sh[..., 19] +\n                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +\n                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +\n                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +\n                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +\n                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])\n    return result\n\ndef RGB2SH(rgb):\n    return (rgb - 0.5) / C0\n\ndef SH2RGB(sh):\n    return sh * C0 + 0.5"
  },
  {
    "path": "physdreamer/gaussian_3d/utils/system_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom errno import EEXIST\nfrom os import makedirs, path\nimport os\n\ndef mkdir_p(folder_path):\n    # Creates a directory. equivalent to using mkdir -p on the command line\n    try:\n        makedirs(folder_path)\n    except OSError as exc: # Python >2.5\n        if exc.errno == EEXIST and path.isdir(folder_path):\n            pass\n        else:\n            raise\n\ndef searchForMaxIteration(folder):\n    saved_iters = [int(fname.split(\"_\")[-1]) for fname in os.listdir(folder)]\n    return max(saved_iters)\n"
  },
  {
    "path": "physdreamer/losses/smoothness_loss.py",
    "content": "import torch\nfrom typing import Tuple\n\n\ndef compute_plane_tv(t: torch.Tensor, only_w: bool = False) -> float:\n    \"\"\"Computes total variance across a plane.\n    From nerf-studio\n\n    Args:\n        t: Plane tensor\n        only_w: Whether to only compute total variance across w dimension\n\n    Returns:\n        Total variance\n    \"\"\"\n    _, h, w = t.shape\n    w_tv = torch.square(t[..., :, 1:] - t[..., :, : w - 1]).mean()\n\n    if only_w:\n        return w_tv\n\n    h_tv = torch.square(t[..., 1:, :] - t[..., : h - 1, :]).mean()\n    return h_tv + w_tv\n\n\ndef compute_plane_smoothness(t: torch.Tensor) -> float:\n    \"\"\"Computes smoothness across the temporal axis of a plane\n    From nerf-studio\n    Args:\n        t: Plane tensor\n\n    Returns:\n        Time smoothness\n    \"\"\"\n    _, h, _ = t.shape\n    # Convolve with a second derivative filter, in the time dimension which is dimension 2\n    first_difference = t[..., 1:, :] - t[..., : h - 1, :]  # [c, h-1, w]\n    second_difference = (\n        first_difference[..., 1:, :] - first_difference[..., : h - 2, :]\n    )  # [c, h-2, w]\n    # Take the L2 norm of the result\n    return torch.square(second_difference).mean()\n"
  },
  {
    "path": "physdreamer/operators/dct.py",
    "content": "\"\"\"\nCode from https://github.com/zh217/torch-dct/blob/master/torch_dct/_dct.py\n\"\"\"\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n\nimport torch.fft\n\n\ndef dct1_rfft_impl(x):\n    return torch.view_as_real(torch.fft.rfft(x, dim=1))\n\n\ndef dct_fft_impl(v):\n    return torch.view_as_real(torch.fft.fft(v, dim=1))\n\n\ndef idct_irfft_impl(V):\n    return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)\n\n\ndef dct(x, norm=None):\n    \"\"\"\n    Discrete Cosine Transform, Type II (a.k.a. the DCT)\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    if norm is None:\n              N-1\n    y[k] = 2* sum x[n]*cos(pi*k*(2n+1)/(2*N)), 0 <= k < N.\n              n=0\n\n    :param x: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the DCT-II of the signal over the last dimension\n    \"\"\"\n    x_shape = x.shape\n    N = x_shape[-1]\n    x = x.contiguous().view(-1, N)\n\n    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)\n\n    Vc = dct_fft_impl(v)\n\n    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)\n    W_r = torch.cos(k)\n    W_i = torch.sin(k)\n\n    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i\n\n    if norm == \"ortho\":\n        V[:, 0] /= np.sqrt(N) * 2\n        V[:, 1:] /= np.sqrt(N / 2) * 2\n\n    V = 2 * V.view(*x_shape)\n\n    return V\n\n\ndef idct(X, norm=None):\n    \"\"\"\n    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III\n\n    Our definition of idct is that idct(dct(x)) == x\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    :param X: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the inverse DCT-II of the signal over the last dimension\n    \"\"\"\n\n    x_shape = 
X.shape\n    N = x_shape[-1]\n\n    X_v = X.contiguous().view(-1, x_shape[-1]) / 2\n\n    if norm == \"ortho\":\n        X_v[:, 0] *= np.sqrt(N) * 2\n        X_v[:, 1:] *= np.sqrt(N / 2) * 2\n\n    k = (\n        torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :]\n        * np.pi\n        / (2 * N)\n    )\n    W_r = torch.cos(k)\n    W_i = torch.sin(k)\n\n    V_t_r = X_v\n    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)\n\n    V_r = V_t_r * W_r - V_t_i * W_i\n    V_i = V_t_r * W_i + V_t_i * W_r\n\n    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)\n\n    v = idct_irfft_impl(V)\n    x = v.new_zeros(v.shape)\n    x[:, ::2] += v[:, : N - (N // 2)]\n    x[:, 1::2] += v.flip([1])[:, : N // 2]\n\n    return x.view(*x_shape)\n\n\ndef dct_3d(x, norm=None):\n    \"\"\"\n    3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    :param x: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the DCT-II of the signal over the last 3 dimensions\n    \"\"\"\n    X1 = dct(x, norm=norm)\n    X2 = dct(X1.transpose(-1, -2), norm=norm)\n    X3 = dct(X2.transpose(-1, -3), norm=norm)\n    return X3.transpose(-1, -3).transpose(-1, -2)\n\n\ndef idct_3d(X, norm=None):\n    \"\"\"\n    The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III\n\n    Our definition of idct is that idct_3d(dct_3d(x)) == x\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    :param X: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the DCT-II of the signal over the last 3 dimensions\n    \"\"\"\n    x1 = idct(X, norm=norm)\n    x2 = idct(x1.transpose(-1, -2), norm=norm)\n    x3 = idct(x2.transpose(-1, -3), norm=norm)\n    
return x3.transpose(-1, -3).transpose(-1, -2)\n\n\ndef code_test_dct3d():\n    # init a tensor of shape [100, 20, 3]\n    x = torch.rand(100, 20, 3)\n\n    dct_coef = dct_3d(x, norm=\"ortho\")\n    print(\"inp signal shape: \", x.shape, \"  dct coef shape: \", dct_coef.shape)\n\n    x_recon = idct_3d(dct_coef, norm=\"ortho\")\n    print(\"inp signal shape: \", x.shape, \"  recon signal shape: \", x_recon.shape)\n\n    print(\"max error: \", torch.max(torch.abs(x - x_recon)))\n\n    dct_coef[:, 0, :] = 0\n\n    x_recon = idct_3d(dct_coef, norm=\"ortho\")\n    print(\"max error after removing first order: \", torch.max(torch.abs(x - x_recon)))\n\n\nif __name__ == \"__main__\":\n    code_test_dct3d()\n"
  },
  {
    "path": "physdreamer/operators/np_operators.py",
    "content": "import torch\nimport numpy as np\nfrom sklearn.decomposition import PCA\nimport matplotlib.pyplot as plt\n\n\ndef feature_map_to_rgb_pca(feature_map):\n    \"\"\"\n    Args:\n        feature_map: (C, H, W) feature map.\n    Outputs:\n        rgb_image: (H, W, 3) image.\n    \"\"\"\n    # Move feature map to CPU and convert to numpy\n    if isinstance(feature_map, torch.Tensor):\n        feature_map = feature_map.detach().cpu().numpy()\n\n    H, W = feature_map.shape[1:]\n    # Flatten spatial dimensions  # [N, C]\n    flattened_map = feature_map.reshape(feature_map.shape[0], -1).T\n\n    # Apply PCA and reduce channel dimension to 3\n    pca = PCA(n_components=3)\n    pca_result = pca.fit_transform(flattened_map)\n\n    # Reshape back to (H, W, 3)\n    rgb_image = pca_result.reshape(H, W, 3)\n\n    # Normalize to [0, 1]\n    rgb_image = (rgb_image - rgb_image.min()) / (\n        rgb_image.max() - rgb_image.min() + 1e-3\n    )\n\n    return rgb_image\n"
  },
  {
    "path": "physdreamer/operators/rotation.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix\n    using Gram--Schmidt orthogonalization per Section B of [1].\n    Args:\n        d6: 6D rotation representation, of size (*, 6)\n\n    Returns:\n        batch of rotation matrices of size (*, 3, 3)\n\n    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.\n    On the Continuity of Rotation Representations in Neural Networks.\n    IEEE Conference on Computer Vision and Pattern Recognition, 2019.\n    Retrieved from http://arxiv.org/abs/1812.07035\n    \"\"\"\n\n    a1, a2 = d6[..., :3], d6[..., 3:]\n    b1 = F.normalize(a1, dim=-1)\n    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1\n    b2 = F.normalize(b2, dim=-1)\n    b3 = torch.cross(b1, b2, dim=-1)\n    return torch.stack((b1, b2, b3), dim=-2)\n\n\ndef matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]\n    by dropping the last row. 
Note that 6D representation is not unique.\n    Args:\n        matrix: batch of rotation matrices of size (*, 3, 3)\n\n    Returns:\n        6D rotation representation, of size (*, 6)\n\n    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.\n    On the Continuity of Rotation Representations in Neural Networks.\n    IEEE Conference on Computer Vision and Pattern Recognition, 2019.\n    Retrieved from http://arxiv.org/abs/1812.07035\n    \"\"\"\n    batch_dim = matrix.size()[:-2]\n    return matrix[..., :2, :].clone().reshape(batch_dim + (6,))\n\n\ndef quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotations given as quaternions to rotation matrices.\n\n    Args:\n        quaternions: quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n    r, i, j, k = torch.unbind(quaternions, -1)\n    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Returns torch.sqrt(torch.max(0, x))\n    but with a zero subgradient where x is 0.\n    \"\"\"\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    ret[positive_mask] = torch.sqrt(x[positive_mask])\n    return ret\n\n\ndef matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotations given as rotation matrices 
to quaternions.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n        matrix.reshape(batch_dim + (9,)), dim=-1\n    )\n\n    q_abs = _sqrt_positive_part(\n        torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    # we produce the desired quaternion multiplied by each of r, i, j, k\n    quat_by_rijk = torch.stack(\n        [\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),\n        ],\n        dim=-2,\n    )\n\n    # We floor here at 0.1 but the exact level is not important; if q_abs is small,\n    # the candidate won't be picked.\n    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))\n\n    # if not 
for numerical problems, quat_candidates[i] should be same (up to a sign),\n    # forall i; we pick the best-conditioned one (with the largest denominator)\n\n    return quat_candidates[\n        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :\n    ].reshape(batch_dim + (4,))\n"
  },
  {
    "path": "physdreamer/utils/camera_utils.py",
    "content": "import numpy as np\n\n\ndef normalize(x: np.ndarray) -> np.ndarray:\n    \"\"\"Normalization helper function.\"\"\"\n    return x / np.linalg.norm(x)\n\n\ndef viewmatrix(lookdir: np.ndarray, up: np.ndarray, position: np.ndarray) -> np.ndarray:\n    \"\"\"Construct lookat view matrix.\"\"\"\n    vec2 = normalize(lookdir)\n    vec0 = normalize(np.cross(up, vec2))\n    vec1 = normalize(np.cross(vec2, vec0))\n    m = np.stack([vec0, vec1, vec2, position], axis=1)\n    return m\n\n\ndef generate_spiral_path(\n    pose: np.ndarray,\n    radius: float,\n    lookat_pt: np.ndarray = np.array([0, 0, 0]),\n    up: np.ndarray = np.array([0, 0, 1]),\n    n_frames: int = 60,\n    n_rots: int = 1,\n    y_scale: float = 1.0,\n) -> np.ndarray:\n    \"\"\"Calculates a forward facing spiral path for rendering.\"\"\"\n    x_axis = pose[:3, 0]\n    y_axis = pose[:3, 1]\n    campos = pose[:3, 3]\n\n    render_poses = []\n    for theta in np.linspace(0.0, 2 * np.pi * n_rots, n_frames, endpoint=False):\n        t = (np.cos(theta) * x_axis + y_scale * np.sin(theta) * y_axis) * radius\n        position = campos + t\n        z_axis = position - lookat_pt\n        new_pose = np.eye(4)\n        new_pose[:3] = viewmatrix(z_axis, up, position)\n        render_poses.append(new_pose)\n    render_poses = np.stack(render_poses, axis=0)\n    return render_poses\n"
  },
  {
    "path": "physdreamer/utils/colmap_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport numpy as np\nimport collections\nimport struct\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"])\nCamera = collections.namedtuple(\n    \"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"])\nPoint3D = collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"])\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12)\n}\nCAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)\n                         for camera_model in CAMERA_MODELS])\nCAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)\n                           for camera_model in CAMERA_MODELS])\n\n\ndef qvec2rotmat(qvec):\n    return np.array([\n        [1 - 2 * 
qvec[2]**2 - 2 * qvec[3]**2,\n         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],\n        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,\n         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],\n        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])\n\ndef rotmat2qvec(R):\n    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat\n    K = np.array([\n        [Rxx - Ryy - Rzz, 0, 0, 0],\n        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],\n        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],\n        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0\n    eigvals, eigvecs = np.linalg.eigh(K)\n    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]\n    if qvec[0] < 0:\n        qvec *= -1\n    return qvec\n\nclass Image(BaseImage):\n    def qvec2rotmat(self):\n        return qvec2rotmat(self.qvec)\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\ndef read_points3D_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DText(const std::string& path)\n        void Reconstruction::WritePoints3DText(const std::string& path)\n    \"\"\"\n    xyzs = None\n    rgbs = None\n    errors = None\n    num_points = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                num_points += 1\n\n\n    xyzs = np.empty((num_points, 3))\n    rgbs = np.empty((num_points, 3))\n    errors = np.empty((num_points, 1))\n    count = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                xyz = np.array(tuple(map(float, elems[1:4])))\n                rgb = np.array(tuple(map(int, elems[4:7])))\n                error = np.array(float(elems[7]))\n                xyzs[count] = xyz\n                rgbs[count] = rgb\n                errors[count] = error\n                count += 1\n\n    return xyzs, rgbs, errors\n\ndef read_points3D_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n\n\n    with open(path_to_model_file, \"rb\") as fid:\n        num_points = read_next_bytes(fid, 8, 
\"Q\")[0]\n\n        xyzs = np.empty((num_points, 3))\n        rgbs = np.empty((num_points, 3))\n        errors = np.empty((num_points, 1))\n\n        for p_id in range(num_points):\n            binary_point_line_properties = read_next_bytes(\n                fid, num_bytes=43, format_char_sequence=\"QdddBBBd\")\n            xyz = np.array(binary_point_line_properties[1:4])\n            rgb = np.array(binary_point_line_properties[4:7])\n            error = np.array(binary_point_line_properties[7])\n            track_length = read_next_bytes(\n                fid, num_bytes=8, format_char_sequence=\"Q\")[0]\n            track_elems = read_next_bytes(\n                fid, num_bytes=8*track_length,\n                format_char_sequence=\"ii\"*track_length)\n            xyzs[p_id] = xyz\n            rgbs[p_id] = rgb\n            errors[p_id] = error\n    return xyzs, rgbs, errors\n\ndef read_intrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                assert model == \"PINHOLE\", \"While the loader support other types, the rest of the code assumes PINHOLE\"\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(id=camera_id, model=model,\n                                            width=width, height=height,\n                                            params=params)\n    return cameras\n\ndef read_extrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n    
    void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    images = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_reg_images = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_reg_images):\n            binary_image_properties = read_next_bytes(\n                fid, num_bytes=64, format_char_sequence=\"idddddddi\")\n            image_id = binary_image_properties[0]\n            qvec = np.array(binary_image_properties[1:5])\n            tvec = np.array(binary_image_properties[5:8])\n            camera_id = binary_image_properties[8]\n            image_name = \"\"\n            current_char = read_next_bytes(fid, 1, \"c\")[0]\n            while current_char != b\"\\x00\":   # look for the ASCII 0 entry\n                image_name += current_char.decode(\"utf-8\")\n                current_char = read_next_bytes(fid, 1, \"c\")[0]\n            num_points2D = read_next_bytes(fid, num_bytes=8,\n                                           format_char_sequence=\"Q\")[0]\n            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,\n                                       format_char_sequence=\"ddq\"*num_points2D)\n            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),\n                                   tuple(map(float, x_y_id_s[1::3]))])\n            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))\n            images[image_id] = Image(\n                id=image_id, qvec=qvec, tvec=tvec,\n                camera_id=camera_id, name=image_name,\n                xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_intrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with 
open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_cameras):\n            camera_properties = read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\")\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(fid, num_bytes=8*num_params,\n                                     format_char_sequence=\"d\"*num_params)\n            cameras[camera_id] = Camera(id=camera_id,\n                                        model=model_name,\n                                        width=width,\n                                        height=height,\n                                        params=np.array(params))\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_extrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    images = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                image_id = int(elems[0])\n                qvec = np.array(tuple(map(float, elems[1:5])))\n                tvec = np.array(tuple(map(float, elems[5:8])))\n                camera_id = int(elems[8])\n                image_name = elems[9]\n                elems = fid.readline().split()\n                xys = np.column_stack([tuple(map(float, elems[0::3])),\n                                       tuple(map(float, elems[1::3]))])\n                point3D_ids = 
np.array(tuple(map(int, elems[2::3])))\n                images[image_id] = Image(\n                    id=image_id, qvec=qvec, tvec=tvec,\n                    camera_id=camera_id, name=image_name,\n                    xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_colmap_bin_array(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py\n\n    :param path: path to the colmap binary file.\n    :return: nd array with the floating point values in the value\n    \"\"\"\n    with open(path, \"rb\") as fid:\n        width, height, channels = np.genfromtxt(fid, delimiter=\"&\", max_rows=1,\n                                                usecols=(0, 1, 2), dtype=int)\n        fid.seek(0)\n        num_delimiter = 0\n        byte = fid.read(1)\n        while True:\n            if byte == b\"&\":\n                num_delimiter += 1\n                if num_delimiter >= 3:\n                    break\n            byte = fid.read(1)\n        array = np.fromfile(fid, np.float32)\n    array = array.reshape((width, height, channels), order=\"F\")\n    return np.transpose(array, (1, 0, 2)).squeeze()\n"
  },
  {
    "path": "physdreamer/utils/config.py",
    "content": "from omegaconf import OmegaConf\n\n\ndef load_config_with_merge(config_path: str):\n    cfg = OmegaConf.load(config_path)\n\n    path_ = cfg.get(\"_base\", None)\n\n    if path_ is not None:\n        print(f\"Merging base config from {path_}\")\n        cfg = OmegaConf.merge(load_config_with_merge(path_), cfg)\n    else:\n        return cfg\n    return cfg\n\n\ndef merge_without_none(base_cfg, override_cfg):\n    for key, value in override_cfg.items():\n        if value is not None:\n            base_cfg[key] = value\n        elif not (key in base_cfg):\n            base_cfg[key] = None\n    return base_cfg\n\n\ndef create_config(config_path, args, cli_args: list = []):\n    \"\"\"\n    Args:\n        config_path: path to config file\n        args: argparse object with known variables\n        cli_args: list of cli args in the format of\n            [\"lr=0.1\", \"model.name=alexnet\"]\n    \"\"\"\n    # recursively merge base config\n    cfg = load_config_with_merge(config_path)\n\n    # parse cli args, and merge them into cfg\n    cli_conf = OmegaConf.from_cli(cli_args)\n    arg_cfg = OmegaConf.create(vars(args))\n\n    # drop None in arg_cfg\n\n    arg_cfg = OmegaConf.merge(arg_cfg, cli_conf)\n\n    # cfg = OmegaConf.merge(cfg, arg_cfg, cli_conf)\n    cfg = merge_without_none(cfg, arg_cfg)\n\n    return cfg\n"
  },
  {
    "path": "physdreamer/utils/img_utils.py",
    "content": "import torch\nimport torchvision\nimport cv2\nimport numpy as np\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom math import exp\n\n\ndef make_grid(imgs: torch.Tensor, scale=0.5):\n    \"\"\"\n    Args:\n        imgs: [B, C, H, W] in [0, 1]\n    Output:\n        x row of images, and 3 x column of images\n        which means 3 x ^ 2 <= B\n\n        img_grid: np.ndarray, [H', W', C]\n    \"\"\"\n\n    B, C, H, W = imgs.shape\n\n    num_row = int(np.sqrt(B / 3))\n    if num_row < 1:\n        num_row = 1\n    num_col = int(np.ceil(B / num_row))\n\n    img_grid = torchvision.utils.make_grid(imgs, nrow=num_col, padding=0)\n\n    img_grid = img_grid.permute(1, 2, 0).cpu().numpy()\n\n    # resize by scale\n    img_grid = cv2.resize(img_grid, None, fx=scale, fy=scale)\n    return img_grid\n\n\ndef compute_psnr(img1, img2, mask=None):\n    \"\"\"\n    Args:\n        img1: [B, C, H, W]\n        img2: [B, C, H, W]\n        mask: [B, 1, H, W] or [1, 1, H, W] or None\n    Outs:\n        psnr: [B]\n    \"\"\"\n    # batch dim is preserved\n    if mask is None:\n        mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n    else:\n        if mask.shape[0] != img1.shape[0]:\n            mask = mask.repeat(img1.shape[0], 1, 1, 1)\n        if mask.shape[1] != img1.shape[1]:\n            mask = mask.repeat(1, img1.shape[1], 1, 1)\n\n        diff = ((img1 - img2)) ** 2\n        diff = diff * mask\n        mse = diff.view(img1.shape[0], -1).sum(1, keepdim=True) / (\n            mask.view(img1.shape[0], -1).sum(1, keepdim=True) + 1e-8\n        )\n\n    return 20 * torch.log10(1.0 / torch.sqrt(mse))\n\n\ndef torch_rgb_to_gray(image):\n    # image is [B, C, H, W]\n    gray_image = (\n        0.299 * image[:, 0, :, :]\n        + 0.587 * image[:, 1, :, :]\n        + 0.114 * image[:, 2, :, :]\n    )\n    gray_image = gray_image.unsqueeze(1)\n\n    return gray_image\n\n\ndef compute_gradient_loss(pred, gt, mask=None):\n    
\"\"\"\n    Args:\n        pred: [B, C, H, W]\n        gt: [B, C, H, W]\n        mask: [B, 1, H, W] or None\n    \"\"\"\n    assert pred.shape == gt.shape, \"a and b must have the same shape\"\n\n    pred = torch_rgb_to_gray(pred)\n    gt = torch_rgb_to_gray(gt)\n\n    sobel_kernel_x = torch.tensor(\n        [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=pred.dtype, device=pred.device\n    )\n    sobel_kernel_y = torch.tensor(\n        [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=pred.dtype, device=pred.device\n    )\n\n    gradient_a_x = (\n        torch.nn.functional.conv2d(\n            pred.repeat(1, 3, 1, 1),\n            sobel_kernel_x.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1),\n            padding=1,\n        )\n        / 3\n    )\n    gradient_a_y = (\n        torch.nn.functional.conv2d(\n            pred.repeat(1, 3, 1, 1),\n            sobel_kernel_y.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1),\n            padding=1,\n        )\n        / 3\n    )\n    # gradient_a_magnitude = torch.sqrt(gradient_a_x ** 2 + gradient_a_y ** 2)\n\n    gradient_b_x = (\n        torch.nn.functional.conv2d(\n            gt.repeat(1, 3, 1, 1),\n            sobel_kernel_x.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1),\n            padding=1,\n        )\n        / 3\n    )\n    gradient_b_y = (\n        torch.nn.functional.conv2d(\n            gt.repeat(1, 3, 1, 1),\n            sobel_kernel_y.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1),\n            padding=1,\n        )\n        / 3\n    )\n    # gradient_b_magnitude = torch.sqrt(gradient_b_x ** 2 + gradient_b_y ** 2)\n\n    pred_grad = torch.cat([gradient_a_x, gradient_a_y], dim=1)\n    gt_grad = torch.cat([gradient_b_x, gradient_b_y], dim=1)\n\n    if mask is None:\n        gradient_difference = torch.abs(pred_grad - gt_grad).mean()\n    else:\n        gradient_difference = torch.abs(pred_grad - gt_grad).mean(dim=1, keepdim=True)[\n            mask\n        ].sum() / (mask.sum() + 1e-8)\n\n    return 
gradient_difference\n\n\ndef mark_image_with_red_squares(img):\n    # img, torch.Tensor of shape [B, H, W, C]\n\n    mark_color = torch.tensor([1.0, 0, 0], dtype=torch.float32)\n\n    for x_offset in range(4):\n        for y_offset in range(4):\n            img[:, x_offset::16, y_offset::16, :] = mark_color\n\n    return img\n\n\n# below for compute batched SSIM\ndef gaussian(window_size, sigma):\n\n    gauss = torch.Tensor(\n        [\n            exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))\n            for x in range(window_size)\n        ]\n    )\n    return gauss / gauss.sum()\n\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(\n        _2D_window.expand(channel, 1, window_size, window_size).contiguous()\n    )\n    return window\n\n\ndef compute_ssim(img1, img2, window_size=11, size_average=True):\n    channel = img1.size(-3)\n    window = create_window(window_size, channel)\n\n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n\n    return _ssim(img1, img2, window, window_size, channel, size_average)\n\n\ndef _ssim(img1, img2, window, window_size, channel, size_average=True):\n    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = (\n        F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq\n    )\n    sigma2_sq = (\n        F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq\n    )\n    sigma12 = (\n        F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)\n        - mu1_mu2\n    )\n\n    C1 = 0.01**2\n    C2 = 0.03**2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * 
sigma12 + C2)) / (\n        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)\n    )\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1).mean(1).mean(1)\n\n\n# above for compute batched SSIM\n\n\ndef compute_low_res_psnr(img1, img2, scale_factor):\n    \"\"\"\n    Args:\n        img1: [B, C, H, W]\n        img2: [B, C, H, W]\n        scale_factor: int\n    \"\"\"\n    img1 = F.interpolate(\n        img1, scale_factor=1 / scale_factor, mode=\"bilinear\", align_corners=False\n    )\n    img2 = F.interpolate(\n        img2, scale_factor=1 / scale_factor, mode=\"bilinear\", align_corners=False\n    )\n    return compute_psnr(img1, img2)\n\n\ndef compute_low_res_mse(img1, img2, scale_factor):\n    \"\"\"\n    Args:\n        img1: [B, C, H, W]\n        img2: [B, C, H, W]\n        scale_factor: int\n    \"\"\"\n    img1 = F.interpolate(\n        img1, scale_factor=1 / scale_factor, mode=\"bilinear\", align_corners=False\n    )\n    img2 = F.interpolate(\n        img2, scale_factor=1 / scale_factor, mode=\"bilinear\", align_corners=False\n    )\n    loss_mse = F.mse_loss(img1, img2, reduction=\"mean\")\n    return loss_mse\n"
  },
  {
    "path": "physdreamer/utils/io_utils.py",
    "content": "import cv2\nimport imageio\nimport numpy as np\nimport mediapy\nimport os\nimport PIL\n\n\ndef read_video_cv2(video_path, rgb=True):\n    \"\"\"Read video using cv2, return [T, 3, H, W] array, fps\"\"\"\n\n    # BGR\n    cap = cv2.VideoCapture(video_path)\n    fps = cap.get(cv2.CAP_PROP_FPS)\n    num_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n    ret_list = []\n    for i in range(num_frame):\n        ret, frame = cap.read()\n        if ret:\n            if rgb:\n                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            frame = np.transpose(frame, [2, 0, 1])  # [3, H, W]\n            ret_list.append(frame[np.newaxis, ...])\n        else:\n            break\n    cap.release()\n    ret_array = np.concatenate(ret_list, axis=0)  # [T, 3, H, W]\n    return ret_array, fps\n\n\ndef save_video_cv2(video_path, img_list, fps):\n    # BGR\n\n    if len(img_list) == 0:\n        return\n    h, w = img_list[0].shape[:2]\n    fourcc = cv2.VideoWriter_fourcc(\n        *\"mp4v\"\n    )  # cv2.VideoWriter_fourcc('m', 'p', '4', 'v')\n    writer = cv2.VideoWriter(video_path, fourcc, fps, (w, h))\n\n    for frame in img_list:\n        writer.write(frame)\n    writer.release()\n\n\ndef save_video_imageio(video_path, img_list, fps):\n    \"\"\"\n    Img_list: [[H, W, 3]]\n    \"\"\"\n    if len(img_list) == 0:\n        return\n    writer = imageio.get_writer(video_path, fps=fps)\n    for frame in img_list:\n        writer.append_data(frame)\n\n    writer.close()\n\n\ndef save_gif_imageio(video_path, img_list, fps):\n    \"\"\"\n    Img_list: [[H, W, 3]]\n    \"\"\"\n    if len(img_list) == 0:\n        return\n    assert video_path.endswith(\".gif\")\n\n    imageio.mimsave(video_path, img_list, format=\"GIF\", fps=fps)\n\n\ndef save_video_mediapy(video_frames, output_video_path: str = None, fps: int = 14):\n    # video_frames: [N, H, W, 3]\n    if isinstance(video_frames[0], PIL.Image.Image):\n        video_frames = [np.array(frame) for frame in 
video_frames]\n    os.makedirs(os.path.dirname(output_video_path), exist_ok=True)\n    mediapy.write_video(output_video_path, video_frames, fps=fps, qp=18)\n"
  },
  {
    "path": "physdreamer/utils/optimizer.py",
    "content": "import torch\nfrom torch.optim.lr_scheduler import LambdaLR\n\n\ndef get_linear_schedule_with_warmup(\n    optimizer, num_warmup_steps, num_training_steps, last_epoch=-1\n):\n    \"\"\"\n    From diffusers.optimization\n    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after\n    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (`int`):\n            The total number of training steps.\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n\n    def lr_lambda(current_step: int):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        return max(\n            0.0,\n            float(num_training_steps - current_step)\n            / float(max(1, num_training_steps - num_warmup_steps)),\n        )\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n"
  },
  {
    "path": "physdreamer/utils/print_utils.py",
    "content": "import torch.distributed as dist\n\n\ndef print_if_zero_rank(s):\n    if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0):\n        print(\"### \" + s)\n"
  },
  {
    "path": "physdreamer/utils/pytorch_mssim.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom math import exp\nimport numpy as np\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])\n    return gauss/gauss.sum()\n\n\ndef create_window(window_size, channel=1):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)\n    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()\n    return window\n\ndef create_window_3d(window_size, channel=1):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t())\n    _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())\n    window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)\n    return window\n\n\ndef ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):\n    # Value range can be different from 255. 
Other common ranges are 1 (sigmoid) and 2 (tanh).\n    if val_range is None:\n        if torch.max(img1) > 128:\n            max_val = 255\n        else:\n            max_val = 1\n\n        if torch.min(img1) < -0.5:\n            min_val = -1\n        else:\n            min_val = 0\n        L = max_val - min_val\n    else:\n        L = val_range\n\n    padd = 0\n    (_, channel, height, width) = img1.size()\n    if window is None:\n        real_size = min(window_size, height, width)\n        window = create_window(real_size, channel=channel).to(img1.device)\n    \n    # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)\n    # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)\n    mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)\n    mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2\n\n    C1 = (0.01 * L) ** 2\n    C2 = (0.03 * L) ** 2\n\n    v1 = 2.0 * sigma12 + C2\n    v2 = sigma1_sq + sigma2_sq + C2\n    cs = torch.mean(v1 / v2)  # contrast sensitivity\n\n    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)\n\n    if size_average:\n        ret = ssim_map.mean()\n    else:\n        ret = ssim_map.mean(1).mean(1).mean(1)\n\n    if full:\n        return ret, cs\n    return ret\n\n\ndef ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):\n    \"\"\"\n    Args:\n        img1, img2: (N, C, H, W)\n    \"\"\"\n    # Value range can be different from 255. 
Other common ranges are 1 (sigmoid) and 2 (tanh).\n    if val_range is None:\n        if torch.max(img1) > 128:\n            max_val = 255\n        else:\n            max_val = 1\n\n        if torch.min(img1) < -0.5:\n            min_val = -1\n        else:\n            min_val = 0\n        L = max_val - min_val\n    else:\n        L = val_range\n\n    padd = 0\n    (_, _, height, width) = img1.size()\n    if window is None:\n        real_size = min(window_size, height, width)\n        window = create_window_3d(real_size, channel=1).to(img1.device)\n        # Channel is set to 1 since we consider color images as volumetric images\n\n    img1 = img1.unsqueeze(1)\n    img2 = img2.unsqueeze(1)\n\n    mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)\n    mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq\n    sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq\n    sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2\n\n    C1 = (0.01 * L) ** 2\n    C2 = (0.03 * L) ** 2\n\n    v1 = 2.0 * sigma12 + C2\n    v2 = sigma1_sq + sigma2_sq + C2\n    cs = torch.mean(v1 / v2)  # contrast sensitivity\n\n    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)\n\n    if size_average:\n        ret = ssim_map.mean()\n    else:\n        ret = ssim_map.mean(1).mean(1).mean(1)\n\n    if full:\n        return ret, cs\n    return ret\n\n\ndef msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):\n    device = img1.device\n    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)\n    levels = 
weights.size()[0]\n    mssim = []\n    mcs = []\n    for _ in range(levels):\n        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)\n        mssim.append(sim)\n        mcs.append(cs)\n\n        img1 = F.avg_pool2d(img1, (2, 2))\n        img2 = F.avg_pool2d(img2, (2, 2))\n\n    mssim = torch.stack(mssim)\n    mcs = torch.stack(mcs)\n\n    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)\n    if normalize:\n        mssim = (mssim + 1) / 2\n        mcs = (mcs + 1) / 2\n\n    pow1 = mcs ** weights\n    pow2 = mssim ** weights\n    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/\n    output = torch.prod(pow1[:-1] * pow2[-1])\n    return output\n\n\n# Classes to re-use window\nclass SSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, val_range=None):\n        super(SSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.val_range = val_range\n\n        # Assume 3 channel for SSIM\n        self.channel = 3\n        self.window = create_window(window_size, channel=self.channel)\n\n    def forward(self, img1, img2):\n        (_, channel, _, _) = img1.size()\n\n        if channel == self.channel and self.window.dtype == img1.dtype:\n            window = self.window\n        else:\n            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)\n            self.window = window\n            self.channel = channel\n\n        _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)\n        dssim = (1 - _ssim) / 2\n        return dssim\n\nclass MSSSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, channel=3):\n        super(MSSSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        
self.channel = channel\n\n    def forward(self, img1, img2):\n        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)"
  },
  {
    "path": "physdreamer/utils/svd_helpper.py",
    "content": "from glob import glob\nfrom sys import version\nfrom typing import Dict, List, Optional, Tuple, Union\nimport numpy as np\nimport torch\nimport os\n\nfrom omegaconf import ListConfig, OmegaConf\nfrom safetensors.torch import load_file as load_safetensors\n\nfrom sgm.inference.helpers import embed_watermark\nfrom sgm.modules.diffusionmodules.guiders import LinearPredictionGuider, VanillaCFG\nfrom sgm.util import append_dims, default, instantiate_from_config\nimport math\nfrom einops import repeat\n\n\ndef init_st(version_dict, load_ckpt=True, load_filter=True):\n    state = dict()\n    if not \"model\" in state:\n        config = version_dict[\"config\"]\n        ckpt = version_dict[\"ckpt\"]\n\n        config = OmegaConf.load(config)\n        model, msg = load_model_from_config(config, ckpt if load_ckpt else None)\n\n        state[\"msg\"] = msg\n        state[\"model\"] = model\n        state[\"ckpt\"] = ckpt if load_ckpt else None\n        state[\"config\"] = config\n        if load_filter:\n            return state\n            # from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\n            state[\"filter\"] = DeepFloydDataFiltering(verbose=False)\n    return state\n\n\ndef load_model_from_config(config, ckpt=None, verbose=True):\n    model = instantiate_from_config(config.model)\n\n    if ckpt is not None:\n        print(f\"Loading model from {ckpt}\")\n        if ckpt.endswith(\"ckpt\"):\n            pl_sd = torch.load(ckpt, map_location=\"cpu\")\n            if \"global_step\" in pl_sd:\n                global_step = pl_sd[\"global_step\"]\n                print(f\"Global Step: {pl_sd['global_step']}\")\n            sd = pl_sd[\"state_dict\"]\n        elif ckpt.endswith(\"safetensors\"):\n            sd = load_safetensors(ckpt)\n        else:\n            raise NotImplementedError\n\n        msg = None\n\n        m, u = model.load_state_dict(sd, strict=False)\n\n        if len(m) > 0 and verbose:\n        
    print(\"missing keys:\")\n            print(m)\n        if len(u) > 0 and verbose:\n            print(\"unexpected keys:\")\n            print(u)\n    else:\n        msg = None\n\n    model = initial_model_load(model)\n    # model.eval()  # ?\n    return model, msg\n\n\ndef load_model(model):\n    model.cuda()\n\n\nlowvram_mode = False\n\n\ndef set_lowvram_mode(mode):\n    global lowvram_mode\n    lowvram_mode = mode\n\n\ndef initial_model_load(model):\n    global lowvram_mode\n    if lowvram_mode:\n        model.model.half()\n    else:\n        model.cuda()\n    return model\n\n\ndef unload_model(model):\n    global lowvram_mode\n    if lowvram_mode:\n        model.cpu()\n        torch.cuda.empty_cache()\n\n\ndef get_unique_embedder_keys_from_conditioner(conditioner):\n    return list(set([x.input_key for x in conditioner.embedders]))\n\n\ndef get_batch(keys, value_dict, N, T, device):\n    batch = {}\n    batch_uc = {}\n\n    for key in keys:\n        if key == \"fps_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"fps_id\"]])\n                .to(device)\n                .repeat(int(math.prod(N)))\n            )\n        elif key == \"motion_bucket_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"motion_bucket_id\"]])\n                .to(device)\n                .repeat(int(math.prod(N)))\n            )\n        elif key == \"cond_aug\":\n            batch[key] = repeat(\n                torch.tensor([value_dict[\"cond_aug\"]]).to(device),\n                \"1 -> b\",\n                b=math.prod(N),\n            )\n        elif key == \"cond_frames\":\n            batch[key] = repeat(value_dict[\"cond_frames\"], \"1 ... -> b ...\", b=N[0])\n        elif key == \"cond_frames_without_noise\":\n            batch[key] = repeat(\n                value_dict[\"cond_frames_without_noise\"], \"1 ... 
-> b ...\", b=N[0]\n            )\n        else:\n            batch[key] = value_dict[key]\n\n    if T is not None:\n        batch[\"num_video_frames\"] = T\n\n    for key in batch.keys():\n        if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n            batch_uc[key] = torch.clone(batch[key])\n    return batch, batch_uc\n\n\nif __name__ == \"__main__\":\n    pass\n"
  },
  {
    "path": "physdreamer/utils/torch_utils.py",
    "content": "import torch\nimport time\n\n\ndef get_sync_time():\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    return time.time()\n"
  },
  {
    "path": "physdreamer/warp_mpm/README.md",
    "content": "This folder is mainly copy paste from  https://github.com/zeshunzong/warp-mpm\n\nThe biggest change is to make some operations during simulation **non-inplace**, and save the intermediate state during simulation, otherwise gradient computed by warp would be wrong. \n"
  },
  {
    "path": "physdreamer/warp_mpm/gaussian_sim_utils.py",
    "content": "import numpy as np\n\n\ndef get_volume(xyzs: np.ndarray, resolution=128) -> np.ndarray:\n\n    # set a grid in the range of [-1, 1], with resolution\n    voxel_counts = np.zeros((resolution, resolution, resolution))\n\n    points_xyzindex = ((xyzs + 1) / 2 * (resolution - 1)).astype(np.uint32)\n    cell_volume = (2.0 / (resolution - 1)) ** 3\n\n    for x, y, z in points_xyzindex:\n        voxel_counts[x, y, z] += 1\n\n    points_number_in_corresponding_voxel = voxel_counts[\n        points_xyzindex[:, 0], points_xyzindex[:, 1], points_xyzindex[:, 2]\n    ]\n\n    points_volume = cell_volume / points_number_in_corresponding_voxel\n\n    points_volume = points_volume.astype(np.float32)\n\n    # some statistics\n    num_non_empyt_voxels = np.sum(voxel_counts > 0)\n    max_points_in_voxel = np.max(voxel_counts)\n    min_points_in_voxel = np.min(voxel_counts)\n    avg_points_in_voxel = np.sum(voxel_counts) / num_non_empyt_voxels\n    print(\"Number of non-empty voxels: \", num_non_empyt_voxels)\n    print(\"Max points in voxel: \", max_points_in_voxel)\n    print(\"Min points in voxel: \", min_points_in_voxel)\n    print(\"Avg points in voxel: \", avg_points_in_voxel)\n\n    return points_volume\n"
  },
  {
    "path": "physdreamer/warp_mpm/mpm_data_structure.py",
    "content": "import warp as wp\nimport warp.torch\nimport torch\nfrom typing import Optional, Union, Sequence, Any\nfrom torch import Tensor\nimport os\nimport sys\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)))\nfrom warp_utils import from_torch_safe\n\n\n@wp.struct\nclass MPMStateStruct(object):\n    ###### essential #####\n    # particle\n    particle_x: wp.array(dtype=wp.vec3)  # current position\n    particle_v: wp.array(dtype=wp.vec3)  # particle velocity\n    particle_F: wp.array(dtype=wp.mat33)  # particle elastic deformation gradient\n    particle_cov: wp.array(dtype=float)  # current covariance matrix\n    particle_F_trial: wp.array(\n        dtype=wp.mat33\n    )  # apply return mapping on this to obtain elastic def grad\n    particle_stress: wp.array(dtype=wp.mat33)  # Kirchoff stress, elastic stress\n    particle_C: wp.array(dtype=wp.mat33)\n    particle_vol: wp.array(dtype=float)  # current volume\n    particle_mass: wp.array(dtype=float)  # mass\n    particle_density: wp.array(dtype=float)  # density\n\n    particle_selection: wp.array(\n        dtype=int\n    )  # only particle_selection[p] = 0 will be simulated\n\n    # grid\n    grid_m: wp.array(dtype=float, ndim=3)\n    grid_v_in: wp.array(dtype=wp.vec3, ndim=3)  # grid node momentum/velocity\n    grid_v_out: wp.array(\n        dtype=wp.vec3, ndim=3\n    )  # grid node momentum/velocity, after grid update\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        # shape default is int. 
number of particles\n        self.particle_x = wp.zeros(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_v = wp.zeros(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_F = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=False\n        )\n\n        self.particle_F_trial = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_stress = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_C = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_vol = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_mass = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_density = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n\n        self.particle_selection = wp.zeros(\n            shape, dtype=int, device=device, requires_grad=False\n        )\n\n        # grid: will init later\n        self.grid_m = wp.zeros(\n            (10, 10, 10), dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.grid_v_in = wp.zeros(\n            (10, 10, 10), dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.grid_v_out = wp.zeros(\n            (10, 10, 10), dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n\n    def init_grid(\n        self, grid_res: int, device: wp.context.Devicelike = None, requires_grad=False\n    ):\n        self.grid_m = wp.zeros(\n            
(grid_res, grid_res, grid_res),\n            dtype=float,\n            device=device,\n            requires_grad=False,\n        )\n        self.grid_v_in = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=wp.vec3,\n            device=device,\n            requires_grad=requires_grad,\n        )\n        self.grid_v_out = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=wp.vec3,\n            device=device,\n            requires_grad=requires_grad,\n        )\n\n    def from_torch(\n        self,\n        tensor_x: Tensor,\n        tensor_volume: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n        n_grid: int = 100,\n        grid_lim=1.0,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n        assert tensor_x.shape[0] == tensor_volume.shape[0]\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n        self.init_grid(grid_res=n_grid, device=device, requires_grad=requires_grad)\n\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_volume is not None:\n            print(self.particle_vol.shape, tensor_volume.shape)\n            volume_numpy = tensor_volume.detach().cpu().numpy()\n            self.particle_vol = wp.from_numpy(\n                volume_numpy, dtype=float, device=device, requires_grad=False\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, requires_grad=False\n            )\n\n        if tensor_velocity is not None:\n            self.particle_v = 
from_torch_safe(\n                tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        # initial trial deformation gradient is set to identity\n\n        print(\"Particles initialized from torch data.\")\n        print(\"Total particles: \", n_particles)\n\n    def reset_state(\n        self,\n        tensor_x: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n        tensor_density: Optional[Tensor] = None,\n        selection_mask: Optional[Tensor] = None,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        # reset p_c, p_v, p_C, p_F_trial\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, requires_grad=False\n            )\n\n        if tensor_velocity is not None:\n            self.particle_v = from_torch_safe(\n                tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_density is not None and selection_mask is not None:\n            wp_density = from_torch_safe(\n                
tensor_density.contiguous().detach().clone(),\n                dtype=wp.float32,\n                requires_grad=False,\n            )\n            # 1 indicate we need to simulate this particle\n            wp_selection_mask = from_torch_safe(\n                selection_mask.contiguous().detach().clone().type(torch.int),\n                dtype=wp.int32,\n                requires_grad=False,\n            )\n\n            wp.launch(\n                kernel=set_float_vec_to_vec_wmask,\n                dim=n_particles,\n                inputs=[self.particle_density, wp_density, wp_selection_mask],\n                device=device,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_zero,\n            dim=n_particles,\n            inputs=[self.particle_C],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_zero,\n            dim=n_particles,\n            inputs=[self.particle_stress],\n            device=device,\n        )\n\n    def continue_from_torch(\n        self,\n        tensor_x: Tensor,\n        tensor_velocity: Optional[Tensor] = None,\n        tensor_F: Optional[Tensor] = None,\n        tensor_C: Optional[Tensor] = None,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_velocity is not None:\n            self.particle_v = from_torch_safe(\n      
          tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_F is not None:\n            self.particle_F_trial = from_torch_safe(\n                tensor_F.contiguous().detach().clone(),\n                dtype=wp.mat33,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_C is not None:\n            self.particle_C = from_torch_safe(\n                tensor_C.contiguous().detach().clone(),\n                dtype=wp.mat33,\n                requires_grad=requires_grad,\n            )\n\n    def set_require_grad(self, requires_grad=True):\n        self.particle_x.requires_grad = requires_grad\n        self.particle_v.requires_grad = requires_grad\n        self.particle_F.requires_grad = requires_grad\n        self.particle_F_trial.requires_grad = requires_grad\n        self.particle_stress.requires_grad = requires_grad\n        self.particle_C.requires_grad = requires_grad\n\n        self.grid_v_out.requires_grad = requires_grad\n        self.grid_v_in.requires_grad = requires_grad\n\n    def reset_density(\n        self,\n        tensor_density: Tensor,\n        selection_mask: Optional[Tensor] = None,\n        device=\"cuda:0\",\n        requires_grad=True,\n        update_mass=False,\n    ):\n        n_particles = tensor_density.shape[0]\n        if tensor_density is not None:\n            wp_density = from_torch_safe(\n                tensor_density.contiguous().detach().clone(),\n                dtype=wp.float32,\n                requires_grad=False,\n            )\n        \n        if selection_mask is not None:\n            # 1 indicate we need to simulate this particle\n            wp_selection_mask = from_torch_safe(\n                selection_mask.contiguous().detach().clone().type(torch.int),\n                dtype=wp.int32,\n                requires_grad=False,\n            )\n\n            wp.launch(\n                
kernel=set_float_vec_to_vec_wmask,\n                dim=n_particles,\n                inputs=[self.particle_density, wp_density, wp_selection_mask],\n                device=device,\n            )\n        else:\n            wp.launch(\n                kernel=set_float_vec_to_vec,\n                dim=n_particles,\n                inputs=[self.particle_density, wp_density],\n                device=device,\n            )\n\n        if update_mass:\n            num_particles = self.particle_x.shape[0]\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=num_particles,\n                inputs=[\n                    self.particle_density,\n                    self.particle_vol,\n                    self.particle_mass,\n                ],\n                device=device,\n            )\n\n    def partial_clone(self, device=\"cuda:0\", requires_grad=True):\n        new_state = MPMStateStruct()\n        n_particles = self.particle_x.shape[0]\n        new_state.init(n_particles, device=device, requires_grad=requires_grad)\n\n        # clone section:\n        # new_state.particle_vol = wp.clone(self.particle_vol, requires_grad=False)\n        # new_state.particle_density = wp.clone(self.particle_density, requires_grad=False)\n        # new_state.particle_mass = wp.clone(self.particle_mass, requires_grad=False)\n\n        # new_state.particle_selection = wp.clone(self.particle_selection, requires_grad=False)\n\n        wp.copy(new_state.particle_vol, self.particle_vol)\n        wp.copy(new_state.particle_density, self.particle_density)\n        wp.copy(new_state.particle_mass, self.particle_mass)\n        wp.copy(new_state.particle_selection, self.particle_selection)\n\n        # init grid to zero with grid res.\n        new_state.init_grid(\n            grid_res=self.grid_v_in.shape[0], device=device, requires_grad=requires_grad\n        )\n\n        # init some matrix to identity\n        wp.launch(\n            
kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[new_state.particle_F_trial],\n            device=device,\n        )\n\n        new_state.set_require_grad(requires_grad=requires_grad)\n        return new_state\n\n\n@wp.struct\nclass MPMModelStruct(object):\n    ####### essential #######\n    grid_lim: float\n    n_particles: int\n    n_grid: int\n    dx: float\n    inv_dx: float\n    grid_dim_x: int\n    grid_dim_y: int\n    grid_dim_z: int\n    mu: wp.array(dtype=float)\n    lam: wp.array(dtype=float)\n    E: wp.array(dtype=float)\n    nu: wp.array(dtype=float)\n    material: int\n\n    ######## for plasticity ####\n    yield_stress: wp.array(dtype=float)\n    friction_angle: float\n    alpha: float\n    gravitational_accelaration: wp.vec3\n    hardening: float\n    xi: float\n    plastic_viscosity: float\n    softening: float\n\n    ####### for damping\n    rpic_damping: float\n    grid_v_damping_scale: float\n\n    ####### for PhysGaussian: covariance\n    update_cov_with_F: int\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        self.E = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )  # young's modulus\n        self.nu = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )  # poisson's ratio\n\n        self.mu = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.lam = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.yield_stress = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n    def finalize_mu_lam(self, n_particles, device=\"cuda:0\"):\n        wp.launch(\n            kernel=compute_mu_lam_from_E_nu_clean,\n            dim=n_particles,\n  
          inputs=[self.mu, self.lam, self.E, self.nu],\n            device=device,\n        )\n\n    def init_other_params(self, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.grid_lim = grid_lim\n        self.n_grid = n_grid\n        self.grid_dim_x = n_grid\n        self.grid_dim_y = n_grid\n        self.grid_dim_z = n_grid\n        (\n            self.dx,\n            self.inv_dx,\n        ) = self.grid_lim / self.n_grid, float(\n            n_grid / grid_lim\n        )  # [0-1]?\n\n        self.update_cov_with_F = False\n\n        # material is used to switch between different elastoplastic models. 0 is jelly\n        self.material = 0\n\n        self.plastic_viscosity = 0.0\n        self.softening = 0.1\n        self.friction_angle = 25.0\n        sin_phi = wp.sin(self.friction_angle / 180.0 * 3.14159265)\n        self.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        self.gravitational_accelaration = wp.vec3(0.0, 0.0, 0.0)\n\n        self.rpic_damping = 0.0  # 0.0 if no damping (apic). 
-1 if pic\n\n        self.grid_v_damping_scale = 1.1  # globally applied\n\n    def from_torch(\n        self, tensor_E: Tensor, tensor_nu: Tensor, device=\"cuda:0\", requires_grad=False\n    ):\n        self.E = wp.from_torch(tensor_E.contiguous(), requires_grad=requires_grad)\n        self.nu = wp.from_torch(tensor_nu.contiguous(), requires_grad=requires_grad)\n        n_particles = tensor_E.shape[0]\n        self.finalize_mu_lam(n_particles=n_particles, device=device)\n\n    def set_require_grad(self, requires_grad=True):\n        self.E.requires_grad = requires_grad\n        self.nu.requires_grad = requires_grad\n        self.mu.requires_grad = requires_grad\n        self.lam.requires_grad = requires_grad\n\n\n# for various boundary conditions\n@wp.struct\nclass Dirichlet_collider:\n    point: wp.vec3\n    normal: wp.vec3\n    direction: wp.vec3\n\n    start_time: float\n    end_time: float\n\n    friction: float\n    surface_type: int\n\n    velocity: wp.vec3\n\n    threshold: float\n    reset: int\n    index: int\n\n    x_unit: wp.vec3\n    y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    height: float\n    length: float\n    R: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    horizontal_axis_2: wp.vec3\n    half_height_and_radius: wp.vec2\n\n\n@wp.struct\nclass GridCollider:\n    point: wp.vec3\n    normal: wp.vec3\n    direction: wp.vec3\n\n    start_time: float\n    end_time: float\n    mask: wp.array(dtype=int, ndim=3)\n\n\n@wp.struct\nclass Impulse_modifier:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: float\n    force: wp.vec3\n    forceTimesDt: wp.vec3\n    numsteps: int\n\n    point: wp.vec3\n    size: wp.vec3\n    mask: wp.array(dtype=int)\n\n\n@wp.struct\nclass MPMtailoredStruct:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: 
float\n    friction: float\n    surface_type: int\n    velocity: wp.vec3\n    threshold: float\n    reset: int\n\n    point_rotate: wp.vec3\n    normal_rotate: wp.vec3\n    x_unit: wp.vec3\n    y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    point_plane: wp.vec3\n    normal_plane: wp.vec3\n    velocity_plane: wp.vec3\n    threshold_plane: float\n\n\n@wp.struct\nclass MaterialParamsModifier:\n    point: wp.vec3\n    size: wp.vec3\n    E: float\n    nu: float\n    density: float\n\n\n@wp.struct\nclass ParticleVelocityModifier:\n    point: wp.vec3\n    normal: wp.vec3\n    half_height_and_radius: wp.vec2\n    rotation_scale: float\n    translation_scale: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    horizontal_axis_2: wp.vec3\n\n    start_time: float\n\n    end_time: float\n\n    velocity: wp.vec3\n\n    mask: wp.array(dtype=int)\n\n\n@wp.kernel\ndef compute_mu_lam_from_E_nu_clean(\n    mu: wp.array(dtype=float),\n    lam: wp.array(dtype=float),\n    E: wp.array(dtype=float),\n    nu: wp.array(dtype=float),\n):\n    p = wp.tid()\n    mu[p] = E[p] / (2.0 * (1.0 + nu[p]))\n    lam[p] = E[p] * nu[p] / ((1.0 + nu[p]) * (1.0 - 2.0 * nu[p]))\n\n\n@wp.kernel\ndef set_vec3_to_zero(target_array: wp.array(dtype=wp.vec3)):\n    tid = wp.tid()\n    target_array[tid] = wp.vec3(0.0, 0.0, 0.0)\n\n\n@wp.kernel\ndef set_vec3_to_vec3(\n    source_array: wp.array(dtype=wp.vec3), target_array: wp.array(dtype=wp.vec3)\n):\n    tid = wp.tid()\n    source_array[tid] = target_array[tid]\n\n\n@wp.kernel\ndef set_float_vec_to_vec_wmask(\n    source_array: wp.array(dtype=float),\n    target_array: wp.array(dtype=float),\n    selection_mask: wp.array(dtype=int),\n):\n    tid = wp.tid()\n    if selection_mask[tid] == 1:\n        source_array[tid] = target_array[tid]\n\n\n@wp.kernel\ndef set_float_vec_to_vec(\n    source_array: wp.array(dtype=float), target_array: wp.array(dtype=float)\n):\n    tid = wp.tid()\n    source_array[tid] = 
target_array[tid]\n\n\n@wp.kernel\ndef set_mat33_to_identity(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n\n\n@wp.kernel\ndef set_mat33_to_zero(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n\n@wp.kernel\ndef add_identity_to_mat33(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.add(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef subtract_identity_to_mat33(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.sub(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef add_vec3_to_vec3(\n    first_array: wp.array(dtype=wp.vec3), second_array: wp.array(dtype=wp.vec3)\n):\n    tid = wp.tid()\n    first_array[tid] = wp.add(first_array[tid], second_array[tid])\n\n\n@wp.kernel\ndef set_value_to_float_array(target_array: wp.array(dtype=float), value: float):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef set_warpvalue_to_float_array(\n    target_array: wp.array(dtype=float), value: warp.types.float32\n):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef get_float_array_product(\n    arrayA: wp.array(dtype=float),\n    arrayB: wp.array(dtype=float),\n    arrayC: wp.array(dtype=float),\n):\n    tid = wp.tid()\n    arrayC[tid] = arrayA[tid] * arrayB[tid]\n\n\ndef torch2warp_quat(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 4\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.quat,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_float(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=warp.types.float32,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_vec3(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.vec3,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_mat33(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.mat33,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n"
  },
  {
    "path": "physdreamer/warp_mpm/mpm_solver_diff.py",
    "content": "import sys\nimport os\n\nimport warp as wp\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)))\nfrom mpm_data_structure import *\nfrom mpm_utils import *\nfrom typing import Optional, Union, Sequence, Any, Tuple\nfrom jaxtyping import Float, Int, Shaped\n\n\nclass MPMWARPDiff(object):\n    # def __init__(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n    #     self.initialize(n_particles, n_grid, grid_lim, device=device)\n    #     self.time_profile = {}\n\n    def __init__(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.initialize(n_particles, n_grid, grid_lim, device=device)\n        self.time_profile = {}\n\n    def initialize(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.n_particles = n_particles\n\n        self.time = 0.0\n\n        self.grid_postprocess = []\n        self.collider_params = []\n        self.modify_bc = []\n\n        self.tailored_struct_for_bc = MPMtailoredStruct()\n        self.pre_p2g_operations = []\n        self.impulse_params = []\n\n        self.particle_velocity_modifiers = []\n        self.particle_velocity_modifier_params = []\n\n    # must give density. 
mass will be updated as density * volume\n    def set_parameters(self, device=\"cuda:0\", **kwargs):\n        self.set_parameters_dict(device, kwargs)\n\n    def set_parameters_dict(self, mpm_model, mpm_state, kwargs={}, device=\"cuda:0\"):\n        if \"material\" in kwargs:\n            if kwargs[\"material\"] == \"jelly\":\n                mpm_model.material = 0\n            elif kwargs[\"material\"] == \"metal\":\n                mpm_model.material = 1\n            elif kwargs[\"material\"] == \"sand\":\n                mpm_model.material = 2\n            elif kwargs[\"material\"] == \"foam\":\n                mpm_model.material = 3\n            elif kwargs[\"material\"] == \"snow\":\n                mpm_model.material = 4\n            elif kwargs[\"material\"] == \"plasticine\":\n                mpm_model.material = 5\n            elif kwargs[\"material\"] == \"neo-hookean\":\n                mpm_model.material = 6\n            else:\n                raise TypeError(\"Undefined material type\")\n\n        if \"yield_stress\" in kwargs:\n            val = kwargs[\"yield_stress\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.yield_stress, val],\n                device=device,\n            )\n        if \"hardening\" in kwargs:\n            mpm_model.hardening = kwargs[\"hardening\"]\n        if \"xi\" in kwargs:\n            mpm_model.xi = kwargs[\"xi\"]\n        if \"friction_angle\" in kwargs:\n            mpm_model.friction_angle = kwargs[\"friction_angle\"]\n            sin_phi = wp.sin(mpm_model.friction_angle / 180.0 * 3.14159265)\n            mpm_model.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        if \"g\" in kwargs:\n            mpm_model.gravitational_accelaration = wp.vec3(\n                kwargs[\"g\"][0], kwargs[\"g\"][1], kwargs[\"g\"][2]\n            )\n\n        if \"density\" in kwargs:\n            density_value = 
kwargs[\"density\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_state.particle_density, density_value],\n                device=device,\n            )\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=self.n_particles,\n                inputs=[\n                    mpm_state.particle_density,\n                    mpm_state.particle_vol,\n                    mpm_state.particle_mass,\n                ],\n                device=device,\n            )\n        if \"rpic_damping\" in kwargs:\n            mpm_model.rpic_damping = kwargs[\"rpic_damping\"]\n        if \"plastic_viscosity\" in kwargs:\n            mpm_model.plastic_viscosity = kwargs[\"plastic_viscosity\"]\n        if \"softening\" in kwargs:\n            mpm_model.softening = kwargs[\"softening\"]\n        if \"grid_v_damping_scale\" in kwargs:\n            mpm_model.grid_v_damping_scale = kwargs[\"grid_v_damping_scale\"]\n\n    def set_E_nu(self, mpm_model, E: float, nu: float, device=\"cuda:0\"):\n        if isinstance(E, float):\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.E, E],\n                device=device,\n            )\n        else:  # E is warp array\n            wp.launch(\n                kernel=set_float_vec_to_vec,\n                dim=self.n_particles,\n                inputs=[mpm_model.E, E],\n                device=device,\n            )\n\n        if isinstance(nu, float):\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.nu, nu],\n                device=device,\n            )\n        else:\n            wp.launch(\n                kernel=set_float_vec_to_vec,\n                dim=self.n_particles,\n                inputs=[mpm_model.nu, nu],\n        
        device=device,\n            )\n\n    def set_E_nu_from_torch(\n        self,\n        mpm_model,\n        E: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        nu: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        device=\"cuda:0\",\n    ):\n        if E.ndim == 0:\n            E_inp = E.item()  # float\n        else:\n            E_inp = from_torch_safe(E, dtype=wp.float32, requires_grad=True)\n\n        if nu.ndim == 0:\n            nu_inp = nu.item()  # float\n        else:\n            nu_inp = from_torch_safe(nu, dtype=wp.float32, requires_grad=True)\n\n        self.set_E_nu(mpm_model, E_inp, nu_inp, device=device)\n\n    def prepare_mu_lam(self, mpm_model, mpm_state, device=\"cuda:0\"):\n        # compute mu and lam from E and nu\n        wp.launch(\n            kernel=compute_mu_lam_from_E_nu,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n    def p2g2p_differentiable(\n        self, mpm_model, mpm_state, next_state, dt, device=\"cuda:0\"\n    ):\n        \"\"\"\n        Some boundary conditions, might not give gradient,\n        see kernels in\n            self.pre_p2g_operations,    Usually None.\n            self.particle_velocity_modifiers.   
Mostly used to freeze points\n            self.grid_postprocess,      Should apply BC here\n        \"\"\"\n        grid_size = (\n            mpm_model.grid_dim_x,\n            mpm_model.grid_dim_y,\n            mpm_model.grid_dim_z,\n        )\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        # apply pre-p2g operations on particles\n        # apply impulse force on particles..\n        for k in range(len(self.pre_p2g_operations)):\n            wp.launch(\n                kernel=self.pre_p2g_operations[k],\n                dim=self.n_particles,\n                inputs=[self.time, dt, mpm_state, self.impulse_params[k]],\n                device=device,\n            )\n\n        # apply dirichlet particle v modifier\n        for k in range(len(self.particle_velocity_modifiers)):\n            wp.launch(\n                kernel=self.particle_velocity_modifiers[k],\n                dim=self.n_particles,\n                inputs=[\n                    self.time,\n                    mpm_state,\n                    self.particle_velocity_modifier_params[k],\n                ],\n                device=device,\n            )\n\n        # compute stress = stress(returnMap(F_trial))\n        # F_trail => F                    # TODO: this is overite..\n        # F, SVD(F), lam, mu => Stress.   
# TODO: this is overite..\n\n        with wp.ScopedTimer(\n            \"compute_stress_from_F_trial\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=compute_stress_from_F_trial,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # F and stress are updated\n\n        # p2g\n        with wp.ScopedTimer(\n            \"p2g\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=p2g_apic_with_stress,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # apply p2g'\n\n        # grid update\n        with wp.ScopedTimer(\n            \"grid_update\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=grid_normalization_and_gravity,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )\n\n        if mpm_model.grid_v_damping_scale < 1.0:\n            wp.launch(\n                kernel=add_damping_via_grid,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model.grid_v_damping_scale],\n                device=device,\n            )\n\n        # apply BC on grid, collide\n        with wp.ScopedTimer(\n            \"apply_BC_on_grid\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            for k in range(len(self.grid_postprocess)):\n                wp.launch(\n                    kernel=self.grid_postprocess[k],\n                    dim=grid_size,\n                    inputs=[\n                        self.time,\n                        dt,\n                        mpm_state,\n                        mpm_model,\n 
                       self.collider_params[k],\n                    ],\n                    device=device,\n                )\n                if self.modify_bc[k] is not None:\n                    self.modify_bc[k](self.time, dt, self.collider_params[k])\n\n        # g2p\n        with wp.ScopedTimer(\n            \"g2p\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=g2p_differentiable,\n                dim=self.n_particles,\n                inputs=[mpm_state, next_state, mpm_model, dt],\n                device=device,\n            )  # x, v, C, F_trial are updated\n\n        self.time = self.time + dt\n\n    def p2g2p(self, mpm_model, mpm_state, step, dt, device=\"cuda:0\"):\n        grid_size = (\n            mpm_model.grid_dim_x,\n            mpm_model.grid_dim_y,\n            mpm_model.grid_dim_z,\n        )\n\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        # apply pre-p2g operations on particles\n        # apply impulse force on particles..\n        for k in range(len(self.pre_p2g_operations)):\n            wp.launch(\n                kernel=self.pre_p2g_operations[k],\n                dim=self.n_particles,\n                inputs=[self.time, dt, mpm_state, self.impulse_params[k]],\n                device=device,\n            )\n\n        # apply dirichlet particle v modifier\n        for k in range(len(self.particle_velocity_modifiers)):\n            wp.launch(\n                kernel=self.particle_velocity_modifiers[k],\n                dim=self.n_particles,\n                inputs=[\n                    self.time,\n                    mpm_state,\n                    self.particle_velocity_modifier_params[k],\n                ],\n                device=device,\n            )\n\n        # compute stress = stress(returnMap(F_trial))\n        # 
F_trail => F                    # TODO: this is overite..\n        # F, SVD(F), lam, mu => Stress.   # TODO: this is overite..\n\n        with wp.ScopedTimer(\n            \"compute_stress_from_F_trial\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=compute_stress_from_F_trial,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # F and stress are updated\n\n        # p2g\n        with wp.ScopedTimer(\n            \"p2g\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=p2g_apic_with_stress,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # apply p2g'\n\n        # grid update\n        with wp.ScopedTimer(\n            \"grid_update\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=grid_normalization_and_gravity,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )\n\n        if mpm_model.grid_v_damping_scale < 1.0:\n            wp.launch(\n                kernel=add_damping_via_grid,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model.grid_v_damping_scale],\n                device=device,\n            )\n\n        # apply BC on grid, collide\n        with wp.ScopedTimer(\n            \"apply_BC_on_grid\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            for k in range(len(self.grid_postprocess)):\n                wp.launch(\n                    kernel=self.grid_postprocess[k],\n                    dim=grid_size,\n                    inputs=[\n                        self.time,\n 
                       dt,\n                        mpm_state,\n                        mpm_model,\n                        self.collider_params[k],\n                    ],\n                    device=device,\n                )\n                if self.modify_bc[k] is not None:\n                    self.modify_bc[k](self.time, dt, self.collider_params[k])\n\n        # g2p\n        with wp.ScopedTimer(\n            \"g2p\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=g2p,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # x, v, C, F_trial are updated\n\n        #### CFL check ####\n        # particle_v = self.mpm_state.particle_v.numpy()\n        # if np.max(np.abs(particle_v)) > self.mpm_model.dx / dt:\n        #     print(\"max particle v: \", np.max(np.abs(particle_v)))\n        #     print(\"max allowed  v: \", self.mpm_model.dx / dt)\n        #     print(\"does not allow v*dt>dx\")\n        #     input()\n        #### CFL check ####\n        with wp.ScopedTimer(\n            \"clip_particle_x\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=clip_particle_x,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model],\n                device=device,\n            )\n\n        self.time = self.time + dt\n\n    def print_time_profile(self):\n        print(\"MPM Time profile:\")\n        for key, value in self.time_profile.items():\n            print(key, sum(value))\n\n    # a surface specified by a point and the normal vector\n    def add_surface_collider(\n        self,\n        point,\n        normal,\n        surface=\"sticky\",\n        friction=0.0,\n        start_time=0.0,\n        end_time=999.0,\n    ):\n        point = list(point)\n        # Normalize normal\n        normal_scale = 1.0 / 
wp.sqrt(float(sum(x**2 for x in normal)))\n        normal = list(normal_scale * x for x in normal)\n\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n        if surface == \"sticky\" and friction != 0:\n            raise ValueError(\"friction must be 0 on sticky surfaces.\")\n        if surface == \"sticky\":\n            collider_param.surface_type = 0\n        elif surface == \"slip\":\n            collider_param.surface_type = 1\n        elif surface == \"cut\":\n            collider_param.surface_type = 11\n        else:\n            collider_param.surface_type = 2\n        # frictional\n        collider_param.friction = friction\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                n = wp.vec3(param.normal[0], param.normal[1], param.normal[2])\n                dotproduct = wp.dot(offset, n)\n\n                if dotproduct < 0.0:\n                    if param.surface_type == 0:\n                        state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                            0.0, 0.0, 0.0\n                        )\n                    elif param.surface_type == 11:\n                        if (\n                            
float(grid_z) * model.dx < 0.4\n                            or float(grid_z) * model.dx > 0.53\n                        ):\n                            state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                                0.0, 0.0, 0.0\n                            )\n                        else:\n                            v_in = state.grid_v_out[grid_x, grid_y, grid_z]\n                            state.grid_v_out[grid_x, grid_y, grid_z] = (\n                                wp.vec3(v_in[0], 0.0, v_in[2]) * 0.3\n                            )\n                    else:\n                        v = state.grid_v_out[grid_x, grid_y, grid_z]\n                        normal_component = wp.dot(v, n)\n                        if param.surface_type == 1:\n                            v = (\n                                v - normal_component * n\n                            )  # Project out all normal component\n                        else:\n                            v = (\n                                v - wp.min(normal_component, 0.0) * n\n                            )  # Project out only inward normal component\n                        if normal_component < 0.0 and wp.length(v) > 1e-20:\n                            v = wp.max(\n                                0.0, wp.length(v) + normal_component * param.friction\n                            ) * wp.normalize(\n                                v\n                            )  # apply friction here\n                        state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                            0.0, 0.0, 0.0\n                        )\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # a cubiod is a rectangular cube'\n    # centered at `point`\n    # dimension is x: point[0]±size[0]\n    #              y: point[1]±size[1]\n    #              z: point[2]±size[2]\n    # all grid nodes lie within the cubiod will have their speed set to velocity\n    # the 
cuboid itself is also moving with const speed = velocity\n    # set the speed to zero to fix BC\n    def set_velocity_on_cuboid(\n        self,\n        point,\n        size,\n        velocity,\n        start_time=0.0,\n        end_time=999.0,\n        reset=0,\n    ):\n        point = list(point)\n\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.size = size\n        collider_param.velocity = wp.vec3(velocity[0], velocity[1], velocity[2])\n        # collider_param.threshold = threshold\n        collider_param.reset = reset\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                if (\n                    wp.abs(offset[0]) < param.size[0]\n                    and wp.abs(offset[1]) < param.size[1]\n                    and wp.abs(offset[2]) < param.size[2]\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = param.velocity\n            elif param.reset == 1:\n                if time < param.end_time + 15.0 * dt:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n\n        def modify(time, dt, param: Dirichlet_collider):\n            if time >= param.start_time and time < param.end_time:\n                param.point = wp.vec3(\n             
       param.point[0] + dt * param.velocity[0],\n                    param.point[1] + dt * param.velocity[1],\n                    param.point[2] + dt * param.velocity[2],\n                )  # param.point + dt * param.velocity\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(modify)\n\n    def add_bounding_box(self, start_time=0.0, end_time=999.0):\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            padding = 3\n            if time >= param.start_time and time < param.end_time:\n                if grid_x < padding and state.grid_v_out[grid_x, grid_y, grid_z][0] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    grid_x >= model.grid_dim_x - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][0] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_y < padding and state.grid_v_out[grid_x, grid_y, grid_z][1] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n    
                    state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    grid_y >= model.grid_dim_y - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][1] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_z < padding and state.grid_v_out[grid_x, grid_y, grid_z][2] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n                if (\n                    grid_z >= model.grid_dim_z - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][2] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # particle_v += force/particle_mass * dt\n    # this is applied from start_dt, ends after num_dt p2g2p's\n    # particle velocity is changed before p2g at each timestep\n    def add_impulse_on_particles(\n        self,\n        mpm_state,\n        force,\n        dt,\n        point=[1, 1, 1],\n        size=[1, 1, 1],\n        num_dt=1,\n        start_time=0.0,\n        device=\"cuda:0\",\n    ):\n        impulse_param = Impulse_modifier()\n        impulse_param.start_time = start_time\n        impulse_param.end_time = start_time + dt * num_dt\n\n        impulse_param.point = 
wp.vec3(point[0], point[1], point[2])\n        impulse_param.size = wp.vec3(size[0], size[1], size[2])\n        impulse_param.mask = wp.zeros(shape=self.n_particles, dtype=int, device=device)\n\n        impulse_param.force = wp.vec3(\n            force[0],\n            force[1],\n            force[2],\n        )\n\n        wp.launch(\n            kernel=selection_add_impulse_on_particles,\n            dim=self.n_particles,\n            inputs=[mpm_state, impulse_param],\n            device=device,\n        )\n\n        self.impulse_params.append(impulse_param)\n\n        @wp.kernel\n        def apply_force(\n            time: float, dt: float, state: MPMStateStruct, param: Impulse_modifier\n        ):\n            p = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                if param.mask[p] == 1:\n                    impulse = wp.vec3(\n                        param.force[0] / state.particle_mass[p],\n                        param.force[1] / state.particle_mass[p],\n                        param.force[2] / state.particle_mass[p],\n                    )\n                    state.particle_v[p] = state.particle_v[p] + impulse * dt\n\n        self.pre_p2g_operations.append(apply_force)\n\n    def enforce_particle_velocity_translation(\n        self, mpm_state, point, size, velocity, start_time, end_time, device=\"cuda:0\"\n    ):\n        # first select certain particles based on position\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], point[1], point[2])\n        velocity_modifier_params.size = wp.vec3(size[0], size[1], size[2])\n\n        velocity_modifier_params.velocity = wp.vec3(\n            velocity[0], velocity[1], velocity[2]\n        )\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, 
dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_translation,\n            dim=self.n_particles,\n            inputs=[mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    state.particle_v[p] = velocity_modifier_params.velocity\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # define a cylinder with center point, half_height, radius, normal\n    # particles within the cylinder are rotating along the normal direction\n    # may also have a translational velocity along the normal direction\n    def enforce_particle_velocity_rotation(\n        self,\n        mpm_state,\n        point,\n        normal,\n        half_height_and_radius,\n        rotation_scale,\n        translation_scale,\n        start_time,\n        end_time,\n        device=\"cuda:0\",\n    ):\n        normal_scale = 1.0 / wp.sqrt(\n            float(normal[0] ** 2 + normal[1] ** 2 + normal[2] ** 2)\n        )\n        normal = list(normal_scale * x for x in normal)\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], point[1], point[2])\n        velocity_modifier_params.half_height_and_radius = wp.vec2(\n            half_height_and_radius[0], half_height_and_radius[1]\n        )\n        velocity_modifier_params.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n 
       horizontal_1 = wp.vec3(1.0, 1.0, 1.0)\n        if wp.abs(wp.dot(velocity_modifier_params.normal, horizontal_1)) < 0.01:\n            horizontal_1 = wp.vec3(0.72, 0.37, -0.67)\n        horizontal_1 = (\n            horizontal_1\n            - wp.dot(horizontal_1, velocity_modifier_params.normal)\n            * velocity_modifier_params.normal\n        )\n        horizontal_1 = horizontal_1 * (1.0 / wp.length(horizontal_1))\n        horizontal_2 = wp.cross(horizontal_1, velocity_modifier_params.normal)\n\n        velocity_modifier_params.horizontal_axis_1 = horizontal_1\n        velocity_modifier_params.horizontal_axis_2 = horizontal_2\n\n        velocity_modifier_params.rotation_scale = rotation_scale\n        velocity_modifier_params.translation_scale = translation_scale\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_cylinder,\n            dim=self.n_particles,\n            inputs=[mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    offset = state.particle_x[p] - velocity_modifier_params.point\n                    horizontal_distance = wp.length(\n                        offset\n                        - wp.dot(offset, 
velocity_modifier_params.normal)\n                        * velocity_modifier_params.normal\n                    )\n                    cosine = (\n                        wp.dot(offset, velocity_modifier_params.horizontal_axis_1)\n                        / horizontal_distance\n                    )\n                    theta = wp.acos(cosine)\n                    if wp.dot(offset, velocity_modifier_params.horizontal_axis_2) > 0:\n                        theta = theta\n                    else:\n                        theta = -theta\n                    axis1_scale = (\n                        -horizontal_distance\n                        * wp.sin(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis2_scale = (\n                        horizontal_distance\n                        * wp.cos(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis_vertical_scale = translation_scale\n                    state.particle_v[p] = (\n                        axis1_scale * velocity_modifier_params.horizontal_axis_1\n                        + axis2_scale * velocity_modifier_params.horizontal_axis_2\n                        + axis_vertical_scale * velocity_modifier_params.normal\n                    )\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # given normal direction, say [0,0,1]\n    # gradually release grid velocities from start position to end position\n    def release_particles_sequentially(\n        self, normal, start_position, end_position, num_layers, start_time, end_time\n    ):\n        num_layers = 50\n        point = [0, 0, 0]\n        size = [0, 0, 0]\n        axis = -1\n        for i in range(3):\n            if normal[i] == 0:\n                point[i] = 1\n                size[i] = 1\n            else:\n                axis = i\n                point[i] = end_position\n\n        
half_length_portion = wp.abs(start_position - end_position) / num_layers\n        end_time_portion = end_time / num_layers\n        for i in range(num_layers):\n            size[axis] = half_length_portion * (num_layers - i)\n            self.enforce_particle_velocity_translation(\n                point=point,\n                size=size,\n                velocity=[0, 0, 0],\n                start_time=start_time,\n                end_time=end_time_portion * (i + 1),\n            )\n\n    def enforce_particle_velocity_by_mask(\n        self,\n        mpm_state,\n        selection_mask: torch.Tensor,\n        velocity,\n        start_time,\n        end_time,\n    ):\n        # first select certain particles based on position\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.velocity = wp.vec3(\n            velocity[0],\n            velocity[1],\n            velocity[2],\n        )\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.from_torch(selection_mask)\n\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    state.particle_v[p] = velocity_modifier_params.velocity\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    def restart_and_compute_F_C(self, mpm_model, mpm_state, target_pos, device):\n        grid_size = (\n            mpm_model.grid_dim_x,\n            mpm_model.grid_dim_y,\n     
       mpm_model.grid_dim_z,\n        )\n\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        wp.launch(\n            set_F_C_p2g,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model, target_pos],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=grid_normalization_and_gravity,\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model, 0],\n            device=device,\n        )\n\n        wp.launch(\n            set_F_C_g2p,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        # set position to target_pos\n        wp.launch(\n            kernel=set_vec3_to_vec3,\n            dim=self.n_particles,\n            inputs=[mpm_state.particle_x, target_pos],\n            device=device,\n        )\n\n    def enforce_grid_velocity_by_mask(\n        self,\n        selection_mask: torch.Tensor,  # should be int\n    ):\n\n        grid_modifier_params = GridCollider()\n\n        grid_modifier_params.mask = wp.from_torch(selection_mask)\n\n        self.collider_params.append(grid_modifier_params)\n\n        @wp.kernel\n        def modify_grid_v_before_g2p(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            grid_modifier_params: GridCollider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n\n            if grid_modifier_params.mask[grid_x, grid_y, grid_z] >= 1:\n                state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n\n        self.grid_postprocess.append(modify_grid_v_before_g2p)\n     
   self.modify_bc.append(None)\n\n    # particle_v += force/particle_mass * dt\n    # this is applied from start_dt, ends after num_dt p2g2p's\n    # particle velocity is changed before p2g at each timestep\n    def add_impulse_on_particles_with_mask(\n        self,\n        mpm_state,\n        force,\n        dt,\n        particle_mask,  # 1 for selected particles, 0 for others\n        point=[1, 1, 1],\n        size=[1, 1, 1],\n        end_time=1,\n        start_time=0.0,\n        device=\"cuda:0\",\n    ):\n        assert (\n            len(particle_mask) == self.n_particles\n        ), \"mask should have n_particles elements\"\n        impulse_param = Impulse_modifier()\n        impulse_param.start_time = start_time\n        impulse_param.end_time = end_time\n        impulse_param.mask = wp.from_torch(particle_mask)\n\n        impulse_param.point = wp.vec3(point[0], point[1], point[2])\n        impulse_param.size = wp.vec3(size[0], size[1], size[2])\n\n        impulse_param.force = wp.vec3(\n            force[0],\n            force[1],\n            force[2],\n        )\n\n        wp.launch(\n            kernel=selection_add_impulse_on_particles,\n            dim=self.n_particles,\n            inputs=[mpm_state, impulse_param],\n            device=device,\n        )\n\n        self.impulse_params.append(impulse_param)\n\n        @wp.kernel\n        def apply_force(\n            time: float, dt: float, state: MPMStateStruct, param: Impulse_modifier\n        ):\n            p = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                if param.mask[p] >= 1:\n                    # impulse = wp.vec3(\n                    #     param.force[0] / state.particle_mass[p],\n                    #     param.force[1] / state.particle_mass[p],\n                    #     param.force[2] / state.particle_mass[p],\n                    # )\n                    impulse = wp.vec3(\n                        param.force[0],\n                        
param.force[1],\n                        param.force[2],\n                    )\n                    state.particle_v[p] = state.particle_v[p] + impulse * dt\n\n        self.pre_p2g_operations.append(apply_force)\n"
  },
  {
    "path": "physdreamer/warp_mpm/mpm_utils.py",
    "content": "import warp as wp\nfrom mpm_data_structure import *\nimport numpy as np\nimport math\n\n\n# compute stress from F\n@wp.func\ndef kirchoff_stress_FCR(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, mu: float, lam: float\n):\n    # compute kirchoff stress for FCR model (remember tau = P F^T)\n    R = U * wp.transpose(V)\n    id = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    return 2.0 * mu * (F - R) * wp.transpose(F) + id * lam * J * (J - 1.0)\n\n\n@wp.func\ndef kirchoff_stress_neoHookean(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, sig: wp.vec3, mu: float, lam: float\n):\n    \"\"\"\n    B = F * wp.transpose(F)\n    dev(B) = B - (1/3) * tr(B) * I\n\n    For a compressible Rivlin neo-Hookean materia, the cauchy stress is given by:\n    mu * J^(-2/3) * dev(B) + lam * J (J - 1) * I\n    see: https://en.wikipedia.org/wiki/Neo-Hookean_solid\n    \"\"\"\n\n    # compute kirchoff stress for FCR model (remember tau = P F^T)\n    b = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])\n    b_hat = b - wp.vec3(\n        (b[0] + b[1] + b[2]) / 3.0,\n        (b[0] + b[1] + b[2]) / 3.0,\n        (b[0] + b[1] + b[2]) / 3.0,\n    )\n    tau = mu * J ** (-2.0 / 3.0) * b_hat + lam / 2.0 * (J * J - 1.0) * wp.vec3(\n        1.0, 1.0, 1.0\n    )\n\n    return (\n        U\n        * wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])\n        * wp.transpose(V)\n        * wp.transpose(F)\n    )\n\n\n@wp.func\ndef kirchoff_stress_StVK(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float\n):\n    sig = wp.vec3(\n        wp.max(sig[0], 0.01), wp.max(sig[1], 0.01), wp.max(sig[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    log_sig_sum = wp.log(sig[0]) + wp.log(sig[1]) + wp.log(sig[2])\n    ONE = wp.vec3(1.0, 1.0, 1.0)\n    tau = 2.0 * mu * epsilon + lam * log_sig_sum * ONE\n    return (\n        U\n        * 
wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])\n        * wp.transpose(V)\n        * wp.transpose(F)\n    )\n\n\n@wp.func\ndef kirchoff_stress_drucker_prager(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float\n):\n    log_sig_sum = wp.log(sig[0]) + wp.log(sig[1]) + wp.log(sig[2])\n    center00 = 2.0 * mu * wp.log(sig[0]) * (1.0 / sig[0]) + lam * log_sig_sum * (\n        1.0 / sig[0]\n    )\n    center11 = 2.0 * mu * wp.log(sig[1]) * (1.0 / sig[1]) + lam * log_sig_sum * (\n        1.0 / sig[1]\n    )\n    center22 = 2.0 * mu * wp.log(sig[2]) * (1.0 / sig[2]) + lam * log_sig_sum * (\n        1.0 / sig[2]\n    )\n    center = wp.mat33(center00, 0.0, 0.0, 0.0, center11, 0.0, 0.0, 0.0, center22)\n    return U * center * wp.transpose(V) * wp.transpose(F)\n\n\n@wp.func\ndef von_mises_return_mapping(F_trial: wp.mat33, model: MPMModelStruct, p: int):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0\n\n    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (\n        epsilon[0] + epsilon[1] + epsilon[2]\n    ) * wp.vec3(1.0, 1.0, 1.0)\n    sum_tau = tau[0] + tau[1] + tau[2]\n    cond = wp.vec3(\n        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0\n    )\n    if wp.length(cond) > model.yield_stress[p]:\n        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)\n        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6\n        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])\n        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        
sig_elastic = wp.mat33(\n            wp.exp(epsilon[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[2]),\n        )\n        F_elastic = U * sig_elastic * wp.transpose(V)\n        if model.hardening == 1:\n            model.yield_stress[p] = (\n                model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma\n            )\n        return F_elastic\n    else:\n        return F_trial\n\n\n@wp.func\ndef von_mises_return_mapping_with_damage(\n    F_trial: wp.mat33, model: MPMModelStruct, p: int\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0\n\n    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (\n        epsilon[0] + epsilon[1] + epsilon[2]\n    ) * wp.vec3(1.0, 1.0, 1.0)\n    sum_tau = tau[0] + tau[1] + tau[2]\n    cond = wp.vec3(\n        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0\n    )\n    if wp.length(cond) > model.yield_stress[p]:\n        if model.yield_stress[p] <= 0:\n            return F_trial\n        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)\n        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6\n        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])\n        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        model.yield_stress[p] = model.yield_stress[p] - model.softening * wp.length(\n            (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        )\n        if model.yield_stress[p] <= 
0:\n            model.mu[p] = 0.0\n            model.lam[p] = 0.0\n        sig_elastic = wp.mat33(\n            wp.exp(epsilon[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[2]),\n        )\n        F_elastic = U * sig_elastic * wp.transpose(V)\n        if model.hardening == 1:\n            model.yield_stress[p] = (\n                model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma\n            )\n        return F_elastic\n    else:\n        return F_trial\n\n\n# for toothpaste\n@wp.func\ndef viscoplasticity_return_mapping_with_StVK(\n    F_trial: wp.mat33, model: MPMModelStruct, p: int, dt: float\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    b_trial = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    trace_epsilon = epsilon[0] + epsilon[1] + epsilon[2]\n    epsilon_hat = epsilon - wp.vec3(\n        trace_epsilon / 3.0, trace_epsilon / 3.0, trace_epsilon / 3.0\n    )\n    s_trial = 2.0 * model.mu[p] * epsilon_hat\n    s_trial_norm = wp.length(s_trial)\n    y = s_trial_norm - wp.sqrt(2.0 / 3.0) * model.yield_stress[p]\n    if y > 0:\n        mu_hat = model.mu[p] * (b_trial[0] + b_trial[1] + b_trial[2]) / 3.0\n        s_new_norm = s_trial_norm - y / (\n            1.0 + model.plastic_viscosity / (2.0 * mu_hat * dt)\n        )\n        s_new = (s_new_norm / s_trial_norm) * s_trial\n        epsilon_new = 1.0 / (2.0 * model.mu[p]) * s_new + wp.vec3(\n            trace_epsilon / 3.0, trace_epsilon / 3.0, trace_epsilon / 3.0\n        )\n     
   sig_elastic = wp.mat33(\n            wp.exp(epsilon_new[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon_new[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon_new[2]),\n        )\n        F_elastic = U * sig_elastic * wp.transpose(V)\n        return F_elastic\n    else:\n        return F_trial\n\n\n@wp.func\ndef sand_return_mapping(\n    F_trial: wp.mat33, state: MPMStateStruct, model: MPMModelStruct, p: int\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig, V)\n\n    epsilon = wp.vec3(\n        wp.log(wp.max(wp.abs(sig[0]), 1e-14)),\n        wp.log(wp.max(wp.abs(sig[1]), 1e-14)),\n        wp.log(wp.max(wp.abs(sig[2]), 1e-14)),\n    )\n    sigma_out = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    tr = epsilon[0] + epsilon[1] + epsilon[2]  # + state.particle_Jp[p]\n    epsilon_hat = epsilon - wp.vec3(tr / 3.0, tr / 3.0, tr / 3.0)\n    epsilon_hat_norm = wp.length(epsilon_hat)\n    delta_gamma = (\n        epsilon_hat_norm\n        + (3.0 * model.lam[p] + 2.0 * model.mu[p])\n        / (2.0 * model.mu[p])\n        * tr\n        * model.alpha\n    )\n\n    if delta_gamma <= 0:\n        F_elastic = F_trial\n\n    if delta_gamma > 0 and tr > 0:\n        F_elastic = U * wp.transpose(V)\n\n    if delta_gamma > 0 and tr <= 0:\n        H = epsilon - epsilon_hat * (delta_gamma / epsilon_hat_norm)\n        s_new = wp.vec3(wp.exp(H[0]), wp.exp(H[1]), wp.exp(H[2]))\n\n        F_elastic = U * wp.diag(s_new) * wp.transpose(V)\n    return F_elastic\n\n\n@wp.kernel\ndef compute_mu_lam_from_E_nu(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n    model.mu[p] = model.E[p] / (2.0 * (1.0 + model.nu[p]))\n    model.lam[p] = (\n        model.E[p] * model.nu[p] / ((1.0 + model.nu[p]) * (1.0 - 2.0 * model.nu[p]))\n    )\n\n\n@wp.kernel\ndef 
zero_grid(state: MPMStateStruct, model: MPMModelStruct):\n    grid_x, grid_y, grid_z = wp.tid()\n    state.grid_m[grid_x, grid_y, grid_z] = 0.0\n    state.grid_v_in[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n\n\n@wp.func\ndef compute_dweight(\n    model: MPMModelStruct, w: wp.mat33, dw: wp.mat33, i: int, j: int, k: int\n):\n    dweight = wp.vec3(\n        dw[0, i] * w[1, j] * w[2, k],\n        w[0, i] * dw[1, j] * w[2, k],\n        w[0, i] * w[1, j] * dw[2, k],\n    )\n    return dweight * model.inv_dx\n\n\n@wp.func\ndef update_cov(state: MPMStateStruct, p: int, grad_v: wp.mat33, dt: float):\n    cov_n = wp.mat33(0.0)\n    cov_n[0, 0] = state.particle_cov[p * 6]\n    cov_n[0, 1] = state.particle_cov[p * 6 + 1]\n    cov_n[0, 2] = state.particle_cov[p * 6 + 2]\n    cov_n[1, 0] = state.particle_cov[p * 6 + 1]\n    cov_n[1, 1] = state.particle_cov[p * 6 + 3]\n    cov_n[1, 2] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 0] = state.particle_cov[p * 6 + 2]\n    cov_n[2, 1] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 2] = state.particle_cov[p * 6 + 5]\n\n    cov_np1 = cov_n + dt * (grad_v * cov_n + cov_n * wp.transpose(grad_v))\n\n    state.particle_cov[p * 6] = cov_np1[0, 0]\n    state.particle_cov[p * 6 + 1] = cov_np1[0, 1]\n    state.particle_cov[p * 6 + 2] = cov_np1[0, 2]\n    state.particle_cov[p * 6 + 3] = cov_np1[1, 1]\n    state.particle_cov[p * 6 + 4] = cov_np1[1, 2]\n    state.particle_cov[p * 6 + 5] = cov_np1[2, 2]\n\n\n@wp.func\ndef update_cov_differentiable(\n    state: MPMStateStruct,\n    next_state: MPMStateStruct,\n    p: int,\n    grad_v: wp.mat33,\n    dt: float,\n):\n    cov_n = wp.mat33(0.0)\n    cov_n[0, 0] = state.particle_cov[p * 6]\n    cov_n[0, 1] = state.particle_cov[p * 6 + 1]\n    cov_n[0, 2] = state.particle_cov[p * 6 + 2]\n    cov_n[1, 0] = state.particle_cov[p * 6 + 1]\n    cov_n[1, 1] = state.particle_cov[p * 6 + 3]\n    cov_n[1, 2] = state.particle_cov[p * 6 + 
4]\n    cov_n[2, 0] = state.particle_cov[p * 6 + 2]\n    cov_n[2, 1] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 2] = state.particle_cov[p * 6 + 5]\n\n    cov_np1 = cov_n + dt * (grad_v * cov_n + cov_n * wp.transpose(grad_v))\n\n    next_state.particle_cov[p * 6] = cov_np1[0, 0]\n    next_state.particle_cov[p * 6 + 1] = cov_np1[0, 1]\n    next_state.particle_cov[p * 6 + 2] = cov_np1[0, 2]\n    next_state.particle_cov[p * 6 + 3] = cov_np1[1, 1]\n    next_state.particle_cov[p * 6 + 4] = cov_np1[1, 2]\n    next_state.particle_cov[p * 6 + 5] = cov_np1[2, 2]\n\n\n@wp.kernel\ndef p2g_apic_with_stress(state: MPMStateStruct, model: MPMModelStruct, dt: float):\n    # input given to p2g:   particle_stress\n    #                       particle_x\n    #                       particle_v\n    #                       particle_C\n    # output:               grid_v_in, grid_m\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        stress = state.particle_stress[p]\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    dpos = (\n                        wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    ) * model.dx\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n  
                  iz = base_pos_z + k\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n\n                    C = state.particle_C[p]\n                    # if model.rpic = 0, standard apic\n                    C = (1.0 - model.rpic_damping) * C + model.rpic_damping / 2.0 * (\n                        C - wp.transpose(C)\n                    )\n\n                    # C = (1.0 - model.rpic_damping) * state.particle_C[\n                    #     p\n                    # ] + model.rpic_damping / 2.0 * (\n                    #     state.particle_C[p] - wp.transpose(state.particle_C[p])\n                    # )\n\n                    if model.rpic_damping < -0.001:\n                        # standard pic\n                        C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n                    elastic_force = -state.particle_vol[p] * stress * dweight\n                    v_in_add = (\n                        weight\n                        * state.particle_mass[p]\n                        * (state.particle_v[p] + C * dpos)\n                        + dt * elastic_force\n                    )\n                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)\n                    wp.atomic_add(\n                        state.grid_m, ix, iy, iz, weight * state.particle_mass[p]\n                    )\n\n\n# add gravity\n@wp.kernel\ndef grid_normalization_and_gravity(\n    state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n    grid_x, grid_y, grid_z = wp.tid()\n    if state.grid_m[grid_x, grid_y, grid_z] > 1e-15:\n        v_out = state.grid_v_in[grid_x, grid_y, grid_z] * (\n            1.0 / state.grid_m[grid_x, grid_y, grid_z]\n        )\n        # add gravity\n        v_out = v_out + dt * model.gravitational_accelaration\n        state.grid_v_out[grid_x, grid_y, grid_z] = v_out\n\n\n@wp.kernel\ndef g2p(state: MPMStateStruct, model: 
MPMModelStruct, dt: float):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_v = wp.vec3(0.0, 0.0, 0.0)\n        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_v = new_v + grid_v * weight\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        state.particle_v[p] = new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * state.particle_v[p]\n\n        # wp.atomic_add(state.particle_x, p, dt * state.particle_v[p]) # 
old one is this..\n        wp.atomic_add(state.particle_x, p, dt * new_v)  # debug\n        # new_x = state.particle_x[p] + dt * state.particle_v[p]\n        # state.particle_x[p] = new_x\n\n        state.particle_C[p] = new_C\n\n        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n        F_tmp = (I33 + new_F * dt) * state.particle_F[p]\n        state.particle_F_trial[p] = F_tmp\n        # debug for jelly\n        # wp.atomic_add(state.particle_F_trial, p, new_F * dt * state.particle_F[p])\n\n        if model.update_cov_with_F:\n            update_cov(state, p, new_F, dt)\n\n\n@wp.kernel\ndef g2p_differentiable(\n    state: MPMStateStruct, next_state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n    \"\"\"\n    Compute:\n        next_state.particle_v, next_state.particle_x, next_state.particle_C, next_state.particle_F_trial\n    \"\"\"\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_v = wp.vec3(0.0, 0.0, 0.0)\n        # new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_C = wp.mat33(new_v, new_v, new_v)\n        \n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = 
base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_v = (\n                        new_v + grid_v * weight\n                    )  # TODO, check gradient from static loop\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        next_state.particle_v[p] = new_v\n\n        # add clip here:\n        new_x = state.particle_x[p] + dt * new_v\n        dx = 1.0 / model.inv_dx\n        a_min = dx * 2.0\n        a_max = model.grid_lim - dx * 2.0\n\n        new_x_clamped = wp.vec3(\n            wp.clamp(new_x[0], a_min, a_max),\n            wp.clamp(new_x[1], a_min, a_max),\n            wp.clamp(new_x[2], a_min, a_max),\n        )\n        next_state.particle_x[p] = new_x_clamped\n\n        # next_state.particle_x[p] = new_x\n\n        next_state.particle_C[p] = new_C\n\n        I33_1 = wp.vec3(1.0, 0.0, 0.0)\n        I33_2 = wp.vec3(0.0, 1.0, 0.0)\n        I33_3 = wp.vec3(0.0, 0.0, 1.0)\n        I33 = wp.mat33(I33_1, I33_2, I33_3)\n        F_tmp = (I33 + new_F * dt) * state.particle_F[p]\n        next_state.particle_F_trial[p] = F_tmp\n\n        if 0:\n            update_cov_differentiable(state, next_state, p, new_F, dt)\n\n\n@wp.kernel\ndef clip_particle_x(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n\n    posx = state.particle_x[p]\n    if state.particle_selection[p] == 0:\n        dx = 1.0 / model.inv_dx\n        a_min = dx * 2.0\n        a_max = model.grid_lim - dx * 2.0\n        new_x = wp.vec3(\n            wp.clamp(posx[0], a_min, 
a_max),\n            wp.clamp(posx[1], a_min, a_max),\n            wp.clamp(posx[2], a_min, a_max),\n        )\n\n        state.particle_x[\n            p\n        ] = new_x  # Warn: this gives wrong gradient, don't use this for backward\n\n\n# compute (Kirchhoff) stress = stress(returnMap(F_trial))\n@wp.kernel\ndef compute_stress_from_F_trial(\n    state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n    \"\"\"\n    state.particle_F_trial => state.particle_F   # return mapping\n    state.particle_F => state.particle_stress    # stress-strain\n\n    TODO: check the gradient of SVD!  is wp.svd3 differentiable? I guess so\n    \"\"\"\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        # apply return mapping\n        if model.material == 1:  # metal\n            state.particle_F[p] = von_mises_return_mapping(\n                state.particle_F_trial[p], model, p\n            )\n        elif model.material == 2:  # sand\n            state.particle_F[p] = sand_return_mapping(\n                state.particle_F_trial[p], state, model, p\n            )\n        elif model.material == 3:  # visplas, with StVk+VM, no thickening\n            state.particle_F[p] = viscoplasticity_return_mapping_with_StVK(\n                state.particle_F_trial[p], model, p, dt\n            )\n        elif model.material == 5:\n            state.particle_F[p] = von_mises_return_mapping_with_damage(\n                state.particle_F_trial[p], model, p\n            )\n        else:  # elastic, jelly, or neo-hookean\n            state.particle_F[p] = state.particle_F_trial[p]\n\n        # also compute stress here\n        J = wp.determinant(state.particle_F[p])\n        U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        sig = wp.vec3(0.0)\n        stress = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        wp.svd3(state.particle_F[p], U, sig, V)\n        if model.material == 0 or 
model.material == 5:\n            stress = kirchoff_stress_FCR(\n                state.particle_F[p], U, V, J, model.mu[p], model.lam[p]\n            )\n        if model.material == 1:\n            stress = kirchoff_stress_StVK(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n        if model.material == 2:\n            stress = kirchoff_stress_drucker_prager(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n        if model.material == 3:\n            # temporarily use stvk, subject to change\n            stress = kirchoff_stress_StVK(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n\n        if model.material == 6:\n            stress = kirchoff_stress_neoHookean(\n                state.particle_F[p], U, V, J, sig, model.mu[p], model.lam[p]\n            )\n        # stress = (stress + wp.transpose(stress)) / 2.0  # enfore symmetry\n        state.particle_stress[p] = (stress + wp.transpose(stress)) / 2.0\n\n\n# @wp.kernel\n# def compute_cov_from_F(state: MPMStateStruct, model: MPMModelStruct):\n#     p = wp.tid()\n\n#     F = state.particle_F_trial[p]\n\n#     init_cov = wp.mat33(0.0)\n#     init_cov[0, 0] = state.particle_init_cov[p * 6]\n#     init_cov[0, 1] = state.particle_init_cov[p * 6 + 1]\n#     init_cov[0, 2] = state.particle_init_cov[p * 6 + 2]\n#     init_cov[1, 0] = state.particle_init_cov[p * 6 + 1]\n#     init_cov[1, 1] = state.particle_init_cov[p * 6 + 3]\n#     init_cov[1, 2] = state.particle_init_cov[p * 6 + 4]\n#     init_cov[2, 0] = state.particle_init_cov[p * 6 + 2]\n#     init_cov[2, 1] = state.particle_init_cov[p * 6 + 4]\n#     init_cov[2, 2] = state.particle_init_cov[p * 6 + 5]\n\n#     cov = F * init_cov * wp.transpose(F)\n\n#     state.particle_cov[p * 6] = cov[0, 0]\n#     state.particle_cov[p * 6 + 1] = cov[0, 1]\n#     state.particle_cov[p * 6 + 2] = cov[0, 2]\n#     state.particle_cov[p * 6 + 3] = cov[1, 
1]\n#     state.particle_cov[p * 6 + 4] = cov[1, 2]\n#     state.particle_cov[p * 6 + 5] = cov[2, 2]\n\n\n# @wp.kernel\n# def compute_R_from_F(state: MPMStateStruct, model: MPMModelStruct):\n#     p = wp.tid()\n\n#     F = state.particle_F_trial[p]\n\n#     # polar svd decomposition\n#     U = wp.mat33(0.0)\n#     V = wp.mat33(0.0)\n#     sig = wp.vec3(0.0)\n#     wp.svd3(F, U, sig, V)\n\n#     if wp.determinant(U) < 0.0:\n#         U[0, 2] = -U[0, 2]\n#         U[1, 2] = -U[1, 2]\n#         U[2, 2] = -U[2, 2]\n\n#     if wp.determinant(V) < 0.0:\n#         V[0, 2] = -V[0, 2]\n#         V[1, 2] = -V[1, 2]\n#         V[2, 2] = -V[2, 2]\n\n#     # compute rotation matrix\n#     R = U * wp.transpose(V)\n#     state.particle_R[p] = wp.transpose(R) # particle R is removed\n\n\n@wp.kernel\ndef add_damping_via_grid(state: MPMStateStruct, scale: float):\n    grid_x, grid_y, grid_z = wp.tid()\n    # state.grid_v_out[grid_x, grid_y, grid_z] = (\n    #     state.grid_v_out[grid_x, grid_y, grid_z] * scale\n    # )\n    wp.atomic_sub(\n        state.grid_v_out,\n        grid_x,\n        grid_y,\n        grid_z,\n        (1.0 - scale) * state.grid_v_out[grid_x, grid_y, grid_z],\n    )\n\n\n@wp.kernel\ndef apply_additional_params(\n    state: MPMStateStruct,\n    model: MPMModelStruct,\n    params_modifier: MaterialParamsModifier,\n):\n    p = wp.tid()\n    pos = state.particle_x[p]\n    if (\n        pos[0] > params_modifier.point[0] - params_modifier.size[0]\n        and pos[0] < params_modifier.point[0] + params_modifier.size[0]\n        and pos[1] > params_modifier.point[1] - params_modifier.size[1]\n        and pos[1] < params_modifier.point[1] + params_modifier.size[1]\n        and pos[2] > params_modifier.point[2] - params_modifier.size[2]\n        and pos[2] < params_modifier.point[2] + params_modifier.size[2]\n    ):\n        model.E[p] = params_modifier.E\n        model.nu[p] = params_modifier.nu\n        state.particle_density[p] = 
params_modifier.density\n\n\n@wp.kernel\ndef selection_add_impulse_on_particles(\n    state: MPMStateStruct, impulse_modifier: Impulse_modifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - impulse_modifier.point\n    if (\n        wp.abs(offset[0]) < impulse_modifier.size[0]\n        and wp.abs(offset[1]) < impulse_modifier.size[1]\n        and wp.abs(offset[2]) < impulse_modifier.size[2]\n    ):\n        impulse_modifier.mask[p] = 1\n    else:\n        impulse_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef selection_enforce_particle_velocity_translation(\n    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - velocity_modifier.point\n    if (\n        wp.abs(offset[0]) < velocity_modifier.size[0]\n        and wp.abs(offset[1]) < velocity_modifier.size[1]\n        and wp.abs(offset[2]) < velocity_modifier.size[2]\n    ):\n        velocity_modifier.mask[p] = 1\n    else:\n        velocity_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef selection_enforce_particle_velocity_cylinder(\n    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - velocity_modifier.point\n\n    vertical_distance = wp.abs(wp.dot(offset, velocity_modifier.normal))\n\n    horizontal_distance = wp.length(\n        offset - wp.dot(offset, velocity_modifier.normal) * velocity_modifier.normal\n    )\n    if (\n        vertical_distance < velocity_modifier.half_height_and_radius[0]\n        and horizontal_distance < velocity_modifier.half_height_and_radius[1]\n    ):\n        velocity_modifier.mask[p] = 1\n    else:\n        velocity_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef compute_position_l2_loss(\n    mpm_state: MPMStateStruct,\n    gt_pos: wp.array(dtype=wp.vec3),\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    pos = mpm_state.particle_x[tid]\n    pos_gt = gt_pos[tid]\n\n    # l1_diff = wp.abs(pos - pos_gt)\n    l2 = wp.length(pos - 
pos_gt)\n\n    wp.atomic_add(loss, 0, l2)\n\n\n@wp.kernel\ndef aggregate_grad(x: wp.array(dtype=float), grad: wp.array(dtype=float)):\n    tid = wp.tid()\n\n    # gradient descent step\n    wp.atomic_add(x, 0, grad[tid])\n\n\n@wp.kernel\ndef set_F_C_p2g(\n    state: MPMStateStruct, model: MPMModelStruct, target_pos: wp.array(dtype=wp.vec3)\n):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        # p2g for displacement\n        particle_disp = target_pos[p] - state.particle_x[p]\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    v_in_add = weight * state.particle_mass[p] * particle_disp\n                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)\n                    wp.atomic_add(\n                        state.grid_m, ix, iy, iz, weight * state.particle_mass[p]\n                    )\n\n\n@wp.kernel\ndef set_F_C_g2p(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 
0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        # g2p for C and F\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        # C should still be zero..\n        # state.particle_C[p] = new_C\n        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n        F_tmp = I33 + new_F\n        state.particle_F_trial[p] = F_tmp\n\n        if model.update_cov_with_F:\n            update_cov(state, p, new_F, 1.0)\n\n\n@wp.kernel\ndef compute_posloss_with_grad(\n    mpm_state: MPMStateStruct,\n    gt_pos: wp.array(dtype=wp.vec3),\n    grad: wp.array(dtype=wp.vec3),\n    dt: float,\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    pos = 
mpm_state.particle_x[tid]\n    pos_gt = gt_pos[tid]\n\n    # l1_diff = wp.abs(pos - pos_gt)\n    # l2 = wp.length(pos - (pos_gt - grad[tid] * dt))\n    diff = pos - (pos_gt - grad[tid] * dt)\n    l2 = wp.dot(diff, diff)\n    wp.atomic_add(loss, 0, l2)\n\n\n@wp.kernel\ndef compute_veloloss_with_grad(\n    mpm_state: MPMStateStruct,\n    gt_pos: wp.array(dtype=wp.vec3),\n    grad: wp.array(dtype=wp.vec3),\n    dt: float,\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    pos = mpm_state.particle_v[tid]\n    pos_gt = gt_pos[tid]\n\n    # l1_diff = wp.abs(pos - pos_gt)\n    # l2 = wp.length(pos - (pos_gt - grad[tid] * dt))\n\n    diff = pos - (pos_gt - grad[tid] * dt)\n    l2 = wp.dot(diff, diff)\n    wp.atomic_add(loss, 0, l2)\n\n\n@wp.kernel\ndef compute_Floss_with_grad(\n    mpm_state: MPMStateStruct,\n    gt_mat: wp.array(dtype=wp.mat33),\n    grad: wp.array(dtype=wp.mat33),\n    dt: float,\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    mat_ = mpm_state.particle_F_trial[tid]\n    mat_gt = gt_mat[tid]\n\n    mat_gt = mat_gt - grad[tid] * dt\n    # l1_diff = wp.abs(pos - pos_gt)\n    mat_diff = mat_ - mat_gt\n\n    l2 = wp.ddot(mat_diff, mat_diff)\n    # l2 = wp.sqrt(\n    #     mat_diff[0, 0] ** 2.0\n    #     + mat_diff[0, 1] ** 2.0\n    #     + mat_diff[0, 2] ** 2.0\n    #     + mat_diff[1, 0] ** 2.0\n    #     + mat_diff[1, 1] ** 2.0\n    #     + mat_diff[1, 2] ** 2.0\n    #     + mat_diff[2, 0] ** 2.0\n    #     + mat_diff[2, 1] ** 2.0\n    #     + mat_diff[2, 2] ** 2.0\n    # )\n\n    wp.atomic_add(loss, 0, l2)\n\n\n@wp.kernel\ndef compute_Closs_with_grad(\n    mpm_state: MPMStateStruct,\n    gt_mat: wp.array(dtype=wp.mat33),\n    grad: wp.array(dtype=wp.mat33),\n    dt: float,\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    mat_ = mpm_state.particle_C[tid]\n    mat_gt = gt_mat[tid]\n\n    mat_gt = mat_gt - grad[tid] * dt\n    # l1_diff = wp.abs(pos - pos_gt)\n\n    mat_diff = mat_ - mat_gt\n    l2 = 
wp.ddot(mat_diff, mat_diff)\n\n    wp.atomic_add(loss, 0, l2)\n"
  },
  {
    "path": "physdreamer/warp_mpm/warp_utils.py",
    "content": "import warp as wp\nimport ctypes\nfrom typing import Optional\n\nfrom warp.torch import (\n    dtype_from_torch,\n    device_from_torch,\n    dtype_is_compatible,\n    from_torch,\n)\n\n\ndef from_torch_safe(t, dtype=None, requires_grad=None, grad=None):\n    \"\"\"Wrap a PyTorch tensor to a Warp array without copying the data.\n\n    Args:\n        t (torch.Tensor): The torch tensor to wrap.\n        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.\n        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.\n\n    Returns:\n        warp.array: The wrapped array.\n    \"\"\"\n    if dtype is None:\n        dtype = dtype_from_torch(t.dtype)\n    elif not dtype_is_compatible(t.dtype, dtype):\n        raise RuntimeError(f\"Incompatible data types: {t.dtype} and {dtype}\")\n\n    # get size of underlying data type to compute strides\n    ctype_size = ctypes.sizeof(dtype._type_)\n\n    shape = tuple(t.shape)\n    strides = tuple(s * ctype_size for s in t.stride())\n\n    # if target is a vector or matrix type\n    # then check if trailing dimensions match\n    # the target type and update the shape\n    if hasattr(dtype, \"_shape_\"):\n        dtype_shape = dtype._shape_\n        dtype_dims = len(dtype._shape_)\n        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:\n            raise RuntimeError(\n                f\"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}\"\n            )\n\n        # ensure the inner strides are contiguous\n        stride = ctype_size\n        for i in range(dtype_dims):\n            if strides[-i - 1] != stride:\n                raise RuntimeError(\n                    f\"Could 
not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous\"\n                )\n            stride *= dtype_shape[-i - 1]\n\n        shape = tuple(shape[:-dtype_dims]) or (1,)\n        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)\n\n    requires_grad = t.requires_grad if requires_grad is None else requires_grad\n    if grad is not None:\n        if not isinstance(grad, wp.array):\n            import torch\n\n            if isinstance(grad, torch.Tensor):\n                grad = from_torch(grad, dtype=dtype)\n            else:\n                raise ValueError(f\"Invalid gradient type: {type(grad)}\")\n    elif requires_grad:\n        # wrap the tensor gradient, allocate if necessary\n        if t.grad is None:\n            # allocate a zero-filled gradient tensor if it doesn't exist\n            import torch\n\n            t.grad = torch.zeros_like(t, requires_grad=False)\n        grad = from_torch(t.grad, dtype=dtype)\n\n    a = wp.types.array(\n        ptr=t.data_ptr(),\n        dtype=dtype,\n        shape=shape,\n        strides=strides,\n        device=device_from_torch(t.device),\n        copy=False,\n        owner=False,\n        grad=grad,\n        requires_grad=requires_grad,\n    )\n\n    # save a reference to the source tensor, otherwise it will be deallocated\n    a._tensor = t\n    return a\n\n\nclass MyTape(wp.Tape):\n    # returns the adjoint of a kernel parameter\n    def get_adjoint(self, a):\n        if not wp.types.is_array(a) and not isinstance(a, wp.codegen.StructInstance):\n            # if input is a simple type (e.g.: float, vec3, etc) then\n            # no gradient needed (we only return gradients through arrays and structs)\n            return a\n\n        elif wp.types.is_array(a) and a.grad:\n            # keep track of all gradients used by the tape (for zeroing)\n            # ignore the scalar loss since we don't want to clear its grad\n            
self.gradients[a] = a.grad\n            return a.grad\n\n        elif isinstance(a, wp.codegen.StructInstance):\n            adj = a._cls()\n            for name, _ in a._cls.ctype._fields_:\n                if name.startswith(\"_\"):\n                    continue\n                if isinstance(a._cls.vars[name].type, wp.array):\n                    arr = getattr(a, name)\n                    if arr is None:\n                        continue\n                    if arr.grad:\n                        grad = self.gradients[arr] = arr.grad\n                    else:\n                        grad = wp.zeros_like(arr)\n                    setattr(adj, name, grad)\n                else:\n                    setattr(adj, name, getattr(a, name))\n\n            self.gradients[a] = adj\n            return adj\n\n        return None\n\n\n# from https://github.com/PingchuanMa/NCLaw/blob/main/nclaw/warp/tape.py\nclass CondTape(object):\n    def __init__(self, tape: Optional[MyTape], cond: bool = True) -> None:\n        self.tape = tape\n        self.cond = cond\n\n    def __enter__(self):\n        if self.tape is not None and self.cond:\n            self.tape.__enter__()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.tape is not None and self.cond:\n            self.tape.__exit__(exc_type, exc_value, traceback)"
  },
  {
    "path": "projects/inference/README.md",
    "content": "## How to run\n\n**config file**\n\nThe config files for four scenes: carnation, aloacasia, hat, telephone is in `configs/` folder. Please check the path for `dataset_dir` and `model_list` is correct after you download all the models. \n\n**inference.py** \n\nPlease follow `run.sh` for common args. \n\nIf you encounter OOM error, it's very likely due to the Kmeans downsampling operations. See line ~260 of `inference.py`:\n\n``` python\n# WARNING: this is a GPU implementation, and will be OOM if the number of points is large\n# you might want to use a CPU implementation if the number of points is large\n# For CPU implementation: uncomment the following lines\n# from local_utils import downsample_with_kmeans\n# sim_xyzs = downsample_with_kmeans(sim_xyzs.detach().cpu().numpy(), num_cluster)\n# sim_xyzs = torch.from_numpy(sim_xyzs).float().to(device)\nsim_xyzs = downsample_with_kmeans_gpu(sim_xyzs, num_cluster)\n```\n\n"
  },
  {
    "path": "projects/inference/config_demo.py",
    "content": "import numpy as np\n\n# from model_config import (\n#     model_list,\n#     camera_cfg_list,\n#     points_list,\n#     force_directions,\n#     simulate_cfg,\n#     dataset_dir,\n#     result_dir,\n#     exp_name,\n# )\nimport importlib.util\nimport os\n\n\nclass DemoParams(object):\n    def __init__(self, scene_name):\n\n        self.scene_name = scene_name\n        base_dir = os.path.dirname(__file__)\n\n        # import_file_path = \".configs.\" + scene_name\n        import_file_path = os.path.join(base_dir, \"configs\", scene_name + \".py\")\n        print(\"loading scene params from: \", import_file_path)\n        spec = importlib.util.spec_from_file_location(scene_name, import_file_path)\n        if spec is None:\n            print(f\"Could not load the spec for: {import_file_path}\")\n        module = importlib.util.module_from_spec(spec)\n        spec.loader.exec_module(module)\n\n        self.model_list = module.model_list\n        self.camera_cfg_list = module.camera_cfg_list\n        self.points_list = module.points_list\n        self.force_directions = module.force_directions\n        self.simulate_cfg = module.simulate_cfg\n        self.dataset_dir = module.dataset_dir\n        self.result_dir = module.result_dir\n        self.exp_name = module.exp_name\n\n        substep = self.simulate_cfg[\"substep\"]\n        grid_size = self.simulate_cfg[\"grid_size\"]\n        self.init_youngs = self.simulate_cfg[\"init_young\"]\n        self.downsample_scale = self.simulate_cfg[\"downsample_scale\"]\n\n        self.demo_dict = {\n            \"baseline\": {\n                \"model_path\": self.model_list[0],\n                \"substep\": substep,\n                \"grid_size\": grid_size,\n                \"name\": \"baseline\",\n                \"camera_cfg\": self.camera_cfg_list[0],\n                \"cam_id\": 0,\n                \"init_youngs\": self.init_youngs,\n                \"downsample_scale\": self.downsample_scale,\n            
}\n        }\n\n    def get_cfg(\n        self,\n        demo_name=None,\n        model_id: int = 0,\n        eval_ys: float = 1.0,\n        force_id: int = 0,\n        force_mag: float = 1.0,\n        velo_scaling: float = 3.0,\n        point_id: int = 0,\n        cam_id: int = 0,\n        apply_force: bool = False,\n    ):\n        if demo_name == \"None\":\n            demo_name = None\n        if (demo_name is not None) and (demo_name in self.demo_dict):\n            cfg = self.demo_dict[demo_name]\n        else:\n            cfg = {}\n            cfg[\"model_path\"] = self.model_list[model_id]\n            cfg[\"center_point\"] = self.points_list[point_id]\n            cfg[\"force\"] = self.force_directions[force_id] * force_mag\n            cfg[\"camera_cfg\"] = self.camera_cfg_list[cam_id]\n            cfg[\"cam_id\"] = cam_id\n            cfg[\"force_duration\"] = 0.75\n            cfg[\"force_radius\"] = 0.1\n            cfg[\"substep\"] = self.simulate_cfg[\"substep\"]\n            cfg[\"grid_size\"] = self.simulate_cfg[\"grid_size\"]\n            cfg[\"total_time\"] = 5\n            cfg[\"eval_ys\"] = eval_ys\n            cfg[\"velo_scaling\"] = velo_scaling\n\n            if demo_name is None:\n                name = \"\"\n            else:\n                name = demo_name + \"_\"\n            name = (\n                name\n                + f\"{self.scene_name}_sv_gres{cfg['grid_size']}_substep{cfg['substep']}\"\n            )\n            if eval_ys > 10:\n                name += f\"_eval_ys_{eval_ys}\"\n            else:\n                name += f\"_model_{model_id}\"\n\n            if apply_force:\n                name += f\"_force_{force_id}_mag_{force_mag}_point_{point_id}\"\n            else:\n                name += f\"_no_force_velo_{velo_scaling}\"\n            cfg[\"name\"] = name\n\n        cfg[\"dataset_dir\"] = self.dataset_dir\n        cfg[\"result_dir\"] = self.result_dir\n        cfg[\"init_youngs\"] = self.init_youngs\n        
cfg[\"downsample_scale\"] = self.downsample_scale\n\n        return cfg\n"
  },
  {
    "path": "projects/inference/configs/alocasia.py",
    "content": "import numpy as np\n\ndataset_dir = \"../../data/physics_dreamer/alocasia/\"\nresult_dir = \"output/alocasia/results\"\nexp_name = \"alocasia\"\n\nmodel_list = [\"../../models/physdreamer/alocasia/model\"]\n\nfocus_point_list = [\n    np.array([-1.242875, -0.468537, -0.251450]),  # botton of the background\n]\n\ncamera_cfg_list = [\n    {\n        \"type\": \"spiral\",\n        \"focus_point\": focus_point_list[0],\n        \"radius\": 0.1,\n        \"up\": np.array([0, 0, 1]),\n    },\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00001.png\",\n        \"end_frame\": \"frame_00019.png\",\n    },\n    # real captured viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00236.png\",\n    },\n    # another viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00006.png\",\n    },\n    # another viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00095.png\",\n    },\n]\n\nsimulate_cfg = {\n    \"substep\": 768,\n    \"grid_size\": 64,\n    \"init_young\": 1e6,\n    \"downsample_scale\": 0.1,  # downsample the points to speed up the simulation\n}\n\n\npoints_list = [\n    np.array([-0.508607, -0.180955, -0.123896]),  # top of the big stem\n    np.array([-0.462227, -0.259485, -0.112966]),  # top of the second stem\n    np.array([-0.728061, -0.092306, -0.149104]),  # top of the third stem\n    np.array([-0.603330 - 0.204207 - 0.127469]),  # top of the 4th stem\n    np.array([-0.408097, -0.076293, -0.110391]),  # top of the big leaf\n    np.array([-0.391575, -0.224018, -0.052054]),  # top of the second leaf\n    np.array([-0.768167, -0.032502, -0.143995]),  # top of the third leaf\n    np.array([-0.633866, -0.170207, -0.103671]),  # top of the 4th leaf\n]\n\nforce_directions = [\n    np.array([1.0, 0.0, 0]),\n    np.array([0.0, 1.0, 0.0]),\n    np.array([1.0, 0.0, 1.0]),\n    np.array([1.0, 1.0, 0.0]),\n    
np.array([1.0, 0.0, 1.0]),\n    np.array([0.0, 1.0, 1.0]),\n    np.array([1.0, 1.0, 1.0]),\n]\n\nforce_directions = np.array(force_directions)\nforce_directions = force_directions / np.linalg.norm(force_directions, axis=1)[:, None]\n"
  },
  {
    "path": "projects/inference/configs/carnation.py",
    "content": "import numpy as np\n\ndataset_dir = \"../../data/physics_dreamer/carnations/\"\nresult_dir = \"output/carnations/demos\"\nexp_name = \"carnations\"\n\n\nmodel_list = [\n    \"../../models/physdreamer/carnations/model\",\n]\n\nfocus_point_list = [\n    np.array([0.189558, 2.064228, -0.216089]),  # botton of the background\n]\n\ncamera_cfg_list = [\n    {\n        \"type\": \"spiral\",\n        \"focus_point\": focus_point_list[0],\n        \"radius\": 0.05,\n        \"up\": np.array([0, -0.5, 1]),\n    },\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00001.png\",\n        \"end_frame\": \"frame_00022.png\",\n    },\n    # real capture viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00219.png\",\n    },\n    # another render viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00106.png\",\n    },\n    # another render viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00011.png\",\n    },\n]\n\nsimulate_cfg = {\n    \"substep\": 768,\n    \"grid_size\": 64,\n    \"init_young\": 2140628.25,  # save the initialized young's modulus, since optimized\n    \"downsample_scale\": 0.1,  # downsample the points to speed up the simulation\n}\n\n\npoints_list = [\n    np.array([0.076272, 0.848310, 0.074134]),  # top of the flower\n    np.array([0.057208, 0.848147, -0.013685]),  # middle of the flower\n    np.array([0.134908, 0.912759, -0.023763]),  # top of the stem\n    np.array([0.169540, 0.968676, -0.095261]),  # middle of the stem\n    np.array([0.186664, 1.028284, -0.187793]),  # bottom of the stem\n]\n\nforce_directions = [\n    np.array([1.0, 0.0, 0]),\n    np.array([0.0, 1.0, 0.0]),\n    np.array([1.0, 0.0, 1.0]),\n    np.array([1.0, 1.0, 0.0]),\n    np.array([1.0, 0.0, 1.0]),\n    np.array([0.0, 1.0, 1.0]),\n    np.array([1.0, 1.0, 1.0]),\n]\n\nforce_directions = np.array(force_directions)\nforce_directions 
= force_directions / np.linalg.norm(force_directions, axis=1)[:, None]\n\nforce_directions_old_carnations = [\n    np.array([2.0, 1.0, 0]),  # horizontal to left\n    np.array([0.0, 1.0, 2.0]),  # vertical to top\n    np.array([1.0, 1.0, 1.0]),  # top right to bottom left\n    np.array([0.0, 1.0, 0.0]),  # orthogonal to the screen,\n]\n"
  },
  {
    "path": "projects/inference/configs/hat.py",
    "content": "import numpy as np\n\ndataset_dir = \"../../data/physics_dreamer/hat/\"\nresult_dir = \"output/hat/demo\"\nexp_name = \"hat\"\n\nmodel_list = [\n    \"../../models/physdreamer/hat/model/\",\n]\n\nfocus_point_list = [\n    np.array([-0.467188, 0.067178, 0.044333]),\n]\n\ncamera_cfg_list = [\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00001.png\",\n        \"end_frame\": \"frame_00187.png\",  # or 91\n    },\n    # real captured viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00217.png\",\n    },\n    # other selected viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00001.png\",\n    },\n    # other selected viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00001.png\",\n    },\n    # other selected viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00079.png\",\n    },\n]\n\nsimulate_cfg = {\n    \"substep\": 384,\n    \"grid_size\": 64,\n    \"init_young\": 1e5,\n    \"downsample_scale\": 0.04,\n}\n\n\npoints_list = [\n    np.array([-0.390069, 0.139051, -0.182607]),  # bottom of the hat\n    np.array([-0.404391, 0.184975, -0.001585]),  # middle of the hat\n    np.array([-0.289375, 0.034581, 0.062010]),  # left of the hat\n    np.array([-0.352060, 0.105737, 0.009359]),  # center of the hat\n]\n\nforce_directions = [\n    np.array([1.0, 0.0, 0]),\n    np.array([0.0, 1.0, 0.0]),\n    np.array([1.0, 0.0, 1.0]),\n    np.array([1.0, 1.0, 0.0]),\n    np.array([1.0, 0.0, 1.0]),\n    np.array([0.0, 1.0, 1.0]),\n    np.array([1.0, 1.0, 1.0]),\n]\n\nforce_directions = np.array(force_directions)\nforce_directions = force_directions / np.linalg.norm(force_directions, axis=1)[:, None]\n"
  },
  {
    "path": "projects/inference/configs/telephone.py",
    "content": "import numpy as np\n\nexp_name = \"telephone\"\ndataset_dir = \"../../data/physics_dreamer/telephone/\"\nresult_dir = \"output/telephone/results\"\n\nmodel_list = [\"../../models/physdreamer/telephone/model\"]\n\nfocus_point_list = [\n    np.array([-0.401468, 0.889287, -0.116852]),  # botton of the background\n]\n\ncamera_cfg_list = [\n    {\n        \"type\": \"spiral\",\n        \"focus_point\": focus_point_list[0],\n        \"radius\": 0.1,\n        \"up\": np.array([0, 0, 1]),\n    },\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00001.png\",\n        \"end_frame\": \"frame_00019.png\",\n    },\n    # real video viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00190.png\",\n    },\n    # other selected viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00037.png\",\n    },\n    # other selected viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00090.png\",\n    },\n]\n\nsimulate_cfg = {\n    \"substep\": 256,\n    \"grid_size\": 96,\n    \"init_young\": 1e5,\n    \"downsample_scale\": 0.1,  # downsample the points to speed up the simulation\n}\n\n\npoints_list = [\n    np.array([-0.417240, 0.907780, -0.379144]),  # bottom of the lines.\n    np.array([-0.374907, 0.796209, -0.178907]),  # middle of the right lines\n    np.array([-0.414156, 0.901207, -0.182275]),  # middle of the left lines\n]\n\nforce_directions = [\n    np.array([1.0, 0.0, 0]),\n    np.array([0.0, 1.0, 0.0]),\n    np.array([1.0, 0.0, 1.0]),\n    np.array([1.0, 1.0, 0.0]),\n    np.array([1.0, 0.0, 1.0]),\n    np.array([0.0, 1.0, 1.0]),\n    np.array([1.0, 1.0, 1.0]),\n]\n\nforce_directions = np.array(force_directions)\nforce_directions = force_directions / np.linalg.norm(force_directions, axis=1)[:, None]\n"
  },
  {
    "path": "projects/inference/demo.py",
    "content": "import argparse\nimport os\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\nimport point_cloud_utils as pcu\nfrom accelerate.utils import ProjectConfiguration\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import set_seed\nfrom accelerate import Accelerator, DistributedDataParallelKwargs\nimport numpy as np\nimport logging\nimport argparse\nimport torch\nimport os\nfrom physdreamer.utils.config import create_config\nimport numpy as np\n\nfrom physdreamer.gaussian_3d.scene import GaussianModel\n\nfrom physdreamer.data.datasets.multiview_dataset import MultiviewImageDataset\nfrom physdreamer.data.datasets.multiview_video_dataset import (\n    MultiviewVideoDataset,\n    camera_dataset_collate_fn,\n)\n\nfrom physdreamer.data.datasets.multiview_dataset import (\n    camera_dataset_collate_fn as camera_dataset_collate_fn_img,\n)\n\nfrom typing import NamedTuple\n\nfrom physdreamer.utils.img_utils import compute_psnr, compute_ssim\nfrom physdreamer.warp_mpm.mpm_data_structure import (\n    MPMStateStruct,\n    MPMModelStruct,\n)\nfrom physdreamer.warp_mpm.mpm_solver_diff import MPMWARPDiff\nfrom physdreamer.warp_mpm.gaussian_sim_utils import get_volume\nimport warp as wp\n\nfrom local_utils import (\n    cycle,\n    create_spatial_fields,\n    find_far_points,\n    apply_grid_bc_w_freeze_pts,\n    add_constant_force,\n    downsample_with_kmeans_gpu,\n    render_gaussian_seq_w_mask_with_disp,\n    render_gaussian_seq_w_mask_cam_seq_with_force_with_disp,\n    get_camera_trajectory,\n    render_gaussian_seq_w_mask_with_disp_for_figure,\n)\nfrom config_demo import DemoParams\nfrom physdreamer.utils.io_utils import save_video_mediapy\n\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef create_dataset(args):\n\n    res = [576, 1024]\n    video_dir_name = \"videos\"\n\n    # dataset = MultiviewVideoDataset(\n    #     args.dataset_dir,\n    #     use_white_background=False,\n    #     resolution=res,\n    #     
scale_x_angle=1.0,\n    #     video_dir_name=video_dir_name,\n    # )\n    dataset = MultiviewImageDataset(\n        args.dataset_dir,\n        use_white_background=False,\n        resolution=res,\n        scale_x_angle=1.0,\n        load_imgs=False,\n    )\n\n    test_dataset = MultiviewImageDataset(\n        args.dataset_dir,\n        use_white_background=False,\n        resolution=res,\n        # use_index=[0],\n        scale_x_angle=1.0,\n        fitler_with_renderd=False,\n        load_imgs=False,\n    )\n    print(\"len of test dataset\", len(test_dataset))\n    return dataset, test_dataset\n\n\nclass Trainer:\n    def __init__(self, args):\n        self.args = args\n\n        logging_dir = os.path.join(args.output_dir, \"debug_demo\")\n        accelerator_project_config = ProjectConfiguration(logging_dir=logging_dir)\n        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n        accelerator = Accelerator(\n            mixed_precision=\"no\",\n            log_with=\"wandb\",\n            project_config=accelerator_project_config,\n            kwargs_handlers=[ddp_kwargs],\n        )\n        logging.basicConfig(\n            format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n            datefmt=\"%m/%d/%Y %H:%M:%S\",\n            level=logging.INFO,\n        )\n        logger.info(accelerator.state, main_process_only=False)\n\n        set_seed(args.seed + accelerator.process_index)\n\n        demo_cfg = DemoParams(args.scene_name).get_cfg(\n            args.demo_name,\n            args.model_id,\n            args.eval_ys,\n            args.force_id,\n            args.force_mag,\n            args.velo_scaling,\n            args.point_id,\n            args.cam_id,\n            args.apply_force,\n        )\n        self.args.dataset_dir = demo_cfg[\"dataset_dir\"]\n        self.demo_cfg = demo_cfg\n\n        # setup the dataset\n        dataset, test_dataset = create_dataset(args)\n        # will be used when synthesize 
camera trajectory\n        self.test_dataset = test_dataset\n        self.dataset = dataset\n        dataset_dir = test_dataset.data_dir\n\n        gaussian_path = os.path.join(dataset_dir, \"point_cloud.ply\")\n        self.setup_render(\n            args,\n            gaussian_path,\n            white_background=True,\n        )\n        self.args.substep = demo_cfg[\"substep\"]\n        self.args.grid_size = demo_cfg[\"grid_size\"]\n        self.args.checkpoint_path = demo_cfg[\"model_path\"]\n        self.demo_cfg = demo_cfg\n\n        self.num_frames = int(args.num_frames)\n\n        dataloader = torch.utils.data.DataLoader(\n            dataset,\n            batch_size=1,\n            shuffle=False,\n            drop_last=False,\n            num_workers=0,\n            # collate_fn=camera_dataset_collate_fn,\n            collate_fn=camera_dataset_collate_fn_img,\n        )\n        dataloader = accelerator.prepare(dataloader)\n        # why be used in self.compute_metric\n        self.dataloader = cycle(dataloader)\n        self.accelerator = accelerator\n\n        # init traiable params\n        E_nu_list = self.init_trainable_params()\n        for p in E_nu_list:\n            p.requires_grad = False\n        self.E_nu_list = E_nu_list\n\n        # init simulation enviroment\n        self.setup_simulation(dataset_dir, grid_size=args.grid_size)\n\n        if args.checkpoint_path == \"None\":\n            args.checkpoint_path = None\n        if args.checkpoint_path is not None:\n            self.load(args.checkpoint_path)\n        self.sim_fields, self.velo_fields = accelerator.prepare(\n            self.sim_fields, self.velo_fields\n        )\n\n    def init_trainable_params(\n        self,\n    ):\n        # init young modulus and poisson ratio\n        # from pre-optimized;  gres32 step 128.  300 epoch. lr 10.0  psnr: 27.72028086735652.  
Stop at 100epoch\n        young_numpy = np.array([self.demo_cfg[\"init_youngs\"]]).astype(np.float32)\n        young_modulus = torch.tensor(young_numpy, dtype=torch.float32).to(\n            self.accelerator.device\n        )\n        poisson_numpy = np.random.uniform(0.1, 0.4)\n        poisson_ratio = torch.tensor(poisson_numpy, dtype=torch.float32).to(\n            self.accelerator.device\n        )\n        trainable_params = [young_modulus, poisson_ratio]\n        print(\n            \"init young modulus: \",\n            young_modulus.item(),\n            \"poisson ratio: \",\n            poisson_ratio.item(),\n        )\n        return trainable_params\n\n    def setup_simulation(self, dataset_dir, grid_size=100):\n        \"\"\"\n        1. load internal filled points.\n        2. pointcloud downsample with KMeans\n        3. Setup MPM simulation environment\n        \"\"\"\n\n        device = \"cuda:{}\".format(self.accelerator.process_index)\n\n        xyzs = self.render_params.gaussians.get_xyz.detach().clone()\n        sim_xyzs = xyzs[self.sim_mask_in_raw_gaussian, :]\n\n        # scale, and shift\n        pos_max = sim_xyzs.max()\n        pos_min = sim_xyzs.min()\n        scale = (pos_max - pos_min) * 1.8\n        shift = -pos_min + (pos_max - pos_min) * 0.25\n        self.scale, self.shift = scale, shift\n        print(\"scale, shift\", scale, shift)\n\n        # load internal filled points.\n        #   if exists, we will use it to fill in the internal points, but not for rendering\n        #   we keep track of render_mask_in_sim_pts, to distinguish the orignal points from the internal filled points\n        filled_in_points_path = os.path.join(dataset_dir, \"internal_filled_points.ply\")\n        if os.path.exists(filled_in_points_path):\n            fill_xyzs = pcu.load_mesh_v(filled_in_points_path)  # [n, 3]\n            fill_xyzs = fill_xyzs[\n                np.random.choice(\n                    fill_xyzs.shape[0], int(fill_xyzs.shape[0] * 1.0), 
replace=False\n                )\n            ]\n            fill_xyzs = torch.from_numpy(fill_xyzs).float().to(\"cuda\")\n            self.fill_xyzs = fill_xyzs\n            print(\n                \"loaded {} internal filled points from: \".format(fill_xyzs.shape[0]),\n                filled_in_points_path,\n            )\n            render_mask_in_sim_pts = torch.cat(\n                [\n                    torch.ones_like(sim_xyzs[:, 0]).bool(),\n                    torch.zeros_like(fill_xyzs[:, 0]).bool(),\n                ],\n                dim=0,\n            ).to(device)\n            sim_xyzs = torch.cat([sim_xyzs, fill_xyzs], dim=0)\n            self.render_mask = render_mask_in_sim_pts\n        else:\n            self.fill_xyzs = None\n            self.render_mask = torch.ones_like(sim_xyzs[:, 0]).bool().to(device)\n\n        sim_xyzs = (sim_xyzs + shift) / scale\n        sim_aabb = torch.stack(\n            [torch.min(sim_xyzs, dim=0)[0], torch.max(sim_xyzs, dim=0)[0]], dim=0\n        )\n        # This AABB is used to constraint the material fields and velocity fields.\n        sim_aabb = (\n            sim_aabb - torch.mean(sim_aabb, dim=0, keepdim=True)\n        ) * 1.2 + torch.mean(sim_aabb, dim=0, keepdim=True)\n\n        print(\"simulation aabb: \", sim_aabb)\n\n        # point cloud resample with kmeans\n        if \"downsample_scale\" in self.demo_cfg:\n            downsample_scale = self.demo_cfg[\"downsample_scale\"]\n        else:\n            downsample_scale = args.downsample_scale\n        if downsample_scale > 0 and downsample_scale < 1.0:\n            print(\"Downsample with ratio: \", downsample_scale)\n            num_cluster = int(sim_xyzs.shape[0] * downsample_scale)\n\n            # WARNING: this is a GPU implementation, and will be OOM if the number of points is large\n            # you might want to use a CPU implementation if the number of points is large\n            # For CPU implementation: uncomment the following lines\n      
      # from local_utils import downsample_with_kmeans\n            # sim_xyzs = downsample_with_kmeans(sim_xyzs.detach().cpu().numpy(), num_cluster)\n            # sim_xyzs = torch.from_numpy(sim_xyzs).float().to(device)\n\n            sim_xyzs = downsample_with_kmeans_gpu(sim_xyzs, num_cluster)\n\n            sim_gaussian_pos = self.render_params.gaussians.get_xyz.detach().clone()[\n                self.sim_mask_in_raw_gaussian, :\n            ]\n            sim_gaussian_pos = (sim_gaussian_pos + shift) / scale\n\n            # record top k index for each point, to interpolate positions and rotations later\n            cdist = torch.cdist(sim_gaussian_pos, sim_xyzs) * -1.0\n            _, top_k_index = torch.topk(cdist, self.args.top_k, dim=-1)\n            self.top_k_index = top_k_index\n\n            print(\"Downsampled to: \", sim_xyzs.shape[0], \"by\", downsample_scale)\n\n        # Compute the volume of each particle\n        points_volume = get_volume(sim_xyzs.detach().cpu().numpy())\n\n        num_particles = sim_xyzs.shape[0]\n        sim_aabb = torch.stack(\n            [torch.min(sim_xyzs, dim=0)[0], torch.max(sim_xyzs, dim=0)[0]], dim=0\n        )\n        sim_aabb = (\n            sim_aabb - torch.mean(sim_aabb, dim=0, keepdim=True)\n        ) * 1.2 + torch.mean(sim_aabb, dim=0, keepdim=True)\n\n        # Initialize MPM state and model\n        wp.init()\n        wp.config.mode = \"debug\"\n        wp.config.verify_cuda = True\n\n        mpm_state = MPMStateStruct()\n        mpm_state.init(num_particles, device=device, requires_grad=False)\n\n        self.particle_init_position = sim_xyzs.clone()\n\n        mpm_state.from_torch(\n            self.particle_init_position.clone(),\n            torch.from_numpy(points_volume).float().to(device).clone(),\n            None,\n            device=device,\n            requires_grad=False,\n            n_grid=grid_size,\n            grid_lim=1.0,\n        )\n        mpm_model = MPMModelStruct()\n        
mpm_model.init(num_particles, device=device, requires_grad=False)\n        # grid from [0.0 - 1.0]\n        mpm_model.init_other_params(n_grid=grid_size, grid_lim=1.0, device=device)\n\n        material_params = {\n            # select from jel\n            \"material\": \"jelly\",  # \"jelly\", \"metal\", \"sand\", \"foam\", \"snow\", \"plasticine\", \"neo-hookean\"\n            \"g\": [0.0, 0.0, 0.0],\n            \"density\": 2000,  # kg / m^3\n            \"grid_v_damping_scale\": 1.1,  # no damping if > 1.0\n        }\n        self.material_name = material_params[\"material\"]\n        mpm_solver = MPMWARPDiff(\n            num_particles, n_grid=grid_size, grid_lim=1.0, device=device\n        )\n        mpm_solver.set_parameters_dict(mpm_model, mpm_state, material_params)\n\n        self.mpm_state, self.mpm_model, self.mpm_solver = (\n            mpm_state,\n            mpm_model,\n            mpm_solver,\n        )\n\n        # setup boundary condition:\n        moving_pts_path = os.path.join(dataset_dir, \"moving_part_points.ply\")\n        assert os.path.exists(\n            moving_pts_path\n        ), \"We need to segment out the moving part to initialize the boundary condition\"\n\n        moving_pts = pcu.load_mesh_v(moving_pts_path)\n        moving_pts = torch.from_numpy(moving_pts).float().to(device)\n        moving_pts = (moving_pts + shift) / scale\n        freeze_mask = find_far_points(\n            sim_xyzs, moving_pts, thres=0.5 / grid_size\n        ).bool()\n        freeze_pts = sim_xyzs[freeze_mask, :]\n\n        grid_freeze_mask = apply_grid_bc_w_freeze_pts(\n            grid_size, 1.0, freeze_pts, mpm_solver\n        )\n        self.freeze_mask = freeze_mask\n\n        num_freeze_pts = self.freeze_mask.sum()\n        print(\n            \"num freeze pts in total\",\n            num_freeze_pts.item(),\n            \"num moving pts\",\n            num_particles - num_freeze_pts.item(),\n        )\n\n        # init fields for simulation, e.g. 
density, external force, etc.\n        # padd init density, youngs,\n        density = (\n            torch.ones_like(self.particle_init_position[..., 0])\n            * material_params[\"density\"]\n        )\n        youngs_modulus = (\n            torch.ones_like(self.particle_init_position[..., 0])\n            * self.E_nu_list[0].detach()\n        )\n        poisson_ratio = torch.ones_like(self.particle_init_position[..., 0]) * 0.3\n        self.density = density\n        self.young_modulus = youngs_modulus\n        self.poisson_ratio = poisson_ratio\n\n        # set density, youngs, poisson\n        mpm_state.reset_density(\n            density.clone(),\n            torch.ones_like(density).type(torch.int),\n            device,\n            update_mass=True,\n        )\n        mpm_solver.set_E_nu_from_torch(\n            mpm_model, youngs_modulus.clone(), poisson_ratio.clone(), device\n        )\n        mpm_solver.prepare_mu_lam(mpm_model, mpm_state, device)\n\n        self.sim_fields = create_spatial_fields(self.args, 1, sim_aabb)\n        self.sim_fields.train()\n\n        self.args.sim_res = 24\n        # self.velo_fields = create_velocity_model(self.args, sim_aabb)\n        self.velo_fields = create_spatial_fields(\n            self.args, 3, sim_aabb, add_entropy=False\n        )\n        self.velo_fields.train()\n\n    def add_constant_force(self, center_point, radius, force, dt, start_time, end_time):\n        xyzs = self.particle_init_position.clone() * self.scale - self.shift\n\n        device = \"cuda:{}\".format(self.accelerator.process_index)\n        add_constant_force(\n            self.mpm_solver,\n            self.mpm_state,\n            xyzs,\n            center_point,\n            radius,\n            force,\n            dt,\n            start_time,\n            end_time,\n            device=device,\n        )\n\n    def get_simulation_input(self, device):\n        \"\"\"\n        Outs: All padded\n            density: [N]\n            
young_modulus: [N]\n            poisson_ratio: [N]\n            velocity: [N, 3]\n            query_mask: [N]\n            particle_F: [N, 3, 3]\n            particle_C: [N, 3, 3]\n        \"\"\"\n\n        density, youngs_modulus, ret_poisson = self.get_material_params(device)\n        initial_position_time0 = self.particle_init_position.clone()\n\n        query_mask = torch.logical_not(self.freeze_mask)\n        query_pts = initial_position_time0[query_mask, :]\n\n        velocity = self.velo_fields(query_pts)[..., :3]\n\n        # scaling lr is similar to scaling the learning rate of velocity fields.\n        velocity = velocity * 0.1  # not padded yet\n        ret_velocity = torch.zeros_like(initial_position_time0)\n        ret_velocity[query_mask, :] = velocity\n\n        # init F as Idensity Matrix, and C and Zero Matrix\n        I_mat = torch.eye(3, dtype=torch.float32).to(device)\n        particle_F = torch.repeat_interleave(\n            I_mat[None, ...], initial_position_time0.shape[0], dim=0\n        )\n        particle_C = torch.zeros_like(particle_F)\n\n        return (\n            density,\n            youngs_modulus,\n            ret_poisson,\n            ret_velocity,\n            query_mask,\n            particle_F,\n            particle_C,\n        )\n\n    def get_material_params(self, device):\n        \"\"\"\n        Outs:\n            density: [N]\n            young_modulus: [N]\n            poisson_ratio: [N]\n        \"\"\"\n\n        initial_position_time0 = self.particle_init_position.detach()\n\n        # query the materials params of all particles\n        query_pts = initial_position_time0\n\n        sim_params = self.sim_fields(query_pts)\n\n        # scale the output of the network, similar to scale the learning rate\n        sim_params = sim_params * 1000\n        youngs_modulus = self.young_modulus.detach().clone()\n        youngs_modulus += sim_params[..., 0]\n\n        # clamp youngs modulus\n        youngs_modulus = 
torch.clamp(youngs_modulus, 1.0, 5e8)\n\n        density = self.density.detach().clone()\n        ret_poisson = self.poisson_ratio.detach().clone()\n\n        return density, youngs_modulus, ret_poisson\n\n    def load(self, checkpoint_dir):\n        name_list = [\n            \"velo_fields\",\n            \"sim_fields\",\n        ]\n        for i, model in enumerate([self.velo_fields, self.sim_fields]):\n            model_name = name_list[i]\n            model_path = os.path.join(checkpoint_dir, model_name + \".pt\")\n            if os.path.exists(model_path):\n                print(\"=> loading: \", model_path)\n                model.load_state_dict(torch.load(model_path))\n            else:\n                print(\"=> not found: \", model_path)\n\n    def setup_render(self, args, gaussian_path, white_background=True):\n        \"\"\"\n        1. Load 3D Gaussians in gaussian_path\n        2. Prepare rendering params in self.render_params\n        3. Load foreground points stored in the same directory as gaussian_path, with name \"clean_object_points.ply\"\n               Only foreground points is used for simulation.\n               We will track foreground points with mask: self.sim_mask_in_raw_gaussian\n        \"\"\"\n\n        # setup gaussians\n        class RenderPipe(NamedTuple):\n            convert_SHs_python = False\n            compute_cov3D_python = False\n            debug = False\n\n        class RenderParams(NamedTuple):\n            render_pipe: RenderPipe\n            bg_color: bool\n            gaussians: GaussianModel\n            camera_list: list\n\n        gaussians = GaussianModel(3)\n        camera_list = self.dataset.test_camera_list\n\n        gaussians.load_ply(gaussian_path)\n        gaussians.detach_grad()\n        print(\n            \"load gaussians from: {}\".format(gaussian_path),\n            \"... 
num gaussians: \",\n            gaussians._xyz.shape[0],\n        )\n        bg_color = [1, 1, 1] if white_background else [0, 0, 0]\n        background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n        render_pipe = RenderPipe()\n\n        render_params = RenderParams(\n            render_pipe=render_pipe,\n            bg_color=background,\n            gaussians=gaussians,\n            camera_list=camera_list,\n        )\n        self.render_params = render_params\n\n        # segment foreground objects. Foreground points is stored in \"clean_object_points.ply\",\n        #    only foreground points is used for simulation\n        #    we will track foreground points with mask: self.sim_mask_in_raw_gaussian\n        gaussian_dir = os.path.dirname(gaussian_path)\n\n        clean_points_path = os.path.join(gaussian_dir, \"clean_object_points.ply\")\n\n        assert os.path.exists(\n            clean_points_path\n        ), \"We need to segment out the forground points to initialize the simulation\"\n\n        clean_xyzs = pcu.load_mesh_v(clean_points_path)\n        clean_xyzs = torch.from_numpy(clean_xyzs).float().to(\"cuda\")\n        self.clean_xyzs = clean_xyzs\n        print(\n            \"loaded {} clean points from: \".format(clean_xyzs.shape[0]),\n            clean_points_path,\n        )\n        not_sim_maks = find_far_points(gaussians._xyz, clean_xyzs, thres=0.01).bool()\n        sim_mask_in_raw_gaussian = torch.logical_not(not_sim_maks)\n        # [N]\n        self.sim_mask_in_raw_gaussian = sim_mask_in_raw_gaussian\n\n    @torch.no_grad()\n    def demo(\n        self,\n        velo_scaling=5.0,\n        num_sec=3.0,\n        eval_ys=1.0,\n        static_camera=False,\n        apply_force=False,\n        save_name=\"demo\",\n    ):\n\n        result_dir = self.demo_cfg[\"result_dir\"]\n        if \"eval_ys\" in self.demo_cfg:\n            eval_ys = self.demo_cfg[\"eval_ys\"]\n        if \"velo_scaling\" in self.demo_cfg:\n          
  velo_scaling = self.demo_cfg[\"velo_scaling\"]\n\n        save_name = self.demo_cfg[\"name\"]\n\n        if save_name.startswith(\"baseline\"):\n            self.compute_metric(save_name, result_dir)\n            return\n\n        # avoid re-run for experiment with the same name\n        os.makedirs(result_dir, exist_ok=True)\n        pos_path = os.path.join(result_dir, save_name + \"_pos.npy\")\n        if os.path.exists(pos_path):\n            pos_array = np.load(pos_path)\n        else:\n            pos_array = None\n\n        device = \"cuda:0\"\n        data = next(self.dataloader)\n        cam = data[\"cam\"][0]\n\n        substep = self.args.substep  # 1e-4\n\n        youngs_modulus = None\n\n        self.sim_fields.eval()\n        self.velo_fields.eval()\n\n        (\n            density,\n            youngs_modulus_,\n            poisson,\n            init_velocity,\n            query_mask,\n            particle_F,\n            particle_C,\n        ) = self.get_simulation_input(device)\n\n        poisson = self.E_nu_list[1].detach().clone()  # override poisson\n\n        if eval_ys < 10:\n            youngs_modulus = youngs_modulus_\n        else:\n            # assign eval_ys to all particles\n            youngs_modulus = torch.ones_like(youngs_modulus_) * eval_ys\n\n        # step-1 Setup simulation parameters. 
External force, or initial velocity.\n        #   if --apply_force, we will apply a constant force to points close to the force center\n        #   otherwise, we will load the initial velocity from pretrained models, and scale it by velo_scaling.\n\n        delta_time = 1.0 / 30  # 30 fps\n        substep_size = delta_time / substep\n        num_substeps = int(substep)\n\n        init_xyzs = self.particle_init_position.clone()\n\n        init_velocity[query_mask, :] = init_velocity[query_mask, :] * velo_scaling\n        if apply_force:\n            init_velocity = torch.zeros_like(init_velocity)\n\n            center_point = (\n                torch.from_numpy(self.demo_cfg[\"center_point\"]).to(device).float()\n            )\n            force = torch.from_numpy(self.demo_cfg[\"force\"]).to(device).float()\n\n            force_duration = self.demo_cfg[\"force_duration\"]  # sec\n            force_duration_steps = int(force_duration / delta_time)\n\n            # apply force to points within the radius of the center point\n            force_radius = self.demo_cfg[\"force_radius\"]\n\n            self.add_constant_force(\n                center_point, force_radius, force, delta_time, 0.0, force_duration\n            )\n\n            # prepare to render force in simulated videos:\n            #   find the closest point to the force center, and will use it to render the force\n            xyzs = self.render_params.gaussians.get_xyz.detach().clone()\n            dist = torch.norm(xyzs - center_point.unsqueeze(dim=0), dim=-1)\n            closest_idx = torch.argmin(dist)\n            closest_xyz = xyzs[closest_idx, :]\n            render_force = force / force.norm() * 0.1\n            do_render_force = True\n        else:\n            do_render_force = False\n\n        # step-3: simulation or load the simulated sequence computed before\n        #   with the same scene_name and demo_name\n        if pos_array is None or save_name == \"debug\":\n            
self.mpm_state.reset_density(\n                density.clone(), query_mask, device, update_mass=True\n            )\n            self.mpm_solver.set_E_nu_from_torch(\n                self.mpm_model, youngs_modulus.clone(), poisson.clone(), device\n            )\n            self.mpm_solver.prepare_mu_lam(self.mpm_model, self.mpm_state, device)\n\n            self.mpm_state.continue_from_torch(\n                init_xyzs,\n                init_velocity,\n                particle_F,\n                particle_C,\n                device=device,\n                requires_grad=False,\n            )\n\n            # record drive points sequence\n            render_pos_list = [(init_xyzs.clone() * self.scale) - self.shift]\n            prev_state = self.mpm_state\n            for i in tqdm(range(int((30) * num_sec))):\n                # iterate over substeps for each frame\n                for substep_local in range(num_substeps):\n                    next_state = prev_state.partial_clone(requires_grad=False)\n                    self.mpm_solver.p2g2p_differentiable(\n                        self.mpm_model,\n                        prev_state,\n                        next_state,\n                        substep_size,\n                        device=device,\n                    )\n                    prev_state = next_state\n\n                pos = wp.to_torch(next_state.particle_x).clone()\n                # undo scaling and shifting\n                pos = (pos * self.scale) - self.shift\n                render_pos_list.append(pos)\n\n            # save the sequence of drive points\n            numpy_pos = torch.stack(render_pos_list, dim=0).detach().cpu().numpy()\n            np.save(pos_path, numpy_pos)\n        else:\n            render_pos_list = []\n            for i in range(pos_array.shape[0]):\n                pos = pos_array[i, ...]\n                render_pos_list.append(torch.from_numpy(pos).to(device))\n\n        num_pos = len(render_pos_list)\n        
init_pos = render_pos_list[0].clone()\n        pos_diff_list = [_ - init_pos for _ in render_pos_list]\n\n        if not static_camera:\n            interpolated_cameras = get_camera_trajectory(\n                cam, num_pos, self.demo_cfg[\"camera_cfg\"], self.test_dataset\n            )\n        else:\n            interpolated_cameras = [cam] * num_pos\n\n        if not do_render_force:\n            video_array, moving_part_video = (\n                render_gaussian_seq_w_mask_with_disp_for_figure(\n                    interpolated_cameras,\n                    self.render_params,\n                    init_pos,\n                    self.top_k_index,\n                    pos_diff_list,\n                    self.sim_mask_in_raw_gaussian,\n                )\n            )\n            video_numpy = video_array.detach().cpu().numpy() * 255\n            video_numpy = np.clip(video_numpy, 0, 255).astype(np.uint8)\n            video_numpy = np.transpose(video_numpy, [0, 2, 3, 1])\n\n            moving_part_video = moving_part_video.detach().cpu().numpy() * 255\n            moving_part_video = np.clip(moving_part_video, 0, 255).astype(np.uint8)\n            moving_part_video = np.transpose(moving_part_video, [0, 2, 3, 1])\n        else:\n            video_numpy = render_gaussian_seq_w_mask_cam_seq_with_force_with_disp(\n                interpolated_cameras,\n                self.render_params,\n                init_pos,\n                self.top_k_index,\n                pos_diff_list,\n                self.sim_mask_in_raw_gaussian,\n                closest_idx,\n                render_force,\n                force_duration_steps,\n            )\n            video_numpy = np.transpose(video_numpy, [0, 2, 3, 1])\n\n        if not static_camera:\n            save_name = (\n                save_name\n                + \"_movingcamera\"\n                + \"_camid_{}\".format(self.demo_cfg[\"cam_id\"])\n            )\n\n        save_name = save_name + \"_\" + 
self.demo_cfg[\"name\"]\n        save_path = os.path.join(result_dir, save_name + \".mp4\")\n\n        print(\"save video to \", save_path)\n        save_video_mediapy(video_numpy, save_path, fps=30)\n\n        # save_path = save_path.replace(\".mp4\", \"_moving_part.mp4\")\n        # save_video_mediapy(moving_part_video, save_path, fps=30)\n\n    def compute_metric(self, exp_name, result_dir):\n\n        data = next(self.dataloader)\n        cam = data[\"cam\"][0]\n\n        # step-2 simulation part\n        substep = self.args.substep  # 1e-4\n        self.sim_fields.eval()\n        self.velo_fields.eval()\n        device = \"cuda:{}\".format(self.accelerator.process_index)\n\n        (\n            density,\n            youngs_modulus,\n            poisson,\n            init_velocity,\n            query_mask,\n            particle_F,\n            particle_C,\n        ) = self.get_simulation_input(device)\n\n        poisson = self.E_nu_list[1].detach().clone()  # override poisson\n        # delta_time = 1.0 / (self.num_frames - 1)\n        delta_time = 1.0 / 30  # 30 fps\n        substep_size = delta_time / substep\n        num_substeps = int(delta_time / substep_size)\n\n        init_xyzs = self.particle_init_position.clone()\n        init_velocity[query_mask, :] = init_velocity[query_mask, :]\n\n        self.mpm_state.reset_density(\n            density.clone(), query_mask, device, update_mass=True\n        )\n        self.mpm_solver.set_E_nu_from_torch(\n            self.mpm_model, youngs_modulus.clone(), poisson.clone(), device\n        )\n        self.mpm_solver.prepare_mu_lam(self.mpm_model, self.mpm_state, device)\n\n        self.mpm_state.continue_from_torch(\n            init_xyzs,\n            init_velocity,\n            particle_F,\n            particle_C,\n            device=device,\n            requires_grad=False,\n        )\n\n        pos_list = [(init_xyzs.clone() * self.scale) - self.shift]\n\n        prev_state = self.mpm_state\n        for i in 
tqdm(range(self.args.num_frames - 1)):\n            for substep_local in range(num_substeps):\n                next_state = prev_state.partial_clone(requires_grad=False)\n                self.mpm_solver.p2g2p_differentiable(\n                    self.mpm_model,\n                    prev_state,\n                    next_state,\n                    substep_size,\n                    device=device,\n                )\n                prev_state = next_state\n\n            pos = wp.to_torch(next_state.particle_x).clone()\n\n            # pos = self.mpm_solver.export_particle_x_to_torch().clone()\n            pos = (pos * self.scale) - self.shift\n            pos_list.append(pos)\n        # setup the camera trajectories (copy the static camera for n frames)\n        init_pos = pos_list[0].clone()\n        pos_diff_list = [_ - init_pos for _ in pos_list]\n\n        interpolated_cameras = [cam] * len(pos_list)\n\n        video_array = render_gaussian_seq_w_mask_with_disp(\n            interpolated_cameras,\n            self.render_params,\n            init_pos,\n            self.top_k_index,\n            pos_diff_list,\n            self.sim_mask_in_raw_gaussian,\n        )\n        video_numpy = video_array.detach().cpu().numpy() * 255\n        video_numpy = np.clip(video_numpy, 0, 255).astype(np.uint8)\n        video_numpy = np.transpose(video_numpy, [0, 2, 3, 1])\n        os.makedirs(result_dir, exist_ok=True)\n        save_path = os.path.join(\n            result_dir,\n            exp_name\n            + \"_jelly_densi2k_video_substep_{}_grid_{}\".format(\n                substep, self.args.grid_size\n            )\n            + \".mp4\",\n        )\n        save_path = save_path.replace(\".gif\", \".mp4\")\n        save_video_mediapy(video_numpy, save_path, fps=25)\n\n        gt_videos = data[\"video_clip\"][0, 0 : self.num_frames, ...]\n        ssim = compute_ssim(video_array, gt_videos)\n        psnr = compute_psnr(video_array, gt_videos)\n\n        print(\"psnr 
for each frame: \", psnr)\n        mean_psnr = psnr.mean().item()\n        print(\"mean psnr: \", mean_psnr, \"mean ssim: \", ssim.item())\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"se3_field\")\n    parser.add_argument(\"--feat_dim\", type=int, default=64)\n    parser.add_argument(\"--num_decoder_layers\", type=int, default=3)\n    parser.add_argument(\"--decoder_hidden_size\", type=int, default=64)\n    # resolution of velocity fields\n    parser.add_argument(\"--spatial_res\", type=int, default=32)\n    parser.add_argument(\"--zero_init\", type=bool, default=True)\n\n    parser.add_argument(\"--num_frames\", type=str, default=14)\n\n    # resolution of material fields\n    parser.add_argument(\"--sim_res\", type=int, default=8)\n    parser.add_argument(\"--sim_output_dim\", type=int, default=1)\n\n    parser.add_argument(\"--downsample_scale\", type=float, default=0.1)\n    parser.add_argument(\"--top_k\", type=int, default=8)\n\n    # Logging and checkpointing\n    parser.add_argument(\"--output_dir\", type=str, default=\"../../output/inverse_sim\")\n    parser.add_argument(\"--seed\", type=int, default=0)\n\n    # demo parameters. related to parameters specified in configs/{scene_name}.py\n    parser.add_argument(\"--scene_name\", type=str, default=\"carnation\")\n    parser.add_argument(\"--demo_name\", type=str, default=\"inference_demo\")\n    parser.add_argument(\"--model_id\", type=int, default=0)\n\n    # if eval_ys > 10. 
Then all the youngs modulus is set to eval_ys homogeneously\n    parser.add_argument(\"--eval_ys\", type=float, default=1.0)\n    parser.add_argument(\"--force_id\", type=int, default=1)\n    parser.add_argument(\"--force_mag\", type=float, default=1.0)\n    parser.add_argument(\"--velo_scaling\", type=float, default=5.0)\n    parser.add_argument(\"--point_id\", type=int, default=0)\n    parser.add_argument(\"--apply_force\", action=\"store_true\", default=False)\n    parser.add_argument(\"--cam_id\", type=int, default=0)\n    parser.add_argument(\"--static_camera\", action=\"store_true\", default=False)\n\n    args, extra_args = parser.parse_known_args()\n\n    return args\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n\n    trainer = Trainer(args)\n\n    trainer.demo(\n        velo_scaling=args.velo_scaling,\n        eval_ys=args.eval_ys,\n        static_camera=args.static_camera,\n        apply_force=args.apply_force,\n        save_name=args.demo_name,\n    )\n"
  },
  {
    "path": "projects/inference/local_utils.py",
    "content": "import os\nimport torch\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor\nfrom time import time\nfrom omegaconf import OmegaConf\nfrom physdreamer.fields.se3_field import TemporalKplanesSE3fields\nfrom physdreamer.fields.triplane_field import TriplaneFields, TriplaneFieldsWithEntropy\n\nfrom physdreamer.gaussian_3d.gaussian_renderer.render import (\n    render_gaussian,\n    render_arrow_in_screen,\n)\nfrom physdreamer.gaussian_3d.gaussian_renderer.flow_depth_render import (\n    render_flow_depth_w_gaussian,\n)\nimport cv2\nimport numpy as np\nfrom sklearn.cluster import KMeans\nfrom time import time\n\nfrom physdreamer.gaussian_3d.utils.rigid_body_utils import (\n    get_rigid_transform,\n    matrix_to_quaternion,\n    quaternion_multiply,\n)\n\n\ndef cycle(dl: torch.utils.data.DataLoader):\n    while True:\n        for data in dl:\n            yield data\n\n\ndef load_motion_model(model, checkpoint_path):\n    model_path = os.path.join(checkpoint_path, \"model.pt\")\n    model.load_state_dict(torch.load(model_path))\n    print(\"load model from: \", model_path)\n    return model\n\n\ndef create_spatial_fields(\n    args, output_dim, aabb: Float[Tensor, \"2 3\"], add_entropy=True\n):\n\n    sp_res = args.sim_res\n    resolutions = [sp_res, sp_res, sp_res]\n    reduce = \"sum\"\n\n    model = TriplaneFields(\n        aabb,\n        resolutions,\n        feat_dim=32,\n        init_a=0.1,\n        init_b=0.5,\n        reduce=reduce,\n        num_decoder_layers=2,\n        decoder_hidden_size=32,\n        output_dim=output_dim,\n        zero_init=args.zero_init,\n    )\n    if args.zero_init:\n        print(\"=> zero init the last layer for Spatial MLP\")\n\n    return model\n\n\ndef create_motion_model(\n    args,\n    aabb: Float[Tensor, \"2 3\"],\n    num_frames=None,\n):\n    assert args.model in [\"se3_field\"]\n\n    sp_res = args.spatial_res\n    if num_frames is None:\n        num_frames = args.num_frames\n    resolutions = 
[sp_res, sp_res, sp_res, (num_frames) // 2 + 1]\n    # resolutions = [64, 64, 64, num_frames // 2 + 1]\n    reduce = \"sum\"\n\n    model = TemporalKplanesSE3fields(\n        aabb,\n        resolutions,\n        feat_dim=args.feat_dim,\n        init_a=0.1,\n        init_b=0.5,\n        reduce=reduce,\n        num_decoder_layers=args.num_decoder_layers,\n        decoder_hidden_size=args.decoder_hidden_size,\n        zero_init=args.zero_init,\n    )\n    if args.zero_init:\n        print(\"=> zero init the last layer for MLP\")\n\n    return model\n\n\ndef create_velocity_model(\n    args,\n    aabb: Float[Tensor, \"2 3\"],\n):\n\n    from physdreamer.fields.offset_field import TemporalKplanesOffsetfields\n\n    sp_res = args.sim_res\n    resolutions = [sp_res, sp_res, sp_res, (args.num_frames) // 2 + 1]\n    reduce = \"sum\"\n    model = TemporalKplanesOffsetfields(\n        aabb,\n        resolutions,\n        feat_dim=32,\n        init_a=0.1,\n        init_b=0.5,\n        reduce=reduce,\n        num_decoder_layers=2,\n        decoder_hidden_size=32,\n        zero_init=args.zero_init,\n    )\n    if args.zero_init:\n        print(\"=> zero init the last layer for velocity MLP\")\n    return model\n\n\ndef create_svd_model(model_name=\"svd_full\", ckpt_path=None):\n    state = dict()\n    cfg_path_dict = {\n        \"svd_full\": \"svd_configs/svd_full_decoder.yaml\",\n    }\n    config = cfg_path_dict[model_name]\n\n    config = OmegaConf.load(config)\n\n    if ckpt_path is not None:\n        # overwrite config.\n        config.model.params.ckpt_path = ckpt_path\n\n    s_time = time()\n    # model will automatically load when create\n    from physdreamer.utils.svd_helpper import load_model_from_config\n\n    model, msg = load_model_from_config(config, None)\n\n    state[\"config\"] = config\n\n    print(f\"Loading svd model takes {time() - s_time} seconds\")\n\n    return model, state\n\n\nclass LinearStepAnneal(object):\n    # def __init__(self, total_iters, 
start_state=[0.02, 0.98], end_state=[0.50, 0.98]):\n    def __init__(\n        self,\n        total_iters,\n        start_state=[0.02, 0.98],\n        end_state=[0.02, 0.98],\n        plateau_iters=-1,\n        warmup_step=300,\n    ):\n        self.total_iters = total_iters\n\n        if plateau_iters < 0:\n            plateau_iters = int(total_iters * 0.2)\n\n        if warmup_step <= 0:\n            warmup_step = 0\n\n        self.total_iters = max(total_iters - plateau_iters - warmup_step, 10)\n\n        self.start_state = start_state\n        self.end_state = end_state\n        self.warmup_step = warmup_step\n\n    def compute_state(self, cur_iter):\n\n        if self.warmup_step > 0:\n            cur_iter = max(0, cur_iter - self.warmup_step)\n        if cur_iter >= self.total_iters:\n            return self.end_state\n        ret = []\n        for s, e in zip(self.start_state, self.end_state):\n            ret.append(s + (e - s) * cur_iter / self.total_iters)\n        return ret\n\n\ndef setup_boundary_condition(\n    xyzs_over_time: torch.Tensor, mpm_solver, mpm_state, num_filled=0\n):\n\n    init_velocity = xyzs_over_time[1] - xyzs_over_time[0]\n    init_velocity_mag = torch.norm(init_velocity, dim=-1)\n\n    # 10% of the velocity\n    velocity_thres = torch.quantile(init_velocity_mag, 0.1, dim=0)\n\n    # [n_particles]. 
1 for freeze, 0 for moving\n    freeze_mask = init_velocity_mag < velocity_thres\n    freeze_mask = freeze_mask.type(torch.int)\n    if num_filled > 0:\n        freeze_mask = torch.cat(\n            [freeze_mask, freeze_mask.new_zeros(num_filled).type(torch.int)], dim=0\n        )\n    num_freeze_pts = freeze_mask.sum()\n    print(\"num freeze pts from static points\", num_freeze_pts.item())\n\n    free_velocity = torch.zeros_like(init_velocity[0])  # [3] in device\n\n    mpm_solver.enforce_particle_velocity_by_mask(\n        mpm_state, freeze_mask, free_velocity, start_time=0, end_time=100000\n    )\n\n    return freeze_mask\n\n\ndef setup_plannar_boundary_condition(\n    xyzs_over_time: torch.Tensor,\n    mpm_solver,\n    mpm_state,\n    gaussian_xyz,\n    plane_mean,\n    plane_normal,\n    thres=0.2,\n):\n    \"\"\"\n    plane_mean and plane_normal are in original coordinate, not being normalized\n    Args:\n        xyzs_over_time: [T, N, 3]\n        gaussian_xyz: [N, 3] torch.Tensor\n        plane_mean: [3]\n        plane_normal: [3]\n        thres: float\n\n    \"\"\"\n\n    plane_normal = plane_normal / torch.norm(plane_normal)\n    # [n_particles]\n    plane_dist = torch.abs(\n        torch.sum(\n            (gaussian_xyz - plane_mean.unsqueeze(0)) * plane_normal.unsqueeze(0), dim=-1\n        )\n    )\n    # [n_particles]\n    freeze_mask = plane_dist < thres\n    freeze_mask = freeze_mask.type(torch.int)\n\n    num_freeze_pts = freeze_mask.sum()\n    print(\"num freeze pts from plannar boundary\", num_freeze_pts.item())\n    free_velocity = xyzs_over_time.new_zeros(3)\n    # print(\"free velocity\", free_velocity.shape, freeze_mask.shape)\n\n    mpm_solver.enforce_particle_velocity_by_mask(\n        mpm_state, freeze_mask, free_velocity, start_time=0, end_time=100000\n    )\n\n    return freeze_mask\n\n\ndef find_far_points(xyzs, selected_points, thres=0.05):\n    \"\"\"\n    Args:\n        xyzs: [N, 3]\n        selected_points: [M, 3]\n    Outs:\n        
freeze_mask: [N], 1 for points that are far away, 0 for points that are close\n                    dtype=torch.int\n    \"\"\"\n    chunk_size = 10000\n\n    freeze_mask_list = []\n    for i in range(0, xyzs.shape[0], chunk_size):\n\n        end_index = min(i + chunk_size, xyzs.shape[0])\n        xyzs_chunk = xyzs[i:end_index]\n        # [M, N]\n        cdist = torch.cdist(xyzs_chunk, selected_points)\n\n        min_dist, _ = torch.min(cdist, dim=-1)\n        freeze_mask = min_dist > thres\n        freeze_mask = freeze_mask.type(torch.int)\n        freeze_mask_list.append(freeze_mask)\n\n    freeze_mask = torch.cat(freeze_mask_list, dim=0)\n\n    # 1 for points that are far away, 0 for points that are close\n    return freeze_mask\n\n\ndef setup_boundary_condition_with_points(\n    xyzs, selected_points, mpm_solver, mpm_state, thres=0.05\n):\n    \"\"\"\n    Args:\n        xyzs: [N, 3]\n        selected_points: [M, 3]\n    \"\"\"\n\n    freeze_mask = find_far_points(xyzs, selected_points, thres=thres)\n    num_freeze_pts = freeze_mask.sum()\n    print(\"num freeze pts from static points\", num_freeze_pts.item())\n\n    free_velocity = torch.zeros_like(xyzs[0])  # [3] in device\n\n    mpm_solver.enforce_particle_velocity_by_mask(\n        mpm_state, freeze_mask, free_velocity, start_time=0, end_time=1000000\n    )\n\n    return freeze_mask\n\n\ndef setup_bottom_boundary_condition(xyzs, mpm_solver, mpm_state, percentile=0.05):\n    \"\"\"\n    Args:\n        xyzs: [N, 3]\n        selected_points: [M, 3]\n    \"\"\"\n    max_z, min_z = torch.max(xyzs[:, 2]), torch.min(xyzs[:, 2])\n    thres = min_z + (max_z - min_z) * percentile\n    freeze_mask = xyzs[:, 2] < thres\n\n    freeze_mask = freeze_mask.type(torch.int)\n    num_freeze_pts = freeze_mask.sum()\n    print(\"num freeze pts from bottom points\", num_freeze_pts.item())\n\n    free_velocity = torch.zeros_like(xyzs[0])  # [3] in device\n\n    mpm_solver.enforce_particle_velocity_by_mask(\n        mpm_state, 
freeze_mask, free_velocity, start_time=0, end_time=1000000\n    )\n\n    return freeze_mask\n\n\ndef render_single_view_video(\n    cam,\n    render_params,\n    motion_model,\n    time_stamps,\n    rand_bg=False,\n    render_flow=False,\n    query_mask=None,\n):\n    \"\"\"\n    Args:\n        cam:\n        motion_model: Callable function, f(x, t) => translation, rotation\n        time_stamps: [T]\n        query_mask: Tensor of [N], 0 for freeze points, 1 for moving points\n    Outs:\n        ret_video: [T, 3, H, W] value in [0, 1]\n    \"\"\"\n\n    if rand_bg:\n        bg_color = torch.rand(3, device=\"cuda\")\n    else:\n        bg_color = render_params.bg_color\n\n    ret_img_list = []\n    for time_stamp in time_stamps:\n        if not render_flow:\n            new_gaussians = render_params.gaussians.apply_se3_fields(\n                motion_model, time_stamp\n            )\n            if query_mask is not None:\n                new_gaussians._xyz = new_gaussians._xyz * query_mask.unsqueeze(\n                    -1\n                ) + render_params.gaussians._xyz * (1 - query_mask.unsqueeze(-1))\n                new_gaussians._rotation = (\n                    new_gaussians._rotation * query_mask.unsqueeze(-1)\n                    + render_params.gaussians._rotation * (1 - query_mask.unsqueeze(-1))\n                )\n            # [3, H, W]\n            img = render_gaussian(\n                cam,\n                new_gaussians,\n                render_params.render_pipe,\n                bg_color,\n            )[\n                \"render\"\n            ]  # value in [0, 1]\n        else:\n            inp_time = (\n                torch.ones_like(render_params.gaussians._xyz[:, 0:1]) * time_stamp\n            )\n            inp = torch.cat([render_params.gaussians._xyz, inp_time], dim=-1)\n            # [bs, 3, 3]. 
[bs, 3]\n            R, point_disp = motion_model(inp)\n\n            img = render_flow_depth_w_gaussian(\n                cam,\n                render_params.gaussians,\n                render_params.render_pipe,\n                point_disp,\n                bg_color,\n            )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    ret_video = torch.cat(ret_img_list, dim=0)  # [T, 3, H, W]\n    return ret_video\n\n\ndef render_gaussian_seq(cam, render_params, gaussian_pos_list, gaussian_cov_list):\n\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(gaussian_pos_list)):\n\n        xyz = gaussian_pos_list[i]\n        gaussians._xyz = xyz\n        # TODO, how to deal with cov\n        img = render_gaussian(\n            cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n\n    return rendered_video\n\n\ndef render_gaussian_seq_w_mask(\n    cam, render_params, gaussian_pos_list, gaussian_cov_list, update_mask\n):\n\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_cov = gaussians.get_covariance().clone()\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(gaussian_pos_list)):\n\n        xyz = gaussian_pos_list[i]\n        gaussians._xyz[update_mask, ...] = xyz\n\n        if gaussian_cov_list is not None:\n            cov = gaussian_cov_list[i]\n            old_cov[update_mask, ...] 
= cov\n            cov3D_precomp = old_cov\n\n        else:\n            cov3D_precomp = None\n\n        img = render_gaussian(\n            cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n\n    return rendered_video\n\n\ndef render_gaussian_seq_w_mask_with_disp(\n    cam, render_params, orign_points, top_k_index, disp_list, update_mask\n):\n    \"\"\"\n    Args:\n        cam: Camera or list of Camera\n        orign_points: [m, 3]\n        disp_list: List[m, 3]\n        top_k_index: [n, top_k]\n\n    \"\"\"\n\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_rotation = gaussians._rotation.clone()\n\n    query_pts = old_xyz[update_mask, ...]\n    query_rotation = old_rotation[update_mask, ...]\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(disp_list)):\n\n        if isinstance(cam, list):\n            render_cam = cam[i]\n        else:\n            render_cam = cam\n        disp = disp_list[i]\n        new_xyz, new_rotation = interpolate_points_w_R(\n            query_pts, query_rotation, orign_points, disp, top_k_index\n        )\n        gaussians._xyz[update_mask, ...] = new_xyz\n        gaussians._rotation[update_mask, ...] 
= new_rotation\n\n        cov3D_precomp = None\n\n        img = render_gaussian(\n            render_cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    gaussians._rotation = old_rotation\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n\n    return rendered_video\n\n\ndef render_gaussian_seq_w_mask_with_disp_for_figure(\n    cam, render_params, orign_points, top_k_index, disp_list, update_mask\n):\n    \"\"\"\n    Args:\n        cam: Camera or list of Camera\n        orign_points: [m, 3]\n        disp_list: List[m, 3]\n        top_k_index: [n, top_k]\n\n    \"\"\"\n\n    ret_img_list = []\n    moving_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_rotation = gaussians._rotation.clone()\n\n    query_pts = old_xyz[update_mask, ...]\n    query_rotation = old_rotation[update_mask, ...]\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    background_black = torch.tensor([0, 0, 0], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(disp_list)):\n\n        if isinstance(cam, list):\n            render_cam = cam[i]\n        else:\n            render_cam = cam\n        disp = disp_list[i]\n        new_xyz, new_rotation = interpolate_points_w_R(\n            query_pts, query_rotation, orign_points, disp, top_k_index\n        )\n        gaussians._xyz[update_mask, ...] = new_xyz\n        gaussians._rotation[update_mask, ...] 
= new_rotation\n\n        cov3D_precomp = None\n\n        img = render_gaussian(\n            render_cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        masked_gaussians = gaussians.apply_mask(update_mask)\n        moving_img = render_gaussian(\n            render_cam,\n            masked_gaussians,\n            render_params.render_pipe,\n            background_black,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n        moving_img_list.append(moving_img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    gaussians._rotation = old_rotation\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n    moving_part_video = torch.cat(moving_img_list, dim=0)\n\n    return rendered_video, moving_part_video\n\n\ndef render_gaussian_seq_w_mask_cam_seq(\n    cam_list, render_params, gaussian_pos_list, gaussian_cov_list, update_mask\n):\n\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_cov = gaussians.get_covariance().clone()\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(gaussian_pos_list)):\n\n        xyz = gaussian_pos_list[i]\n        gaussians._xyz[update_mask, ...] = xyz\n\n        if gaussian_cov_list is not None:\n            cov = gaussian_cov_list[i]\n            old_cov[update_mask, ...] 
= cov\n            cov3D_precomp = old_cov\n\n        else:\n            cov3D_precomp = None\n\n        img = render_gaussian(\n            cam_list[i],\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n\n    return rendered_video\n\n\ndef apply_grid_bc_w_freeze_pts(grid_size, grid_lim, freeze_pts, mpm_solver):\n\n    device = freeze_pts.device\n\n    grid_pts_cnt = torch.zeros(\n        (grid_size, grid_size, grid_size), dtype=torch.int32, device=device\n    )\n\n    dx = grid_lim / grid_size\n    inv_dx = 1.0 / dx\n\n    freeze_pts = (freeze_pts * inv_dx).long()\n\n    for x, y, z in freeze_pts:\n        grid_pts_cnt[x, y, z] += 1\n\n    freeze_grid_mask = grid_pts_cnt >= 1\n\n    freeze_grid_mask_int = freeze_grid_mask.type(torch.int32)\n\n    number_freeze_grid = freeze_grid_mask_int.sum().item()\n    print(\"number of freeze grid\", number_freeze_grid)\n\n    mpm_solver.enforce_grid_velocity_by_mask(freeze_grid_mask_int)\n\n    # add debug section:\n\n    return freeze_grid_mask\n\n\ndef add_constant_force(\n    mpm_sovler,\n    mpm_state,\n    xyzs,\n    center_point,\n    radius,\n    force,\n    dt,\n    start_time,\n    end_time,\n    device,\n):\n    \"\"\"\n    Args:\n        xyzs: [N, 3]\n        center_point: [3]\n        radius: float\n        force: [3]\n\n    \"\"\"\n\n    # compute distance from xyzs to center_point\n    # [N]\n    dist = torch.norm(xyzs - center_point.unsqueeze(0), dim=-1)\n\n    apply_force_mask = dist < radius\n    apply_force_mask = apply_force_mask.type(torch.int)\n\n    print(apply_force_mask.shape, apply_force_mask.sum().item(), \"apply force mask\")\n\n    mpm_sovler.add_impulse_on_particles_with_mask(\n        mpm_state,\n        force,\n        
dt,\n        apply_force_mask,\n        start_time=start_time,\n        end_time=end_time,\n        device=device,\n    )\n\n\n@torch.no_grad()\ndef render_force_2d(cam, render_params, center_point, force):\n\n    force_in_2d_scale = 80  # unit as pixel\n    two_points = torch.stack([center_point, center_point + force], dim=0)\n\n    gaussians = render_params.gaussians\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n\n    # [3, H, W]\n    img = render_gaussian(\n        cam,\n        gaussians,\n        render_params.render_pipe,\n        background,\n    )[\"render\"]\n    img = img.detach().contiguous()\n    img = img.cpu().numpy().transpose(1, 2, 0)\n    img = img * 255\n    img = img.astype(np.uint8).copy()\n\n    # two_points.  [2, 3]\n    # arrow_2d: [2, 2]\n    arrow_2d = render_arrow_in_screen(cam, two_points)\n\n    arrow_2d = arrow_2d.cpu().numpy()\n\n    start, vec_2d = arrow_2d[0], arrow_2d[1] - arrow_2d[0]\n    vec_2d = vec_2d / np.linalg.norm(vec_2d)\n\n    start = start  # + np.array([540.0, 288.0])  # [W, H] / 2\n    # debug here.\n    # 1. unit in pixel?\n    # 2. 
use cv2 to add arrow?\n    # draw cirrcle at start in img\n\n    # img = img.transpose(2, 0, 1)\n    img = cv2.circle(img, (int(start[0]), int(start[1])), 40, (255, 255, 255), 8)\n\n    # draw arrow in img\n    end = start + vec_2d * force_in_2d_scale\n    end = end.astype(np.int32)\n    start = start.astype(np.int32)\n    img = cv2.arrowedLine(img, (start[0], start[1]), (end[0], end[1]), (0, 255, 255), 8)\n\n    return img\n\n\ndef render_gaussian_seq_w_mask_cam_seq_with_force(\n    cam_list,\n    render_params,\n    gaussian_pos_list,\n    gaussian_cov_list,\n    update_mask,\n    pts_index,\n    force,\n    force_steps,\n):\n\n    force_in_2d_scale = 80  # unit as pixel\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_cov = gaussians.get_covariance().clone()\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(gaussian_pos_list)):\n\n        xyz = gaussian_pos_list[i]\n        gaussians._xyz[update_mask, ...] = xyz\n\n        if gaussian_cov_list is not None:\n            cov = gaussian_cov_list[i]\n            old_cov[update_mask, ...] 
= cov\n            cov3D_precomp = old_cov\n\n        else:\n            cov3D_precomp = None\n\n        img = render_gaussian(\n            cam_list[i],\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        # to [H, W, 3]\n        img = img.detach().contiguous().cpu().numpy().transpose(1, 2, 0)\n        img = np.clip((img * 255), 0, 255).astype(np.uint8).copy()\n\n        if i < force_steps:\n            center_point = gaussians._xyz[pts_index]\n            two_points = torch.stack([center_point, center_point + force], dim=0)\n\n            arrow_2d = render_arrow_in_screen(cam_list[i], two_points)\n\n            arrow_2d = arrow_2d.cpu().numpy()\n\n            start, vec_2d = arrow_2d[0], arrow_2d[1] - arrow_2d[0]\n            vec_2d = vec_2d / np.linalg.norm(vec_2d)\n\n            start = start  # + np.array([540.0, 288.0])\n\n            img = cv2.circle(\n                img, (int(start[0]), int(start[1])), 40, (255, 255, 255), 8\n            )\n\n            # draw arrow in img\n            end = start + vec_2d * force_in_2d_scale\n            end = end.astype(np.int32)\n            start = start.astype(np.int32)\n            img = cv2.arrowedLine(\n                img, (start[0], start[1]), (end[0], end[1]), (0, 255, 255), 8\n            )\n\n        img = img.transpose(2, 0, 1)\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    # [T, C, H, W], in [0, 1]\n    rendered_video = np.concatenate(ret_img_list, axis=0)\n\n    return rendered_video\n\n\ndef render_gaussian_seq_w_mask_cam_seq_with_force_with_disp(\n    cam_list,\n    render_params,\n    orign_points,\n    top_k_index,\n    disp_list,\n    update_mask,\n    pts_index,\n    force,\n    force_steps,\n):\n\n    force_in_2d_scale = 80  # unit as pixel\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n  
  old_rotation = gaussians._rotation.clone()\n\n    query_pts = old_xyz[update_mask, ...]\n    query_rotation = old_rotation[update_mask, ...]\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(disp_list)):\n\n        disp = disp_list[i]\n        new_xyz, new_rotation = interpolate_points_w_R(\n            query_pts, query_rotation, orign_points, disp, top_k_index\n        )\n        gaussians._xyz[update_mask, ...] = new_xyz\n        gaussians._rotation[update_mask, ...] = new_rotation\n\n        cov3D_precomp = None\n\n        img = render_gaussian(\n            cam_list[i],\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        # to [H, W, 3]\n        img = img.detach().contiguous().cpu().numpy().transpose(1, 2, 0)\n        img = np.clip((img * 255), 0, 255).astype(np.uint8).copy()\n\n        if i < force_steps:\n            center_point = gaussians._xyz[pts_index]\n            two_points = torch.stack([center_point, center_point + force], dim=0)\n\n            arrow_2d = render_arrow_in_screen(cam_list[i], two_points)\n\n            arrow_2d = arrow_2d.cpu().numpy()\n\n            start, vec_2d = arrow_2d[0], arrow_2d[1] - arrow_2d[0]\n            vec_2d = vec_2d / np.linalg.norm(vec_2d)\n\n            start = start  # + np.array([540.0, 288.0])\n\n            img = cv2.circle(\n                img, (int(start[0]), int(start[1])), 40, (255, 255, 255), 8\n            )\n\n            # draw arrow in img\n            end = start + vec_2d * force_in_2d_scale\n            end = end.astype(np.int32)\n            start = start.astype(np.int32)\n            img = cv2.arrowedLine(\n                img, (start[0], start[1]), (end[0], end[1]), (0, 255, 255), 8\n            )\n\n        img = img.transpose(2, 0, 1)\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    
gaussians._rotation = old_rotation\n    # [T, C, H, W], in [0, 1]\n    rendered_video = np.concatenate(ret_img_list, axis=0)\n\n    return rendered_video\n\n\ndef downsample_with_kmeans(points_array: np.ndarray, num_points: int):\n    \"\"\"\n    Args:\n        points_array: [N, 3]\n        num_points: int\n    Outs:\n        downsampled_points: [num_points, 3]\n    \"\"\"\n\n    print(\n        \"=> staring downsample with kmeans from \",\n        points_array.shape[0],\n        \" points to \",\n        num_points,\n        \" points\",\n    )\n    s_time = time()\n    kmeans = KMeans(n_clusters=num_points, random_state=0).fit(points_array)\n    cluster_centers = kmeans.cluster_centers_\n    e_time = time()\n\n    print(\"=> downsample with kmeans takes \", e_time - s_time, \" seconds\")\n    return cluster_centers\n\n\n@torch.no_grad()\ndef downsample_with_kmeans_gpu(points_array: torch.Tensor, num_points: int):\n\n    from kmeans_gpu import KMeans\n\n    kmeans = KMeans(\n        n_clusters=num_points,\n        max_iter=100,\n        tolerance=1e-4,\n        distance=\"euclidean\",\n        sub_sampling=None,\n        max_neighbors=15,\n    )\n\n    features = torch.ones(1, 1, points_array.shape[0], device=points_array.device)\n    points_array = points_array.unsqueeze(0)\n    # Forward\n\n    print(\n        \"=> staring downsample with kmeans from \",\n        points_array.shape[1],\n        \" points to \",\n        num_points,\n        \" points\",\n    )\n    s_time = time()\n    centroids, features = kmeans(points_array, features)\n\n    ret_points = centroids.squeeze(0)\n    e_time = time()\n    print(\"=> downsample with kmeans takes \", e_time - s_time, \" seconds\")\n\n    # [np_subsample, 3]\n    return ret_points\n\n\n@torch.no_grad()\ndef downsample_with_kmeans_gpu_with_chunk(points_array: torch.Tensor, num_points: int):\n    # split the points_array into chunks, and then do kmeans on each chunk\n    #   to save memory.\n\n    from kmeans_gpu 
import KMeans\n\n    points_array_sum = points_array.sum(dim=1)\n    arg_idx = torch.argsort(points_array_sum, descending=True)\n    points_array = points_array[arg_idx, :]\n\n    features = torch.ones(1, 1, points_array.shape[0], device=points_array.device)\n    points_array = points_array.unsqueeze(0)\n    # Forward\n\n    print(\n        \"=> staring downsample with kmeans from \",\n        points_array.shape[1],\n        \" points to \",\n        num_points,\n        \" points\",\n        points_array.shape,\n    )\n    s_time = time()\n\n    num_raw_points = points_array.shape[1]\n    chunk_size = 150000\n\n    num_chunks = num_raw_points // chunk_size + 1\n\n    ret_list = []\n    for i in range(num_chunks):\n\n        start = i * chunk_size\n        end = min((i + 1) * chunk_size, num_raw_points)\n        points_chunk = points_array[:, start:end, :]\n        features_chunk = features[:, :, start:end]\n\n        num_target_points = min(chunk_size, num_points // num_chunks)\n\n        kmeans = KMeans(\n            n_clusters=num_target_points,\n            max_iter=100,\n            tolerance=1e-4,\n            distance=\"euclidean\",\n            sub_sampling=None,\n            max_neighbors=15,\n        )\n        centroids, _ = kmeans(points_chunk, features_chunk)\n        ret_list.append(centroids.squeeze(0))\n\n    ret_points = torch.cat(ret_list, dim=0)\n    e_time = time()\n    print(\"=> downsample with kmeans takes \", e_time - s_time, \" seconds\")\n\n    # [np_subsample, 3]\n    return ret_points\n\n\ndef interpolate_points(query_points, drive_displacement, top_k_index):\n    \"\"\"\n    Args:\n        query_points: [n, 3]\n        drive_displacement: [m, 3]\n        top_k_index: [n, top_k] < m\n    \"\"\"\n\n    top_k_disp = drive_displacement[top_k_index]\n\n    t = top_k_disp.mean(dim=1)\n\n    ret_points = query_points + t\n\n    return ret_points\n\n\ndef interpolate_points_w_R(\n    query_points, query_rotation, drive_origin_pts, 
drive_displacement, top_k_index\n):\n    \"\"\"\n    Args:\n        query_points: [n, 3]\n        drive_origin_pts: [m, 3]\n        drive_displacement: [m, 3]\n        top_k_index: [n, top_k] < m\n\n    Or directly call: apply_discrete_offset_filds_with_R(self, origin_points, offsets, topk=6):\n        Args:\n            origin_points: (N_r, 3)\n            offsets: (N_r, 3)\n        in rendering\n    \"\"\"\n\n    # [n, topk, 3]\n    top_k_disp = drive_displacement[top_k_index]\n    source_points = drive_origin_pts[top_k_index]\n\n    R, t = get_rigid_transform(source_points, source_points + top_k_disp)\n\n    avg_offsets = top_k_disp.mean(dim=1)\n\n    ret_points = query_points + avg_offsets\n\n    new_rotation = quaternion_multiply(matrix_to_quaternion(R), query_rotation)\n\n    return ret_points, new_rotation\n\n\ndef create_camera_path(\n    cam,\n    radius: float,\n    focus_pt: np.ndarray = np.array([0, 0, 0]),\n    up: np.ndarray = np.array([0, 0, 1]),\n    n_frames: int = 60,\n    n_rots: int = 1,\n    y_scale: float = 1.0,\n):\n\n    R, T = cam.R, cam.T\n    # R, T = R.cpu().numpy(), T.cpu().numpy()\n\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()\n    Rt[:3, 3] = T\n    Rt[3, 3] = 1.0\n    C2W = np.linalg.inv(Rt)\n    C2W[:3, 1:3] *= -1\n\n    import copy\n    from physdreamer.utils.camera_utils import generate_spiral_path\n    from physdreamer.data.cameras import Camera\n\n    lookat_pt = focus_pt\n    render_poses = generate_spiral_path(\n        C2W, radius, lookat_pt, up, n_frames, n_rots, y_scale\n    )\n\n    FoVy, FoVx = cam.FoVy, cam.FoVx\n    height, width = cam.image_height, cam.image_width\n\n    ret_cam_list = []\n    for i in range(n_frames):\n        c2w_opengl = render_poses[i]\n        c2w = copy.deepcopy(c2w_opengl)\n        c2w[:3, 1:3] *= -1\n\n        # get the world-to-camera transform and set R, T\n        w2c = np.linalg.inv(c2w)\n        R = np.transpose(\n            w2c[:3, :3]\n        )  # R is stored transposed 
due to 'glm' in CUDA code\n        T = w2c[:3, 3]\n        cam = Camera(\n            R=R,\n            T=T,\n            FoVy=FoVy,\n            FoVx=FoVx,\n            img_path=None,\n            img_hw=(height, width),\n            timestamp=None,\n            data_device=\"cuda\",\n        )\n        ret_cam_list.append(cam)\n\n    return ret_cam_list\n\n\ndef get_camera_trajectory(cam, num_pos, camera_cfg: dict, dataset):\n    if camera_cfg[\"type\"] == \"spiral\":\n        interpolated_cameras = create_camera_path(\n            cam,\n            radius=camera_cfg[\"radius\"],\n            focus_pt=camera_cfg[\"focus_point\"],\n            up=camera_cfg[\"up\"],\n            n_frames=num_pos,\n        )\n    elif camera_cfg[\"type\"] == \"interpolation\":\n        if \"start_frame\" in camera_cfg and \"end_frame\" in camera_cfg:\n            interpolated_cameras = dataset.interpolate_camera(\n                camera_cfg[\"start_frame\"], camera_cfg[\"end_frame\"], num_pos\n            )\n        else:\n            interpolated_cameras = dataset.interpolate_camera(\n                camera_cfg[\"start_frame\"], camera_cfg[\"start_frame\"], num_pos\n            )\n\n    print(\n        \"number of simulated frames: \",\n        num_pos,\n        \"num camera viewpoints: \",\n        len(interpolated_cameras),\n    )\n    return interpolated_cameras\n"
  },
  {
    "path": "projects/inference/run.sh",
    "content": "# python3 demo.py --scene_name carnation --apply_force --force_id 1  --point_id 0 --force_mag 2.0 --cam_id 0\n\n# python3 demo.py --scene_name hat --apply_force --force_id 0  --point_id 0 --force_mag 3.0 --cam_id 0\n\npython3 demo.py --scene_name telephone --apply_force --force_id 0  --point_id 0 --force_mag 0.1 --cam_id 0\n\npython3 demo.py --scene_name alocasia --apply_force --force_id 0 --point_id 0 --force_mag 3.0 --cam_id 0\n"
  },
  {
    "path": "projects/uncleaned_train/.gitignore",
    "content": "img_data/\ntmp/\n./data/\ndataset/\nmodels/\nmodel\noutput/\noutputs/\n*.sh\nexp_motion/*.sh\n__pycache__\n*__pycache__/\n*/__pycache__/\n*/wandb/*\nwandb\n*/*.pyc\n*.sh.log*\n*.gif\n*.mp4\n*.pt\n*.ipynb\n\n"
  },
  {
    "path": "projects/uncleaned_train/README.md",
    "content": "This folder contains the original, uncleaned training code. It can be viewed as an independent folder: it does not use the code in physdreamer/ and projects/inference.\n\n`exp_motion/train` contains the code for velocity and material training.\n\nVelocity training and material training differ slightly in:\n1. How many frames are used for training.\n2. How many frames the backpropagation needs to be passed through.\n3. Velocity training typically uses a smaller spatial resolution (grid_size) and temporal resolution (number of substeps).\n\nTwo major differences between this code and the inference code are:\n1. All the helper functions here are installed in a folder called \"motionrep\", while the inference code uses \"physdreamer\". The physdreamer/ and motionrep/ folders should share most of the code.\n2. The config.yaml file has different contents and format.\n"
  },
  {
    "path": "projects/uncleaned_train/exp_motion/train/config.yml",
    "content": "dataset_dir: \n# optimization\nwarmup_step: 10\nmax_grad_norm: 10.0\n\nrand_bg: False\n\nvelo_dir: [\n  \"../../data/physics_dreamer/alocasia_nerfstudio/mul_videos/velopretrain_models/frame_00037_mb-8_fps-30_8\",\n]\n"
  },
  {
    "path": "projects/uncleaned_train/exp_motion/train/config_demo.py",
    "content": "import numpy as np\n\nfrom model_config import (\n    model_list,\n    camera_cfg_list,\n    points_list,\n    force_directions,\n    simulate_cfg,\n    dataset_dir,\n    result_dir,\n    exp_name,\n)\n\n\nclass DemoParams(object):\n    def __init__(self):\n\n        self.demo_dict = {\n            \"baseline\": {\n                \"model_path\": model_list[0],\n                \"substep\": 768,\n                \"grid_size\": 64,\n                \"name\": \"baseline\",\n                \"camera_cfg\": camera_cfg_list[0],\n                \"cam_id\": 0,\n            },\n            \"demo_dummy\": {\n                \"model_path\": model_list[0],\n                \"center_point\": points_list[0],\n                \"force\": np.array([0.15, 0, 0]),\n                \"camera_cfg\": camera_cfg_list[0],\n                \"force_duration\": 0.75,\n                \"force_radius\": 0.1,\n                \"substep\": 256,\n                \"grid_size\": 96,\n                \"total_time\": 5,\n                \"name\": \"alocasia_sv_gres96_substep256_force_top_of_flower\",\n            },\n        }\n\n    def get_cfg(\n        self,\n        demo_name=None,\n        model_id: int = 0,\n        eval_ys: float = 1.0,\n        force_id: int = 0,\n        force_mag: float = 1.0,\n        velo_scaling: float = 3.0,\n        point_id: int = 0,\n        cam_id: int = 0,\n        apply_force: bool = False,\n    ):\n        if demo_name == \"None\":\n            demo_name = None\n        if (demo_name is not None) and (demo_name in self.demo_dict):\n            cfg = self.demo_dict[demo_name]\n        else:\n            cfg = {}\n            cfg[\"model_path\"] = model_list[model_id]\n            cfg[\"center_point\"] = points_list[point_id]\n            cfg[\"force\"] = force_directions[force_id] * force_mag\n            cfg[\"camera_cfg\"] = camera_cfg_list[cam_id]\n            cfg[\"cam_id\"] = cam_id\n            cfg[\"force_duration\"] = 0.75\n            
cfg[\"force_radius\"] = 0.1\n            cfg[\"substep\"] = simulate_cfg[\"substep\"]\n            cfg[\"grid_size\"] = simulate_cfg[\"grid_size\"]\n            cfg[\"total_time\"] = 5\n            cfg[\"eval_ys\"] = eval_ys\n            cfg[\"velo_scaling\"] = velo_scaling\n\n            if demo_name is None:\n                name = \"\"\n            else:\n                name = demo_name + \"_\"\n            name = (\n                name + f\"{exp_name}_sv_gres{cfg['grid_size']}_substep{cfg['substep']}\"\n            )\n            if eval_ys > 10:\n                name += f\"_eval_ys_{eval_ys}\"\n            else:\n                name += f\"_model_{model_id}\"\n\n            if apply_force:\n                name += f\"_force_{force_id}_mag_{force_mag}_point_{point_id}\"\n            else:\n                name += f\"_no_force_velo_{velo_scaling}\"\n            cfg[\"name\"] = name\n\n        cfg[\"dataset_dir\"] = dataset_dir\n        cfg[\"result_dir\"] = result_dir\n\n        return cfg\n"
  },
  {
    "path": "projects/uncleaned_train/exp_motion/train/convert_gaussian_to_mesh.py",
    "content": "import os\nfrom random import gauss\nfrom fire import Fire\nfrom motionrep.gaussian_3d.scene import GaussianModel\nimport numpy as np\nimport torch\n\n\ndef convert_gaussian_to_mesh(gaussian_path, thresh=0.1, save_path=None):\n    if save_path is None:\n        dir_path = os.path.dirname(gaussian_path)\n        save_path = os.path.join(dir_path, \"gaussian_to_mesh_thres_{}.obj\".format(thresh))\n\n    gaussian_path = os.path.join(gaussian_path)\n\n    gaussians = GaussianModel(3)\n\n    gaussians.load_ply(gaussian_path)\n    gaussians.detach_grad()\n    print(\n        \"load gaussians from: {}\".format(gaussian_path),\n        \"... num gaussians: \",\n        gaussians._xyz.shape[0],\n    )\n\n    mesh = gaussians.extract_mesh(\n        save_path, density_thresh=thresh, resolution=128, decimate_target=1e5\n    )\n\n    mesh.write(save_path)\n\n\ndef internal_filling(gaussian_path, thresh=2.0,  save_path=None, resolution=256, \n                     num_pts=4):\n    if save_path is None:\n        dir_path = os.path.dirname(gaussian_path)\n        save_path = os.path.join(dir_path, \"gaussian_internal_fill.ply\")\n\n    gaussians = GaussianModel(3)\n\n    gaussians.load_ply(gaussian_path)\n    gaussians.detach_grad()\n\n    print(\n        \"load gaussians from: {}\".format(gaussian_path),\n        \"... 
num gaussians: \",\n        gaussians._xyz.shape[0],\n    )\n\n    # [res, res, res]\n    # torch.linspace(-1, 1, resolution) for the coords\n    # x[0] => -1,  x[resolution-1] = 1\n    # x[i] = -1 + i * 2 / (resolution - 1)\n    # index_x = (x[i] + 1) / 2 * (resolution - 1)\n    occ = (\n        gaussians.extract_fields(resolution=resolution, num_blocks=16, relax_ratio=1.5)\n        .detach()\n        .cpu()\n        .numpy()\n    )\n\n    xyzs = gaussians._xyz.detach().cpu().numpy()\n\n    center = gaussians.center.detach().cpu().numpy()\n    scale = gaussians.scale # float\n    xyzs = (xyzs - center) * scale # [-1, 1]?\n\n    percentile = [95, 97, 99][1]\n\n    # from IPython import embed\n    # embed()\n\n    thres_ = np.percentile(occ, percentile)\n    print(\"density threshold: {:.5f} -- in percentile: {:.1f} \".format(thres_, percentile))\n    occ_large_thres = occ > thresh\n    # get the xyz of the occupied voxels\n    # xyz = np.argwhere(occ)\n    # normalize to [-1, 1]\n    # xyz = xyz / (resolution - 1) * 2 - 1\n\n    voxel_counts = np.zeros((resolution, resolution, resolution))\n\n    points_xyzindex = ((xyzs + 1) / 2 * (resolution - 1)).astype(np.uint32)\n\n    for x, y, z in points_xyzindex:\n        voxel_counts[x, y, z] += 1\n    \n    add_points = np.logical_and(occ_large_thres, voxel_counts <= 1)\n\n    add_xyz = np.argwhere(add_points).astype(np.float32)\n    add_xyz = add_xyz / (resolution - 1) * 2 - 1  # [x,y,z]_min of the unit cell.  
randomly add points in the unit cell\n\n    cell_width = 2.0 / (resolution - 1)\n\n    # copy add_xyz \"num_pts\" times\n    add_xyz = np.repeat(add_xyz, num_pts, axis=0)\n\n    random_offset_within_cell = np.random.uniform(-cell_width / 2, cell_width / 2, size=add_xyz.shape)\n    add_xyz += random_offset_within_cell\n\n    all_xyz = np.concatenate([xyzs, add_xyz], axis=0)\n\n    print(\"added points: \", add_xyz.shape[0])\n    \n    # save to ply\n    import point_cloud_utils as pcu\n\n    # pcu.save_mesh_vf(save_path, all_xyz, np.zeros((0, 3), dtype=np.int32))\n\n    add_path = os.path.join(os.path.dirname(save_path), \"extra_filled_points_thresh_{}.ply\".format(thresh))\n    pcu.save_mesh_v(add_path, add_xyz)\n\n    \n\nif __name__ == \"__main__\":\n    Fire(convert_gaussian_to_mesh)\n    # Fire(internal_filling)\n"
  },
  {
    "path": "projects/uncleaned_train/exp_motion/train/fast_train_velocity.py",
    "content": "import argparse\nimport os\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom torch import Tensor\nfrom jaxtyping import Float, Int, Shaped\nfrom typing import List\n\nimport point_cloud_utils as pcu\n\nfrom accelerate.utils import ProjectConfiguration\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import set_seed\nfrom accelerate import Accelerator, DistributedDataParallelKwargs\n\nimport numpy as np\nimport logging\nimport argparse\nimport shutil\nimport wandb\nimport torch\nimport os\nfrom motionrep.utils.config import create_config\nfrom motionrep.utils.optimizer import get_linear_schedule_with_warmup\nfrom time import time\nfrom omegaconf import OmegaConf\nimport numpy as np\n\n# from motionrep.utils.torch_utils import get_sync_time\nfrom einops import rearrange, repeat\n\nfrom motionrep.gaussian_3d.gaussian_renderer.feat_render import render_feat_gaussian\nfrom motionrep.gaussian_3d.scene import GaussianModel\n\nfrom motionrep.data.datasets.multiview_dataset import MultiviewImageDataset\nfrom motionrep.data.datasets.multiview_video_dataset import (\n    MultiviewVideoDataset,\n    camera_dataset_collate_fn,\n)\n\nfrom motionrep.data.datasets.multiview_dataset import (\n    camera_dataset_collate_fn as camera_dataset_collate_fn_img,\n)\n\nfrom typing import NamedTuple\nimport torch.nn.functional as F\n\nfrom motionrep.utils.img_utils import compute_psnr, compute_ssim\nfrom thirdparty_code.warp_mpm.mpm_data_structure import (\n    MPMStateStruct,\n    MPMModelStruct,\n    get_float_array_product,\n)\nfrom thirdparty_code.warp_mpm.mpm_solver_diff import MPMWARPDiff\nfrom thirdparty_code.warp_mpm.warp_utils import from_torch_safe\nfrom thirdparty_code.warp_mpm.gaussian_sim_utils import get_volume\nimport warp as wp\nimport random\n\nfrom local_utils import (\n    cycle,\n    create_spatial_fields,\n    find_far_points,\n    LinearStepAnneal,\n    apply_grid_bc_w_freeze_pts,\n    
render_gaussian_seq_w_mask_with_disp,\n    downsample_with_kmeans_gpu,\n)\nfrom interface import MPMDifferentiableSimulation\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef create_dataset(args):\n    assert args.dataset_res in [\"middle\", \"small\", \"large\"]\n    if args.dataset_res == \"middle\":\n        res = [320, 576]\n    elif args.dataset_res == \"small\":\n        res = [192, 320]\n    elif args.dataset_res == \"large\":\n        res = [576, 1024]\n    else:\n        raise NotImplementedError\n\n    video_dir_name = \"videos_2\"\n    dataset = MultiviewVideoDataset(\n        args.dataset_dir,\n        use_white_background=False,\n        resolution=res,\n        scale_x_angle=1.0,\n        video_dir_name=video_dir_name,\n    )\n\n    test_dataset = MultiviewImageDataset(\n        args.dataset_dir,\n        use_white_background=False,\n        resolution=res,\n        # use_index=list(range(0, 30, 4)),\n        # use_index=[0],\n        scale_x_angle=1.0,\n        fitler_with_renderd=True,\n        load_imgs=False,\n    )\n    print(\"len of test dataset\", len(test_dataset))\n    return dataset, test_dataset\n\n\nclass Trainer:\n    def __init__(self, args):\n        self.args = args\n\n        self.ssim = args.ssim\n        args.warmup_step = int(args.warmup_step * args.gradient_accumulation_steps)\n        args.train_iters = int(args.train_iters * args.gradient_accumulation_steps)\n        os.environ[\"WANDB__SERVICE_WAIT\"] = \"600\"\n        args.wandb_name += (\n            \"decay_{}_substep_{}_{}_lr_{}_tv_{}_iters_{}_sw_{}_cw_{}\".format(\n                args.loss_decay,\n                args.substep,\n                args.model,\n                args.lr,\n                args.tv_loss_weight,\n                args.train_iters,\n                args.start_window_size,\n                args.compute_window,\n            )\n        )\n\n        logging_dir = os.path.join(args.output_dir, args.wandb_name)\n        
accelerator_project_config = ProjectConfiguration(logging_dir=logging_dir)\n        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n        accelerator = Accelerator(\n            gradient_accumulation_steps=1,  # args.gradient_accumulation_steps,\n            mixed_precision=\"no\",\n            log_with=\"wandb\",\n            project_config=accelerator_project_config,\n            kwargs_handlers=[ddp_kwargs],\n        )\n        self.gradient_accumulation_steps = args.gradient_accumulation_steps\n        logging.basicConfig(\n            format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n            datefmt=\"%m/%d/%Y %H:%M:%S\",\n            level=logging.INFO,\n        )\n        logger.info(accelerator.state, main_process_only=False)\n\n        set_seed(args.seed + accelerator.process_index)\n        print(\"process index\", accelerator.process_index)\n        if accelerator.is_main_process:\n            output_path = os.path.join(logging_dir, f\"seed{args.seed}\")\n            os.makedirs(output_path, exist_ok=True)\n            self.output_path = output_path\n\n        self.rand_bg = args.rand_bg\n        # setup the dataset\n        dataset, test_dataset = create_dataset(args)\n        self.test_dataset = test_dataset\n\n        dataset_dir = test_dataset.data_dir\n        self.dataset = dataset\n\n        gaussian_path = os.path.join(dataset_dir, \"point_cloud.ply\")\n        aabb = self.setup_eval(\n            args,\n            gaussian_path,\n            white_background=True,\n        )\n        self.aabb = aabb\n\n        self.num_frames = int(args.num_frames)\n        self.window_size_schduler = LinearStepAnneal(\n            args.train_iters,\n            start_state=[args.start_window_size],\n            end_state=[13],\n            plateau_iters=0,\n            warmup_step=300,\n        )\n\n        test_dataloader = torch.utils.data.DataLoader(\n            test_dataset,\n            
batch_size=args.batch_size,\n            shuffle=False,\n            drop_last=True,\n            num_workers=0,\n            collate_fn=camera_dataset_collate_fn_img,\n        )\n        # why prepare here again?\n        test_dataloader = accelerator.prepare(test_dataloader)\n        self.test_dataloader = cycle(test_dataloader)\n\n        dataloader = torch.utils.data.DataLoader(\n            dataset,\n            batch_size=args.batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=0,\n            collate_fn=camera_dataset_collate_fn,\n        )\n        # why prepare here again?\n        dataloader = accelerator.prepare(dataloader)\n        self.dataloader = cycle(dataloader)\n\n        self.train_iters = args.train_iters\n        self.accelerator = accelerator\n        # init traiable params\n        E_nu_list = self.init_trainable_params()\n        for p in E_nu_list:\n            p.requires_grad = True\n        self.E_nu_list = E_nu_list\n\n        self.setup_simulation(dataset_dir, grid_size=args.grid_size)\n\n        if args.checkpoint_path == \"None\":\n            args.checkpoint_path = None\n        if args.checkpoint_path is not None:\n            self.load(args.checkpoint_path)\n            trainable_params = list(self.sim_fields.parameters()) + self.E_nu_list\n            optim_list = [\n                {\"params\": self.E_nu_list, \"lr\": args.lr * 1e-10},\n                {\n                    \"params\": self.sim_fields.parameters(),\n                    \"lr\": args.lr,\n                    \"weight_decay\": 1e-4,\n                },\n                # {\"params\": self.velo_fields.parameters(), \"lr\": args.lr * 1e-3, \"weight_decay\": 1e-4},\n            ]\n            self.freeze_velo = True\n            self.velo_optimizer = None\n        else:\n            trainable_params = list(self.sim_fields.parameters()) + self.E_nu_list\n            optim_list = [\n                {\"params\": self.E_nu_list, 
\"lr\": args.lr * 1e-10},\n                {\n                    \"params\": self.sim_fields.parameters(),\n                    \"lr\": args.lr * 1e-10,\n                    \"weight_decay\": 1e-4,\n                },\n            ]\n            self.freeze_velo = False\n            self.window_size_schduler.warmup_step = 800\n\n            velo_optim = [\n                {\n                    \"params\": self.velo_fields.parameters(),\n                    \"lr\": args.lr * 0.1,\n                    \"weight_decay\": 1e-4,\n                },\n            ]\n            self.velo_optimizer = torch.optim.AdamW(\n                velo_optim,\n                lr=args.lr,\n                weight_decay=0.0,\n            )\n            self.velo_scheduler = get_linear_schedule_with_warmup(\n                optimizer=self.velo_optimizer,\n                num_warmup_steps=args.warmup_step,\n                num_training_steps=args.train_iters,\n            )\n            self.velo_optimizer, self.velo_scheduler = accelerator.prepare(\n                self.velo_optimizer, self.velo_scheduler\n            )\n\n        self.optimizer = torch.optim.AdamW(\n            optim_list,\n            lr=args.lr,\n            weight_decay=0.0,\n        )\n        self.trainable_params = trainable_params\n        self.scheduler = get_linear_schedule_with_warmup(\n            optimizer=self.optimizer,\n            num_warmup_steps=args.warmup_step,\n            num_training_steps=args.train_iters,\n        )\n        self.sim_fields, self.optimizer, self.scheduler = accelerator.prepare(\n            self.sim_fields, self.optimizer, self.scheduler\n        )\n        self.velo_fields = accelerator.prepare(self.velo_fields)\n\n        # setup train info\n        self.step = 0\n        self.batch_size = args.batch_size\n        self.tv_loss_weight = args.tv_loss_weight\n\n        self.log_iters = args.log_iters\n        self.wandb_iters = args.wandb_iters\n        self.max_grad_norm = 
args.max_grad_norm\n\n        self.use_wandb = args.use_wandb\n        if self.accelerator.is_main_process:\n            if args.use_wandb:\n                run = wandb.init(\n                    config=dict(args),\n                    dir=self.output_path,\n                    **{\n                        \"mode\": \"online\",\n                        \"entity\": args.wandb_entity,\n                        \"project\": args.wandb_project,\n                    },\n                )\n                wandb.run.log_code(\".\")\n                wandb.run.name = args.wandb_name\n                print(f\"run dir: {run.dir}\")\n                self.wandb_folder = run.dir\n                os.makedirs(self.wandb_folder, exist_ok=True)\n\n    def init_trainable_params(\n        self,\n    ):\n\n        # init young modulus and poisson ratio\n\n        young_numpy = np.exp(np.random.uniform(np.log(1e-3), np.log(1e3))).astype(\n            np.float32\n        )\n        young_numpy = 1e6 * 1.0\n\n        young_modulus = torch.tensor(young_numpy, dtype=torch.float32).to(\n            self.accelerator.device\n        )\n\n        poisson_numpy = np.random.uniform(0.1, 0.4)\n        poisson_ratio = torch.tensor(poisson_numpy, dtype=torch.float32).to(\n            self.accelerator.device\n        )\n\n        trainable_params = [young_modulus, poisson_ratio]\n\n        print(\n            \"init young modulus: \",\n            young_modulus.item(),\n            \"poisson ratio: \",\n            poisson_ratio.item(),\n        )\n        return trainable_params\n\n    def setup_simulation(self, dataset_dir, grid_size=100):\n\n        device = \"cuda:{}\".format(self.accelerator.process_index)\n\n        xyzs = self.render_params.gaussians.get_xyz.detach().clone()\n        sim_xyzs = xyzs[self.sim_mask_in_raw_gaussian, :]\n        sim_cov = (\n            self.render_params.gaussians.get_covariance()[\n                self.sim_mask_in_raw_gaussian, :\n            ]\n            
.detach()\n            .clone()\n        )\n\n        # scale, and shift\n        pos_max = sim_xyzs.max()\n        pos_min = sim_xyzs.min()\n        scale = (pos_max - pos_min) * 1.8\n        shift = -pos_min + (pos_max - pos_min) * 0.25\n        self.scale, self.shift = scale, shift\n        print(\"scale, shift\", scale, shift)\n\n        # filled\n        filled_in_points_path = os.path.join(dataset_dir, \"internal_filled_points.ply\")\n\n        if os.path.exists(filled_in_points_path):\n            fill_xyzs = pcu.load_mesh_v(filled_in_points_path)  # [n, 3]\n            fill_xyzs = fill_xyzs[\n                np.random.choice(\n                    fill_xyzs.shape[0], int(fill_xyzs.shape[0] * 0.25), replace=False\n                )\n            ]\n            fill_xyzs = torch.from_numpy(fill_xyzs).float().to(\"cuda\")\n            self.fill_xyzs = fill_xyzs\n            print(\n                \"loaded {} internal filled points from: \".format(fill_xyzs.shape[0]),\n                filled_in_points_path,\n            )\n        else:\n            self.fill_xyzs = None\n\n        if self.fill_xyzs is not None:\n            render_mask_in_sim_pts = torch.cat(\n                [\n                    torch.ones_like(sim_xyzs[:, 0]).bool(),\n                    torch.zeros_like(fill_xyzs[:, 0]).bool(),\n                ],\n                dim=0,\n            ).to(device)\n            sim_xyzs = torch.cat([sim_xyzs, fill_xyzs], dim=0)\n            sim_cov = torch.cat(\n                [sim_cov, sim_cov.new_ones((fill_xyzs.shape[0], sim_cov.shape[-1]))],\n                dim=0,\n            )\n            self.render_mask = render_mask_in_sim_pts\n        else:\n            self.render_mask = torch.ones_like(sim_xyzs[:, 0]).bool().to(device)\n\n        sim_xyzs = (sim_xyzs + shift) / scale\n\n        sim_aabb = torch.stack(\n            [torch.min(sim_xyzs, dim=0)[0], torch.max(sim_xyzs, dim=0)[0]], dim=0\n        )\n        sim_aabb = (\n            sim_aabb - 
torch.mean(sim_aabb, dim=0, keepdim=True)\n        ) * 1.2 + torch.mean(sim_aabb, dim=0, keepdim=True)\n\n        print(\"simulation aabb: \", sim_aabb)\n\n        # point cloud resample with kmeans\n        downsample_scale = self.args.downsample_scale\n        num_cluster = int(sim_xyzs.shape[0] * downsample_scale)\n        sim_xyzs = downsample_with_kmeans_gpu(sim_xyzs, num_cluster)\n\n        sim_gaussian_pos = self.render_params.gaussians.get_xyz.detach().clone()[\n            self.sim_mask_in_raw_gaussian, :\n        ]\n        sim_gaussian_pos = (sim_gaussian_pos + shift) / scale\n\n        cdist = torch.cdist(sim_gaussian_pos, sim_xyzs) * -1.0\n        _, top_k_index = torch.topk(cdist, self.args.top_k, dim=-1)\n        self.top_k_index = top_k_index\n\n        print(\"Downsampled to: \", sim_xyzs.shape[0], \"by\", downsample_scale)\n\n        # compute volue for each point.\n        points_volume = get_volume(sim_xyzs.detach().cpu().numpy())\n\n        num_particles = sim_xyzs.shape[0]\n\n        wp.init()\n        wp.config.mode = \"debug\"\n        wp.config.verify_cuda = True\n\n        mpm_state = MPMStateStruct()\n        mpm_state.init(num_particles, device=device, requires_grad=True)\n\n        self.particle_init_position = sim_xyzs.clone()\n\n        mpm_state.from_torch(\n            self.particle_init_position.clone(),\n            torch.from_numpy(points_volume).float().to(device).clone(),\n            None,  # set cov to None, since it is not used.\n            device=device,\n            requires_grad=True,\n            n_grid=grid_size,\n            grid_lim=1.0,\n        )\n        mpm_model = MPMModelStruct()\n        mpm_model.init(num_particles, device=device, requires_grad=True)\n        mpm_model.init_other_params(n_grid=grid_size, grid_lim=1.0, device=device)\n\n        material_params = {\n            \"material\": \"jelly\",  # \"jelly\", \"metal\", \"sand\", \"foam\", \"snow\", \"plasticine\", \"neo-hookean\"\n            \"g\": 
[0.0, 0.0, 0.0],\n            \"density\": 2000,  # kg / m^3\n            \"grid_v_damping_scale\": 1.1,  # 0.999,\n        }\n\n        self.v_damping = material_params[\"grid_v_damping_scale\"]\n        self.material_name = material_params[\"material\"]\n        mpm_solver = MPMWARPDiff(\n            num_particles, n_grid=grid_size, grid_lim=1.0, device=device\n        )\n        mpm_solver.set_parameters_dict(mpm_model, mpm_state, material_params)\n\n        self.mpm_state, self.mpm_model, self.mpm_solver = (\n            mpm_state,\n            mpm_model,\n            mpm_solver,\n        )\n\n        # setup boundary condition:\n        moving_pts_path = os.path.join(dataset_dir, \"moving_part_points.ply\")\n        if os.path.exists(moving_pts_path):\n            moving_pts = pcu.load_mesh_v(moving_pts_path)\n            moving_pts = torch.from_numpy(moving_pts).float().to(device)\n            moving_pts = (moving_pts + shift) / scale\n            freeze_mask = find_far_points(\n                sim_xyzs, moving_pts, thres=0.25 / grid_size\n            ).bool()\n            freeze_pts = sim_xyzs[freeze_mask, :]\n\n            grid_freeze_mask = apply_grid_bc_w_freeze_pts(\n                grid_size, 1.0, freeze_pts, mpm_solver\n            )\n            self.freeze_mask = freeze_mask\n\n            # does not prefer boundary condition on particle\n            # freeze_mask_select = setup_boundary_condition_with_points(sim_xyzs, moving_pts,\n            #                                                         self.mpm_solver, self.mpm_state, thres=0.5 / grid_size)\n            # self.freeze_mask = freeze_mask_select.bool()\n        else:\n            raise NotImplementedError\n\n        num_freeze_pts = self.freeze_mask.sum()\n        print(\n            \"num freeze pts in total\",\n            num_freeze_pts.item(),\n            \"num moving pts\",\n            num_particles - num_freeze_pts.item(),\n        )\n\n        # init fields for simulation, e.g. 
density, external force, etc.\n\n        # padd init density, youngs,\n        density = (\n            torch.ones_like(self.particle_init_position[..., 0])\n            * material_params[\"density\"]\n        )\n        youngs_modulus = (\n            torch.ones_like(self.particle_init_position[..., 0])\n            * self.E_nu_list[0].detach()\n        )\n        poisson_ratio = torch.ones_like(self.particle_init_position[..., 0]) * 0.3\n\n        # load stem for higher density\n        stem_pts_path = os.path.join(dataset_dir, \"stem_points.ply\")\n        if os.path.exists(stem_pts_path):\n            stem_pts = pcu.load_mesh_v(stem_pts_path)\n            stem_pts = torch.from_numpy(stem_pts).float().to(device)\n            stem_pts = (stem_pts + shift) / scale\n            no_stem_mask = find_far_points(\n                sim_xyzs, stem_pts, thres=2.0 / grid_size\n            ).bool()\n            stem_mask = torch.logical_not(no_stem_mask)\n            density[stem_mask] = 2000\n            print(\"num stem pts\", stem_mask.sum().item())\n\n        self.density = density\n        self.young_modulus = youngs_modulus\n        self.poisson_ratio = poisson_ratio\n\n        # set density, youngs, poisson\n        mpm_state.reset_density(\n            density.clone(),\n            torch.ones_like(density).type(torch.int),\n            device,\n            update_mass=True,\n        )\n        mpm_solver.set_E_nu_from_torch(\n            mpm_model, youngs_modulus.clone(), poisson_ratio.clone(), device\n        )\n        mpm_solver.prepare_mu_lam(mpm_model, mpm_state, device)\n\n        self.sim_fields = create_spatial_fields(self.args, 1, sim_aabb)\n        self.sim_fields.train()\n\n        self.args.sim_res = 24\n        # self.velo_fields = create_velocity_model(self.args, sim_aabb)\n        self.velo_fields = create_spatial_fields(\n            self.args, 3, sim_aabb, add_entropy=False\n        )\n        self.velo_fields.train()\n\n    def 
set_simulation_state(\n        self,\n        init_xyzs,\n        init_velocity,\n        device,\n        requires_grad=False,\n        use_precompute_F=False,\n        use_density=True,\n    ):\n\n        initial_position_time0 = self.particle_init_position.clone()\n\n        if use_precompute_F:\n            self.mpm_state.reset_state(\n                initial_position_time0,\n                None,\n                init_velocity.clone(),\n                device=device,\n                requires_grad=True,\n            )\n\n            init_xyzs_wp = from_torch_safe(\n                init_xyzs.clone().detach().contiguous(),\n                dtype=wp.vec3,\n                requires_grad=True,\n            )\n            self.mpm_solver.restart_and_compute_F_C(\n                self.mpm_model, self.mpm_state, init_xyzs_wp, device=device\n            )\n        else:\n            self.mpm_state.reset_state(\n                init_xyzs.clone(),\n                None,\n                init_velocity.clone(),\n                device=device,\n                requires_grad=True,\n            )\n\n    def get_density_velocity(self, time_stamp: float, device, requires_grad=True):\n\n        initial_position_time0 = self.particle_init_position.clone()\n\n        query_mask = torch.logical_not(self.freeze_mask)\n        query_pts = initial_position_time0[query_mask, :]\n        sim_params = self.sim_fields(query_pts)\n        # density = sim_params[..., 0]\n\n        # 0.1\n        young_modulus = sim_params[..., 0]\n        # young_modulus = torch.exp(sim_params[..., 0]) + init_young\n        young_modulus = torch.clamp(young_modulus, 1e-3, 1e8)\n\n        # young_padded = torch.ones_like(initial_position_time0[..., 0]) * init_young\n        young_padded = self.young_modulus.detach().clone()\n        young_padded[query_mask] = young_padded[query_mask] + young_modulus * 1\n\n        density = self.density.detach().clone()\n\n        velocity = self.velo_fields(query_pts)[..., 
:3]\n\n        # scaling.\n        velocity = velocity * 0.1\n\n        return density, young_padded, velocity, query_mask\n\n    def train_one_step(self):\n\n        self.sim_fields.train()\n        self.velo_fields.train()\n        accelerator = self.accelerator\n        device = \"cuda:{}\".format(accelerator.process_index)\n        data = next(self.dataloader)\n        cam = data[\"cam\"][0]\n\n        time_stamps = np.linspace(0, 1, self.num_frames).astype(np.float32)[1:]\n\n        gt_videos = data[\"video_clip\"][0, 1 : self.num_frames, ...]\n\n        window_size = int(self.window_size_schduler.compute_state(self.step)[0])\n        print(\"window size\", window_size)\n        stop_velo_opt_thres = 4\n        do_velo_opt = not self.freeze_velo\n        if not do_velo_opt:\n            stop_velo_opt_thres = (\n                0  # stop velocity optimization if we are loading from checkpoint\n            )\n        if window_size >= stop_velo_opt_thres:\n            self.velo_fields.eval()\n            do_velo_opt = False\n\n        rendered_video_list = []\n        log_loss_dict = {\n            \"loss\": [],\n            \"l2_loss\": [],\n            \"psnr\": [],\n            \"ssim\": [],\n        }\n\n        init_xyzs = self.particle_init_position.clone()\n        num_particles = init_xyzs.shape[0]\n        # delta_time = 1.0 / (self.num_frames - 1)\n        delta_time = 1.0 / 30\n        substep_size = delta_time / self.args.substep\n        num_substeps = int(delta_time / substep_size)\n\n        start_time_idx = max(0, window_size - self.args.compute_window)\n        for time_idx in range(start_time_idx, window_size):\n            # time_stamp = time_stamps[time_idx]\n            time_stamp = time_stamps[0]  # fix to begining.. 
Start at the begining\n\n            density, youngs_padded, init_velocity, query_mask = (\n                self.get_density_velocity(time_stamp, device)\n            )\n\n            if not do_velo_opt:\n                init_velocity = init_velocity.detach()\n            padded_velocity = torch.zeros_like(init_xyzs)\n            padded_velocity[query_mask, :] = init_velocity\n\n            gt_frame = gt_videos[[time_idx]]\n\n            extra_no_grad_step = max(\n                0, (time_idx - self.args.grad_window + 1) * num_substeps\n            )\n            if do_velo_opt:\n                extra_no_grad_step = 0\n\n            num_step_with_grad = num_substeps * (time_idx + 1) - extra_no_grad_step\n\n            particle_pos = MPMDifferentiableSimulation.apply(\n                self.mpm_solver,\n                self.mpm_state,\n                self.mpm_model,\n                0,\n                substep_size,\n                num_step_with_grad,\n                init_xyzs,\n                padded_velocity,\n                youngs_padded,\n                self.E_nu_list[1],\n                density,\n                query_mask,\n                None,\n                device,\n                True,\n                extra_no_grad_step,\n            )\n\n            gaussian_pos = particle_pos * self.scale - self.shift\n            undeformed_gaussian_pos = (\n                self.particle_init_position * self.scale - self.shift\n            )\n            disp_offset = gaussian_pos - undeformed_gaussian_pos.detach()\n            # gaussian_pos.requires_grad = True\n\n            simulated_video = render_gaussian_seq_w_mask_with_disp(\n                cam,\n                self.render_params,\n                undeformed_gaussian_pos.detach(),\n                self.top_k_index,\n                [disp_offset],\n                self.sim_mask_in_raw_gaussian,\n            )\n\n            # print(\"debug\", simulated_video.shape, gt_frame.shape, gaussian_pos.shape, 
init_xyzs.shape, density.shape, query_mask.sum().item())\n            rendered_video_list.append(simulated_video.detach())\n\n            l2_loss = 0.5 * F.mse_loss(simulated_video, gt_frame, reduction=\"mean\")\n            ssim_loss = compute_ssim(simulated_video, gt_frame)\n            loss = l2_loss * (1.0 - self.ssim) + (1.0 - ssim_loss) * self.ssim\n\n            sm_velo_loss = self.velo_fields.compute_smoothess_loss()\n            if time_idx > 2 or window_size > stop_velo_opt_thres:\n                sm_velo_loss = sm_velo_loss.detach()\n            sm_spatial_loss = self.sim_fields.compute_smoothess_loss()\n\n            sm_loss = sm_velo_loss + sm_spatial_loss\n            loss = (\n                loss * (self.args.loss_decay**time_idx) + sm_loss * self.tv_loss_weight\n            )\n            loss = loss / self.args.compute_window\n            loss.backward()\n\n            with torch.no_grad():\n                psnr = compute_psnr(simulated_video, gt_frame).mean()\n                log_loss_dict[\"loss\"].append(loss.item())\n                log_loss_dict[\"l2_loss\"].append(l2_loss.item())\n                log_loss_dict[\"psnr\"].append(psnr.item())\n                log_loss_dict[\"ssim\"].append(ssim_loss.item())\n\n            # subtep-4: pass gradients to mpm solver\n\n        nu_grad_norm = self.E_nu_list[1].grad.norm(2).item()\n        spatial_grad_norm = 0\n        for p in self.sim_fields.parameters():\n            if p.grad is not None:\n                spatial_grad_norm += p.grad.norm(2).item()\n        velo_grad_norm = 0\n        for p in self.velo_fields.parameters():\n            if p.grad is not None:\n                velo_grad_norm += p.grad.norm(2).item()\n\n        renderd_video = torch.cat(rendered_video_list, dim=0)\n        renderd_video = torch.clamp(renderd_video, 0.0, 1.0)\n        visual_video = (renderd_video.detach().cpu().numpy() * 255.0).astype(np.uint8)\n        gt_video = (gt_videos.detach().cpu().numpy() * 
255.0).astype(np.uint8)\n\n        if (\n            self.step % self.gradient_accumulation_steps == 0\n            or self.step == (self.train_iters - 1)\n            or (self.step % self.log_iters == self.log_iters - 1)\n        ):\n\n            torch.nn.utils.clip_grad_norm_(\n                self.trainable_params,\n                self.max_grad_norm,\n                error_if_nonfinite=False,\n            )  # error if nonfinite is false\n\n            self.optimizer.step()\n            self.optimizer.zero_grad()\n            if do_velo_opt:\n                assert self.velo_optimizer is not None\n                torch.nn.utils.clip_grad_norm_(\n                    self.velo_fields.parameters(),\n                    self.max_grad_norm * 10,\n                    error_if_nonfinite=False,\n                )  # error if nonfinite is false\n                self.velo_optimizer.step()\n                self.velo_optimizer.zero_grad()\n                self.velo_scheduler.step()\n            with torch.no_grad():\n                self.E_nu_list[0].data.clamp_(1e-3, 2000)\n                self.E_nu_list[1].data.clamp_(1e-2, 0.449)\n        self.scheduler.step()\n\n        for k, v in log_loss_dict.items():\n            log_loss_dict[k] = np.mean(v)\n\n        print(log_loss_dict)\n        print(\n            \"nu: \",\n            self.E_nu_list[1].item(),\n            nu_grad_norm,\n            spatial_grad_norm,\n            velo_grad_norm,\n            \"young_mean, max:\",\n            youngs_padded.mean().item(),\n            youngs_padded.max().item(),\n            do_velo_opt,\n        )\n\n        if accelerator.is_main_process and (self.step % self.wandb_iters == 0):\n            with torch.no_grad():\n                wandb_dict = {\n                    \"nu_grad_norm\": nu_grad_norm,\n                    \"spatial_grad_norm\": spatial_grad_norm,\n                    \"velo_grad_norm\": velo_grad_norm,\n                    \"nu\": self.E_nu_list[1].item(),\n    
                # \"mean_density\": density.mean().item(),\n                    \"mean_E\": youngs_padded.mean().item(),\n                    \"max_E\": youngs_padded.max().item(),\n                    \"min_E\": youngs_padded.min().item(),\n                    \"smoothness_loss\": sm_loss.item(),\n                    \"window_size\": window_size,\n                    \"velo_mean\": init_velocity.mean().item(),\n                    \"velo_max\": init_velocity.max().item(),\n                }\n\n                simulated_video = self.inference(cam)\n                sim_video_torch = (\n                    torch.from_numpy(simulated_video).float().to(device) / 255.0\n                )\n                gt_video_torch = torch.from_numpy(gt_video).float().to(device) / 255.0\n\n                full_psnr = compute_psnr(sim_video_torch[1:], gt_video_torch)\n\n                first_psnr = full_psnr[:6].mean().item()\n                last_psnr = full_psnr[-6:].mean().item()\n                full_psnr = full_psnr.mean().item()\n                wandb_dict[\"full_psnr\"] = full_psnr\n                wandb_dict[\"first_psnr\"] = first_psnr\n                wandb_dict[\"last_psnr\"] = last_psnr\n                wandb_dict.update(log_loss_dict)\n\n                if self.step % int(5 * self.wandb_iters) == 0:\n\n                    wandb_dict[\"rendered_video\"] = wandb.Video(\n                        visual_video, fps=visual_video.shape[0]\n                    )\n\n                    wandb_dict[\"gt_video\"] = wandb.Video(\n                        gt_video,\n                        fps=gt_video.shape[0],\n                    )\n\n                    wandb_dict[\"inference_video\"] = wandb.Video(\n                        simulated_video,\n                        fps=simulated_video.shape[0],\n                    )\n\n                    simulated_video = self.inference(\n                        cam, num_sec=3, substep=self.args.substep\n                    )\n                    
wandb_dict[\"inference_video_t3\"] = wandb.Video(\n                        simulated_video,\n                        fps=simulated_video.shape[0] // 3,\n                    )\n\n                    simulated_video = self.inference(\n                        cam, velo_scaling=5.0, num_sec=3, substep=self.args.substep\n                    )\n                    wandb_dict[\"inference_video_v5_t3\"] = wandb.Video(\n                        simulated_video,\n                        fps=simulated_video.shape[0] // 3,\n                    )\n\n                if self.use_wandb:\n                    wandb.log(wandb_dict, step=self.step)\n\n        self.accelerator.wait_for_everyone()\n\n    def train(self):\n        # might remove tqdm when multiple node\n        for index in tqdm(range(self.step, self.train_iters), desc=\"Training progress\"):\n            self.train_one_step()\n            if self.step % self.log_iters == self.log_iters - 1:\n                if self.accelerator.is_main_process:\n                    self.save()\n                    # self.test()\n            # self.accelerator.wait_for_everyone()\n            self.step += 1\n        if self.accelerator.is_main_process:\n            self.save()\n\n    @torch.no_grad()\n    def inference(\n        self, cam, velo_scaling=1.0, num_sec=1, nu=None, young_scaling=1.0, substep=20\n    ):\n\n        self.sim_fields.eval()\n        self.velo_fields.eval()\n\n        device = \"cuda:{}\".format(self.accelerator.process_index)\n\n        time_stamps = np.linspace(0, 1, self.num_frames).astype(np.float32)[1:]\n        time_idx = 0\n        time_stamp = time_stamps[time_idx]\n\n        density, youngs_padded, init_velocity, query_mask = self.get_density_velocity(\n            time_stamp, device\n        )\n        youngs_padded = youngs_padded * young_scaling\n        init_xyzs = self.particle_init_position\n\n        padded_velocity = torch.zeros_like(init_xyzs)\n        padded_velocity[query_mask, :] = init_velocity 
* velo_scaling\n\n        num_particles = init_xyzs.shape[0]\n\n        delta_time = 1.0 / (self.num_frames - 1)\n        delta_time = 1.0 / 30\n        substep_size = delta_time / substep\n        num_substeps = int(delta_time / substep_size)\n        # reset state\n        self.set_simulation_state(\n            init_xyzs,\n            padded_velocity,\n            device,\n            requires_grad=True,\n            use_precompute_F=False,\n            use_density=False,\n        )\n\n        if nu is None:\n            E, nu = self.E_nu_list[0].item(), self.E_nu_list[1].item()\n        E_wp = from_torch_safe(youngs_padded, dtype=wp.float32, requires_grad=False)\n        self.mpm_solver.set_E_nu(self.mpm_model, E_wp, nu, device=device)\n        self.mpm_solver.prepare_mu_lam(self.mpm_model, self.mpm_state, device=device)\n\n        wp.launch(\n            kernel=get_float_array_product,\n            dim=num_particles,\n            inputs=[\n                self.mpm_state.particle_density,\n                self.mpm_state.particle_vol,\n                self.mpm_state.particle_mass,\n            ],\n            device=device,\n        )\n\n        pos_list = [self.particle_init_position.clone() * self.scale - self.shift]\n\n        for i in tqdm(range((self.num_frames - 1) * num_sec)):\n            for substep in range(num_substeps):\n                self.mpm_solver.p2g2p(\n                    self.mpm_model,\n                    self.mpm_state,\n                    substep,\n                    substep_size,\n                    device=\"cuda:0\",\n                )\n\n            pos = wp.to_torch(self.mpm_state.particle_x).clone()\n            pos = (pos * self.scale) - self.shift\n            pos_list.append(pos)\n\n        init_pos = pos_list[0].clone()\n        pos_diff_list = [_ - init_pos for _ in pos_list]\n\n        video_array = render_gaussian_seq_w_mask_with_disp(\n            cam,\n            self.render_params,\n            init_pos,\n            
self.top_k_index,\n            pos_diff_list,\n            self.sim_mask_in_raw_gaussian,\n        )\n\n        video_numpy = video_array.detach().cpu().numpy() * 255\n        video_numpy = np.clip(video_numpy, 0, 255).astype(np.uint8)\n\n        return video_numpy\n\n    def save(\n        self,\n    ):\n        # training states\n        output_path = os.path.join(\n            self.output_path, f\"checkpoint_model_{self.step:06d}\"\n        )\n        os.makedirs(output_path, exist_ok=True)\n\n        name_list = [\n            \"velo_fields\",\n            \"sim_fields\",\n        ]\n        for i, model in enumerate(\n            [\n                self.accelerator.unwrap_model(self.velo_fields, keep_fp32_wrapper=True),\n                self.accelerator.unwrap_model(self.sim_fields, keep_fp32_wrapper=True),\n            ]\n        ):\n            model_name = name_list[i]\n            model_path = os.path.join(output_path, model_name + \".pt\")\n            torch.save(model.state_dict(), model_path)\n\n    def load(self, checkpoint_dir):\n        name_list = [\n            \"velo_fields\",\n            \"sim_fields\",\n        ]\n        for i, model in enumerate([self.velo_fields, self.sim_fields]):\n            model_name = name_list[i]\n            if model_name == \"sim_fields\" and (not self.args.run_eval):\n                continue\n            model_path = os.path.join(checkpoint_dir, model_name + \".pt\")\n            model.load_state_dict(torch.load(model_path))\n            print(\"=> loaded: \", model_path)\n\n    def setup_eval(self, args, gaussian_path, white_background=True):\n        # setup gaussians\n        class RenderPipe(NamedTuple):\n            convert_SHs_python = False\n            compute_cov3D_python = False\n            debug = False\n\n        class RenderParams(NamedTuple):\n            render_pipe: RenderPipe\n            bg_color: bool\n            gaussians: GaussianModel\n            camera_list: list\n\n        gaussians = 
GaussianModel(3)\n        camera_list = self.dataset.test_camera_list\n\n        gaussians.load_ply(gaussian_path)\n        gaussians.detach_grad()\n        print(\n            \"load gaussians from: {}\".format(gaussian_path),\n            \"... num gaussians: \",\n            gaussians._xyz.shape[0],\n        )\n        bg_color = [1, 1, 1] if white_background else [0, 0, 0]\n        background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n        render_pipe = RenderPipe()\n\n        render_params = RenderParams(\n            render_pipe=render_pipe,\n            bg_color=background,\n            gaussians=gaussians,\n            camera_list=camera_list,\n        )\n        self.render_params = render_params\n\n        # get_gaussian scene box\n        scaler = 1.1\n        points = gaussians._xyz\n\n        min_xyz = torch.min(points, dim=0)[0]\n        max_xyz = torch.max(points, dim=0)[0]\n\n        center = (min_xyz + max_xyz) / 2\n\n        scaled_min_xyz = (min_xyz - center) * scaler + center\n        scaled_max_xyz = (max_xyz - center) * scaler + center\n\n        aabb = torch.stack([scaled_min_xyz, scaled_max_xyz], dim=0)\n\n        # add filled in points\n        gaussian_dir = os.path.dirname(gaussian_path)\n\n        clean_points_path = os.path.join(gaussian_dir, \"clean_object_points.ply\")\n        if os.path.exists(clean_points_path):\n            clean_xyzs = pcu.load_mesh_v(clean_points_path)\n            clean_xyzs = torch.from_numpy(clean_xyzs).float().to(\"cuda\")\n            self.clean_xyzs = clean_xyzs\n            print(\n                \"loaded {} clean points from: \".format(clean_xyzs.shape[0]),\n                clean_points_path,\n            )\n            # we can use tight threshold here\n            not_sim_maks = find_far_points(\n                gaussians._xyz, clean_xyzs, thres=0.01\n            ).bool()\n            sim_mask_in_raw_gaussian = torch.logical_not(not_sim_maks)\n            # [N]\n            
self.sim_mask_in_raw_gaussian = sim_mask_in_raw_gaussian\n        else:\n            self.clean_xyzs = None\n            self.sim_mask_in_raw_gaussian = torch.ones_like(gaussians._xyz[:, 0]).bool()\n\n        return aabb\n\n    def eval(\n        self,\n    ):\n\n        accelerator = self.accelerator\n        device = \"cuda:{}\".format(accelerator.process_index)\n        data = next(self.dataloader)\n        cam = data[\"cam\"][0]\n\n        nu = 0.1\n        young_scaling = 5000.0\n        substep = 800  # 1e-4\n        video_numpy = self.inference(\n            cam,\n            velo_scaling=5.0,\n            num_sec=3,\n            nu=nu,\n            young_scaling=young_scaling,\n            substep=substep,\n        )\n\n        video_numpy = np.transpose(video_numpy, [0, 2, 3, 1])\n        from motionrep.utils.io_utils import save_video_imageio, save_gif_imageio\n\n        # output_dir = os.path.join(self.output_path, \"simulation\")\n        output_dir = \"./\"\n\n        save_path = os.path.join(\n            output_dir,\n            \"eval_fill2k_video_nu_{}_ys_{}_substep_{}_grid_{}\".format(\n                nu, young_scaling, substep, self.args.grid_size\n            )\n            + \".gif\",\n        )\n        print(\"save video to \", save_path)\n        # save_video_imageio(save_path, video_numpy, fps=12)\n        save_gif_imageio(save_path, video_numpy, fps=12)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config\", type=str, default=\"config.yml\")\n\n    # dataset params\n    parser.add_argument(\n        \"--dataset_dir\",\n        type=str,\n        default=\"../../data/physics_dreamer/alocasia_nerfstudio\",\n    )\n    parser.add_argument(\n        \"--dataset_res\",\n        type=str,\n        default=\"large\",  # [\"middle\", \"small\", \"large\"]\n    )\n\n    parser.add_argument(\"--model\", type=str, default=\"se3_field\")\n    parser.add_argument(\"--feat_dim\", type=int, default=64)\n    
parser.add_argument(\"--num_decoder_layers\", type=int, default=3)\n    parser.add_argument(\"--decoder_hidden_size\", type=int, default=64)\n    parser.add_argument(\"--spatial_res\", type=int, default=32)\n    parser.add_argument(\"--zero_init\", type=bool, default=True)\n    parser.add_argument(\"--entropy_cls\", type=int, default=0)\n\n    parser.add_argument(\"--num_frames\", type=str, default=14)\n\n    parser.add_argument(\"--grid_size\", type=int, default=32)\n    parser.add_argument(\"--sim_res\", type=int, default=24)\n    parser.add_argument(\"--sim_output_dim\", type=int, default=1)\n    parser.add_argument(\"--substep\", type=int, default=96)\n    parser.add_argument(\"--loss_decay\", type=float, default=1.0)\n    parser.add_argument(\"--start_window_size\", type=int, default=2)\n    parser.add_argument(\"--compute_window\", type=int, default=2)\n    parser.add_argument(\"--grad_window\", type=int, default=14)\n\n    parser.add_argument(\"--downsample_scale\", type=float, default=0.1)\n    parser.add_argument(\"--top_k\", type=int, default=8)\n\n    # loss parameters\n    parser.add_argument(\"--tv_loss_weight\", type=float, default=1e-2)\n    parser.add_argument(\"--ssim\", type=float, default=0.5)\n\n    # Logging and checkpointing\n    parser.add_argument(\"--output_dir\", type=str, default=\"../../output/inverse_sim\")\n    parser.add_argument(\"--log_iters\", type=int, default=100)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\n        \"--checkpoint_path\", type=str, default=None, help=\"path to load checkpoint from\"\n    )\n    # training parameters\n    parser.add_argument(\"--train_iters\", type=int, default=300)\n    parser.add_argument(\"--batch_size\", type=int, default=1)\n    parser.add_argument(\"--lr\", type=float, default=1e-2)\n    parser.add_argument(\"--max_grad_norm\", type=float, default=1.0)\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        
default=1,\n    )\n\n    # wandb parameters\n    parser.add_argument(\"--use_wandb\", action=\"store_true\", default=False)\n    parser.add_argument(\"--wandb_entity\", type=str, default=\"mit-cv\")\n    parser.add_argument(\"--wandb_project\", type=str, default=\"inverse_sim\")\n    parser.add_argument(\"--wandb_iters\", type=int, default=20)\n    parser.add_argument(\"--wandb_name\", type=str, required=True)\n    parser.add_argument(\"--run_eval\", action=\"store_true\", default=False)\n\n    # distributed training args\n    parser.add_argument(\n        \"--local_rank\",\n        type=int,\n        default=-1,\n        help=\"For distributed training: local_rank\",\n    )\n\n    args, extra_args = parser.parse_known_args()\n    cfg = create_config(args.config, args, extra_args)\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n    print(args.local_rank, \"local rank\")\n\n    return cfg\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n\n    # torch.backends.cuda.matmul.allow_tf32 = True\n\n    trainer = Trainer(args)\n\n    if args.run_eval:\n        trainer.eval()\n    else:\n        # trainer.debug()\n        trainer.train()\n"
  },
  {
    "path": "projects/uncleaned_train/exp_motion/train/interface.py",
    "content": "from typing import Optional, Tuple\nfrom jaxtyping import Float, Int, Shaped\nimport torch\nimport torch.autograd as autograd\nimport torch.nn as nn\nfrom torch import Tensor\n\nimport warp as wp\n\nfrom thirdparty_code.warp_mpm.warp_utils import from_torch_safe, MyTape, CondTape\nfrom thirdparty_code.warp_mpm.mpm_solver_diff import MPMWARPDiff\nfrom thirdparty_code.warp_mpm.mpm_utils import compute_position_l2_loss, aggregate_grad, compute_posloss_with_grad\nfrom thirdparty_code.warp_mpm.mpm_data_structure import MPMStateStruct, MPMModelStruct, get_float_array_product\nfrom thirdparty_code.warp_mpm.mpm_utils import (compute_Closs_with_grad, compute_Floss_with_grad, \n                                                compute_posloss_with_grad, compute_veloloss_with_grad)\n\n\nclass MPMDifferentiableSimulation(autograd.Function):\n\n    @staticmethod\n    def forward(\n        ctx: autograd.function.FunctionCtx,\n        mpm_solver: MPMWARPDiff,\n        mpm_state: MPMStateStruct,\n        mpm_model: MPMModelStruct,\n        substep: int, \n        substep_size: float, \n        num_substeps: int,\n        init_pos: Float[Tensor, \"n 3\"],\n        init_velocity: Float[Tensor, \"n 3\"],\n        E: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        nu: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        particle_density: Optional[Float[Tensor, \"n\"] | Float[Tensor, \"1\"]]=None,\n        density_change_mask: Optional[Int[Tensor, \"n\"]] = None,\n        static_pos: Optional[Float[Tensor, \"n 3\"]] = None,\n        device: str=\"cuda:0\",\n        requires_grad: bool=True,\n        extra_no_grad_steps: int=0,\n    ) -> Float[Tensor, \"n 3\"]:\n        \"\"\"\n        Args:\n            density_change_mask: [n] 0 or 1.  
1 means the density of this particle can change.\n        \"\"\"\n        \n        num_particles = init_pos.shape[0]\n        if static_pos is None:\n            \n            mpm_state.reset_state(\n                init_pos.clone(),\n                None,\n                init_velocity, #.clone(),\n                device=device,\n                requires_grad=requires_grad,\n            )\n        else:\n            mpm_state.reset_state(\n                static_pos.clone(),\n                None,\n                init_velocity, #.clone(),\n                device=device,\n                requires_grad=requires_grad,\n\n            )\n            init_xyzs_wp = from_torch_safe(init_pos.clone().detach().contiguous(), dtype=wp.vec3, requires_grad=requires_grad)\n            mpm_solver.restart_and_compute_F_C(mpm_model, mpm_state, init_xyzs_wp, device=device)\n        \n        if E.ndim == 0:\n            E_inp = E.item() # float\n            ctx.aggregating_E = True\n        else:\n            E_inp = from_torch_safe(E, dtype=wp.float32, requires_grad=requires_grad)\n            ctx.aggregating_E = False\n        if nu.ndim == 0:\n            nu_inp = nu.item() # float\n            ctx.aggregating_nu = True\n        else:\n            nu_inp = from_torch_safe(nu, dtype=wp.float32, requires_grad=requires_grad)\n            ctx.aggregating_nu = False\n            \n        mpm_solver.set_E_nu(mpm_model, E_inp, nu_inp, device=device)\n\n        mpm_state.reset_density(\n            tensor_density=particle_density,\n            selection_mask=density_change_mask,\n            device=device,\n            requires_grad=requires_grad)\n        \n        prev_state = mpm_state\n\n        if extra_no_grad_steps > 0:\n            with torch.no_grad():\n                wp.launch(\n                    kernel=get_float_array_product,\n                    dim=num_particles,\n                    inputs=[\n                        mpm_state.particle_density,\n                       
 mpm_state.particle_vol,\n                        mpm_state.particle_mass,\n                    ],\n                    device=device,\n                )\n                mpm_solver.prepare_mu_lam(mpm_model, mpm_state, device=device)\n\n                for i in range(extra_no_grad_steps):\n                    next_state = prev_state.partial_clone(requires_grad=requires_grad)\n                    mpm_solver.p2g2p_differentiable(mpm_model, prev_state, next_state, substep_size, device=device)\n                    prev_state = next_state\n        else:\n            prev_state = mpm_state\n\n        wp_tape = MyTape()\n        cond_tape: CondTape = CondTape(wp_tape, requires_grad)\n\n        next_state_list = [] \n        \n        with cond_tape:\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=num_particles,\n                inputs=[\n                    prev_state.particle_density,\n                    prev_state.particle_vol,\n                    prev_state.particle_mass,\n                ],\n                device=device,\n            )\n            mpm_solver.prepare_mu_lam(mpm_model, prev_state, device=device)\n\n            for substep_local in range(num_substeps):\n                next_state = prev_state.partial_clone(requires_grad=requires_grad)\n                mpm_solver.p2g2p_differentiable(mpm_model, prev_state, next_state, substep_size, device=device)\n\n                # next_state = mpm_solver.p2g2p_differentiable(mpm_model, prev_state, substep_size, device=device)\n                next_state_list.append(next_state)\n                prev_state = next_state\n        \n        ctx.mpm_solver = mpm_solver\n        ctx.mpm_state = mpm_state\n        ctx.mpm_model = mpm_model\n        ctx.tape = cond_tape.tape\n        ctx.device = device\n        ctx.num_particles = num_particles\n\n        ctx.next_state_list = next_state_list\n\n        ctx.save_for_backward(density_change_mask)\n\n        last_state = 
next_state_list[-1]\n        particle_pos = wp.to_torch(last_state.particle_x).detach().clone()\n\n        return particle_pos\n    \n\n    @staticmethod\n    def backward(ctx, out_pos_grad: Float[Tensor, \"n 3\"]):\n        \n        num_particles = ctx.num_particles\n        tape, device = ctx.tape, ctx.device\n        mpm_solver, mpm_state, mpm_model = ctx.mpm_solver, ctx.mpm_state, ctx.mpm_model\n        last_state = ctx.next_state_list[-1]\n        density_change_mask = ctx.saved_tensors[0]\n\n        grad_pos_wp = from_torch_safe(out_pos_grad, dtype=wp.vec3, requires_grad=False)\n        target_pos_detach = wp.clone(last_state.particle_x, device=device, requires_grad=False)\n\n        with tape:\n            loss_wp = torch.zeros(1, device=device)\n            loss_wp = wp.from_torch(loss_wp, requires_grad=True)\n            wp.launch(\n                compute_posloss_with_grad, \n                dim=num_particles,\n                inputs=[\n                    last_state,\n                    target_pos_detach,\n                    grad_pos_wp,\n                    0.5,\n                    loss_wp,\n                ],\n                device=device,\n            )\n\n        tape.backward(loss_wp)\n\n        pos_grad = None\n        if mpm_state.particle_v.grad is None:\n            velo_grad = None\n        else:\n            velo_grad = wp.to_torch(mpm_state.particle_v.grad).detach().clone()\n\n        # print(\"debug back\", velo_grad)\n\n        # grad for E, nu. 
TODO: add spatially varying E, nu later\n        if ctx.aggregating_E:\n            E_grad = wp.from_torch(torch.zeros(1, device=device), requires_grad=False)\n            wp.launch(\n                aggregate_grad,\n                dim=num_particles,\n                inputs=[\n                    E_grad,\n                    mpm_model.E.grad,\n                ],\n                device=device,\n            )\n            E_grad = wp.to_torch(E_grad)[0] / num_particles\n        else:\n            E_grad = wp.to_torch(mpm_model.E.grad).detach().clone()\n\n        if ctx.aggregating_nu:\n            nu_grad = wp.from_torch(torch.zeros(1, device=device), requires_grad=False)\n            wp.launch(\n                aggregate_grad,\n                dim=num_particles,\n                inputs=[nu_grad, mpm_model.nu.grad],\n                device=device,\n            )\n            nu_grad = wp.to_torch(nu_grad)[0] / num_particles   \n        else:\n            nu_grad = wp.to_torch(mpm_model.nu.grad).detach().clone()\n\n        # grad for density\n        if mpm_state.particle_density.grad is None:\n            density_grad = None\n        else:\n            density_grad = wp.to_torch(mpm_state.particle_density.grad).detach()\n            density_grad = density_grad[density_change_mask.type(torch.bool)]\n        \n        density_mask_grad = None\n        static_pos_grad = None \n\n        # from IPython import embed; embed()\n        tape.zero()\n        # print(density_grad.abs().sum(), velo_grad.abs().sum(), E_grad.abs().item(), nu_grad.abs().item(), \"in sim func\")\n\n        return (None, None, None, None, None, None, \n                pos_grad, velo_grad, E_grad, nu_grad,\n                 density_grad, density_mask_grad, \n                 static_pos_grad, None, None, None)\n    \n\n\nclass MPMDifferentiableSimulationWCheckpoint(autograd.Function):\n    \"\"\"\n    Current version does not support grad for density. 
\n    Please set vol, mass before calling this function.\n    \"\"\"\n\n    @staticmethod\n    @torch.no_grad()\n    def forward(\n        ctx: autograd.function.FunctionCtx,\n        mpm_solver: MPMWARPDiff,\n        mpm_state: MPMStateStruct,\n        mpm_model: MPMModelStruct,\n        substep_size: float, \n        num_substeps: int,\n        particle_x: Float[Tensor, \"n 3\"], \n        particle_v: Float[Tensor, \"n 3\"],\n        particle_F: Float[Tensor, \"n 3 3\"],\n        particle_C: Float[Tensor, \"n 3 3\"],\n        E: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        nu: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        particle_density: Optional[Float[Tensor, \"n\"] | Float[Tensor, \"1\"]]=None,\n        query_mask: Optional[Int[Tensor, \"n\"]] = None,\n        device: str=\"cuda:0\",\n        requires_grad: bool=True,\n        extra_no_grad_steps: int=0,\n    ) -> Tuple[Float[Tensor, \"n 3\"], Float[Tensor, \"n 3\"], Float[Tensor, \"n 9\"], Float[Tensor, \"n 9\"]]:\n        \"\"\"\n        Args:\n            query_mask: [n] 0 or 1.  1 means the density or young's modulus, or poisson'ratio of this particle can change.\n        \"\"\"\n        \n        # initialization work is done before calling forward! 
\n\n        num_particles = particle_x.shape[0]\n\n        mpm_state.continue_from_torch(\n            particle_x, particle_v, particle_F, particle_C, device=device, requires_grad=True\n        )\n        # set x, v, F, C.\n\n        if E.ndim == 0:\n            E_inp = E.item() # float\n            ctx.aggregating_E = True\n        else:\n            E_inp = from_torch_safe(E, dtype=wp.float32, requires_grad=True)\n            ctx.aggregating_E = False\n        if nu.ndim == 0:\n            nu_inp = nu.item() # float\n            ctx.aggregating_nu = True\n        else:\n            nu_inp = from_torch_safe(nu, dtype=wp.float32, requires_grad=True)\n            ctx.aggregating_nu = False\n            \n        mpm_solver.set_E_nu(mpm_model, E_inp, nu_inp, device=device)\n        mpm_solver.prepare_mu_lam(mpm_model, mpm_state, device=device)\n\n        mpm_state.reset_density(\n            tensor_density=particle_density,\n            selection_mask=query_mask,\n            device=device,\n            requires_grad=True,\n            update_mass=True)\n        \n        prev_state = mpm_state\n\n        if extra_no_grad_steps > 0:\n            with torch.no_grad():\n                for i in range(extra_no_grad_steps):\n                    next_state = prev_state.partial_clone(requires_grad=True)\n                    mpm_solver.p2g2p_differentiable(mpm_model, prev_state, next_state, substep_size, device=device)\n                    prev_state = next_state\n\n        # following steps will be checkpointed. 
then replayed in backward\n        ctx.prev_state = prev_state\n        \n        for substep_local in range(num_substeps):\n            next_state = prev_state.partial_clone(requires_grad=True)\n            mpm_solver.p2g2p_differentiable(mpm_model, prev_state, next_state, substep_size, device=device)\n            prev_state = next_state\n        \n        \n        ctx.mpm_solver = mpm_solver\n        ctx.mpm_state = mpm_state # state at the begining of this function; TODO: drop it?\n        ctx.mpm_model = mpm_model\n        ctx.device = device\n        ctx.num_particles = num_particles\n\n        ctx.num_substeps = num_substeps\n        ctx.substep_size = substep_size\n        \n        ctx.save_for_backward(E, nu, particle_density, query_mask)\n\n        last_state = next_state\n        particle_pos = wp.to_torch(last_state.particle_x).detach().clone()\n        particle_velo = wp.to_torch(last_state.particle_v).detach().clone()\n        particle_F = wp.to_torch(last_state.particle_F_trial).detach().clone()\n        particle_C = wp.to_torch(last_state.particle_C).detach().clone()\n\n        return particle_pos, particle_velo, particle_F, particle_C\n    \n\n    @staticmethod\n    def backward(ctx, out_pos_grad: Float[Tensor, \"n 3\"], out_velo_grad: Float[Tensor, \"n 3\"], \n                 out_F_grad: Float[Tensor, \"n 9\"], out_C_grad: Float[Tensor, \"n 9\"]):\n        \n        num_particles = ctx.num_particles\n        device = ctx.device\n        mpm_solver, mpm_model = ctx.mpm_solver, ctx.mpm_model\n        prev_state = ctx.prev_state\n        starting_state = ctx.prev_state \n\n        E, nu, particle_density, query_mask = ctx.saved_tensors\n\n        num_substeps, substep_size = ctx.num_substeps, ctx.substep_size\n\n        # rolling back\n        # setting initial param first: \n        if E.ndim == 0:\n            E_inp = E.item() # float\n            ctx.aggregating_E = True\n        else:\n            E_inp = from_torch_safe(E, dtype=wp.float32, 
requires_grad=True)\n            ctx.aggregating_E = False\n        if nu.ndim == 0:\n            nu_inp = nu.item() # float\n            ctx.aggregating_nu = True\n        else:\n            nu_inp = from_torch_safe(nu, dtype=wp.float32, requires_grad=True)\n            ctx.aggregating_nu = False\n            \n        mpm_solver.set_E_nu(mpm_model, E_inp, nu_inp, device=device)\n\n        starting_state.reset_density(\n            tensor_density=particle_density,\n            selection_mask=query_mask,\n            device=device,\n            requires_grad=True)\n        \n        next_state_list = []\n\n        with wp.ScopedDevice(device):\n            tape = MyTape()\n\n            # handle it later\n            grad_pos_wp = from_torch_safe(out_pos_grad, dtype=wp.vec3, requires_grad=False)\n            if out_velo_grad is not None:\n                grad_velo_wp = from_torch_safe(out_velo_grad, dtype=wp.vec3, requires_grad=False)\n            else:\n                grad_velo_wp = None\n            \n            if out_F_grad is not None:\n                grad_F_wp = from_torch_safe(out_F_grad, dtype=wp.mat33, requires_grad=False)\n            else:\n                grad_F_wp = None\n            \n            if out_C_grad is not None:\n                grad_C_wp = from_torch_safe(out_C_grad, dtype=wp.mat33, requires_grad=False)\n            else:\n                grad_C_wp = None\n\n            with tape:\n\n                wp.launch(\n                    kernel=get_float_array_product,\n                    dim=num_particles,\n                    inputs=[\n                        prev_state.particle_density,\n                        prev_state.particle_vol,\n                        prev_state.particle_mass,\n                    ],\n                    device=device,\n                )\n                mpm_solver.prepare_mu_lam(mpm_model, prev_state, device=device)\n\n                for substep_local in range(num_substeps):\n                    next_state = 
prev_state.partial_clone(requires_grad=True)\n                    mpm_solver.p2g2p_differentiable(mpm_model, prev_state, next_state, substep_size, device=device)\n\n                    # next_state = mpm_solver.p2g2p_differentiable(mpm_model, prev_state, substep_size, device=device)\n                    next_state_list.append(next_state)\n                    prev_state = next_state\n\n                # simulation done. Compute loss:\n                \n                loss_wp = torch.zeros(1, device=device)\n                loss_wp = wp.from_torch(loss_wp, requires_grad=True)\n                target_pos_detach = wp.clone(next_state.particle_x, device=device, requires_grad=False)\n                wp.launch(\n                    compute_posloss_with_grad, \n                    dim=num_particles,\n                    inputs=[\n                        next_state,\n                        target_pos_detach,\n                        grad_pos_wp,\n                        0.5,\n                        loss_wp,\n                    ],\n                    device=device,\n                )\n                if grad_velo_wp is not None:\n                    target_velo_detach = wp.clone(next_state.particle_v, device=device, requires_grad=False)\n                    wp.launch(\n                        compute_veloloss_with_grad, \n                        dim=num_particles,\n                        inputs=[\n                            next_state,\n                            target_velo_detach,\n                            grad_velo_wp,\n                            0.5,\n                            loss_wp,\n                        ],\n                        device=device,\n                    )\n                \n                if grad_F_wp is not None:\n                    target_F_detach = wp.clone(next_state.particle_F_trial, device=device, requires_grad=False)\n                    wp.launch(\n                        compute_Floss_with_grad, \n                        
dim=num_particles,\n                        inputs=[\n                            next_state,\n                            target_F_detach,\n                            grad_F_wp,\n                            0.5,\n                            loss_wp,\n                        ],\n                        device=device,\n                    )\n                if grad_C_wp is not None:\n                    target_C_detach = wp.clone(next_state.particle_C, device=device, requires_grad=False)\n                    wp.launch(\n                        compute_Closs_with_grad, \n                        dim=num_particles,\n                        inputs=[\n                            next_state,\n                            target_C_detach,\n                            grad_C_wp,\n                            0.5,\n                            loss_wp,\n                        ],\n                        device=device,)\n\n            # wp.synchronize_device(device)            \n            tape.backward(loss_wp)\n            # from IPython import embed; embed()\n\n        pos_grad = wp.to_torch(starting_state.particle_x.grad).detach().clone()\n        velo_grad = wp.to_torch(starting_state.particle_v.grad).detach().clone()\n        F_grad = wp.to_torch(starting_state.particle_F_trial.grad).detach().clone()\n        C_grad = wp.to_torch(starting_state.particle_C.grad).detach().clone()\n        # print(\"debug back\", velo_grad)\n\n        # grad for E, nu. 
TODO: add spatially varying E, nu later\n        if ctx.aggregating_E:\n            E_grad = wp.from_torch(torch.zeros(1, device=device), requires_grad=False)\n            wp.launch(\n                aggregate_grad,\n                dim=num_particles,\n                inputs=[\n                    E_grad,\n                    mpm_model.E.grad,\n                ],\n                device=device,\n            )\n            E_grad = wp.to_torch(E_grad)[0] / num_particles\n        else:\n            E_grad = wp.to_torch(mpm_model.E.grad).detach().clone()\n\n        if ctx.aggregating_nu:\n            nu_grad = wp.from_torch(torch.zeros(1, device=device), requires_grad=False)\n            wp.launch(\n                aggregate_grad,\n                dim=num_particles,\n                inputs=[nu_grad, mpm_model.nu.grad],\n                device=device,\n            )\n            nu_grad = wp.to_torch(nu_grad)[0] / num_particles   \n        else:\n            nu_grad = wp.to_torch(mpm_model.nu.grad).detach().clone()\n\n        # grad for density\n        if starting_state.particle_density.grad is None:\n            density_grad = None\n        else:\n            density_grad = wp.to_torch(starting_state.particle_density.grad).detach()\n\n        \n        density_mask_grad = None\n        static_pos_grad = None \n\n        tape.zero()\n        # print(density_grad.abs().sum(), velo_grad.abs().sum(), E_grad.abs().item(), nu_grad.abs().item(), \"in sim func\")\n        # from IPython import embed; embed()\n        \n        return (None, None, None, None, None,\n                pos_grad, velo_grad, F_grad, C_grad, \n                E_grad, nu_grad,\n                density_grad, density_mask_grad, \n                None, None, None)\n\n\nclass MPMDifferentiableSimulationClean(autograd.Function):\n    \"\"\"\n    Current version does not support grad for density. 
\n    Please set vol, mass before calling this function.\n    \"\"\"\n\n    @staticmethod\n    @torch.no_grad()\n    def forward(\n        ctx: autograd.function.FunctionCtx,\n        mpm_solver: MPMWARPDiff,\n        mpm_state: MPMStateStruct,\n        mpm_model: MPMModelStruct,\n        substep_size: float, \n        num_substeps: int,\n        particle_x: Float[Tensor, \"n 3\"], \n        particle_v: Float[Tensor, \"n 3\"],\n        particle_F: Float[Tensor, \"n 3 3\"],\n        particle_C: Float[Tensor, \"n 3 3\"],\n        E: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        nu: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        particle_density: Optional[Float[Tensor, \"n\"] | Float[Tensor, \"1\"]]=None,\n        query_mask: Optional[Int[Tensor, \"n\"]] = None,\n        device: str=\"cuda:0\",\n        requires_grad: bool=True,\n        extra_no_grad_steps: int=0,\n    ) -> Tuple[Float[Tensor, \"n 3\"], Float[Tensor, \"n 3\"], Float[Tensor, \"n 9\"], Float[Tensor, \"n 9\"], Float[Tensor, \"n 6\"]]:\n        \"\"\"\n        Args:\n            query_mask: [n] 0 or 1.  1 means the density or young's modulus, or poisson'ratio of this particle can change.\n        \"\"\"\n        \n        # initialization work is done before calling forward! 
\n\n        num_particles = particle_x.shape[0]\n\n        mpm_state.continue_from_torch(\n            particle_x, particle_v, particle_F, particle_C, device=device, requires_grad=True\n        )\n        # set x, v, F, C.\n\n        if E.ndim == 0:\n            E_inp = E.item() # float\n            ctx.aggregating_E = True\n        else:\n            E_inp = from_torch_safe(E, dtype=wp.float32, requires_grad=True)\n            ctx.aggregating_E = False\n        if nu.ndim == 0:\n            nu_inp = nu.item() # float\n            ctx.aggregating_nu = True\n        else:\n            nu_inp = from_torch_safe(nu, dtype=wp.float32, requires_grad=True)\n            ctx.aggregating_nu = False\n            \n        mpm_solver.set_E_nu(mpm_model, E_inp, nu_inp, device=device)\n        mpm_solver.prepare_mu_lam(mpm_model, mpm_state, device=device)\n\n        mpm_state.reset_density(\n            tensor_density=particle_density,\n            selection_mask=query_mask,\n            device=device,\n            requires_grad=True,\n            update_mass=True)\n        \n        prev_state = mpm_state\n\n        if extra_no_grad_steps > 0:\n            with torch.no_grad():\n                for i in range(extra_no_grad_steps):\n                    next_state = prev_state.partial_clone(requires_grad=True)\n                    mpm_solver.p2g2p_differentiable(mpm_model, prev_state, next_state, substep_size, device=device)\n                    prev_state = next_state\n\n        # following steps will be checkpointed. 
then replayed in backward\n        ctx.prev_state = prev_state\n\n        wp_tape = MyTape()\n        cond_tape: CondTape = CondTape(wp_tape, requires_grad)\n        next_state_list = [] \n\n        with cond_tape:\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=num_particles,\n                inputs=[\n                    prev_state.particle_density,\n                    prev_state.particle_vol,\n                    prev_state.particle_mass,\n                ],\n                device=device,\n            )\n            mpm_solver.prepare_mu_lam(mpm_model, prev_state, device=device)\n\n            for substep_local in range(num_substeps):\n                next_state = prev_state.partial_clone(requires_grad=True)\n                mpm_solver.p2g2p_differentiable(mpm_model, prev_state, next_state, substep_size, device=device)\n                next_state_list.append(next_state)\n                prev_state = next_state\n        \n        ctx.mpm_solver = mpm_solver\n        ctx.mpm_model = mpm_model\n        ctx.next_state_list = next_state_list\n        ctx.device = device\n        ctx.num_particles = num_particles\n        ctx.tape = cond_tape.tape\n\n        ctx.save_for_backward(query_mask)\n\n        last_state = next_state\n        particle_pos = wp.to_torch(last_state.particle_x).detach().clone()\n        particle_velo = wp.to_torch(last_state.particle_v).detach().clone()\n        particle_F = wp.to_torch(last_state.particle_F_trial).detach().clone()\n        particle_C = wp.to_torch(last_state.particle_C).detach().clone()\n        # [N * 6, ]\n        particle_cov = wp.to_torch(last_state.particle_cov).detach().clone()\n\n        particle_cov = particle_cov.view(-1, 6)\n\n        return particle_pos, particle_velo, particle_F, particle_C, particle_cov\n    \n\n    @staticmethod\n    def backward(ctx, out_pos_grad: Float[Tensor, \"n 3\"], out_velo_grad: Float[Tensor, \"n 3\"], \n                 out_F_grad: 
Float[Tensor, \"n 9\"], out_C_grad: Float[Tensor, \"n 9\"], out_cov_grad: Float[Tensor, \"n 6\"]):\n        \n        num_particles = ctx.num_particles\n        device = ctx.device\n        mpm_solver, mpm_model = ctx.mpm_solver, ctx.mpm_model\n        tape = ctx.tape\n        starting_state = ctx.prev_state\n        \n        next_state_list = ctx.next_state_list\n        next_state = next_state_list[-1]\n\n        query_mask = ctx.saved_tensors\n    \n        with wp.ScopedDevice(device):\n            \n            grad_pos_wp = from_torch_safe(out_pos_grad, dtype=wp.vec3, requires_grad=False)\n            \n            with tape:\n                loss_wp = torch.zeros(1, device=device)\n                loss_wp = wp.from_torch(loss_wp, requires_grad=True)\n                target_pos_detach = wp.clone(next_state.particle_x, device=device, requires_grad=False)\n                wp.launch(\n                    compute_posloss_with_grad, \n                    dim=num_particles,\n                    inputs=[\n                        next_state,\n                        target_pos_detach,\n                        grad_pos_wp,\n                        0.5,\n                        loss_wp,\n                    ],\n                    device=device,\n                )\n\n            # wp.synchronize_device(device)            \n            tape.backward(loss_wp)\n            # from IPython import embed; embed()\n\n        pos_grad = wp.to_torch(starting_state.particle_x.grad).detach().clone()\n        velo_grad = wp.to_torch(starting_state.particle_v.grad).detach().clone()\n        F_grad = wp.to_torch(starting_state.particle_F_trial.grad).detach().clone()\n        C_grad = wp.to_torch(starting_state.particle_C.grad).detach().clone()\n        # print(\"debug back\", velo_grad)\n\n        # grad for E, nu. 
TODO: add spatially varying E, nu later\n        if ctx.aggregating_E:\n            E_grad = wp.from_torch(torch.zeros(1, device=device), requires_grad=False)\n            wp.launch(\n                aggregate_grad,\n                dim=num_particles,\n                inputs=[\n                    E_grad,\n                    mpm_model.E.grad,\n                ],\n                device=device,\n            )\n            E_grad = wp.to_torch(E_grad)[0] / num_particles\n        else:\n            E_grad = wp.to_torch(mpm_model.E.grad).detach().clone()\n\n        if ctx.aggregating_nu:\n            nu_grad = wp.from_torch(torch.zeros(1, device=device), requires_grad=False)\n            wp.launch(\n                aggregate_grad,\n                dim=num_particles,\n                inputs=[nu_grad, mpm_model.nu.grad],\n                device=device,\n            )\n            nu_grad = wp.to_torch(nu_grad)[0] / num_particles   \n        else:\n            nu_grad = wp.to_torch(mpm_model.nu.grad).detach().clone()\n\n        # grad for density\n        if starting_state.particle_density.grad is None:\n            density_grad = None\n        else:\n            density_grad = wp.to_torch(starting_state.particle_density.grad).detach()\n        density_mask_grad = None\n\n        tape.zero()\n        # print(density_grad.abs().sum(), velo_grad.abs().sum(), E_grad.abs().item(), nu_grad.abs().item(), \"in sim func\")\n        # from IPython import embed; embed()\n        \n        return (None, None, None, None, None,\n                pos_grad, velo_grad, F_grad, C_grad, \n                E_grad, nu_grad,\n                density_grad, density_mask_grad, \n                None, None, None)"
  },
  {
    "path": "projects/uncleaned_train/exp_motion/train/local_utils.py",
    "content": "import os\nimport torch\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor\nfrom time import time\nfrom omegaconf import OmegaConf\nfrom motionrep.fields.se3_field import TemporalKplanesSE3fields\nfrom motionrep.fields.triplane_field import TriplaneFields, TriplaneFieldsWithEntropy\nfrom motionrep.utils.svd_helpper import load_model_from_config\n\nfrom motionrep.gaussian_3d.gaussian_renderer.render import (\n    render_gaussian,\n    render_arrow_in_screen,\n)\nfrom motionrep.gaussian_3d.gaussian_renderer.flow_depth_render import (\n    render_flow_depth_w_gaussian,\n)\nimport cv2\nimport numpy as np\nfrom sklearn.cluster import KMeans\nfrom time import time\n\nfrom motionrep.gaussian_3d.utils.rigid_body_utils import (\n    get_rigid_transform,\n    matrix_to_quaternion,\n    quaternion_multiply,\n)\n\n\ndef cycle(dl: torch.utils.data.DataLoader):\n    while True:\n        for data in dl:\n            yield data\n\n\ndef load_motion_model(model, checkpoint_path):\n    model_path = os.path.join(checkpoint_path, \"model.pt\")\n    model.load_state_dict(torch.load(model_path))\n    print(\"load model from: \", model_path)\n    return model\n\n\ndef create_spatial_fields(\n    args, output_dim, aabb: Float[Tensor, \"2 3\"], add_entropy=True\n):\n\n    sp_res = args.sim_res\n\n    resolutions = [sp_res, sp_res, sp_res]\n    reduce = \"sum\"\n\n    if args.entropy_cls > 0 and add_entropy:\n        model = TriplaneFieldsWithEntropy(\n            aabb,\n            resolutions,\n            feat_dim=32,\n            init_a=0.1,\n            init_b=0.5,\n            reduce=reduce,\n            num_decoder_layers=2,\n            decoder_hidden_size=32,\n            output_dim=output_dim,\n            zero_init=args.zero_init,\n            num_cls=args.entropy_cls,\n        )\n    else:\n        model = TriplaneFields(\n            aabb,\n            resolutions,\n            feat_dim=32,\n            init_a=0.1,\n            init_b=0.5,\n     
       reduce=reduce,\n            num_decoder_layers=2,\n            decoder_hidden_size=32,\n            output_dim=output_dim,\n            zero_init=args.zero_init,\n        )\n    if args.zero_init:\n        print(\"=> zero init the last layer for Spatial MLP\")\n\n    return model\n\n\ndef create_motion_model(\n    args,\n    aabb: Float[Tensor, \"2 3\"],\n    num_frames=None,\n):\n    assert args.model in [\"se3_field\"]\n\n    sp_res = args.spatial_res\n    if num_frames is None:\n        num_frames = args.num_frames\n    resolutions = [sp_res, sp_res, sp_res, (num_frames) // 2 + 1]\n    # resolutions = [64, 64, 64, num_frames // 2 + 1]\n    reduce = \"sum\"\n\n    model = TemporalKplanesSE3fields(\n        aabb,\n        resolutions,\n        feat_dim=args.feat_dim,\n        init_a=0.1,\n        init_b=0.5,\n        reduce=reduce,\n        num_decoder_layers=args.num_decoder_layers,\n        decoder_hidden_size=args.decoder_hidden_size,\n        zero_init=args.zero_init,\n    )\n    if args.zero_init:\n        print(\"=> zero init the last layer for MLP\")\n\n    return model\n\n\ndef create_velocity_model(\n    args,\n    aabb: Float[Tensor, \"2 3\"],\n):\n\n    from motionrep.fields.offset_field import TemporalKplanesOffsetfields\n\n    sp_res = args.sim_res\n    resolutions = [sp_res, sp_res, sp_res, (args.num_frames) // 2 + 1]\n    reduce = \"sum\"\n    model = TemporalKplanesOffsetfields(\n        aabb,\n        resolutions,\n        feat_dim=32,\n        init_a=0.1,\n        init_b=0.5,\n        reduce=reduce,\n        num_decoder_layers=2,\n        decoder_hidden_size=32,\n        zero_init=args.zero_init,\n    )\n    if args.zero_init:\n        print(\"=> zero init the last layer for velocity MLP\")\n    return model\n\n\ndef create_svd_model(model_name=\"svd_full\", ckpt_path=None):\n    state = dict()\n    cfg_path_dict = {\n        \"svd_full\": \"svd_configs/svd_full_decoder.yaml\",\n    }\n    config = cfg_path_dict[model_name]\n\n    config = 
OmegaConf.load(config)\n\n    if ckpt_path is not None:\n        # overwrite config.\n        config.model.params.ckpt_path = ckpt_path\n\n    s_time = time()\n    # model will automatically load when create\n    model, msg = load_model_from_config(config, None)\n\n    state[\"config\"] = config\n\n    print(f\"Loading svd model takes {time() - s_time} seconds\")\n\n    return model, state\n\n\nclass LinearStepAnneal(object):\n    # def __init__(self, total_iters, start_state=[0.02, 0.98], end_state=[0.50, 0.98]):\n    def __init__(\n        self,\n        total_iters,\n        start_state=[0.02, 0.98],\n        end_state=[0.02, 0.98],\n        plateau_iters=-1,\n        warmup_step=300,\n    ):\n        self.total_iters = total_iters\n\n        if plateau_iters < 0:\n            plateau_iters = int(total_iters * 0.2)\n\n        if warmup_step <= 0:\n            warmup_step = 0\n\n        self.total_iters = max(total_iters - plateau_iters - warmup_step, 10)\n\n        self.start_state = start_state\n        self.end_state = end_state\n        self.warmup_step = warmup_step\n\n    def compute_state(self, cur_iter):\n\n        if self.warmup_step > 0:\n            cur_iter = max(0, cur_iter - self.warmup_step)\n        if cur_iter >= self.total_iters:\n            return self.end_state\n        ret = []\n        for s, e in zip(self.start_state, self.end_state):\n            ret.append(s + (e - s) * cur_iter / self.total_iters)\n        return ret\n\n\ndef setup_boundary_condition(\n    xyzs_over_time: torch.Tensor, mpm_solver, mpm_state, num_filled=0\n):\n\n    init_velocity = xyzs_over_time[1] - xyzs_over_time[0]\n    init_velocity_mag = torch.norm(init_velocity, dim=-1)\n\n    # 10% of the velocity\n    velocity_thres = torch.quantile(init_velocity_mag, 0.1, dim=0)\n\n    # [n_particles]. 
1 for freeze, 0 for moving\n    freeze_mask = init_velocity_mag < velocity_thres\n    freeze_mask = freeze_mask.type(torch.int)\n    if num_filled > 0:\n        freeze_mask = torch.cat(\n            [freeze_mask, freeze_mask.new_zeros(num_filled).type(torch.int)], dim=0\n        )\n    num_freeze_pts = freeze_mask.sum()\n    print(\"num freeze pts from static points\", num_freeze_pts.item())\n\n    free_velocity = torch.zeros_like(init_velocity[0])  # [3] in device\n\n    mpm_solver.enforce_particle_velocity_by_mask(\n        mpm_state, freeze_mask, free_velocity, start_time=0, end_time=100000\n    )\n\n    return freeze_mask\n\n\ndef setup_plannar_boundary_condition(\n    xyzs_over_time: torch.Tensor,\n    mpm_solver,\n    mpm_state,\n    gaussian_xyz,\n    plane_mean,\n    plane_normal,\n    thres=0.2,\n):\n    \"\"\"\n    plane_mean and plane_normal are in original coordinate, not being normalized\n    Args:\n        xyzs_over_time: [T, N, 3]\n        gaussian_xyz: [N, 3] torch.Tensor\n        plane_mean: [3]\n        plane_normal: [3]\n        thres: float\n\n    \"\"\"\n\n    plane_normal = plane_normal / torch.norm(plane_normal)\n    # [n_particles]\n    plane_dist = torch.abs(\n        torch.sum(\n            (gaussian_xyz - plane_mean.unsqueeze(0)) * plane_normal.unsqueeze(0), dim=-1\n        )\n    )\n    # [n_particles]\n    freeze_mask = plane_dist < thres\n    freeze_mask = freeze_mask.type(torch.int)\n\n    num_freeze_pts = freeze_mask.sum()\n    print(\"num freeze pts from plannar boundary\", num_freeze_pts.item())\n    free_velocity = xyzs_over_time.new_zeros(3)\n    # print(\"free velocity\", free_velocity.shape, freeze_mask.shape)\n\n    mpm_solver.enforce_particle_velocity_by_mask(\n        mpm_state, freeze_mask, free_velocity, start_time=0, end_time=100000\n    )\n\n    return freeze_mask\n\n\ndef find_far_points(xyzs, selected_points, thres=0.05):\n    \"\"\"\n    Args:\n        xyzs: [N, 3]\n        selected_points: [M, 3]\n    Outs:\n        
freeze_mask: [N], 1 for points that are far away, 0 for points that are close\n                    dtype=torch.int\n    \"\"\"\n    chunk_size = 10000\n\n    freeze_mask_list = []\n    for i in range(0, xyzs.shape[0], chunk_size):\n\n        end_index = min(i + chunk_size, xyzs.shape[0])\n        xyzs_chunk = xyzs[i:end_index]\n        # [M, N]\n        cdist = torch.cdist(xyzs_chunk, selected_points)\n\n        min_dist, _ = torch.min(cdist, dim=-1)\n        freeze_mask = min_dist > thres\n        freeze_mask = freeze_mask.type(torch.int)\n        freeze_mask_list.append(freeze_mask)\n\n    freeze_mask = torch.cat(freeze_mask_list, dim=0)\n\n    # 1 for points that are far away, 0 for points that are close\n    return freeze_mask\n\n\ndef setup_boundary_condition_with_points(\n    xyzs, selected_points, mpm_solver, mpm_state, thres=0.05\n):\n    \"\"\"\n    Args:\n        xyzs: [N, 3]\n        selected_points: [M, 3]\n    \"\"\"\n\n    freeze_mask = find_far_points(xyzs, selected_points, thres=thres)\n    num_freeze_pts = freeze_mask.sum()\n    print(\"num freeze pts from static points\", num_freeze_pts.item())\n\n    free_velocity = torch.zeros_like(xyzs[0])  # [3] in device\n\n    mpm_solver.enforce_particle_velocity_by_mask(\n        mpm_state, freeze_mask, free_velocity, start_time=0, end_time=1000000\n    )\n\n    return freeze_mask\n\n\ndef setup_bottom_boundary_condition(xyzs, mpm_solver, mpm_state, percentile=0.05):\n    \"\"\"\n    Args:\n        xyzs: [N, 3]\n        selected_points: [M, 3]\n    \"\"\"\n    max_z, min_z = torch.max(xyzs[:, 2]), torch.min(xyzs[:, 2])\n    thres = min_z + (max_z - min_z) * percentile\n    freeze_mask = xyzs[:, 2] < thres\n\n    freeze_mask = freeze_mask.type(torch.int)\n    num_freeze_pts = freeze_mask.sum()\n    print(\"num freeze pts from bottom points\", num_freeze_pts.item())\n\n    free_velocity = torch.zeros_like(xyzs[0])  # [3] in device\n\n    mpm_solver.enforce_particle_velocity_by_mask(\n        mpm_state, 
freeze_mask, free_velocity, start_time=0, end_time=1000000\n    )\n\n    return freeze_mask\n\n\ndef render_single_view_video(\n    cam,\n    render_params,\n    motion_model,\n    time_stamps,\n    rand_bg=False,\n    render_flow=False,\n    query_mask=None,\n):\n    \"\"\"\n    Args:\n        cam:\n        motion_model: Callable function, f(x, t) => translation, rotation\n        time_stamps: [T]\n        query_mask: Tensor of [N], 0 for freeze points, 1 for moving points\n    Outs:\n        ret_video: [T, 3, H, W] value in [0, 1]\n    \"\"\"\n\n    if rand_bg:\n        bg_color = torch.rand(3, device=\"cuda\")\n    else:\n        bg_color = render_params.bg_color\n\n    ret_img_list = []\n    for time_stamp in time_stamps:\n        if not render_flow:\n            new_gaussians = render_params.gaussians.apply_se3_fields(\n                motion_model, time_stamp\n            )\n            if query_mask is not None:\n                new_gaussians._xyz = new_gaussians._xyz * query_mask.unsqueeze(\n                    -1\n                ) + render_params.gaussians._xyz * (1 - query_mask.unsqueeze(-1))\n                new_gaussians._rotation = (\n                    new_gaussians._rotation * query_mask.unsqueeze(-1)\n                    + render_params.gaussians._rotation * (1 - query_mask.unsqueeze(-1))\n                )\n            # [3, H, W]\n            img = render_gaussian(\n                cam,\n                new_gaussians,\n                render_params.render_pipe,\n                bg_color,\n            )[\n                \"render\"\n            ]  # value in [0, 1]\n        else:\n            inp_time = (\n                torch.ones_like(render_params.gaussians._xyz[:, 0:1]) * time_stamp\n            )\n            inp = torch.cat([render_params.gaussians._xyz, inp_time], dim=-1)\n            # [bs, 3, 3]. 
[bs, 3]\n            R, point_disp = motion_model(inp)\n\n            img = render_flow_depth_w_gaussian(\n                cam,\n                render_params.gaussians,\n                render_params.render_pipe,\n                point_disp,\n                bg_color,\n            )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    ret_video = torch.cat(ret_img_list, dim=0)  # [T, 3, H, W]\n    return ret_video\n\n\ndef render_gaussian_seq(cam, render_params, gaussian_pos_list, gaussian_cov_list):\n\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(gaussian_pos_list)):\n\n        xyz = gaussian_pos_list[i]\n        gaussians._xyz = xyz\n        # TODO, how to deal with cov\n        img = render_gaussian(\n            cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n\n    return rendered_video\n\n\ndef render_gaussian_seq_w_mask(\n    cam, render_params, gaussian_pos_list, gaussian_cov_list, update_mask\n):\n\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_cov = gaussians.get_covariance().clone()\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(gaussian_pos_list)):\n\n        xyz = gaussian_pos_list[i]\n        gaussians._xyz[update_mask, ...] = xyz\n\n        if gaussian_cov_list is not None:\n            cov = gaussian_cov_list[i]\n            old_cov[update_mask, ...] 
= cov\n            cov3D_precomp = old_cov\n\n        else:\n            cov3D_precomp = None\n\n        img = render_gaussian(\n            cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n\n    return rendered_video\n\n\ndef render_gaussian_seq_w_mask_with_disp(\n    cam, render_params, orign_points, top_k_index, disp_list, update_mask\n):\n    \"\"\"\n    Args:\n        cam: Camera or list of Camera\n        orign_points: [m, 3]\n        disp_list: List[m, 3]\n        top_k_index: [n, top_k]\n\n    \"\"\"\n\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_rotation = gaussians._rotation.clone()\n\n    query_pts = old_xyz[update_mask, ...]\n    query_rotation = old_rotation[update_mask, ...]\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(disp_list)):\n\n        if isinstance(cam, list):\n            render_cam = cam[i]\n        else:\n            render_cam = cam\n        disp = disp_list[i]\n        new_xyz, new_rotation = interpolate_points_w_R(\n            query_pts, query_rotation, orign_points, disp, top_k_index\n        )\n        gaussians._xyz[update_mask, ...] = new_xyz\n        gaussians._rotation[update_mask, ...] 
= new_rotation\n\n        cov3D_precomp = None\n\n        img = render_gaussian(\n            render_cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    gaussians._rotation = old_rotation\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n\n    return rendered_video\n\n\ndef render_gaussian_seq_w_mask_with_disp_for_figure(\n    cam, render_params, orign_points, top_k_index, disp_list, update_mask\n):\n    \"\"\"\n    Args:\n        cam: Camera or list of Camera\n        orign_points: [m, 3]\n        disp_list: List[m, 3]\n        top_k_index: [n, top_k]\n\n    \"\"\"\n\n    ret_img_list = []\n    moving_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_rotation = gaussians._rotation.clone()\n\n    query_pts = old_xyz[update_mask, ...]\n    query_rotation = old_rotation[update_mask, ...]\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    background_black = torch.tensor([0, 0, 0], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(disp_list)):\n\n        if isinstance(cam, list):\n            render_cam = cam[i]\n        else:\n            render_cam = cam\n        disp = disp_list[i]\n        new_xyz, new_rotation = interpolate_points_w_R(\n            query_pts, query_rotation, orign_points, disp, top_k_index\n        )\n        gaussians._xyz[update_mask, ...] = new_xyz\n        gaussians._rotation[update_mask, ...] 
= new_rotation\n\n        cov3D_precomp = None\n\n        img = render_gaussian(\n            render_cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        masked_gaussians = gaussians.apply_mask(update_mask)\n        moving_img = render_gaussian(\n            render_cam,\n            masked_gaussians,\n            render_params.render_pipe,\n            background_black,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n        moving_img_list.append(moving_img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    gaussians._rotation = old_rotation\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n    moving_part_video = torch.cat(moving_img_list, dim=0)\n\n    return rendered_video, moving_part_video\n\n\ndef render_gaussian_seq_w_mask_cam_seq(\n    cam_list, render_params, gaussian_pos_list, gaussian_cov_list, update_mask\n):\n\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_cov = gaussians.get_covariance().clone()\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(gaussian_pos_list)):\n\n        xyz = gaussian_pos_list[i]\n        gaussians._xyz[update_mask, ...] = xyz\n\n        if gaussian_cov_list is not None:\n            cov = gaussian_cov_list[i]\n            old_cov[update_mask, ...] 
= cov\n            cov3D_precomp = old_cov\n\n        else:\n            cov3D_precomp = None\n\n        img = render_gaussian(\n            cam_list[i],\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    # [T, C, H, W], in [0, 1]\n    rendered_video = torch.cat(ret_img_list, dim=0)\n\n    return rendered_video\n\n\ndef apply_grid_bc_w_freeze_pts(grid_size, grid_lim, freeze_pts, mpm_solver):\n\n    device = freeze_pts.device\n\n    grid_pts_cnt = torch.zeros(\n        (grid_size, grid_size, grid_size), dtype=torch.int32, device=device\n    )\n\n    dx = grid_lim / grid_size\n    inv_dx = 1.0 / dx\n\n    freeze_pts = (freeze_pts * inv_dx).long()\n\n    for x, y, z in freeze_pts:\n        grid_pts_cnt[x, y, z] += 1\n\n    freeze_grid_mask = grid_pts_cnt >= 1\n\n    freeze_grid_mask_int = freeze_grid_mask.type(torch.int32)\n\n    number_freeze_grid = freeze_grid_mask_int.sum().item()\n    print(\"number of freeze grid\", number_freeze_grid)\n\n    mpm_solver.enforce_grid_velocity_by_mask(freeze_grid_mask_int)\n\n    # add debug section:\n\n    return freeze_grid_mask\n\n\ndef add_constant_force(\n    mpm_sovler,\n    mpm_state,\n    xyzs,\n    center_point,\n    radius,\n    force,\n    dt,\n    start_time,\n    end_time,\n    device,\n):\n    \"\"\"\n    Args:\n        xyzs: [N, 3]\n        center_point: [3]\n        radius: float\n        force: [3]\n\n    \"\"\"\n\n    # compute distance from xyzs to center_point\n    # [N]\n    dist = torch.norm(xyzs - center_point.unsqueeze(0), dim=-1)\n\n    apply_force_mask = dist < radius\n    apply_force_mask = apply_force_mask.type(torch.int)\n\n    print(apply_force_mask.shape, apply_force_mask.sum().item(), \"apply force mask\")\n\n    mpm_sovler.add_impulse_on_particles_with_mask(\n        mpm_state,\n        force,\n        
dt,\n        apply_force_mask,\n        start_time=start_time,\n        end_time=end_time,\n        device=device,\n    )\n\n\n@torch.no_grad()\ndef render_force_2d(cam, render_params, center_point, force):\n\n    force_in_2d_scale = 80  # unit as pixel\n    two_points = torch.stack([center_point, center_point + force], dim=0)\n\n    gaussians = render_params.gaussians\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n\n    # [3, H, W]\n    img = render_gaussian(\n        cam,\n        gaussians,\n        render_params.render_pipe,\n        background,\n    )[\"render\"]\n    img = img.detach().contiguous()\n    img = img.cpu().numpy().transpose(1, 2, 0)\n    img = img * 255\n    img = img.astype(np.uint8).copy()\n\n    # two_points.  [2, 3]\n    # arrow_2d: [2, 2]\n    arrow_2d = render_arrow_in_screen(cam, two_points)\n\n    arrow_2d = arrow_2d.cpu().numpy()\n\n    start, vec_2d = arrow_2d[0], arrow_2d[1] - arrow_2d[0]\n    vec_2d = vec_2d / np.linalg.norm(vec_2d)\n\n    start = start  # + np.array([540.0, 288.0])  # [W, H] / 2\n    # debug here.\n    # 1. unit in pixel?\n    # 2. 
use cv2 to add arrow?\n    # draw circle at start in img\n\n    # img = img.transpose(2, 0, 1)\n    img = cv2.circle(img, (int(start[0]), int(start[1])), 40, (255, 255, 255), 8)\n\n    # draw arrow in img\n    end = start + vec_2d * force_in_2d_scale\n    end = end.astype(np.int32)\n    start = start.astype(np.int32)\n    img = cv2.arrowedLine(img, (start[0], start[1]), (end[0], end[1]), (0, 255, 255), 8)\n\n    return img\n\n\ndef render_gaussian_seq_w_mask_cam_seq_with_force(\n    cam_list,\n    render_params,\n    gaussian_pos_list,\n    gaussian_cov_list,\n    update_mask,\n    pts_index,\n    force,\n    force_steps,\n):\n\n    force_in_2d_scale = 80  # unit as pixel\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n    old_cov = gaussians.get_covariance().clone()\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(gaussian_pos_list)):\n\n        xyz = gaussian_pos_list[i]\n        gaussians._xyz[update_mask, ...] = xyz\n\n        if gaussian_cov_list is not None:\n            cov = gaussian_cov_list[i]\n            old_cov[update_mask, ...] 
= cov\n            cov3D_precomp = old_cov\n\n        else:\n            cov3D_precomp = None\n\n        img = render_gaussian(\n            cam_list[i],\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        # to [H, W, 3]\n        img = img.detach().contiguous().cpu().numpy().transpose(1, 2, 0)\n        img = np.clip((img * 255), 0, 255).astype(np.uint8).copy()\n\n        if i < force_steps:\n            center_point = gaussians._xyz[pts_index]\n            two_points = torch.stack([center_point, center_point + force], dim=0)\n\n            arrow_2d = render_arrow_in_screen(cam_list[i], two_points)\n\n            arrow_2d = arrow_2d.cpu().numpy()\n\n            start, vec_2d = arrow_2d[0], arrow_2d[1] - arrow_2d[0]\n            vec_2d = vec_2d / np.linalg.norm(vec_2d)\n\n            start = start  # + np.array([540.0, 288.0])\n\n            img = cv2.circle(\n                img, (int(start[0]), int(start[1])), 40, (255, 255, 255), 8\n            )\n\n            # draw arrow in img\n            end = start + vec_2d * force_in_2d_scale\n            end = end.astype(np.int32)\n            start = start.astype(np.int32)\n            img = cv2.arrowedLine(\n                img, (start[0], start[1]), (end[0], end[1]), (0, 255, 255), 8\n            )\n\n        img = img.transpose(2, 0, 1)\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    # [T, C, H, W], in [0, 1]\n    rendered_video = np.concatenate(ret_img_list, axis=0)\n\n    return rendered_video\n\n\ndef render_gaussian_seq_w_mask_cam_seq_with_force_with_disp(\n    cam_list,\n    render_params,\n    orign_points,\n    top_k_index,\n    disp_list,\n    update_mask,\n    pts_index,\n    force,\n    force_steps,\n):\n\n    force_in_2d_scale = 80  # unit as pixel\n    ret_img_list = []\n    gaussians = render_params.gaussians\n    old_xyz = gaussians._xyz.clone()\n  
  old_rotation = gaussians._rotation.clone()\n\n    query_pts = old_xyz[update_mask, ...]\n    query_rotation = old_rotation[update_mask, ...]\n\n    background = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n    for i in range(len(disp_list)):\n\n        disp = disp_list[i]\n        new_xyz, new_rotation = interpolate_points_w_R(\n            query_pts, query_rotation, orign_points, disp, top_k_index\n        )\n        gaussians._xyz[update_mask, ...] = new_xyz\n        gaussians._rotation[update_mask, ...] = new_rotation\n\n        cov3D_precomp = None\n\n        img = render_gaussian(\n            cam_list[i],\n            gaussians,\n            render_params.render_pipe,\n            background,\n            cov3D_precomp=cov3D_precomp,\n        )[\"render\"]\n\n        # to [H, W, 3]\n        img = img.detach().contiguous().cpu().numpy().transpose(1, 2, 0)\n        img = np.clip((img * 255), 0, 255).astype(np.uint8).copy()\n\n        if i < force_steps:\n            center_point = gaussians._xyz[pts_index]\n            two_points = torch.stack([center_point, center_point + force], dim=0)\n\n            arrow_2d = render_arrow_in_screen(cam_list[i], two_points)\n\n            arrow_2d = arrow_2d.cpu().numpy()\n\n            start, vec_2d = arrow_2d[0], arrow_2d[1] - arrow_2d[0]\n            vec_2d = vec_2d / np.linalg.norm(vec_2d)\n\n            start = start  # + np.array([540.0, 288.0])\n\n            img = cv2.circle(\n                img, (int(start[0]), int(start[1])), 40, (255, 255, 255), 5\n            )\n\n            # draw arrow in img\n            end = start + vec_2d * force_in_2d_scale\n            end = end.astype(np.int32)\n            start = start.astype(np.int32)\n            img = cv2.arrowedLine(\n                img, (start[0], start[1]), (end[0], end[1]), (255, 255, 0), 4\n            )\n\n        img = img.transpose(2, 0, 1)\n        ret_img_list.append(img[None, ...])\n\n    gaussians._xyz = old_xyz  # set back\n    
gaussians._rotation = old_rotation\n    # [T, C, H, W], uint8 in [0, 255]\n    rendered_video = np.concatenate(ret_img_list, axis=0)\n\n    return rendered_video\n\n\ndef downsample_with_kmeans(points_array: np.ndarray, num_points: int):\n    \"\"\"\n    Args:\n        points_array: [N, 3]\n        num_points: int\n    Outs:\n        downsampled_points: [num_points, 3]\n    \"\"\"\n\n    print(\n        \"=> starting downsample with kmeans from \",\n        points_array.shape[0],\n        \" points to \",\n        num_points,\n        \" points\",\n    )\n    s_time = time()\n    kmeans = KMeans(n_clusters=num_points, random_state=0).fit(points_array)\n    cluster_centers = kmeans.cluster_centers_\n    e_time = time()\n\n    print(\"=> downsample with kmeans takes \", e_time - s_time, \" seconds\")\n    return cluster_centers\n\n\n@torch.no_grad()\ndef downsample_with_kmeans_gpu(points_array: torch.Tensor, num_points: int):\n\n    from kmeans_gpu import KMeans\n\n    kmeans = KMeans(\n        n_clusters=num_points,\n        max_iter=100,\n        tolerance=1e-4,\n        distance=\"euclidean\",\n        sub_sampling=None,\n        max_neighbors=15,\n    )\n\n    features = torch.ones(1, 1, points_array.shape[0], device=points_array.device)\n    points_array = points_array.unsqueeze(0)\n    # Forward\n\n    print(\n        \"=> starting downsample with kmeans from \",\n        points_array.shape[1],\n        \" points to \",\n        num_points,\n        \" points\",\n    )\n    s_time = time()\n    centroids, features = kmeans(points_array, features)\n\n    ret_points = centroids.squeeze(0)\n    e_time = time()\n    print(\"=> downsample with kmeans takes \", e_time - s_time, \" seconds\")\n\n    # [np_subsample, 3]\n    return ret_points\n\n\ndef interpolate_points(query_points, drive_displacement, top_k_index):\n    \"\"\"\n    Args:\n        query_points: [n, 3]\n        drive_displacement: [m, 3]\n        top_k_index: [n, top_k] < m\n    \"\"\"\n\n    top_k_disp = 
drive_displacement[top_k_index]\n\n    t = top_k_disp.mean(dim=1)\n\n    ret_points = query_points + t\n\n    return ret_points\n\n\ndef interpolate_points_w_R(\n    query_points, query_rotation, drive_origin_pts, drive_displacement, top_k_index\n):\n    \"\"\"\n    Args:\n        query_points: [n, 3]\n        drive_origin_pts: [m, 3]\n        drive_displacement: [m, 3]\n        top_k_index: [n, top_k] < m\n\n    Or directly call: apply_discrete_offset_filds_with_R(self, origin_points, offsets, topk=6):\n        Args:\n            origin_points: (N_r, 3)\n            offsets: (N_r, 3)\n        in rendering\n    \"\"\"\n\n    # [n, topk, 3]\n    top_k_disp = drive_displacement[top_k_index]\n    source_points = drive_origin_pts[top_k_index]\n\n    R, t = get_rigid_transform(source_points, source_points + top_k_disp)\n\n    avg_offsets = top_k_disp.mean(dim=1)\n\n    ret_points = query_points + avg_offsets\n\n    new_rotation = quaternion_multiply(matrix_to_quaternion(R), query_rotation)\n\n    return ret_points, new_rotation\n\n\ndef create_camera_path(\n    cam,\n    radius: float,\n    focus_pt: np.ndarray = np.array([0, 0, 0]),\n    up: np.ndarray = np.array([0, 0, 1]),\n    n_frames: int = 60,\n    n_rots: int = 1,\n    y_scale: float = 1.0,\n):\n\n    R, T = cam.R, cam.T\n    # R, T = R.cpu().numpy(), T.cpu().numpy()\n\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()\n    Rt[:3, 3] = T\n    Rt[3, 3] = 1.0\n    C2W = np.linalg.inv(Rt)\n    C2W[:3, 1:3] *= -1\n\n    import copy\n    from motionrep.utils.camera_utils import generate_spiral_path\n    from motionrep.data.cameras import Camera\n\n    lookat_pt = focus_pt\n    render_poses = generate_spiral_path(\n        C2W, radius, lookat_pt, up, n_frames, n_rots, y_scale\n    )\n\n    FoVy, FoVx = cam.FoVy, cam.FoVx\n    height, width = cam.image_height, cam.image_width\n\n    ret_cam_list = []\n    for i in range(n_frames):\n        c2w_opengl = render_poses[i]\n        c2w = copy.deepcopy(c2w_opengl)\n  
      c2w[:3, 1:3] *= -1\n\n        # get the world-to-camera transform and set R, T\n        w2c = np.linalg.inv(c2w)\n        R = np.transpose(\n            w2c[:3, :3]\n        )  # R is stored transposed due to 'glm' in CUDA code\n        T = w2c[:3, 3]\n        cam = Camera(\n            R=R,\n            T=T,\n            FoVy=FoVy,\n            FoVx=FoVx,\n            img_path=None,\n            img_hw=(height, width),\n            timestamp=None,\n            data_device=\"cuda\",\n        )\n        ret_cam_list.append(cam)\n\n    return ret_cam_list\n\n\ndef get_camera_trajectory(cam, num_pos, camera_cfg: dict, dataset):\n    if camera_cfg[\"type\"] == \"spiral\":\n        interpolated_cameras = create_camera_path(\n            cam,\n            radius=camera_cfg[\"radius\"],\n            focus_pt=camera_cfg[\"focus_point\"],\n            up=camera_cfg[\"up\"],\n            n_frames=num_pos,\n        )\n    elif camera_cfg[\"type\"] == \"interpolation\":\n        if \"start_frame\" in camera_cfg and \"end_frame\" in camera_cfg:\n            interpolated_cameras = dataset.interpolate_camera(\n                camera_cfg[\"start_frame\"], camera_cfg[\"end_frame\"], num_pos\n            )\n        else:\n            interpolated_cameras = dataset.interpolate_camera(\n                camera_cfg[\"start_frame\"], camera_cfg[\"start_frame\"], num_pos\n            )\n\n    print(\n        \"number of simulated frames: \",\n        num_pos,\n        \"num camera viewpoints: \",\n        len(interpolated_cameras),\n    )\n    return interpolated_cameras\n"
  },
  {
    "path": "projects/uncleaned_train/exp_motion/train/model_config.py",
    "content": "import numpy as np\n\ndataset_dir = \"../../data/physics_dreamer/hat_nerfstudio/\"\nresult_dir = \"output/hat/results_force\"\nexp_name = \"hat\"\n\nmodel_list = [\n    # multiview 64 364\n    \"../../output/inverse_sim/fast_hat_videos2_sv64-384_init1e5decay_1.0_substep_384_se3_field_lr_0.03_tv_0.0001_iters_200_sw_6_cw_1/seed0/checkpoint_model_000019\",\n]\n\nfocus_point_list = [\n    np.array([-0.467188, 0.067178, 0.044333]),  # botton of the background\n]\n\ncamera_cfg_list = [\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00001.png\",\n        \"end_frame\": \"frame_00187.png\",  # or 91\n    },\n    # real captured viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00217.png\",\n    },\n    # other selected viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00001.png\",\n    },\n    # other selected viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00001.png\",\n    },\n    # other selected viewpoint\n    {\n        \"type\": \"interpolation\",\n        \"start_frame\": \"frame_00079.png\",\n    },\n]\n\nsimulate_cfg = {\n    \"substep\": 384,\n    \"grid_size\": 64,\n}\n\n\npoints_list = [\n    np.array([-0.390069, 0.139051, -0.182607]),  # bottom of the hat\n    np.array([-0.404391, 0.184975, -0.001585]),  # middle of the hat\n    np.array([-0.289375, 0.034581, 0.062010]),  # left of the hat\n    np.array([-0.352060, 0.105737, 0.009359]),  # center of the hat\n]\n\nforce_directions = [\n    np.array([1.0, 0.0, 0]),\n    np.array([0.0, 1.0, 0.0]),\n    np.array([1.0, 0.0, 1.0]),\n    np.array([1.0, 1.0, 0.0]),\n    np.array([1.0, 0.0, 1.0]),\n    np.array([0.0, 1.0, 1.0]),\n    np.array([1.0, 1.0, 1.0]),\n]\n\nforce_directions = np.array(force_directions)\nforce_directions = force_directions / np.linalg.norm(force_directions, axis=1)[:, None]\n"
  },
  {
    "path": "projects/uncleaned_train/exp_motion/train/train_material.py",
    "content": "import argparse\nimport os\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom torch import Tensor\nfrom jaxtyping import Float, Int, Shaped\nfrom typing import List\n\nimport point_cloud_utils as pcu\n\nfrom accelerate.utils import ProjectConfiguration\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import set_seed\nfrom accelerate import Accelerator, DistributedDataParallelKwargs\n\nimport numpy as np\nimport logging\nimport argparse\nimport shutil\nimport wandb\nimport torch\nimport os\nfrom motionrep.utils.config import create_config\nfrom motionrep.utils.optimizer import get_linear_schedule_with_warmup\nfrom time import time\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nimport imageio\nimport numpy as np\n\n# from motionrep.utils.torch_utils import get_sync_time\nfrom einops import rearrange, repeat\n\nfrom motionrep.gaussian_3d.gaussian_renderer.feat_render import render_feat_gaussian\nfrom motionrep.gaussian_3d.scene import GaussianModel\nfrom motionrep.fields.se3_field import TemporalKplanesSE3fields\n\nfrom motionrep.data.datasets.multiview_dataset import MultiviewImageDataset\nfrom motionrep.data.datasets.multiview_video_dataset import (\n    MultiviewVideoDataset,\n    camera_dataset_collate_fn,\n)\n\nfrom motionrep.data.datasets.multiview_dataset import (\n    camera_dataset_collate_fn as camera_dataset_collate_fn_img,\n)\n\nfrom typing import NamedTuple\nimport torch.nn.functional as F\n\nfrom motionrep.utils.img_utils import compute_psnr, compute_ssim\nfrom thirdparty_code.warp_mpm.mpm_data_structure import (\n    MPMStateStruct,\n    MPMModelStruct,\n    get_float_array_product,\n)\nfrom thirdparty_code.warp_mpm.mpm_solver_diff import MPMWARPDiff\nfrom thirdparty_code.warp_mpm.warp_utils import from_torch_safe\nfrom thirdparty_code.warp_mpm.gaussian_sim_utils import get_volume\nimport warp as wp\nimport random\n\nfrom local_utils import (\n    cycle,\n    load_motion_model,\n    
create_motion_model,\n    create_spatial_fields,\n    find_far_points,\n    LinearStepAnneal,\n    apply_grid_bc_w_freeze_pts,\n    render_gaussian_seq_w_mask_cam_seq,\n    downsample_with_kmeans_gpu,\n    render_gaussian_seq_w_mask_with_disp,\n)\nfrom interface import (\n    MPMDifferentiableSimulationWCheckpoint,\n    MPMDifferentiableSimulationClean,\n)\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nmodel_dict = {\n    # psnr: 29.9\n    # \"videos\": \"../../output/inverse_sim/fast_hat_velopretraindecay_1.0_substep_96_se3_field_lr_0.001_tv_0.01_iters_300_sw_2_cw_2/seed0/checkpoint_model_000299\",\n    # psnr: 30.25\n    \"videos\": \"../../output/inverse_sim/fast_hat_velopretrain_g48-192decay_1.0_substep_192_se3_field_lr_0.003_tv_0.01_iters_300_sw_2_cw_2/seed0/checkpoint_model_000199\",\n    # psnr: 30.52\n    \"videos_2\": \"../../output/inverse_sim/fast_hat_videos2_velopretraindecay_1.0_substep_96_se3_field_lr_0.003_tv_0.01_iters_300_sw_2_cw_2/seed0/checkpoint_model_000199\",\n}\n\n\ndef create_dataset(args):\n    assert args.dataset_res in [\"middle\", \"small\", \"large\"]\n    if args.dataset_res == \"middle\":\n        res = [320, 576]\n    elif args.dataset_res == \"small\":\n        res = [192, 320]\n    elif args.dataset_res == \"large\":\n        res = [576, 1024]\n    else:\n        raise NotImplementedError\n\n    video_dir_name = \"videos\"\n    video_dir_name = args.video_dir_name\n\n    if args.test_convergence:\n        video_dir_name = \"simulated_videos\"\n    dataset = MultiviewVideoDataset(\n        args.dataset_dir,\n        use_white_background=False,\n        resolution=res,\n        scale_x_angle=1.0,\n        video_dir_name=video_dir_name,\n    )\n\n    test_dataset = MultiviewImageDataset(\n        args.dataset_dir,\n        use_white_background=False,\n        resolution=res,\n        # use_index=list(range(0, 30, 4)),\n        # use_index=[0],\n        scale_x_angle=1.0,\n        fitler_with_renderd=False,\n        
load_imgs=False,\n    )\n    print(\"len of test dataset\", len(test_dataset))\n    return dataset, test_dataset\n\n\nclass Trainer:\n    def __init__(self, args):\n        self.args = args\n\n        self.ssim = args.ssim\n        args.warmup_step = int(args.warmup_step * args.gradient_accumulation_steps)\n        args.train_iters = int(args.train_iters * args.gradient_accumulation_steps)\n        os.environ[\"WANDB__SERVICE_WAIT\"] = \"600\"\n        args.wandb_name += (\n            \"decay_{}_substep_{}_{}_lr_{}_tv_{}_iters_{}_sw_{}_cw_{}\".format(\n                args.loss_decay,\n                args.substep,\n                args.model,\n                args.lr,\n                args.tv_loss_weight,\n                args.train_iters,\n                args.start_window_size,\n                args.compute_window,\n            )\n        )\n\n        logging_dir = os.path.join(args.output_dir, args.wandb_name)\n        accelerator_project_config = ProjectConfiguration(logging_dir=logging_dir)\n        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n        accelerator = Accelerator(\n            gradient_accumulation_steps=1,  # args.gradient_accumulation_steps,\n            mixed_precision=\"no\",\n            log_with=\"wandb\",\n            project_config=accelerator_project_config,\n            kwargs_handlers=[ddp_kwargs],\n        )\n        self.gradient_accumulation_steps = args.gradient_accumulation_steps\n        logging.basicConfig(\n            format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n            datefmt=\"%m/%d/%Y %H:%M:%S\",\n            level=logging.INFO,\n        )\n        logger.info(accelerator.state, main_process_only=False)\n\n        set_seed(args.seed + accelerator.process_index)\n        print(\"process index\", accelerator.process_index)\n        if accelerator.is_main_process:\n            output_path = os.path.join(logging_dir, f\"seed{args.seed}\")\n            
os.makedirs(output_path, exist_ok=True)\n            self.output_path = output_path\n\n        self.rand_bg = args.rand_bg\n        # setup the dataset\n        dataset, test_dataset = create_dataset(args)\n        self.test_dataset = test_dataset\n\n        dataset_dir = test_dataset.data_dir\n        self.dataset = dataset\n\n        gaussian_path = os.path.join(dataset_dir, \"point_cloud.ply\")\n        aabb = self.setup_eval(\n            args,\n            gaussian_path,\n            white_background=True,\n        )\n        self.aabb = aabb\n        self.model = create_motion_model(\n            args,\n            aabb=aabb,\n            num_frames=9,\n        )\n        if args.motion_model_path is not None:\n            self.model = load_motion_model(self.model, args.motion_model_path)\n        self.model.eval()\n\n        self.num_frames = int(args.num_frames)\n        self.window_size_schduler = LinearStepAnneal(\n            args.train_iters,\n            start_state=[args.start_window_size],\n            end_state=[13],\n            plateau_iters=-1,\n            warmup_step=20,\n        )\n\n        test_dataloader = torch.utils.data.DataLoader(\n            test_dataset,\n            batch_size=args.batch_size,\n            shuffle=False,\n            drop_last=True,\n            num_workers=0,\n            collate_fn=camera_dataset_collate_fn_img,\n        )\n        # why prepare here again?\n        test_dataloader = accelerator.prepare(test_dataloader)\n        self.test_dataloader = cycle(test_dataloader)\n\n        dataloader = torch.utils.data.DataLoader(\n            dataset,\n            batch_size=args.batch_size,\n            shuffle=False,\n            drop_last=False,\n            num_workers=0,\n            collate_fn=camera_dataset_collate_fn,\n        )\n        # why prepare here again?\n        dataloader = accelerator.prepare(dataloader)\n        self.dataloader = cycle(dataloader)\n\n        self.train_iters = args.train_iters\n   
     self.accelerator = accelerator\n        # init traiable params\n        E_nu_list = self.init_trainable_params()\n        for p in E_nu_list:\n            p.requires_grad = True\n        self.E_nu_list = E_nu_list\n\n        self.model = accelerator.prepare(self.model)\n        self.setup_simulation(dataset_dir, grid_size=args.grid_size)\n\n        if args.checkpoint_path == \"None\":\n            args.checkpoint_path = None\n        if args.checkpoint_path is not None:\n\n            if args.video_dir_name in model_dict:\n                args.checkpoint_path = model_dict[args.video_dir_name]\n            self.load(args.checkpoint_path)\n            trainable_params = list(self.sim_fields.parameters()) + self.E_nu_list\n            optim_list = [\n                {\"params\": self.E_nu_list, \"lr\": args.lr * 1e-10},\n                {\n                    \"params\": self.sim_fields.parameters(),\n                    \"lr\": args.lr,\n                    \"weight_decay\": 1e-4,\n                },\n                # {\"params\": self.velo_fields.parameters(), \"lr\": args.lr * 1e-3, \"weight_decay\": 1e-4},\n            ]\n\n            if args.update_velo:\n                self.freeze_velo = False\n                velo_optim = [\n                    {\n                        \"params\": self.velo_fields.parameters(),\n                        \"lr\": args.lr * 1e-4,\n                        \"weight_decay\": 1e-4,\n                    },\n                ]\n                self.velo_optimizer = torch.optim.AdamW(\n                    velo_optim,\n                    lr=args.lr,\n                    weight_decay=0.0,\n                )\n                self.velo_scheduler = get_linear_schedule_with_warmup(\n                    optimizer=self.velo_optimizer,\n                    num_warmup_steps=args.warmup_step,\n                    num_training_steps=args.train_iters,\n                )\n            else:\n                self.freeze_velo = True\n            
    self.velo_optimizer = None\n        else:\n            trainable_params = list(self.sim_fields.parameters()) + self.E_nu_list\n            optim_list = [\n                {\"params\": self.E_nu_list, \"lr\": args.lr * 1e-10},\n                {\n                    \"params\": self.sim_fields.parameters(),\n                    \"lr\": args.lr,\n                    \"weight_decay\": 1e-4,\n                },\n            ]\n            self.freeze_velo = False\n            self.window_size_schduler.warmup_step = 800\n\n            velo_optim = [\n                {\n                    \"params\": self.velo_fields.parameters(),\n                    \"lr\": args.lr,\n                    \"weight_decay\": 1e-4,\n                },\n            ]\n            self.velo_optimizer = torch.optim.AdamW(\n                velo_optim,\n                lr=args.lr,\n                weight_decay=0.0,\n            )\n            self.velo_scheduler = get_linear_schedule_with_warmup(\n                optimizer=self.velo_optimizer,\n                num_warmup_steps=args.warmup_step,\n                num_training_steps=args.train_iters // 3,\n            )\n            self.velo_optimizer, self.velo_scheduler = accelerator.prepare(\n                self.velo_optimizer, self.velo_scheduler\n            )\n\n        self.optimizer = torch.optim.AdamW(\n            optim_list,\n            lr=args.lr,\n            weight_decay=0.0,\n        )\n        self.trainable_params = trainable_params\n        self.scheduler = get_linear_schedule_with_warmup(\n            optimizer=self.optimizer,\n            num_warmup_steps=args.warmup_step,\n            num_training_steps=args.train_iters,\n        )\n        self.sim_fields, self.optimizer, self.scheduler = accelerator.prepare(\n            self.sim_fields, self.optimizer, self.scheduler\n        )\n        self.velo_fields = accelerator.prepare(self.velo_fields)\n\n        # setup train info\n        self.step = 0\n        
self.batch_size = args.batch_size\n        self.tv_loss_weight = args.tv_loss_weight\n\n        self.log_iters = args.log_iters\n        self.wandb_iters = args.wandb_iters\n        self.max_grad_norm = args.max_grad_norm\n\n        self.use_wandb = args.use_wandb\n        if self.accelerator.is_main_process:\n            if args.use_wandb:\n                run = wandb.init(\n                    config=dict(args),\n                    dir=self.output_path,\n                    **{\n                        \"mode\": \"online\",\n                        \"entity\": args.wandb_entity,\n                        \"project\": args.wandb_project,\n                    },\n                )\n                wandb.run.log_code(\".\")\n                wandb.run.name = args.wandb_name\n                print(f\"run dir: {run.dir}\")\n                self.wandb_folder = run.dir\n                os.makedirs(self.wandb_folder, exist_ok=True)\n\n    def init_trainable_params(\n        self,\n    ):\n\n        # init young modulus and poisson ratio\n\n        young_numpy = np.exp(np.random.uniform(np.log(1e-3), np.log(1e3))).astype(\n            np.float32\n        )\n\n        young_numpy = np.array([1e5]).astype(np.float32)\n\n        young_modulus = torch.tensor(young_numpy, dtype=torch.float32).to(\n            self.accelerator.device\n        )\n\n        poisson_numpy = np.random.uniform(0.1, 0.4)\n        poisson_ratio = torch.tensor(poisson_numpy, dtype=torch.float32).to(\n            self.accelerator.device\n        )\n\n        trainable_params = [young_modulus, poisson_ratio]\n\n        print(\n            \"init young modulus: \",\n            young_modulus.item(),\n            \"poisson ratio: \",\n            poisson_ratio.item(),\n        )\n        return trainable_params\n\n    def setup_simulation(self, dataset_dir, grid_size=100):\n\n        device = \"cuda:{}\".format(self.accelerator.process_index)\n\n        xyzs = 
self.render_params.gaussians.get_xyz.detach().clone()\n        sim_xyzs = xyzs[self.sim_mask_in_raw_gaussian, :]\n        sim_cov = (\n            self.render_params.gaussians.get_covariance()[\n                self.sim_mask_in_raw_gaussian, :\n            ]\n            .detach()\n            .clone()\n        )\n\n        # scale, and shift\n        pos_max = sim_xyzs.max()\n        pos_min = sim_xyzs.min()\n        scale = (pos_max - pos_min) * 1.8\n        shift = -pos_min + (pos_max - pos_min) * 0.25\n        self.scale, self.shift = scale, shift\n        print(\"scale, shift\", scale, shift)\n\n        # filled\n        filled_in_points_path = os.path.join(dataset_dir, \"internal_filled_points.ply\")\n\n        if os.path.exists(filled_in_points_path):\n            fill_xyzs = pcu.load_mesh_v(filled_in_points_path)  # [n, 3]\n            fill_xyzs = fill_xyzs[\n                np.random.choice(\n                    fill_xyzs.shape[0], int(fill_xyzs.shape[0] * 0.25), replace=False\n                )\n            ]\n            fill_xyzs = torch.from_numpy(fill_xyzs).float().to(\"cuda\")\n            self.fill_xyzs = fill_xyzs\n            print(\n                \"loaded {} internal filled points from: \".format(fill_xyzs.shape[0]),\n                filled_in_points_path,\n            )\n        else:\n            self.fill_xyzs = None\n\n        if self.fill_xyzs is not None:\n            render_mask_in_sim_pts = torch.cat(\n                [\n                    torch.ones_like(sim_xyzs[:, 0]).bool(),\n                    torch.zeros_like(fill_xyzs[:, 0]).bool(),\n                ],\n                dim=0,\n            ).to(device)\n            sim_xyzs = torch.cat([sim_xyzs, fill_xyzs], dim=0)\n            sim_cov = torch.cat(\n                [sim_cov, sim_cov.new_ones((fill_xyzs.shape[0], sim_cov.shape[-1]))],\n                dim=0,\n            )\n            self.render_mask = render_mask_in_sim_pts\n        else:\n            self.render_mask = 
torch.ones_like(sim_xyzs[:, 0]).bool().to(device)\n\n        sim_xyzs = (sim_xyzs + shift) / scale\n\n        sim_aabb = torch.stack(\n            [torch.min(sim_xyzs, dim=0)[0], torch.max(sim_xyzs, dim=0)[0]], dim=0\n        )\n        sim_aabb = (\n            sim_aabb - torch.mean(sim_aabb, dim=0, keepdim=True)\n        ) * 1.2 + torch.mean(sim_aabb, dim=0, keepdim=True)\n\n        print(\"simulation aabb: \", sim_aabb)\n\n        # point cloud resample with kmeans\n\n        downsample_scale = self.args.downsample_scale\n        num_cluster = int(sim_xyzs.shape[0] * downsample_scale)\n        sim_xyzs = downsample_with_kmeans_gpu(sim_xyzs, num_cluster)\n\n        sim_gaussian_pos = self.render_params.gaussians.get_xyz.detach().clone()[\n            self.sim_mask_in_raw_gaussian, :\n        ]\n        sim_gaussian_pos = (sim_gaussian_pos + shift) / scale\n\n        cdist = torch.cdist(sim_gaussian_pos, sim_xyzs) * -1.0\n        _, top_k_index = torch.topk(cdist, self.args.top_k, dim=-1)\n        self.top_k_index = top_k_index\n\n        print(\"Downsampled to: \", sim_xyzs.shape[0], \"by\", downsample_scale)\n\n        points_volume = get_volume(sim_xyzs.detach().cpu().numpy())\n\n        num_particles = sim_xyzs.shape[0]\n\n        sim_aabb = torch.stack(\n            [torch.min(sim_xyzs, dim=0)[0], torch.max(sim_xyzs, dim=0)[0]], dim=0\n        )\n        sim_aabb = (\n            sim_aabb - torch.mean(sim_aabb, dim=0, keepdim=True)\n        ) * 1.2 + torch.mean(sim_aabb, dim=0, keepdim=True)\n\n        print(\"simulation aabb: \", sim_aabb)\n\n        wp.init()\n        wp.config.mode = \"debug\"\n        wp.config.verify_cuda = True\n\n        mpm_state = MPMStateStruct()\n        mpm_state.init(num_particles, device=device, requires_grad=True)\n\n        self.particle_init_position = sim_xyzs.clone()\n\n        mpm_state.from_torch(\n            self.particle_init_position.clone(),\n            torch.from_numpy(points_volume).float().to(device).clone(),\n   
         sim_cov,\n            device=device,\n            requires_grad=True,\n            n_grid=grid_size,\n            grid_lim=1.0,\n        )\n        mpm_model = MPMModelStruct()\n        mpm_model.init(num_particles, device=device, requires_grad=True)\n        mpm_model.init_other_params(n_grid=grid_size, grid_lim=1.0, device=device)\n\n        material_params = {\n            \"material\": \"jelly\",  # \"jelly\", \"metal\", \"sand\", \"foam\", \"snow\", \"plasticine\", \"neo-hookean\"\n            \"g\": [0.0, 0.0, 0.0],\n            \"density\": 2000,  # kg / m^3\n            \"grid_v_damping_scale\": 1.1,  # 0.999,\n        }\n\n        self.v_damping = material_params[\"grid_v_damping_scale\"]\n        self.material_name = material_params[\"material\"]\n        mpm_solver = MPMWARPDiff(\n            num_particles, n_grid=grid_size, grid_lim=1.0, device=device\n        )\n        mpm_solver.set_parameters_dict(mpm_model, mpm_state, material_params)\n\n        self.mpm_state, self.mpm_model, self.mpm_solver = (\n            mpm_state,\n            mpm_model,\n            mpm_solver,\n        )\n\n        # setup boundary condition:\n        moving_pts_path = os.path.join(dataset_dir, \"moving_part_points.ply\")\n        if os.path.exists(moving_pts_path):\n            moving_pts = pcu.load_mesh_v(moving_pts_path)\n            moving_pts = torch.from_numpy(moving_pts).float().to(device)\n            moving_pts = (moving_pts + shift) / scale\n            freeze_mask = find_far_points(\n                sim_xyzs, moving_pts, thres=0.5 / grid_size\n            ).bool()\n            freeze_pts = sim_xyzs[freeze_mask, :]\n\n            grid_freeze_mask = apply_grid_bc_w_freeze_pts(\n                grid_size, 1.0, freeze_pts, mpm_solver\n            )\n            self.freeze_mask = freeze_mask\n\n            # does not prefer boundary condition on particle\n            # freeze_mask_select = setup_boundary_condition_with_points(sim_xyzs, moving_pts,\n          
  #                                                         self.mpm_solver, self.mpm_state, thres=0.5 / grid_size)\n            # self.freeze_mask = freeze_mask_select.bool()\n        else:\n            raise NotImplementedError\n\n        num_freeze_pts = self.freeze_mask.sum()\n        print(\n            \"num freeze pts in total\",\n            num_freeze_pts.item(),\n            \"num moving pts\",\n            num_particles - num_freeze_pts.item(),\n        )\n\n        # init fields for simulation, e.g. density, external force, etc.\n\n        # padd init density, youngs,\n        density = (\n            torch.ones_like(self.particle_init_position[..., 0])\n            * material_params[\"density\"]\n        )\n        youngs_modulus = (\n            torch.ones_like(self.particle_init_position[..., 0])\n            * self.E_nu_list[0].detach()\n        )\n        poisson_ratio = torch.ones_like(self.particle_init_position[..., 0]) * 0.3\n\n        # load stem for higher density\n        stem_pts_path = os.path.join(dataset_dir, \"stem_points.ply\")\n        if os.path.exists(stem_pts_path):\n            stem_pts = pcu.load_mesh_v(stem_pts_path)\n            stem_pts = torch.from_numpy(stem_pts).float().to(device)\n            stem_pts = (stem_pts + shift) / scale\n            no_stem_mask = find_far_points(\n                sim_xyzs, stem_pts, thres=2.0 / grid_size\n            ).bool()\n            stem_mask = torch.logical_not(no_stem_mask)\n            density[stem_mask] = 2000\n            print(\"num stem pts\", stem_mask.sum().item())\n\n        self.density = density\n        self.young_modulus = youngs_modulus\n        self.poisson_ratio = poisson_ratio\n\n        # set density, youngs, poisson\n        mpm_state.reset_density(\n            density.clone(),\n            torch.ones_like(density).type(torch.int),\n            device,\n            update_mass=True,\n        )\n        mpm_solver.set_E_nu_from_torch(\n            mpm_model, 
youngs_modulus.clone(), poisson_ratio.clone(), device\n        )\n        mpm_solver.prepare_mu_lam(mpm_model, mpm_state, device)\n\n        self.sim_fields = create_spatial_fields(self.args, 1, sim_aabb)\n        self.sim_fields.train()\n\n        self.args.sim_res = 24\n        # self.velo_fields = create_velocity_model(self.args, sim_aabb)\n        self.velo_fields = create_spatial_fields(\n            self.args, 3, sim_aabb, add_entropy=False\n        )\n        self.velo_fields.train()\n\n    def get_simulation_input(self, device):\n        \"\"\"\n        Outs: All padded\n            density: [N]\n            young_modulus: [N]\n            poisson_ratio: [N]\n            velocity: [N, 3]\n            query_mask: [N]\n        \"\"\"\n\n        density, youngs_modulus, ret_poisson, entropy = self.get_material_params(device)\n        initial_position_time0 = self.particle_init_position.clone()\n\n        query_mask = torch.logical_not(self.freeze_mask)\n        query_pts = initial_position_time0[query_mask, :]\n\n        # velocity = self.velo_fields(torch.cat([query_pts, time_array.unsqueeze(-1)], dim=-1))[..., :3]\n        velocity = self.velo_fields(query_pts)[..., :3]\n\n        # scaling\n        velocity = velocity * 0.1  # not padded yet\n        ret_velocity = torch.zeros_like(initial_position_time0)\n        ret_velocity[query_mask, :] = velocity\n\n        # init F, and C\n\n        I_mat = torch.eye(3, dtype=torch.float32).to(device)\n        particle_F = torch.repeat_interleave(\n            I_mat[None, ...], initial_position_time0.shape[0], dim=0\n        )\n        particle_C = torch.zeros_like(particle_F)\n\n        return (\n            density,\n            youngs_modulus,\n            ret_poisson,\n            ret_velocity,\n            query_mask,\n            particle_F,\n            particle_C,\n            entropy,\n        )\n\n    def get_material_params(self, device):\n\n        initial_position_time0 = 
self.particle_init_position.detach()\n\n        # query_mask = torch.logical_not(self.freeze_mask)\n        query_mask = torch.ones_like(self.freeze_mask).bool()\n        query_pts = initial_position_time0[query_mask, :]\n        if self.args.entropy_cls > 0:\n            sim_params, entropy = self.sim_fields(query_pts)\n        else:\n            sim_params = self.sim_fields(query_pts)\n            entropy = torch.zeros(1).to(sim_params.device)\n\n        sim_params = sim_params * 1000\n        # sim_params = torch.exp(self.sim_fields(query_pts))\n\n        # density = sim_params[..., 0]\n\n        youngs_modulus = self.young_modulus.detach().clone()\n        youngs_modulus[query_mask] += sim_params[..., 0]\n\n        # young_modulus = torch.exp(sim_params[..., 0]) + init_young\n        youngs_modulus = torch.clamp(youngs_modulus, 1000.0, 5e8)\n\n        density = self.density.detach().clone()\n        # density[self.freeze_mask] = 100000\n        ret_poisson = self.poisson_ratio.detach().clone()\n\n        return density, youngs_modulus, ret_poisson, entropy\n\n    def train_one_step(self):\n\n        self.sim_fields.train()\n        self.velo_fields.train()\n        self.model.eval()\n        accelerator = self.accelerator\n        device = \"cuda:{}\".format(accelerator.process_index)\n        data = next(self.dataloader)\n        cam = data[\"cam\"][0]\n\n        gt_videos = data[\"video_clip\"][0, 1 : self.num_frames, ...]\n\n        window_size = int(self.window_size_schduler.compute_state(self.step)[0])\n        stop_velo_opt_thres = 15\n        do_velo_opt = not self.freeze_velo\n        if not do_velo_opt:\n            stop_velo_opt_thres = (\n                0  # stop velocity optimization if we are loading from checkpoint\n            )\n            self.velo_fields.eval()\n\n        rendered_video_list = []\n        log_loss_dict = {\n            \"loss\": [],\n            \"l2_loss\": [],\n            \"psnr\": [],\n            \"ssim\": [],\n         
   \"entropy\": [],\n        }\n        log_psnr_dict = {}\n\n        particle_pos = self.particle_init_position.clone()\n        # clean grid, stress, F, C and rest initial position\n        self.mpm_state.reset_state(\n            particle_pos.clone(),\n            None,\n            None,  # .clone(),\n            device=device,\n            requires_grad=True,\n        )\n        self.mpm_state.set_require_grad(True)\n\n        (\n            density,\n            youngs_modulus,\n            poisson,\n            particle_velo,\n            query_mask,\n            particle_F,\n            particle_C,\n            entropy,\n        ) = self.get_simulation_input(device)\n\n        init_velo_mean = particle_velo[query_mask, :].mean().item()\n        init_velo_max = particle_velo[query_mask, :].max().item()\n\n        if not do_velo_opt:\n            particle_velo = particle_velo.detach()\n        # print(\"does do velo opt\": do_velo_opt)\n\n        num_particles = particle_pos.shape[0]\n\n        delta_time = 1.0 / 30  # 30 fps\n        substep_size = delta_time / self.args.substep\n        num_substeps = int(delta_time / substep_size)\n\n        checkpoint_steps = self.args.checkpoint_steps\n\n        start_time_idx = max(0, window_size - self.args.compute_window)\n\n        temporal_stride = self.args.stride\n\n        if temporal_stride < 0 or temporal_stride > window_size:\n            temporal_stride = window_size\n\n        for start_time_idx in range(0, window_size, temporal_stride):\n\n            end_time_idx = min(start_time_idx + temporal_stride, window_size)\n\n            num_step_with_grad = num_substeps * (end_time_idx - start_time_idx)\n\n            gt_frame = gt_videos[[end_time_idx - 1]]\n\n            if start_time_idx != 0:\n                density, youngs_modulus, poisson, entropy = self.get_material_params(\n                    device\n                )\n\n            if checkpoint_steps > 0 and checkpoint_steps < num_step_with_grad:\n    
            for time_step in range(0, num_step_with_grad, checkpoint_steps):\n                    num_step = min(num_step_with_grad - time_step, checkpoint_steps)\n                    if num_step == 0:\n                        break\n                    particle_pos, particle_velo, particle_F, particle_C = (\n                        MPMDifferentiableSimulationWCheckpoint.apply(\n                            self.mpm_solver,\n                            self.mpm_state,\n                            self.mpm_model,\n                            substep_size,\n                            num_step,\n                            particle_pos,\n                            particle_velo,\n                            particle_F,\n                            particle_C,\n                            youngs_modulus,\n                            self.E_nu_list[1],\n                            density,\n                            query_mask,\n                            device,\n                            True,\n                            0,\n                        )\n                    )\n            else:\n                particle_pos, particle_velo, particle_F, particle_C, particle_cov = (\n                    MPMDifferentiableSimulationClean.apply(\n                        self.mpm_solver,\n                        self.mpm_state,\n                        self.mpm_model,\n                        substep_size,\n                        num_step_with_grad,\n                        particle_pos,\n                        particle_velo,\n                        particle_F,\n                        particle_C,\n                        youngs_modulus,\n                        self.E_nu_list[1],\n                        density,\n                        query_mask,\n                        device,\n                        True,\n                        0,\n                    )\n                )\n\n            # substep-3: render gaussian\n\n            gaussian_pos = particle_pos 
* self.scale - self.shift\n            undeformed_gaussian_pos = (\n                self.particle_init_position * self.scale - self.shift\n            )\n            disp_offset = gaussian_pos - undeformed_gaussian_pos.detach()\n            # gaussian_pos.requires_grad = True\n\n            simulated_video = render_gaussian_seq_w_mask_with_disp(\n                cam,\n                self.render_params,\n                undeformed_gaussian_pos.detach(),\n                self.top_k_index,\n                [disp_offset],\n                self.sim_mask_in_raw_gaussian,\n            )\n\n            # print(\"debug\", simulated_video.shape, gt_frame.shape, gaussian_pos.shape, init_xyzs.shape, density.shape, query_mask.sum().item())\n            rendered_video_list.append(simulated_video.detach())\n\n            l2_loss = 0.5 * F.mse_loss(simulated_video, gt_frame, reduction=\"mean\")\n            ssim_loss = compute_ssim(simulated_video, gt_frame)\n            loss = l2_loss * (1.0 - self.ssim) + (1.0 - ssim_loss) * self.ssim\n\n            loss = loss * (self.args.loss_decay**end_time_idx)\n            sm_velo_loss = self.velo_fields.compute_smoothess_loss() * 10.0\n            if not (do_velo_opt and start_time_idx == 0):\n                sm_velo_loss = sm_velo_loss.detach()\n\n            sm_spatial_loss = self.sim_fields.compute_smoothess_loss()\n\n            sm_loss = (\n                sm_velo_loss + sm_spatial_loss\n            )  # typically 20 times larger than rendering loss\n\n            loss = loss + sm_loss * self.tv_loss_weight\n            loss = loss + entropy * self.args.entropy_reg\n            loss = loss / self.args.compute_window\n            loss.backward()\n\n            # from IPython import embed; embed()\n            # print(self.E_nu_list[1].grad)\n\n            particle_pos, particle_velo, particle_F, particle_C = (\n                particle_pos.detach(),\n                particle_velo.detach(),\n                particle_F.detach(),\n      
          particle_C.detach(),\n            )\n\n            with torch.no_grad():\n                psnr = compute_psnr(simulated_video, gt_frame).mean()\n                log_loss_dict[\"loss\"].append(loss.item())\n                log_loss_dict[\"l2_loss\"].append(l2_loss.item())\n                log_loss_dict[\"psnr\"].append(psnr.item())\n                log_loss_dict[\"ssim\"].append(ssim_loss.item())\n                log_loss_dict[\"entropy\"].append(entropy.item())\n\n                print(\n                    psnr.item(),\n                    end_time_idx,\n                    youngs_modulus.max().item(),\n                    density.max().item(),\n                )\n                log_psnr_dict[\"psnr_frame_{}\".format(end_time_idx)] = psnr.item()\n                # print(psnr.item(), end_time_idx, youngs_modulus.max().item(), density.max().item())\n\n        nu_grad_norm = self.E_nu_list[1].grad.norm(2).item()\n        spatial_grad_norm = 0\n        for p in self.sim_fields.parameters():\n            if p.grad is not None:\n                spatial_grad_norm += p.grad.norm(2).item()\n        velo_grad_norm = 0\n        for p in self.velo_fields.parameters():\n            if p.grad is not None:\n                velo_grad_norm += p.grad.norm(2).item()\n\n        renderd_video = torch.cat(rendered_video_list, dim=0)\n        renderd_video = torch.clamp(renderd_video, 0.0, 1.0)\n        visual_video = (renderd_video.detach().cpu().numpy() * 255.0).astype(np.uint8)\n        gt_video = (gt_videos.detach().cpu().numpy() * 255.0).astype(np.uint8)\n\n        if (\n            self.step % self.gradient_accumulation_steps == 0\n            or self.step == (self.train_iters - 1)\n            or (self.step % self.log_iters == self.log_iters - 1)\n        ):\n\n            torch.nn.utils.clip_grad_norm_(\n                self.trainable_params,\n                self.max_grad_norm,\n                error_if_nonfinite=False,\n            )  # error if nonfinite is 
false\n\n            self.optimizer.step()\n            self.optimizer.zero_grad()\n            if do_velo_opt:\n                assert self.velo_optimizer is not None\n                torch.nn.utils.clip_grad_norm_(\n                    self.velo_fields.parameters(),\n                    self.max_grad_norm,\n                    error_if_nonfinite=False,\n                )  # error if nonfinite is false\n                self.velo_optimizer.step()\n                self.velo_optimizer.zero_grad()\n                self.velo_scheduler.step()\n            with torch.no_grad():\n                self.E_nu_list[0].data.clamp_(1e-1, 1e8)\n                self.E_nu_list[1].data.clamp_(1e-2, 0.449)\n        self.scheduler.step()\n\n        for k, v in log_loss_dict.items():\n            log_loss_dict[k] = np.mean(v)\n\n        print(log_loss_dict)\n        print(\n            \"nu: \",\n            self.E_nu_list[1].item(),\n            nu_grad_norm,\n            spatial_grad_norm,\n            velo_grad_norm,\n            \"young_mean, max:\",\n            youngs_modulus.mean().item(),\n            youngs_modulus.max().item(),\n            do_velo_opt,\n            \"init_velo_mean:\",\n            init_velo_mean,\n        )\n\n        if accelerator.is_main_process and (self.step % self.wandb_iters == 0):\n            with torch.no_grad():\n                wandb_dict = {\n                    \"nu_grad_norm\": nu_grad_norm,\n                    \"spatial_grad_norm\": spatial_grad_norm,\n                    \"velo_grad_norm\": velo_grad_norm,\n                    \"nu\": self.E_nu_list[1].item(),\n                    # \"mean_density\": density.mean().item(),\n                    \"mean_E\": youngs_modulus.mean().item(),\n                    \"max_E\": youngs_modulus.max().item(),\n                    \"min_E\": youngs_modulus.min().item(),\n                    \"smoothness_loss\": sm_loss.item(),\n                    \"window_size\": window_size,\n                    
\"max_particle_velo\": particle_velo.max().item(),\n                    \"init_velo_mean\": init_velo_mean,\n                    \"init_velo_max\": init_velo_max,\n                }\n\n                wandb_dict.update(log_psnr_dict)\n                simulated_video = self.inference(cam, substep=num_substeps)\n                sim_video_torch = (\n                    torch.from_numpy(simulated_video).float().to(device) / 255.0\n                )\n                gt_video_torch = torch.from_numpy(gt_video).float().to(device) / 255.0\n\n                full_psnr = compute_psnr(sim_video_torch[1:], gt_video_torch)\n\n                first_psnr = full_psnr[:6].mean().item()\n                last_psnr = full_psnr[-6:].mean().item()\n                full_psnr = full_psnr.mean().item()\n                wandb_dict[\"full_psnr\"] = full_psnr\n                wandb_dict[\"first_psnr\"] = first_psnr\n                wandb_dict[\"last_psnr\"] = last_psnr\n                wandb_dict.update(log_loss_dict)\n\n                # add young render\n\n                youngs_norm = youngs_modulus - youngs_modulus.min() + 1e-2\n                young_color = youngs_norm / torch.quantile(youngs_norm, 0.99)\n                young_color = torch.clamp(young_color, 0.0, 1.0)\n                young_color[self.freeze_mask] = 0.0\n                queryed_young_color = young_color[self.top_k_index]  # [n_raw, topk]\n                young_color = queryed_young_color.mean(dim=-1)\n\n                young_color_full = torch.ones_like(\n                    self.render_params.gaussians._xyz[:, 0]\n                )\n\n                young_color_full[self.sim_mask_in_raw_gaussian] = young_color\n                young_color = torch.stack(\n                    [young_color_full, young_color_full, young_color_full], dim=-1\n                )\n\n                young_img = render_feat_gaussian(\n                    cam,\n                    self.render_params.gaussians,\n                    
self.render_params.render_pipe,\n                    self.render_params.bg_color,\n                    young_color,\n                )[\"render\"]\n                young_img = (\n                    (young_img.detach().cpu().numpy() * 255.0)\n                    .astype(np.uint8)\n                    .transpose(1, 2, 0)\n                )\n                wandb_dict[\"young_img\"] = wandb.Image(young_img)\n\n                if self.step % int(10 * self.wandb_iters) == 0:\n\n                    wandb_dict[\"rendered_video\"] = wandb.Video(\n                        visual_video, fps=visual_video.shape[0]\n                    )\n\n                    wandb_dict[\"gt_video\"] = wandb.Video(\n                        gt_video,\n                        fps=gt_video.shape[0],\n                    )\n\n                    wandb_dict[\"inference_video\"] = wandb.Video(\n                        simulated_video,\n                        fps=simulated_video.shape[0],\n                    )\n\n                    simulated_video = self.inference(\n                        cam, velo_scaling=5.0, num_sec=3, substep=num_substeps\n                    )\n                    wandb_dict[\"inference_video_v5_t3\"] = wandb.Video(\n                        simulated_video,\n                        fps=30,\n                    )\n\n                if self.use_wandb:\n                    wandb.log(wandb_dict, step=self.step)\n\n        self.accelerator.wait_for_everyone()\n\n    def train(self):\n        # might remove tqdm when multiple node\n        for index in tqdm(range(self.step, self.train_iters), desc=\"Training progress\"):\n            self.train_one_step()\n            if self.step % self.log_iters == self.log_iters - 1:\n                if self.accelerator.is_main_process:\n                    self.save()\n                    # self.test()\n            # self.accelerator.wait_for_everyone()\n            self.step += 1\n        if self.accelerator.is_main_process:\n            
self.save()\n\n    @torch.no_grad()\n    def inference(\n        self,\n        cam,\n        velo_scaling=1.0,\n        num_sec=1,\n        nu=None,\n        young_scaling=1.0,\n        substep=64,\n        youngs_modulus=None,\n    ):\n\n        self.sim_fields.eval()\n        self.velo_fields.eval()\n\n        device = \"cuda:{}\".format(self.accelerator.process_index)\n\n        (\n            density,\n            youngs_modulus_,\n            poisson,\n            init_velocity,\n            query_mask,\n            particle_F,\n            particle_C,\n            entropy,\n        ) = self.get_simulation_input(device)\n\n        poisson = self.E_nu_list[1].detach().clone()  # override poisson\n\n        if youngs_modulus is None:\n            youngs_modulus = youngs_modulus_ * young_scaling\n        init_xyzs = self.particle_init_position.clone()\n\n        init_velocity[query_mask, :] = init_velocity[query_mask, :] * velo_scaling\n\n        num_particles = init_xyzs.shape[0]\n\n        # delta_time = 1.0 / (self.num_frames - 1)\n        delta_time = 1.0 / 30  # 30 fps\n        substep_size = delta_time / substep\n        num_substeps = int(delta_time / substep_size)\n        # reset state\n\n        self.mpm_state.reset_density(\n            density.clone(), query_mask, device, update_mass=True\n        )\n        self.mpm_solver.set_E_nu_from_torch(\n            self.mpm_model, youngs_modulus.clone(), poisson.clone(), device\n        )\n        self.mpm_solver.prepare_mu_lam(self.mpm_model, self.mpm_state, device)\n\n        self.mpm_state.continue_from_torch(\n            init_xyzs,\n            init_velocity,\n            particle_F,\n            particle_C,\n            device=device,\n            requires_grad=False,\n        )\n\n        pos_list = [self.particle_init_position.clone() * self.scale - self.shift]\n\n        prev_state = self.mpm_state\n        for i in tqdm(range((self.num_frames - 1) * num_sec)):\n            # for substep in 
range(num_substeps):\n            #     self.mpm_solver.p2g2p(self.mpm_model, self.mpm_state, substep, substep_size, device=\"cuda:0\")\n            # pos = wp.to_torch(self.mpm_state.particle_x).clone()\n\n            for substep_local in range(num_substeps):\n                next_state = prev_state.partial_clone(requires_grad=False)\n                self.mpm_solver.p2g2p_differentiable(\n                    self.mpm_model, prev_state, next_state, substep_size, device=device\n                )\n                prev_state = next_state\n\n            pos = wp.to_torch(next_state.particle_x).clone()\n            pos = (pos * self.scale) - self.shift\n            pos_list.append(pos)\n\n        init_pos = pos_list[0].clone()\n        pos_diff_list = [_ - init_pos for _ in pos_list]\n\n        video_array = render_gaussian_seq_w_mask_with_disp(\n            cam,\n            self.render_params,\n            init_pos,\n            self.top_k_index,\n            pos_diff_list,\n            self.sim_mask_in_raw_gaussian,\n        )\n\n        video_numpy = video_array.detach().cpu().numpy() * 255\n        video_numpy = np.clip(video_numpy, 0, 255).astype(np.uint8)\n\n        return video_numpy\n\n    def save(\n        self,\n    ):\n        # training states\n        output_path = os.path.join(\n            self.output_path, f\"checkpoint_model_{self.step:06d}\"\n        )\n        os.makedirs(output_path, exist_ok=True)\n\n        name_list = [\n            \"velo_fields\",\n            \"sim_fields\",\n        ]\n        for i, model in enumerate(\n            [\n                self.accelerator.unwrap_model(self.velo_fields, keep_fp32_wrapper=True),\n                self.accelerator.unwrap_model(self.sim_fields, keep_fp32_wrapper=True),\n            ]\n        ):\n            model_name = name_list[i]\n            model_path = os.path.join(output_path, model_name + \".pt\")\n            torch.save(model.state_dict(), model_path)\n\n    def load(self, 
checkpoint_dir):\n        name_list = [\n            \"velo_fields\",\n            \"sim_fields\",\n        ]\n        for i, model in enumerate([self.velo_fields, self.sim_fields]):\n            model_name = name_list[i]\n            if model_name == \"sim_fields\" and (not self.args.load_sim):\n                continue\n            model_path = os.path.join(checkpoint_dir, model_name + \".pt\")\n            print(\"=> loading: \", model_path)\n            model.load_state_dict(torch.load(model_path))\n\n    def setup_eval(self, args, gaussian_path, white_background=True):\n        # setup gaussians\n        class RenderPipe(NamedTuple):\n            convert_SHs_python = False\n            compute_cov3D_python = False\n            debug = False\n\n        class RenderParams(NamedTuple):\n            render_pipe: RenderPipe\n            bg_color: bool\n            gaussians: GaussianModel\n            camera_list: list\n\n        gaussians = GaussianModel(3)\n        camera_list = self.dataset.test_camera_list\n\n        gaussians.load_ply(gaussian_path)\n        gaussians.detach_grad()\n        print(\n            \"load gaussians from: {}\".format(gaussian_path),\n            \"... 
num gaussians: \",\n            gaussians._xyz.shape[0],\n        )\n        bg_color = [1, 1, 1] if white_background else [0, 0, 0]\n        background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n        render_pipe = RenderPipe()\n\n        render_params = RenderParams(\n            render_pipe=render_pipe,\n            bg_color=background,\n            gaussians=gaussians,\n            camera_list=camera_list,\n        )\n        self.render_params = render_params\n\n        # get_gaussian scene box\n        scaler = 1.1\n        points = gaussians._xyz\n\n        min_xyz = torch.min(points, dim=0)[0]\n        max_xyz = torch.max(points, dim=0)[0]\n\n        center = (min_xyz + max_xyz) / 2\n\n        scaled_min_xyz = (min_xyz - center) * scaler + center\n        scaled_max_xyz = (max_xyz - center) * scaler + center\n\n        aabb = torch.stack([scaled_min_xyz, scaled_max_xyz], dim=0)\n\n        # add filled in points\n        gaussian_dir = os.path.dirname(gaussian_path)\n\n        clean_points_path = os.path.join(gaussian_dir, \"clean_object_points.ply\")\n        if os.path.exists(clean_points_path):\n            clean_xyzs = pcu.load_mesh_v(clean_points_path)\n            clean_xyzs = torch.from_numpy(clean_xyzs).float().to(\"cuda\")\n            self.clean_xyzs = clean_xyzs\n            print(\n                \"loaded {} clean points from: \".format(clean_xyzs.shape[0]),\n                clean_points_path,\n            )\n            # we can use tight threshold here\n            not_sim_maks = find_far_points(\n                gaussians._xyz, clean_xyzs, thres=0.01\n            ).bool()\n            sim_mask_in_raw_gaussian = torch.logical_not(not_sim_maks)\n            # [N]\n            self.sim_mask_in_raw_gaussian = sim_mask_in_raw_gaussian\n        else:\n            self.clean_xyzs = None\n            self.sim_mask_in_raw_gaussian = torch.ones_like(gaussians._xyz[:, 0]).bool()\n\n        return aabb\n\n    def demo(\n        
self,\n        velo_scaling=5.0,\n        num_sec=8.0,\n        eval_ys=1.0,\n        static_camera=False,\n        save_name=\"demo_3sec\",\n    ):\n\n        result_dir = \"output/alocasia/results\"\n        pos_path = os.path.join(result_dir, save_name + \"_pos.npy\")\n\n        if os.path.exists(pos_path):\n            pos_array = np.load(pos_path)\n        else:\n            pos_array = None\n        pos_array = None\n        accelerator = self.accelerator\n        data = next(self.dataloader)\n        cam = data[\"cam\"][0]\n\n        for i in range(10):\n            next_data = next(self.test_dataloader)\n        next_cam = next_data[\"cam\"][0]\n\n        substep = self.args.substep  # 1e-4\n\n        youngs_modulus = None\n\n        self.sim_fields.eval()\n        self.velo_fields.eval()\n\n        device = \"cuda:{}\".format(self.accelerator.process_index)\n\n        (\n            density,\n            youngs_modulus_,\n            poisson,\n            init_velocity,\n            query_mask,\n            particle_F,\n            particle_C,\n            entropy,\n        ) = self.get_simulation_input(device)\n\n        poisson = self.E_nu_list[1].detach().clone()  # override poisson\n\n        if eval_ys < 10:\n            youngs_modulus = youngs_modulus_\n        else:\n            youngs_modulus = torch.ones_like(youngs_modulus_) * eval_ys\n\n        # from IPython import embed; embed()\n\n        if pos_array is None:\n            init_xyzs = self.particle_init_position.clone()\n\n            init_velocity[query_mask, :] = init_velocity[query_mask, :] * velo_scaling\n\n            num_particles = init_xyzs.shape[0]\n\n            # delta_time = 1.0 / (self.num_frames - 1)\n            delta_time = 1.0 / 30  # 30 fps\n            substep_size = delta_time / substep\n            num_substeps = int(delta_time / substep_size)\n            # reset state\n\n            self.mpm_state.reset_density(\n                density.clone(), query_mask, device, 
update_mass=True\n            )\n            self.mpm_solver.set_E_nu_from_torch(\n                self.mpm_model, youngs_modulus.clone(), poisson.clone(), device\n            )\n            self.mpm_solver.prepare_mu_lam(self.mpm_model, self.mpm_state, device)\n\n            self.mpm_state.continue_from_torch(\n                init_xyzs,\n                init_velocity,\n                particle_F,\n                particle_C,\n                device=device,\n                requires_grad=False,\n            )\n\n            pos_list = [self.particle_init_position.clone() * self.scale - self.shift]\n\n            prev_state = self.mpm_state\n            for i in tqdm(range(int((self.num_frames - 1) * num_sec))):\n                # for substep in range(num_substeps):\n                #     self.mpm_solver.p2g2p(self.mpm_model, self.mpm_state, substep, substep_size, device=\"cuda:0\")\n                # pos = wp.to_torch(self.mpm_state.particle_x).clone()\n\n                for substep_local in range(num_substeps):\n                    next_state = prev_state.partial_clone(requires_grad=False)\n                    self.mpm_solver.p2g2p_differentiable(\n                        self.mpm_model,\n                        prev_state,\n                        next_state,\n                        substep_size,\n                        device=device,\n                    )\n                    prev_state = next_state\n\n                pos = wp.to_torch(next_state.particle_x).clone()\n                pos = (pos * self.scale) - self.shift\n                pos_list.append(pos)\n\n            numpy_pos = torch.stack(pos_list, dim=0).detach().cpu().numpy()\n\n            np.save(pos_path, numpy_pos)\n        else:\n            pos_list = []\n            for i in range(pos_array.shape[0]):\n                pos = pos_array[i, ...]\n                pos_list.append(torch.from_numpy(pos).to(device))\n\n        init_pos = pos_list[0].clone()\n        pos_diff_list = [_ - init_pos for _ 
in pos_list]\n\n        video_array = render_gaussian_seq_w_mask_with_disp(\n            cam,\n            self.render_params,\n            init_pos,\n            self.top_k_index,\n            pos_diff_list,\n            self.sim_mask_in_raw_gaussian,\n        )\n\n        video_numpy = video_array.detach().cpu().numpy() * 255\n        video_numpy = np.clip(video_numpy, 0, 255).astype(np.uint8)\n        video_numpy = np.transpose(video_numpy, [0, 2, 3, 1])\n\n        from motionrep.utils.io_utils import save_video_imageio, save_gif_imageio\n\n        save_path = os.path.join(\n            save_name\n            + \"_jelly_video_substep_{}_grid_{}_evalys_{}\".format(\n                substep, self.args.grid_size, eval_ys\n            )\n            + \".gif\"\n        )\n        print(\"save video to \", save_path)\n        save_gif_imageio(save_path, video_numpy, fps=30)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config\", type=str, default=\"config.yml\")\n\n    # dataset params\n    parser.add_argument(\n        \"--dataset_dir\",\n        type=str,\n        default=\"../../data/physics_dreamer/hat_nerfstudio/\",\n    )\n    parser.add_argument(\"--video_dir_name\", type=str, default=\"videos\")\n    parser.add_argument(\n        \"--dataset_res\",\n        type=str,\n        default=\"large\",  # [\"middle\", \"small\", \"large\"]\n    )\n    parser.add_argument(\n        \"--motion_model_path\",\n        type=str,\n        default=None,  # not used\n        help=\"path to load the pretrained motion model from\",\n    )\n\n    parser.add_argument(\"--model\", type=str, default=\"se3_field\")\n    parser.add_argument(\"--feat_dim\", type=int, default=64)\n    parser.add_argument(\"--num_decoder_layers\", type=int, default=3)\n    parser.add_argument(\"--decoder_hidden_size\", type=int, default=64)\n    parser.add_argument(\"--spatial_res\", type=int, default=32)\n    parser.add_argument(\"--zero_init\", type=bool, 
default=True)\n\n    parser.add_argument(\"--entropy_cls\", type=int, default=-1)\n    parser.add_argument(\"--entropy_reg\", type=float, default=1e-2)\n\n    parser.add_argument(\"--num_frames\", type=str, default=14)\n\n    parser.add_argument(\"--grid_size\", type=int, default=64)\n    parser.add_argument(\"--sim_res\", type=int, default=8)\n    parser.add_argument(\"--sim_output_dim\", type=int, default=1)\n    parser.add_argument(\"--substep\", type=int, default=768)\n    parser.add_argument(\"--loss_decay\", type=float, default=1.0)\n    parser.add_argument(\"--start_window_size\", type=int, default=6)\n    parser.add_argument(\"--compute_window\", type=int, default=1)\n    parser.add_argument(\"--grad_window\", type=int, default=14)\n    # -1 means no gradient checkpointing\n    parser.add_argument(\"--checkpoint_steps\", type=int, default=-1)\n    parser.add_argument(\"--stride\", type=int, default=1)\n\n    parser.add_argument(\"--downsample_scale\", type=float, default=0.04)\n    parser.add_argument(\"--top_k\", type=int, default=8)\n\n    # loss parameters\n    parser.add_argument(\"--tv_loss_weight\", type=float, default=1e-4)\n    parser.add_argument(\"--ssim\", type=float, default=0.9)\n\n    # Logging and checkpointing\n    parser.add_argument(\"--output_dir\", type=str, default=\"../../output/inverse_sim\")\n    parser.add_argument(\"--log_iters\", type=int, default=10)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\n        \"--checkpoint_path\",\n        type=str,\n        # psnr 29.0\n        default=\"../../output/inverse_sim/fast_alocasia_velopretrain_cleandecay_1.0_substep_96_se3_field_lr_0.01_tv_0.01_iters_300_sw_2_cw_2/seed0/checkpoint_model_000299\",\n        help=\"path to load velocity pretrain checkpoint from\",\n    )\n    # training parameters\n    parser.add_argument(\"--train_iters\", type=int, default=200)\n    parser.add_argument(\"--batch_size\", type=int, default=1)\n    
parser.add_argument(\"--lr\", type=float, default=1e-3)\n    parser.add_argument(\"--max_grad_norm\", type=float, default=1.0)\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n    )\n\n    # wandb parameters\n    parser.add_argument(\"--use_wandb\", action=\"store_true\", default=False)\n    parser.add_argument(\"--wandb_entity\", type=str, default=\"mit-cv\")\n    parser.add_argument(\"--wandb_project\", type=str, default=\"inverse_sim\")\n    parser.add_argument(\"--wandb_iters\", type=int, default=10)\n    parser.add_argument(\"--wandb_name\", type=str, required=True)\n    parser.add_argument(\"--run_eval\", action=\"store_true\", default=False)\n    parser.add_argument(\"--load_sim\", action=\"store_true\", default=False)\n    parser.add_argument(\"--test_convergence\", action=\"store_true\", default=False)\n    parser.add_argument(\"--update_velo\", action=\"store_true\", default=False)\n    parser.add_argument(\"--eval_iters\", type=int, default=8)\n    parser.add_argument(\"--eval_ys\", type=float, default=1e6)\n    parser.add_argument(\"--demo_name\", type=str, default=\"demo_3sec_sv_gres48_lr1e-2\")\n    parser.add_argument(\"--velo_scaling\", type=float, default=5.0)\n\n    # distributed training args\n    parser.add_argument(\n        \"--local_rank\",\n        type=int,\n        default=-1,\n        help=\"For distributed training: local_rank\",\n    )\n\n    args, extra_args = parser.parse_known_args()\n    cfg = create_config(args.config, args, extra_args)\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n    print(args.local_rank, \"local rank\")\n\n    return cfg\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n\n    # torch.backends.cuda.matmul.allow_tf32 = True\n\n    trainer = Trainer(args)\n\n    if args.run_eval:\n        trainer.demo(\n            
velo_scaling=args.velo_scaling,\n            eval_ys=args.eval_ys,\n            save_name=args.demo_name,\n        )\n    else:\n        # trainer.debug()\n        trainer.train()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/_convert_fbx_to_mesh.py",
    "content": "import bpy\nimport numpy as np\nimport sys\nimport point_cloud_utils as pcu\nimport os\n\n\ndef convert(fbx_path):\n    bpy.ops.wm.read_factory_settings(use_empty=True)\n    # 1. Import the FBX file\n    bpy.ops.import_scene.fbx(filepath=fbx_path)\n    print(\"loaded fbx from: \", fbx_path)\n\n    # Assuming the FBX file has one main mesh object, get it\n    mesh_obj = bpy.context.selected_objects[1]\n\n    # 2. Duplicate the mesh for the first frame\n    bpy.context.view_layer.objects.active = mesh_obj\n    mesh_obj.select_set(True)\n    bpy.ops.object.duplicate()\n    static_mesh = bpy.context.object\n\n    # Apply the first frame's pose to the static mesh\n    bpy.context.scene.frame_set(1)\n    bpy.ops.object.modifier_apply({\"object\": static_mesh}, modifier=\"Armature\")\n\n    # 3. Calculate and store vertex offsets for each subsequent frame\n    vertex_traj_list = []\n    num_frames = bpy.context.scene.frame_end\n\n    for frame in range(1, num_frames + 1):\n        bpy.context.scene.frame_set(frame)\n\n        bpy.context.view_layer.update()\n        # Update the mesh to the current frame's pose\n        # mesh_obj.data.update()\n\n        all_pts_3d = []\n        for v1, v2 in zip(static_mesh.data.vertices, mesh_obj.data.vertices):\n            pts_3d = v2.co\n            all_pts_3d.append(pts_3d)\n\n        vertex_traj_list.append(np.array(all_pts_3d))\n\n    vertex_traj_list = np.stack(vertex_traj_list, axis=0)\n\n    # Now, frame_offsets contains the vertex offsets for each frame\n    vertex_array = vertex_traj_list[0]  # first frame\n\n    # get face indx\n    bpy.context.view_layer.objects.active = static_mesh\n    bpy.ops.object.mode_set(mode=\"EDIT\")\n    bpy.ops.mesh.select_all(action=\"SELECT\")\n    bpy.ops.mesh.quads_convert_to_tris(quad_method=\"BEAUTY\", ngon_method=\"BEAUTY\")\n    bpy.ops.object.mode_set(mode=\"OBJECT\")\n    faces_list = [list(face.vertices) for face in static_mesh.data.polygons]\n\n    faces_array = 
np.array(faces_list, dtype=np.int32)\n    vertices = np.array([v.co for v in static_mesh.data.vertices])\n    print(\"vertices shape: \", vertices.shape)\n\n    print(\n        \"num_frames: \",\n        num_frames,\n        \"offsets shape\",\n        vertex_traj_list.shape,\n        \"num_faces\",\n        faces_array.shape,\n        \"max offset\",\n        np.max(vertex_traj_list - vertex_array[np.newaxis, :, :]),\n        np.min(vertex_traj_list - vertex_array[np.newaxis, :, :]),\n    )\n\n    mean = np.mean(vertices, axis=0)\n    max_range = np.max(np.max(vertices, axis=0) - np.min(vertices, axis=0))\n    print(\"max_range: \", max_range, \"mean: \", mean)\n\n    # normalize\n    # vertex_array = (vertex_array - mean[np.newaxis, :]) / max_range\n    # vertex_traj_list = (vertex_traj_list - mean[np.newaxis, np.newaxis, :]) / max_range\n\n    return faces_array, vertex_array, vertex_traj_list\n\n\ndef convert2(fbx_path):\n    bpy.ops.import_scene.fbx(filepath=fbx_path)\n\n    # Assuming the imported object is the active object\n    obj = bpy.context.active_object\n    for obj in bpy.context.selected_objects:\n        print(\"obj: \", obj.name, obj.type)\n    mesh_objects = [obj for obj in bpy.context.selected_objects if obj.type == \"MESH\"]\n\n    # Ensure it's in object mode\n    bpy.ops.object.mode_set(mode=\"OBJECT\")\n\n    # Get the total number of frames in the scene\n    start_frame = bpy.context.scene.frame_start\n    end_frame = bpy.context.scene.frame_end\n\n    # Create a dictionary to store vertex positions for each frame\n    vertex_data_list = []\n\n    # Get the dependency graph\n    depsgraph = bpy.context.evaluated_depsgraph_get()\n\n    # Iterate over each frame\n    for frame in range(start_frame, end_frame + 1):\n        bpy.context.scene.frame_set(frame)\n\n        # Update the scene to reflect changes\n        bpy.context.view_layer.update()\n\n        ret_list = []\n        for mesh_obj in mesh_objects:\n            # deformed_mesh = 
mesh_obj.to_mesh()\n            # Extract vertex positions for the current frame\n            # vertex_positions = [vertex.co for vertex in deformed_mesh.vertices]\n            # vertex_positions = [vertex.co.copy() for vertex in deformed_mesh.vertices]\n\n            duplicated_obj = mesh_obj.copy()\n            duplicated_obj.data = mesh_obj.data.copy()\n            bpy.context.collection.objects.link(duplicated_obj)\n\n            # Make the duplicated object the active object\n            bpy.context.view_layer.objects.active = duplicated_obj\n            duplicated_obj.select_set(True)\n            print(\"duplicated_obj.modifiers\", duplicated_obj.modifiers)\n\n            for mod in duplicated_obj.modifiers:\n                bpy.ops.object.modifier_apply(\n                    {\"object\": duplicated_obj}, modifier=mod.name\n                )\n\n            # if \"Armature\" in duplicated_obj.modifiers:\n            #     bpy.ops.object.modifier_apply(\n            #         {\"object\": duplicated_obj}, modifier=\"Armature\"\n            #     )\n\n            # Extract vertex positions from the duplicated object\n            vertex_positions = [\n                vertex.co.copy() for vertex in duplicated_obj.data.vertices\n            ]\n\n            ret_list += vertex_positions\n\n        # Convert to numpy array and store in the dictionary\n        vertex_data_list.append(np.array(ret_list))\n\n    vertex_traj_list = np.stack(vertex_data_list, axis=0)\n    print(\n        \"offsets shape\",\n        vertex_traj_list.shape,\n        \"max offset\",\n        np.max(vertex_traj_list - vertex_traj_list[0:1, :, :]),\n        np.min(vertex_traj_list - vertex_traj_list[0:1, :, :]),\n    )\n\n    # bpy.ops.object.mode_set(mode=\"EDIT\")\n    # bpy.ops.mesh.select_all(action=\"SELECT\")\n    # bpy.ops.mesh.quads_convert_to_tris(quad_method=\"BEAUTY\", ngon_method=\"BEAUTY\")\n    # bpy.ops.object.mode_set(mode=\"OBJECT\")\n    if bpy.context.active_object.type == 
\"MESH\":\n        obj = bpy.context.active_object\n\n        # Set the mode to 'EDIT'\n        bpy.ops.object.mode_set(mode=\"EDIT\")\n\n        # Ensure the mesh is the active object and is in edit mode\n        if bpy.context.mode == \"EDIT_MESH\" and bpy.context.object == obj:\n            bpy.ops.mesh.select_all(action=\"SELECT\")\n            bpy.ops.mesh.quads_convert_to_tris(\n                quad_method=\"BEAUTY\", ngon_method=\"BEAUTY\"\n            )\n            bpy.ops.object.mode_set(mode=\"OBJECT\")\n        else:\n            print(\"Failed to set the correct context.\")\n    else:\n        print(\"Active object is not a mesh.\")\n    faces_list = []\n    for mesh_obj in mesh_objects:\n        _fl = [list(face.vertices) for face in obj.data.polygons]\n        faces_list += _fl\n\n    faces_array = np.array(faces_list, dtype=np.int32)\n    vertex_array = vertex_traj_list[0]  # first frame\n    print(\"face shape\", faces_array.shape)\n    return faces_array, vertex_array, vertex_traj_list\n\n\ndef main():\n    argv = sys.argv\n    argv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\n    print(argv)\n    fbx_path = argv[0]  # input mesh path\n    output_dir = argv[1]  # output dir\n    # num_frames = int(argv[2])\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    faces_array, vertex_array, vertex_traj_array = convert2(fbx_path)\n\n    save_mesh_path = os.path.join(output_dir, \"mesh0.obj\")\n    pcu.save_mesh_vf(save_mesh_path, vertex_array, faces_array)\n\n    save_traj_path = os.path.join(output_dir, \"traj.npy\")\n    np.save(save_traj_path, vertex_traj_array)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/blender_deforming_things4d.py",
    "content": "import sys\nimport numpy\nimport os\n\nimport PIL\nimport mathutils\n\n\ndef anime_read(filename):\n    \"\"\"\n    filename: path of .anime file\n    return:\n        nf: number of frames in the animation\n        nv: number of vertices in the mesh (mesh topology fixed through frames)\n        nt: number of triangle face in the mesh\n        vert_data: vertice data of the 1st frame (3D positions in x-y-z-order)\n        face_data: riangle face data of the 1st frame\n        offset_data: 3D offset data from the 2nd to the last frame\n    \"\"\"\n    f = open(filename, \"rb\")\n    nf = np.fromfile(f, dtype=np.int32, count=1)[0]\n    nv = np.fromfile(f, dtype=np.int32, count=1)[0]\n    nt = np.fromfile(f, dtype=np.int32, count=1)[0]\n    vert_data = np.fromfile(f, dtype=np.float32, count=nv * 3)\n    face_data = np.fromfile(f, dtype=np.int32, count=nt * 3)\n    offset_data = np.fromfile(f, dtype=np.float32, count=-1)\n    \"\"\"check data consistency\"\"\"\n    if len(offset_data) != (nf - 1) * nv * 3:\n        raise (\"data inconsistent error!\", filename)\n    vert_data = vert_data.reshape((-1, 3))\n    face_data = face_data.reshape((-1, 3))\n    offset_data = offset_data.reshape((nf - 1, nv, 3))\n    return nf, nv, nt, vert_data, face_data, offset_data\n\n\ndef \n\n\nargv = sys.argv\nargv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\nprint(argv)  # ['arg1', 'arg2', 'arg3']\n\n\n\n# for package install\n# import site\n# import pip\n# pip.main([\"install\", \"Pillow\", \"--target\", site.USER_SITE])\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/blender_install_packages.py",
    "content": "import site\nimport pip\n\n# pip.main([\"install\", \"point-cloud-utils\", \"--target\", site.USER_SITE])\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/blender_render_imgs.py",
    "content": "import bpy\nimport os\nimport numpy as np\nimport math\nimport sys\nimport struct\nimport collections\nfrom mathutils import Matrix, Quaternion\nfrom scipy.spatial.transform import Rotation\n\n\ndef focal2fov(focal, pixels):\n    return 2 * math.atan(pixels / (2 * focal))\n\n\ndef create_camera(location, rotation):\n    # Create a new camera\n    bpy.ops.object.camera_add(location=location, rotation=rotation)\n    return bpy.context.active_object\n\n\ndef set_camera_look_at(camera, target_point):\n    # Compute the direction vector from the camera to the target point\n    direction = target_point - camera.location\n    # Compute the rotation matrix to align the camera's -Z axis to this direction\n    rot_quat = direction.to_track_quat(\"-Z\", \"Y\")\n    camera.rotation_euler = rot_quat.to_euler()\n\n    return rot_quat\n\n\ndef setup_alpha_mask(obj_name, pass_index=1):\n    # Set the object's pass index\n    obj = bpy.data.objects[obj_name]\n    obj.pass_index = pass_index\n\n    # Enable the Object Index pass for the active render layer\n    bpy.context.view_layer.use_pass_object_index = True\n\n    # Enable 'Use Nodes':\n    bpy.context.scene.use_nodes = True\n    tree = bpy.context.scene.node_tree\n\n    # Clear default nodes\n    for node in tree.nodes:\n        tree.nodes.remove(node)\n\n    # Add Render Layers node\n    render_layers = tree.nodes.new(\"CompositorNodeRLayers\")\n\n    # Add Composite node (output)\n    composite = tree.nodes.new(\"CompositorNodeComposite\")\n\n    # Add ID Mask node\n    id_mask = tree.nodes.new(\"CompositorNodeIDMask\")\n    id_mask.index = pass_index\n\n    # Add Set Alpha node\n    set_alpha = tree.nodes.new(\"CompositorNodeSetAlpha\")\n\n    # Connect nodes\n    tree.links.new(render_layers.outputs[\"Image\"], set_alpha.inputs[\"Image\"])\n    tree.links.new(render_layers.outputs[\"IndexOB\"], id_mask.inputs[0])\n    tree.links.new(id_mask.outputs[0], set_alpha.inputs[\"Alpha\"])\n    
tree.links.new(set_alpha.outputs[\"Image\"], composite.inputs[\"Image\"])\n\n\ndef render_scene(camera, output_path):\n    bpy.context.scene.render.film_transparent = True\n\n    setup_alpha_mask(\"MyMeshObject\", 1)\n    # Set the active camera\n    bpy.context.scene.render.image_settings.color_mode = \"RGBA\"\n\n    bpy.context.scene.camera = camera\n\n    # Set the output path for the render\n    bpy.context.scene.render.filepath = output_path\n\n    # Render the scene\n    bpy.ops.render.render(write_still=True)\n\n\ndef setup_light():\n    # Add first directional light (Sun lamp)\n    light_data_1 = bpy.data.lights.new(name=\"Directional_Light_1\", type=\"SUN\")\n    light_data_1.energy = 3  # Adjust energy as needed\n    light_1 = bpy.data.objects.new(name=\"Directional_Light_1\", object_data=light_data_1)\n    bpy.context.collection.objects.link(light_1)\n    light_1.location = (10, 10, 10)  # Adjust location as needed\n    light_1.rotation_euler = (\n        np.radians(45),\n        np.radians(0),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n    # Add second directional light (Sun lamp)\n    light_data_2 = bpy.data.lights.new(name=\"Directional_Light_2\", type=\"SUN\")\n    light_data_2.energy = 5  # Adjust energy as needed\n    light_2 = bpy.data.objects.new(name=\"Directional_Light_2\", object_data=light_data_2)\n    bpy.context.collection.objects.link(light_2)\n    light_2.location = (10, -10, 10)  # Adjust location as needed\n    light_2.rotation_euler = (\n        np.radians(45),\n        np.radians(180),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n\ndef create_mesh_from_data(vertices, faces):\n    # Clear existing mesh objects in the scene\n    bpy.ops.object.select_all(action=\"DESELECT\")\n    bpy.ops.object.select_by_type(type=\"MESH\")\n    bpy.ops.object.delete()\n\n    vertices_list = vertices.tolist()\n    faces_list = faces.tolist()\n\n    # Create a new mesh\n    mesh_name = \"MyMesh\"\n    mesh 
= bpy.data.meshes.new(name=mesh_name)\n    obj = bpy.data.objects.new(\"MyMeshObject\", mesh)\n\n    # Link it to the scene\n    bpy.context.collection.objects.link(obj)\n    bpy.context.view_layer.objects.active = obj\n    obj.select_set(True)\n\n    # Load the mesh data\n    mesh.from_pydata(vertices_list, [], faces_list)\n    mesh.update()\n\n    # mesh_data = bpy.data.meshes.new(mesh_name)\n    # mesh_data.from_pydata(vertices_list, [], faces_list)\n    # mesh_data.update()\n    # the_mesh = bpy.data.objects.new(mesh_name, mesh_data)\n    # the_mesh.data.vertex_colors.new()  # init color\n    # bpy.context.collection.objects.link(the_mesh)\n\n    # UV unwrap the mesh\n    bpy.ops.object.select_all(action=\"DESELECT\")\n    obj.select_set(True)\n    bpy.context.view_layer.objects.active = obj\n    bpy.ops.object.mode_set(mode=\"EDIT\")\n    bpy.ops.mesh.select_all(action=\"SELECT\")\n    bpy.ops.uv.smart_project()\n    bpy.ops.object.mode_set(mode=\"OBJECT\")\n\n    # Texture the mesh based on its normals\n    mat = bpy.data.materials.new(name=\"NormalMaterial\")\n    mat.use_nodes = True\n    bsdf = mat.node_tree.nodes[\"Principled BSDF\"]\n    normal_node = mat.node_tree.nodes.new(type=\"ShaderNodeNormal\")\n    geometry = mat.node_tree.nodes.new(type=\"ShaderNodeNewGeometry\")\n\n    # mat.node_tree.links.new(geometry.outputs[\"Normal\"], normal_node.inputs[\"Normal\"])\n    # mat.node_tree.links.new(normal_node.outputs[\"Dot\"], bsdf.inputs[\"Base Color\"])\n    mat.node_tree.links.new(geometry.outputs[\"Normal\"], bsdf.inputs[\"Base Color\"])\n\n    obj.data.materials.append(mat)\n\n    return None\n\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"]\n)\nCamera = collections.namedtuple(\"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"]\n)\nPoint3D = 
collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"]\n)\n\n\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12),\n}\nCAMERA_MODEL_IDS = dict(\n    [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]\n)\nCAMERA_MODEL_NAMES = dict(\n    [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]\n)\n\n\ndef write_next_bytes(fid, data, format_char_sequence, endian_character=\"<\"):\n    \"\"\"pack and write to a binary file.\n    :param fid:\n    :param data: data to send, if multiple elements are sent at the same time,\n    they should be encapsuled either in a list or a tuple\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    should be the same length as the data list or tuple\n    :param endian_character: Any of {@, =, <, >, !}\n    \"\"\"\n    if isinstance(data, (list, tuple)):\n        bytes = struct.pack(endian_character + format_char_sequence, *data)\n    else:\n        bytes = struct.pack(endian_character + format_char_sequence, data)\n    fid.write(bytes)\n\n\ndef write_cameras_binary(cameras, path_to_model_file):\n    \"\"\"\n    see: src/colmap/scene/reconstruction.cc\n      
  void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    with open(path_to_model_file, \"wb\") as fid:\n        write_next_bytes(fid, len(cameras), \"Q\")\n        for _, cam in cameras.items():\n            model_id = CAMERA_MODEL_NAMES[cam.model].model_id\n            camera_properties = [cam.id, model_id, cam.width, cam.height]\n            write_next_bytes(fid, camera_properties, \"iiQQ\")\n            for p in cam.params:\n                write_next_bytes(fid, float(p), \"d\")\n    return cameras\n\n\ndef write_images_binary(images, path_to_model_file):\n    \"\"\"\n    see: src/colmap/scene/reconstruction.cc\n        void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    with open(path_to_model_file, \"wb\") as fid:\n        write_next_bytes(fid, len(images), \"Q\")\n        for _, img in images.items():\n            write_next_bytes(fid, img.id, \"i\")\n            write_next_bytes(fid, img.qvec.tolist(), \"dddd\")\n            write_next_bytes(fid, img.tvec.tolist(), \"ddd\")\n            write_next_bytes(fid, img.camera_id, \"i\")\n            for char in img.name:\n                write_next_bytes(fid, char.encode(\"utf-8\"), \"c\")\n            write_next_bytes(fid, b\"\\x00\", \"c\")\n            write_next_bytes(fid, len(img.point3D_ids), \"Q\")\n            for xy, p3d_id in zip(img.xys, img.point3D_ids):\n                write_next_bytes(fid, [*xy, p3d_id], \"ddq\")\n\n\ndef write_points3D_binary(points3D, path_to_model_file):\n    \"\"\"\n    see: src/colmap/scene/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n    with open(path_to_model_file, \"wb\") as fid:\n        write_next_bytes(fid, len(points3D), \"Q\")\n        for 
_, pt in points3D.items():\n            write_next_bytes(fid, pt.id, \"Q\")\n            write_next_bytes(fid, pt.xyz.tolist(), \"ddd\")\n            write_next_bytes(fid, pt.rgb.tolist(), \"BBB\")\n            write_next_bytes(fid, pt.error, \"d\")\n            track_length = pt.image_ids.shape[0]\n            write_next_bytes(fid, track_length, \"Q\")\n            for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):\n                write_next_bytes(fid, [image_id, point2D_id], \"ii\")\n\n\ndef get_colmap_camera(camera_obj, render_resolution):\n    \"\"\"\n    Extract the intrinsic matrix from a Blender camera.\n\n    Args:\n    - camera_obj: The Blender camera object.\n    - render_resolution: Tuple of (width, height) indicating the render resolution.\n\n    Returns:\n    - colmap_camera: dict of [\"id\", \"model\", \"width\", \"height\", \"params\"]\n    \"\"\"\n\n    # Get the camera data\n    cam = camera_obj.data\n\n    # Ensure it's a perspective camera\n    if cam.type != \"PERSP\":\n        raise ValueError(\"Only 'PERSP' camera type is supported.\")\n\n    # Image resolution\n    width, height = render_resolution\n\n    # Sensor width and height in millimeters\n    sensor_width_mm = cam.sensor_width\n    sensor_height_mm = cam.sensor_height\n\n    # Calculate the focal length in pixels\n    fx = (cam.lens / sensor_width_mm) * width\n    fy = (cam.lens / sensor_height_mm) * height\n\n    # Principal point, usually at the center of the image\n    cx = width / 2.0\n    cy = height / 2.0\n\n    _cam_dict = {\n        \"id\": 0,\n        \"model\": \"PINHOLE\",  # PINHOLE\n        \"width\": width,\n        \"height\": height,\n        \"params\": [fx, fy, cx, cy],\n    }\n\n    colmap_cameras = {0: Camera(**_cam_dict)}\n\n    print(\"focal\", fx, fy, cx, cy)\n\n    return colmap_cameras\n\n\ndef main():\n    import point_cloud_utils as pcu\n\n    argv = sys.argv\n    argv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\n    
print(argv)\n    inp_mesh_path = argv[0]  # input mesh path\n    output_dir = argv[1]  # output dir\n    # num_frames = int(argv[2])\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    img_output_dir = os.path.join(output_dir, \"images\")\n    if not os.path.exists(img_output_dir):\n        os.makedirs(img_output_dir)\n\n    vertices, faces = pcu.load_mesh_vf(inp_mesh_path)\n\n    # normalize\n    verices_center = np.mean(vertices, axis=0)\n    max_range = np.max(np.max(vertices, axis=0) - np.min(vertices, axis=0))\n    print(\n        max_range.shape, max_range, verices_center.shape, verices_center, vertices.shape\n    )\n    vertices = (vertices - verices_center[np.newaxis, :]) / max_range\n\n    # Create the 3D mesh in Blender from your data.\n    obj = create_mesh_from_data(vertices, faces)\n\n    object_center = bpy.context.scene.objects[\"MyMeshObject\"].location\n\n    # Number of viewpoints\n    num_views = 180  # 180\n    radius = 6  # Distance of the camera from the object center\n\n    setup_light()\n    # Set up rendering parameters\n    bpy.context.scene.render.image_settings.file_format = \"PNG\"\n    bpy.context.scene.render.resolution_x = 1080\n    bpy.context.scene.render.resolution_y = 720\n\n    camera = create_camera((1, 1, 1), (0, 0, 0))\n    colmap_camera_dict = get_colmap_camera(\n        camera,\n        (bpy.context.scene.render.resolution_x, bpy.context.scene.render.resolution_y),\n    )\n\n    transform_dict = {\n        \"frames\": [],\n        \"camera_angle_x\": focal2fov(\n            colmap_camera_dict[0].params[0], colmap_camera_dict[0].width\n        ),\n    }\n    img_indx = 0\n    num_elevations = 6\n    colmap_images_dict = {}\n    for j in range(num_elevations):\n        num_imgs = num_views // num_elevations\n        for i in range(num_imgs):\n            angle = 2 * math.pi * i / num_imgs\n            x = object_center.x + radius * math.cos(angle)\n            y = object_center.y + radius * 
math.sin(angle)\n            z = (\n                object_center.z + (j - num_elevations / 3.0) * 4.0 / num_elevations\n            )  # Adjust this if you want the camera to be above or below the object's center\n\n            camera = create_camera((x, y, z), (0, 0, 0))\n            rot_quant = set_camera_look_at(camera, object_center)\n            tvec = np.array([x, y, z])\n            bpy.context.view_layer.update()\n\n            # plan-1\n            # w2c = np.array(camera.matrix_world.inverted())\n            # w2c[1:3, :] *= -1.0\n            # rotation_matrix = w2c[:3, :3]\n            # tvec = w2c[:3, 3]\n            # plan-1 end\n\n            # plan-2\n            camera_to_world_matrix = camera.matrix_world\n            # [4, 4]\n            camera_to_world_matrix = np.array(camera_to_world_matrix).copy()\n            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            camera_to_world_matrix[:3, 1:3] *= -1.0\n            w2c = np.linalg.inv(camera_to_world_matrix)\n            rotation_matrix = w2c[:3, :3]\n            tvec = w2c[:3, 3]\n\n            # c2w rotation\n            # rotation_matrix = rot_quant.to_matrix()  # .to_4x4()\n            # # w2c rotation\n            # rotation_matrix = np.array(rotation_matrix)\n            # # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            # rotation_matrix[:3, 1:3] *= -1.0\n            # rotation_matrix = rotation_matrix.transpose()\n            # tvec = (rotation_matrix @ tvec[:, np.newaxis]).squeeze(axis=-1) * -1.0\n\n            rot_quant = Rotation.from_matrix(rotation_matrix).as_quat()\n            # print(\"r shape\", rotation_matrix.shape, tvec.shape)\n\n            img_dict = {\n                \"id\": img_indx,\n                \"qvec\": rot_quant,\n                \"tvec\": tvec,\n                \"camera_id\": 0,\n                \"name\": f\"img_{img_indx}.png\",\n                \"xys\": [[k, k] 
for k in range(i, i + 10)],  # placeholder\n                \"point3D_ids\": list(range(i, i + 10)),  # placeholder\n            }\n            colmap_images_dict[img_indx] = BaseImage(**img_dict)\n\n            # also prepare transforms.json\n            fname = f\"images/img_{img_indx}\"\n            cam2world = np.array(camera.matrix_world)\n            transform_dict[\"frames\"].append(\n                {\"file_path\": fname, \"transform_matrix\": cam2world.tolist()}\n            )\n\n            # render_scene(camera, os.path.join(img_output_dir, f\"img_{img_indx}.png\"))\n            img_indx += 1\n\n    # sample 3D points\n    num_3d_points = 10000\n\n    # sample 3D points\n    sampled_points = vertices[np.random.choice(vertices.shape[0], num_3d_points), :]\n\n    print(\n        \"samping {} points out of {} vertices\".format(num_3d_points, vertices.shape[0])\n    )\n\n    # save into ply\n    pcu.save_mesh_v(\n        os.path.join(output_dir, \"sampled_point_cloud.ply\"),\n        sampled_points,\n    )\n\n    # format into colmap points format\n    colmap_points_dict = {}\n    for i in range(num_3d_points):\n        pnt_dict = {\n            \"id\": i,\n            \"xyz\": sampled_points[i, :],\n            \"rgb\": np.array([100, 100, 100]),  # place holder , need integers\n            \"error\": 0.0,  # place holder\n            \"image_ids\": np.array(list(range(i, i + 10))),  # placeholder\n            \"point2D_idxs\": np.array(list(range(i, i + 10))),  # placeholder\n        }\n\n        colmap_points_dict[i] = Point3D(**pnt_dict)\n\n    trans_fpath = os.path.join(output_dir, \"transforms_train.json\")\n    import json\n\n    with open(trans_fpath, \"w\") as f:\n        json.dump(transform_dict, f)\n\n    output_dir = os.path.join(output_dir, \"sparse/0\")\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n    # Write to binary\n    write_cameras_binary(colmap_camera_dict, os.path.join(output_dir, \"cameras.bin\"))\n    
write_images_binary(colmap_images_dict, os.path.join(output_dir, \"images.bin\"))\n    write_points3D_binary(colmap_points_dict, os.path.join(output_dir, \"points3D.bin\"))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/deforming_things4d.py",
    "content": "import os\nimport numpy as np\n\n\ndef anime_read(filename):\n    \"\"\"\n    filename: path of .anime file\n    return:\n        nf: number of frames in the animation\n        nv: number of vertices in the mesh (mesh topology fixed through frames)\n        nt: number of triangle face in the mesh\n        vert_data: vertex data of the 1st frame (3D positions in x-y-z-order). [nv, 3]\n        face_data: triangle face data of the 1st frame. [nt, 3] dtype = int32\n        offset_data: 3D offset data from the 2nd to the last frame. [nf, nv, 3]\n    \"\"\"\n    f = open(filename, \"rb\")\n    nf = np.fromfile(f, dtype=np.int32, count=1)[0]\n    nv = np.fromfile(f, dtype=np.int32, count=1)[0]\n    nt = np.fromfile(f, dtype=np.int32, count=1)[0]\n    vert_data = np.fromfile(f, dtype=np.float32, count=nv * 3)\n    face_data = np.fromfile(f, dtype=np.int32, count=nt * 3)\n    offset_data = np.fromfile(f, dtype=np.float32, count=-1)\n    \"\"\"check data consistency\"\"\"\n    if len(offset_data) != (nf - 1) * nv * 3:\n        raise ValueError(\"data inconsistent error! \" + filename)\n    vert_data = vert_data.reshape((-1, 3))\n    face_data = face_data.reshape((-1, 3))\n    offset_data = offset_data.reshape((nf - 1, nv, 3))\n    return nf, nv, nt, vert_data, face_data, offset_data\n\n\ndef extract_trajectory(\n    trajectory_array: np.ndarray,\n    topk_freq: int = 8,\n):\n    \"\"\"\n    Args:\n        trajectory_array: [nf, nv, 3]. The 3D position of each point in each frame.\n        topk_freq: int. 
FFT frequency.\n    \"\"\"\n\n    # doing fft on trajectory_array\n    # [nf, nv, 3]\n    trajectory_fft = np.fft.fft(trajectory_array, axis=0)\n    # only keep topk_freq\n    # [topk_freq, nv, 3]\n    trajectory_fft = trajectory_fft[:topk_freq, :, :]\n    trajectory_fft[topk_freq:-topk_freq, :, :] = 0\n\n    # doing ifft on trajectory_fft\n    # [nf, nv, 3]\n    trajectory_array = np.fft.ifft(trajectory_fft, axis=0).real\n\n\ndef main():\n    import argparse\n    import point_cloud_utils as pcu\n\n    parser = argparse.ArgumentParser(description=\"None description\")\n\n    parser.add_argument(\"--input\", type=str, help=\"input path\")\n    parser.add_argument(\"--output_dir\", type=str, help=\"output path\")\n    parser.add_argument(\n        \"--skip\",\n        type=int,\n        default=-1,\n        help=\"skipping between frame saving. -1 indicates only save first frame\",\n    )\n\n    args = parser.parse_args()\n\n    if not os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n\n    inp_ani_path = args.input\n\n    nf, nv, nt, vert_data, face_data, offset_data = anime_read(inp_ani_path)\n    #  face_data:  offset_data [nf, nv, 3]\n\n    # normalize\n    verices_center = np.mean(vert_data, axis=0)\n    max_range = np.max(np.max(vert_data, axis=0) - np.min(vert_data, axis=0))\n    print(\n        max_range.shape,\n        max_range,\n        verices_center.shape,\n        verices_center,\n        vert_data.shape,\n    )\n    vert_data = (vert_data - verices_center[np.newaxis, :]) / max_range\n    offset_data = offset_data / max_range\n\n    # save trajectory as numpy array\n\n    # [nf, nv, 3]\n    trajectory_array = offset_data + vert_data[None, :, :]\n    trajectory_array = np.concatenate([vert_data[None, :, :], trajectory_array], axis=0)\n    out_traj_path = os.path.join(args.output_dir, \"trajectory.npy\")\n    print(\"trajectory array of shape [nf, nv, 3]. 
key: data\", trajectory_array.shape)\n    save_dict = {\n        \"help\": \"trajectory array of shape [nf, nv, 3]. key: data\",\n        \"data\": trajectory_array,\n    }\n\n    # np.savez(out_traj_path, save_dict)\n    # np.save(out_traj_path, trajectory_array)\n\n    if args.skip == -1:\n        # save mesh as .obj\n        out_obj_path = os.path.join(args.output_dir, \"mesh0.obj\")\n        pcu.save_mesh_vf(out_obj_path, vert_data, face_data)\n\n        return\n\n    for i in range(nf):\n        if i % args.skip != 0:\n            continue\n        out_obj_path = os.path.join(args.output_dir, \"mesh{}.obj\".format(i))\n        pcu.save_mesh_vf(out_obj_path, trajectory_array[i], face_data)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/dragon_animation.py",
    "content": "import bpy\n\n# Clear existing data\nbpy.ops.wm.read_factory_settings(use_empty=True)\n\n# 1. Import the FBX file\nfbx_path = \"/local/cg/rundi/data/New-FBX-BVH_Z-OO/Truebone_Z-OO/Dragon/Wyvern-Fly.fbx\"\n# fbx_path = \"../../../data/motion_dataset/pirate-flag-animated/source/pirate_flag.fbx\"\nbpy.ops.import_scene.fbx(filepath=fbx_path)\n\n\n# 2. Set up the camera and lighting (assuming they aren't in the FBX)\n# Add a camera\nbpy.ops.object.camera_add(location=(0, -14, 7))\ncamera = bpy.context.object\ncamera.rotation_euler = (1.5708, 0, 0)  # Point the camera towards the origin\n\n# Ensure the camera is in the scene and set it as the active camera\nif \"Camera\" in bpy.data.objects:\n    bpy.context.scene.camera = camera\nelse:\n    print(\"Camera not added!\")\n\n# Add a light\nbpy.ops.object.light_add(type=\"SUN\", location=(15, -15, 15))\n\n# 3. Set up the render settings\nbpy.context.scene.render.engine = \"CYCLES\"  # or 'EEVEE'\nbpy.context.scene.render.image_settings.file_format = \"FFMPEG\"\nbpy.context.scene.render.ffmpeg.format = \"MPEG4\"\nbpy.context.scene.render.ffmpeg.codec = \"H264\"\nbpy.context.scene.render.ffmpeg.constant_rate_factor = \"MEDIUM\"\nbpy.context.scene.render.filepath = \".data/dragon/dragon.mp4\"\nbpy.context.scene.frame_start = 1\nbpy.context.scene.frame_end = 250  # Adjust this based on your needs\n\n# 4. Render the animation\nbpy.ops.render.render(animation=True)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/fbx_to_mesh.py",
    "content": "import bpy\nimport os\nimport sys\n\n\ndef convert_to_mesh(fbx_path, output_dir):\n    bpy.ops.import_scene.fbx(filepath=fbx_path)\n\n    # Assuming the imported object is the active object\n    obj = bpy.context.active_object\n    for obj in bpy.context.selected_objects:\n        print(\"obj: \", obj.name, obj.type)\n    mesh_objects = [obj for obj in bpy.context.selected_objects if obj.type == \"MESH\"]\n\n    # Get the active object (assuming it's the imported FBX mesh)\n    obj = bpy.context.active_object\n\n    for obj in mesh_objects:\n        # Add the subdivision modifier\n        mod = obj.modifiers.new(name=\"Subdivision\", type=\"SUBSURF\")\n        mod.levels = 1  # Set this to the desired subdivision level\n        mod.render_levels = 2  # Set this to the desired subdivision level for rendering\n\n        # Apply the modifier\n        bpy.context.view_layer.objects.active = obj\n        bpy.ops.object.modifier_apply(modifier=mod.name)\n\n    # Set the start and end frames (modify these values if needed)\n    start_frame = bpy.context.scene.frame_start\n    end_frame = bpy.context.scene.frame_end\n\n    # Iterate through each frame and export\n    for frame in range(start_frame, end_frame + 1):\n        bpy.context.scene.frame_set(frame)\n\n        # Construct the filename\n        filename = os.path.join(\n            output_dir, f\"frame_{frame:04}.obj\"\n        )  # Change to .blend for Blender format\n\n        # Export to OBJ format\n        bpy.ops.export_scene.obj(filepath=filename, use_selection=True)\n\n        # Uncomment the line below and comment the above line if you want to export in Blender format\n        # bpy.ops.wm.save_as_mainfile(filepath=filename, copy=True)\n\n    print(\"Export complete!\")\n\n\ndef convert_obj_to_traj(meshes_dir):\n    import glob\n    import numpy as np\n    import point_cloud_utils as pcu\n\n    meshes = sorted(glob.glob(os.path.join(meshes_dir, \"*.obj\")))\n    print(\"total of {} meshes: 
\".format(len(meshes)), meshes[:5], \"....\")\n    traj = []\n    R_mat = np.array(\n        [[1.0, 0, 0], [0, 0, 1.0], [0, 1.0, 0]],\n    )\n    for mesh in meshes:\n        print(mesh)\n        verts, faces = pcu.load_mesh_vf(mesh)\n        verts = R_mat[np.newaxis, :, :] @ verts[:, :, np.newaxis]\n        verts = verts.squeeze(axis=-1)\n        traj.append(verts)\n    traj = np.array(traj)\n\n    print(\"final traj shape\", traj.shape)\n\n    save_path = os.path.join(meshes_dir, \"traj.npy\")\n    np.save(save_path, traj)\n\n    save_path = os.path.join(meshes_dir, \"../\", \"traj.npy\")\n    np.save(save_path, traj)\n\n\ndef main():\n    argv = sys.argv\n    argv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\n    print(argv)\n    inp_fpx_path = argv[0]  # input mesh path\n    output_dir = argv[1]  # output dir\n    # num_frames = int(argv[2])\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    output_dir = os.path.join(output_dir, \"meshes\")\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    convert_to_mesh(inp_fpx_path, output_dir)\n    convert_obj_to_traj(output_dir)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/fbx_to_mesh_flag.py",
    "content": "import bpy\nimport os\nimport sys\n\n\ndef convert_to_mesh(fbx_path, output_dir):\n    bpy.ops.import_scene.fbx(filepath=fbx_path)\n\n    # Assuming the imported object is the active object\n    original_obj = bpy.context.active_object\n\n    for obj in bpy.context.selected_objects:\n        print(\"obj: \", obj.name, obj.type)\n\n    # Duplicate the original object\n    bpy.ops.object.duplicate()\n    duplicate_obj = bpy.context.active_object\n\n    # Remove shape keys from the duplicate\n    if duplicate_obj.data.shape_keys:\n        bpy.context.view_layer.objects.active = duplicate_obj\n        bpy.ops.object.shape_key_remove(all=True)\n\n    # Add and apply the subdivision modifier to the duplicate\n    mod = duplicate_obj.modifiers.new(name=\"Subdivision\", type=\"SUBSURF\")\n    mod.levels = 1\n    mod.render_levels = 1\n    bpy.context.view_layer.objects.active = duplicate_obj\n    bpy.ops.object.modifier_apply(modifier=mod.name)\n\n    # Set the start and end frames\n    start_frame = bpy.context.scene.frame_start\n    end_frame = bpy.context.scene.frame_end\n\n    # Iterate through each frame, set the shape of the duplicate to match the original, and export\n    for frame in range(start_frame, end_frame + 1):\n        bpy.context.scene.frame_set(frame)\n\n        # Transfer shape from original to duplicate (this assumes the original animation uses shape keys)\n        if original_obj.data.shape_keys:\n            for key_block in original_obj.data.shape_keys.key_blocks:\n                duplicate_obj.data.vertices.foreach_set(\"co\", key_block.data[:])\n\n        # Construct the filename\n        filename = os.path.join(output_dir, f\"frame_{frame:04}.obj\")\n\n        # Export to OBJ format\n        bpy.ops.export_scene.obj(filepath=filename, use_selection=True)\n\n    print(\"Export complete!\")\n\n\ndef subdivde_mesh(mesh_directory, output_directory):\n    # Ensure the output directory exists\n    if not 
os.path.exists(output_directory):\n        os.makedirs(output_directory)\n\n    # List all files in the directory\n    all_files = os.listdir(mesh_directory)\n\n    # Filter for .obj files (or change to the format you're using)\n    obj_files = [f for f in all_files if f.endswith(\".obj\")]\n\n    for obj_file in obj_files:\n        # Construct full path\n        full_path = os.path.join(mesh_directory, obj_file)\n\n        # Clear existing mesh data\n        bpy.ops.object.select_all(action=\"DESELECT\")\n        bpy.ops.object.select_by_type(type=\"MESH\")\n        bpy.ops.object.delete()\n\n        # Import the mesh\n        bpy.ops.import_scene.obj(filepath=full_path)\n\n        # Select all imported objects (assuming they are meshes)\n        bpy.ops.object.select_all(action=\"SELECT\")\n\n        # Apply subdivision\n        for obj in bpy.context.selected_objects:\n            if obj.type == \"MESH\":\n                print(\"apply subdivide to: \", obj.name)\n                mod = obj.modifiers.new(name=\"Subdivision\", type=\"SUBSURF\")\n                mod.levels = 1\n                mod.render_levels = 1\n                bpy.context.view_layer.objects.active = obj\n                bpy.ops.object.modifier_apply(modifier=mod.name)\n\n        # Export the mesh with subdivision\n        output_path = os.path.join(output_directory, obj_file)\n        bpy.ops.export_scene.obj(filepath=output_path, use_selection=True)\n\n    print(\"Processing complete!\")\n\n\ndef convert_obj_to_traj(meshes_dir):\n    import glob\n    import numpy as np\n    import point_cloud_utils as pcu\n\n    meshes = sorted(glob.glob(os.path.join(meshes_dir, \"*.obj\")))\n    print(\"total of {} meshes: \".format(len(meshes)), meshes[:5], \"....\")\n    traj = []\n    R_mat = np.array(\n        [[1.0, 0, 0], [0, 0, 1.0], [0, 1.0, 0]],\n    )\n    for mesh in meshes:\n        print(mesh)\n        verts, faces = pcu.load_mesh_vf(mesh)\n        verts = R_mat[np.newaxis, :, :] @ verts[:, :, 
np.newaxis]\n        verts = verts.squeeze(axis=-1)\n        traj.append(verts)\n    traj = np.array(traj)\n\n    print(\"final traj shape\", traj.shape)\n\n    save_path = os.path.join(meshes_dir, \"traj.npy\")\n    np.save(save_path, traj)\n\n    save_path = os.path.join(meshes_dir, \"../\", \"traj.npy\")\n    np.save(save_path, traj)\n\n\ndef main():\n    argv = sys.argv\n    argv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\n    print(argv)\n    inp_fpx_path = argv[0]  # input mesh path\n    output_dir = argv[1]  # output dir\n    # num_frames = int(argv[2])\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    output_dir = os.path.join(output_dir, \"meshes\")\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    # convert_to_mesh(inp_fpx_path, output_dir)\n    dense_output_dir = os.path.join(output_dir, \"denser_mesh\")\n    os.makedirs(dense_output_dir, exist_ok=True)\n    subdivde_mesh(output_dir, dense_output_dir)\n    convert_obj_to_traj(dense_output_dir)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/render_blender_annimations.py",
    "content": "import bpy\nimport os\nimport numpy as np\nimport math\nimport sys\nimport struct\nimport collections\nfrom mathutils import Matrix, Quaternion\nfrom scipy.spatial.transform import Rotation\n\n\ndef focal2fov(focal, pixels):\n    return 2 * math.atan(pixels / (2 * focal))\n\n\ndef create_camera(location, rotation):\n    # Create a new camera\n    bpy.ops.object.camera_add(location=location, rotation=rotation)\n    return bpy.context.active_object\n\n\ndef set_camera_look_at(camera, target_point):\n    # Compute the direction vector from the camera to the target point\n    direction = target_point - camera.location\n    # Compute the rotation matrix to align the camera's -Z axis to this direction\n    rot_quat = direction.to_track_quat(\"-Z\", \"Y\")\n    camera.rotation_euler = rot_quat.to_euler()\n\n    return rot_quat\n\n\ndef setup_alpha_mask(obj_name, pass_index=1):\n    # Set the object's pass index\n    obj = bpy.data.objects[obj_name]\n    obj.pass_index = pass_index\n\n    # Enable the Object Index pass for the active render layer\n    bpy.context.view_layer.use_pass_object_index = True\n\n    # Enable 'Use Nodes':\n    bpy.context.scene.use_nodes = True\n    tree = bpy.context.scene.node_tree\n\n    # Clear default nodes\n    for node in tree.nodes:\n        tree.nodes.remove(node)\n\n    # Add Render Layers node\n    render_layers = tree.nodes.new(\"CompositorNodeRLayers\")\n\n    # Add Composite node (output)\n    composite = tree.nodes.new(\"CompositorNodeComposite\")\n\n    # Add ID Mask node\n    id_mask = tree.nodes.new(\"CompositorNodeIDMask\")\n    id_mask.index = pass_index\n\n    # Add Set Alpha node\n    set_alpha = tree.nodes.new(\"CompositorNodeSetAlpha\")\n\n    # Connect nodes\n    tree.links.new(render_layers.outputs[\"Image\"], set_alpha.inputs[\"Image\"])\n    tree.links.new(render_layers.outputs[\"IndexOB\"], id_mask.inputs[0])\n    tree.links.new(id_mask.outputs[0], set_alpha.inputs[\"Alpha\"])\n    
tree.links.new(set_alpha.outputs[\"Image\"], composite.inputs[\"Image\"])\n\n\ndef render_scene(camera, output_path):\n    bpy.context.scene.render.film_transparent = True\n\n    setup_alpha_mask(\"MyMeshObject\", 1)\n    # Set the active camera\n    bpy.context.scene.render.image_settings.color_mode = \"RGBA\"\n\n    bpy.context.scene.camera = camera\n\n    # Set the output path for the render\n    bpy.context.scene.render.filepath = output_path\n\n    # Render the scene\n    bpy.ops.render.render(write_still=True)\n\n\ndef setup_light():\n    # Add first directional light (Sun lamp)\n    light_data_1 = bpy.data.lights.new(name=\"Directional_Light_1\", type=\"SUN\")\n    light_data_1.energy = 3  # Adjust energy as needed\n    light_1 = bpy.data.objects.new(name=\"Directional_Light_1\", object_data=light_data_1)\n    bpy.context.collection.objects.link(light_1)\n    light_1.location = (10, 10, 10)  # Adjust location as needed\n    light_1.rotation_euler = (\n        np.radians(45),\n        np.radians(0),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n    # Add second directional light (Sun lamp)\n    light_data_2 = bpy.data.lights.new(name=\"Directional_Light_2\", type=\"SUN\")\n    light_data_2.energy = 5  # Adjust energy as needed\n    light_2 = bpy.data.objects.new(name=\"Directional_Light_2\", object_data=light_data_2)\n    bpy.context.collection.objects.link(light_2)\n    light_2.location = (10, -10, 10)  # Adjust location as needed\n    light_2.rotation_euler = (\n        np.radians(45),\n        np.radians(180),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n\ndef create_mesh_from_data(vertices, faces):\n    # Clear existing mesh objects in the scene\n    bpy.ops.object.select_all(action=\"DESELECT\")\n    bpy.ops.object.select_by_type(type=\"MESH\")\n    bpy.ops.object.delete()\n\n    vertices_list = vertices.tolist()\n    faces_list = faces.tolist()\n\n    # Create a new mesh\n    mesh_name = \"MyMesh\"\n    mesh 
= bpy.data.meshes.new(name=mesh_name)\n    obj = bpy.data.objects.new(\"MyMeshObject\", mesh)\n\n    # Link it to the scene\n    bpy.context.collection.objects.link(obj)\n    bpy.context.view_layer.objects.active = obj\n    obj.select_set(True)\n\n    # Load the mesh data\n    mesh.from_pydata(vertices_list, [], faces_list)\n    mesh.update()\n\n    # mesh_data = bpy.data.meshes.new(mesh_name)\n    # mesh_data.from_pydata(vertices_list, [], faces_list)\n    # mesh_data.update()\n    # the_mesh = bpy.data.objects.new(mesh_name, mesh_data)\n    # the_mesh.data.vertex_colors.new()  # init color\n    # bpy.context.collection.objects.link(the_mesh)\n\n    # UV unwrap the mesh\n    bpy.ops.object.select_all(action=\"DESELECT\")\n    obj.select_set(True)\n    bpy.context.view_layer.objects.active = obj\n    bpy.ops.object.mode_set(mode=\"EDIT\")\n    bpy.ops.mesh.select_all(action=\"SELECT\")\n    bpy.ops.uv.smart_project()\n    bpy.ops.object.mode_set(mode=\"OBJECT\")\n\n    # Texture the mesh based on its normals\n    mat = bpy.data.materials.new(name=\"NormalMaterial\")\n    mat.use_nodes = True\n    bsdf = mat.node_tree.nodes[\"Principled BSDF\"]\n    normal_node = mat.node_tree.nodes.new(type=\"ShaderNodeNormal\")\n    geometry = mat.node_tree.nodes.new(type=\"ShaderNodeNewGeometry\")\n\n    # mat.node_tree.links.new(geometry.outputs[\"Normal\"], normal_node.inputs[\"Normal\"])\n    # mat.node_tree.links.new(normal_node.outputs[\"Dot\"], bsdf.inputs[\"Base Color\"])\n    mat.node_tree.links.new(geometry.outputs[\"Normal\"], bsdf.inputs[\"Base Color\"])\n\n    obj.data.materials.append(mat)\n\n    return None\n\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"]\n)\nCamera = collections.namedtuple(\"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"]\n)\nPoint3D = 
collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"]\n)\n\n\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12),\n}\nCAMERA_MODEL_IDS = dict(\n    [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]\n)\nCAMERA_MODEL_NAMES = dict(\n    [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]\n)\n\n\ndef write_next_bytes(fid, data, format_char_sequence, endian_character=\"<\"):\n    \"\"\"pack and write to a binary file.\n    :param fid:\n    :param data: data to send, if multiple elements are sent at the same time,\n    they should be encapsuled either in a list or a tuple\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    should be the same length as the data list or tuple\n    :param endian_character: Any of {@, =, <, >, !}\n    \"\"\"\n    if isinstance(data, (list, tuple)):\n        bytes = struct.pack(endian_character + format_char_sequence, *data)\n    else:\n        bytes = struct.pack(endian_character + format_char_sequence, data)\n    fid.write(bytes)\n\n\ndef write_cameras_binary(cameras, path_to_model_file):\n    \"\"\"\n    see: src/colmap/scene/reconstruction.cc\n      
  void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    with open(path_to_model_file, \"wb\") as fid:\n        write_next_bytes(fid, len(cameras), \"Q\")\n        for _, cam in cameras.items():\n            model_id = CAMERA_MODEL_NAMES[cam.model].model_id\n            camera_properties = [cam.id, model_id, cam.width, cam.height]\n            write_next_bytes(fid, camera_properties, \"iiQQ\")\n            for p in cam.params:\n                write_next_bytes(fid, float(p), \"d\")\n    return cameras\n\n\ndef write_images_binary(images, path_to_model_file):\n    \"\"\"\n    see: src/colmap/scene/reconstruction.cc\n        void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    with open(path_to_model_file, \"wb\") as fid:\n        write_next_bytes(fid, len(images), \"Q\")\n        for _, img in images.items():\n            write_next_bytes(fid, img.id, \"i\")\n            write_next_bytes(fid, img.qvec.tolist(), \"dddd\")\n            write_next_bytes(fid, img.tvec.tolist(), \"ddd\")\n            write_next_bytes(fid, img.camera_id, \"i\")\n            for char in img.name:\n                write_next_bytes(fid, char.encode(\"utf-8\"), \"c\")\n            write_next_bytes(fid, b\"\\x00\", \"c\")\n            write_next_bytes(fid, len(img.point3D_ids), \"Q\")\n            for xy, p3d_id in zip(img.xys, img.point3D_ids):\n                write_next_bytes(fid, [*xy, p3d_id], \"ddq\")\n\n\ndef write_points3D_binary(points3D, path_to_model_file):\n    \"\"\"\n    see: src/colmap/scene/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n    with open(path_to_model_file, \"wb\") as fid:\n        write_next_bytes(fid, len(points3D), \"Q\")\n        for 
_, pt in points3D.items():\n            write_next_bytes(fid, pt.id, \"Q\")\n            write_next_bytes(fid, pt.xyz.tolist(), \"ddd\")\n            write_next_bytes(fid, pt.rgb.tolist(), \"BBB\")\n            write_next_bytes(fid, pt.error, \"d\")\n            track_length = pt.image_ids.shape[0]\n            write_next_bytes(fid, track_length, \"Q\")\n            for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):\n                write_next_bytes(fid, [image_id, point2D_id], \"ii\")\n\n\ndef get_colmap_camera(camera_obj, render_resolution):\n    \"\"\"\n    Extract the intrinsic matrix from a Blender camera.\n\n    Args:\n    - camera_obj: The Blender camera object.\n    - render_resolution: Tuple of (width, height) indicating the render resolution.\n\n    Returns:\n    - colmap_camera: dict of [\"id\", \"model\", \"width\", \"height\", \"params\"]\n    \"\"\"\n\n    # Get the camera data\n    cam = camera_obj.data\n\n    # Ensure it's a perspective camera\n    if cam.type != \"PERSP\":\n        raise ValueError(\"Only 'PERSP' camera type is supported.\")\n\n    # Image resolution\n    width, height = render_resolution\n\n    # Sensor width and height in millimeters\n    sensor_width_mm = cam.sensor_width\n    sensor_height_mm = cam.sensor_height\n\n    # Calculate the focal length in pixels\n    fx = (cam.lens / sensor_width_mm) * width\n    fy = (cam.lens / sensor_height_mm) * height\n\n    # Principal point, usually at the center of the image\n    cx = width / 2.0\n    cy = height / 2.0\n\n    _cam_dict = {\n        \"id\": 0,\n        \"model\": \"PINHOLE\",  # PINHOLE\n        \"width\": width,\n        \"height\": height,\n        \"params\": [fx, fy, cx, cy],\n    }\n\n    colmap_cameras = {0: Camera(**_cam_dict)}\n\n    print(\"focal\", fx, fy, cx, cy)\n\n    return colmap_cameras\n\n\ndef main():\n    import point_cloud_utils as pcu\n\n    argv = sys.argv\n    argv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\n    
print(argv)\n    inp_mesh_path = argv[0]  # input mesh path\n    output_dir = argv[1]  # output dir\n    # num_frames = int(argv[2])\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    img_output_dir = os.path.join(output_dir, \"images\")\n    if not os.path.exists(img_output_dir):\n        os.makedirs(img_output_dir)\n\n    tmp_mesh_path = \"data/deer_attack/mesh0.obj\"\n\n    vertices, faces = pcu.load_mesh_vf(tmp_mesh_path)\n    # normalize\n    verices_center = np.mean(vertices, axis=0)\n    max_range = np.max(np.max(vertices, axis=0) - np.min(vertices, axis=0))\n    print(\n        max_range.shape, max_range, verices_center.shape, verices_center, vertices.shape\n    )\n\n    vertices, faces = pcu.load_mesh_vf(inp_mesh_path)\n\n    mesh_name = os.path.basename(inp_mesh_path).split(\".\")[0]\n\n    vertices = (vertices - verices_center[np.newaxis, :]) / max_range\n\n    # Create the 3D mesh in Blender from your data.\n    obj = create_mesh_from_data(vertices, faces)\n\n    object_center = bpy.context.scene.objects[\"MyMeshObject\"].location\n\n    # Number of viewpoints\n    num_views = 180  # 180\n    radius = 6  # Distance of the camera from the object center\n\n    setup_light()\n    # Set up rendering parameters\n    bpy.context.scene.render.image_settings.file_format = \"PNG\"\n    bpy.context.scene.render.resolution_x = 1080\n    bpy.context.scene.render.resolution_y = 720\n\n    camera = create_camera((1, 1, 1), (0, 0, 0))\n    colmap_camera_dict = get_colmap_camera(\n        camera,\n        (bpy.context.scene.render.resolution_x, bpy.context.scene.render.resolution_y),\n    )\n\n    transform_dict = {\n        \"frames\": [],\n        \"camera_angle_x\": focal2fov(\n            colmap_camera_dict[0].params[0], colmap_camera_dict[0].width\n        ),\n    }\n    img_indx = 0\n    num_elevations = 6\n    colmap_images_dict = {}\n    for j in range(num_elevations):\n        num_imgs = num_views // num_elevations\n        for i 
in range(num_imgs):\n            angle = 2 * math.pi * i / num_imgs\n            x = object_center.x + radius * math.cos(angle)\n            y = object_center.y + radius * math.sin(angle)\n            z = (\n                object_center.z + (j - num_elevations / 3.0) * 4.0 / num_elevations\n            )  # Adjust this if you want the camera to be above or below the object's center\n\n            camera = create_camera((x, y, z), (0, 0, 0))\n            rot_quant = set_camera_look_at(camera, object_center)\n            tvec = np.array([x, y, z])\n            bpy.context.view_layer.update()\n\n            # plan-1\n            # w2c = np.array(camera.matrix_world.inverted())\n            # w2c[1:3, :] *= -1.0\n            # rotation_matrix = w2c[:3, :3]\n            # tvec = w2c[:3, 3]\n            # plan-1 end\n\n            # plan-2\n            camera_to_world_matrix = camera.matrix_world\n            # [4, 4]\n            camera_to_world_matrix = np.array(camera_to_world_matrix).copy()\n            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            camera_to_world_matrix[:3, 1:3] *= -1.0\n            w2c = np.linalg.inv(camera_to_world_matrix)\n            rotation_matrix = w2c[:3, :3]\n            tvec = w2c[:3, 3]\n\n            # c2w rotation\n            # rotation_matrix = rot_quant.to_matrix()  # .to_4x4()\n            # # w2c rotation\n            # rotation_matrix = np.array(rotation_matrix)\n            # # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            # rotation_matrix[:3, 1:3] *= -1.0\n            # rotation_matrix = rotation_matrix.transpose()\n            # tvec = (rotation_matrix @ tvec[:, np.newaxis]).squeeze(axis=-1) * -1.0\n\n            rot_quant = Rotation.from_matrix(rotation_matrix).as_quat()\n            # print(\"r shape\", rotation_matrix.shape, tvec.shape)\n\n            img_dict = {\n                \"id\": img_indx,\n                
\"qvec\": rot_quant,\n                \"tvec\": tvec,\n                \"camera_id\": 0,\n                \"name\": f\"img_{img_indx}.png\",\n                \"xys\": [[k, k] for k in range(i, i + 10)],  # placeholder\n                \"point3D_ids\": list(range(i, i + 10)),  # placeholder\n            }\n            colmap_images_dict[img_indx] = BaseImage(**img_dict)\n\n            # also prepare transforms.json\n            fname = f\"images/img_{img_indx}\"\n            cam2world = np.array(camera.matrix_world)\n            transform_dict[\"frames\"].append(\n                {\"file_path\": fname, \"transform_matrix\": cam2world.tolist()}\n            )\n\n            render_scene(camera, os.path.join(img_output_dir, f\"img_{mesh_name}.png\"))\n            img_indx += 1\n\n            return\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/render_fbx_first_frame.py",
    "content": "import bpy\nimport os\nimport numpy as np\nimport math\nimport sys\nimport struct\nimport collections\nfrom mathutils import Matrix, Quaternion, Vector\nfrom scipy.spatial.transform import Rotation\n\n\ndef focal2fov(focal, pixels):\n    return 2 * math.atan(pixels / (2 * focal))\n\n\ndef create_camera(location, rotation):\n    # Create a new camera\n    bpy.ops.object.camera_add(location=location, rotation=rotation)\n    return bpy.context.active_object\n\n\ndef set_camera_look_at(camera, target_point):\n    # Compute the direction vector from the camera to the target point\n    direction = target_point - camera.location\n    # Compute the rotation matrix to align the camera's -Z axis to this direction\n    rot_quat = direction.to_track_quat(\"-Z\", \"Y\")\n    camera.rotation_euler = rot_quat.to_euler()\n\n    return rot_quat\n\n\ndef setup_alpha_mask(obj_name, pass_index=1):\n    # Set the object's pass index\n    obj = bpy.data.objects[obj_name]\n    obj.pass_index = pass_index\n\n    # Enable the Object Index pass for the active render layer\n    bpy.context.view_layer.use_pass_object_index = True\n\n    # Enable 'Use Nodes':\n    bpy.context.scene.use_nodes = True\n    tree = bpy.context.scene.node_tree\n\n    # Clear default nodes\n    for node in tree.nodes:\n        tree.nodes.remove(node)\n\n    # Add Render Layers node\n    render_layers = tree.nodes.new(\"CompositorNodeRLayers\")\n\n    # Add Composite node (output)\n    composite = tree.nodes.new(\"CompositorNodeComposite\")\n\n    # Add ID Mask node\n    id_mask = tree.nodes.new(\"CompositorNodeIDMask\")\n    id_mask.index = pass_index\n\n    # Add Set Alpha node\n    set_alpha = tree.nodes.new(\"CompositorNodeSetAlpha\")\n\n    # Connect nodes\n    tree.links.new(render_layers.outputs[\"Image\"], set_alpha.inputs[\"Image\"])\n    tree.links.new(render_layers.outputs[\"IndexOB\"], id_mask.inputs[0])\n    tree.links.new(id_mask.outputs[0], set_alpha.inputs[\"Alpha\"])\n    
tree.links.new(set_alpha.outputs[\"Image\"], composite.inputs[\"Image\"])\n\n\ndef render_scene(camera, output_path, mask_name=\"U3DMesh\"):\n    bpy.context.scene.render.film_transparent = True\n\n    setup_alpha_mask(mask_name, 1)\n    # Set the active camera\n    bpy.context.scene.render.image_settings.color_mode = \"RGBA\"\n\n    bpy.context.scene.camera = camera\n\n    # Set the output path for the render\n    bpy.context.scene.render.filepath = output_path\n\n    # Render the scene\n    bpy.ops.render.render(write_still=True)\n\n\ndef normalize_mesh(obj):\n    # Ensure the object is a mesh\n    if obj.type != \"MESH\":\n        print(f\"{obj.name} is not a mesh object.\")\n        return\n\n    # Calculate the mesh's bounding box dimensions\n    bbox = obj.bound_box\n    dimensions = obj.dimensions\n\n    # Calculate the center of the bounding box\n    center = [(bbox[i][0] + bbox[(i + 4) % 8][0]) * 0.5 for i in range(4)]\n\n    # Move the object to the origin based on its bounding box center\n    obj.location = [\n        0.0,\n        0.0,\n        0.0,\n    ]  # [-center[0], -center[1], -center[2]]  #  [0.0, 0.0, 0.0]\n\n    # Calculate the scaling factor based on the largest dimension\n    scale_factor = 1.0 / max(dimensions)\n    print(obj.scale, \"prev\")\n    # Apply the scale to the object\n    obj.scale = [scale_factor] * 3\n\n    print(\"scalar\", obj.scale, scale_factor)\n    # Update the scene (important for getting correct visual updates)\n    bpy.context.view_layer.update()\n\n\ndef setup_light():\n    # Add first directional light (Sun lamp)\n    light_data_1 = bpy.data.lights.new(name=\"Directional_Light_1\", type=\"SUN\")\n    light_data_1.energy = 3  # Adjust energy as needed\n    light_1 = bpy.data.objects.new(name=\"Directional_Light_1\", object_data=light_data_1)\n    bpy.context.collection.objects.link(light_1)\n    light_1.location = (20, 20, 20)  # Adjust location as needed\n    light_1.rotation_euler = (\n        np.radians(45),\n     
   np.radians(0),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n    # Add second directional light (Sun lamp)\n    light_data_2 = bpy.data.lights.new(name=\"Directional_Light_2\", type=\"SUN\")\n    light_data_2.energy = 5  # Adjust energy as needed\n    light_2 = bpy.data.objects.new(name=\"Directional_Light_2\", object_data=light_data_2)\n    bpy.context.collection.objects.link(light_2)\n    light_2.location = (20, -20, 20)  # Adjust location as needed\n    light_2.rotation_euler = (\n        np.radians(45),\n        np.radians(180),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n    # Add third directional light (Sun lamp)\n    light_data_3 = bpy.data.lights.new(name=\"Directional_Light_3\", type=\"SUN\")\n    light_data_3.energy = 3  # Adjust energy as needed\n    light_3 = bpy.data.objects.new(name=\"Directional_Light_3\", object_data=light_data_3)\n    bpy.context.collection.objects.link(light_3)\n    light_3.location = (-20, 20, 20)  # Adjust location as needed\n    light_3.rotation_euler = (\n        np.radians(-135),\n        np.radians(0),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n\ndef create_mesh_from_fpx(fbx_path):\n    # Clear existing mesh objects in the scene\n    bpy.ops.object.select_all(action=\"DESELECT\")\n    bpy.ops.object.select_by_type(type=\"MESH\")\n    bpy.ops.object.delete()\n\n    bpy.ops.import_scene.fbx(filepath=fbx_path, use_image_search=True)\n\n    # Assuming the imported object is the active object\n    obj = bpy.context.active_object\n    for obj in bpy.context.selected_objects:\n        print(\"obj: \", obj.name, obj.type)\n    mesh_objects = [obj for obj in bpy.context.selected_objects if obj.type == \"MESH\"]\n\n    return None\n\n\nCamera = collections.namedtuple(\"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\n\n\ndef get_colmap_camera(camera_obj, render_resolution):\n    \"\"\"\n    Extract the intrinsic matrix from a Blender 
camera.\n\n    Args:\n    - camera_obj: The Blender camera object.\n    - render_resolution: Tuple of (width, height) indicating the render resolution.\n\n    Returns:\n    - colmap_camera: dict of [\"id\", \"model\", \"width\", \"height\", \"params\"]\n    \"\"\"\n\n    # Get the camera data\n    cam = camera_obj.data\n\n    # Ensure it's a perspective camera\n    if cam.type != \"PERSP\":\n        raise ValueError(\"Only 'PERSP' camera type is supported.\")\n\n    # Image resolution\n    width, height = render_resolution\n\n    # Sensor width and height in millimeters\n    sensor_width_mm = cam.sensor_width\n    sensor_height_mm = cam.sensor_height\n\n    # Calculate the focal length in pixels\n    fx = (cam.lens / sensor_width_mm) * width\n    fy = (cam.lens / sensor_height_mm) * height\n\n    # Principal point, usually at the center of the image\n    cx = width / 2.0\n    cy = height / 2.0\n\n    _cam_dict = {\n        \"id\": 0,\n        \"model\": \"PINHOLE\",  # PINHOLE\n        \"width\": width,\n        \"height\": height,\n        \"params\": [fx, fy, cx, cy],\n    }\n\n    colmap_cameras = {0: Camera(**_cam_dict)}\n\n    print(\"focal\", fx, fy, cx, cy)\n\n    return colmap_cameras\n\n\ndef get_textures(\n    texture_dir=\"/local/cg/rundi/data/motion_dataset/pirate-flag-animated/source/textures\",\n):\n    # Ensure the \"flag\" object is selected\n    # bpy.context.view_layer.objects.active = bpy.data.objects[\"flag\"]\n\n    obj = bpy.data.objects[\"flag\"]\n\n    # Create a new material or get the existing one\n    if not obj.data.materials:\n        mat = bpy.data.materials.new(name=\"FBX_Material\")\n        obj.data.materials.append(mat)\n    else:\n        mat = obj.data.materials[0]\n\n    # Use nodes for the material\n    mat.use_nodes = True\n    nodes = mat.node_tree.nodes\n\n    # Clear default nodes\n    for node in nodes:\n        nodes.remove(node)\n\n    # Add a Principled BSDF shader and connect it to the Material Output\n    shader = 
nodes.new(type=\"ShaderNodeBsdfPrincipled\")\n    shader.location = (0, 0)\n\n    output = nodes.new(type=\"ShaderNodeOutputMaterial\")\n    output.location = (400, 0)\n    mat.node_tree.links.new(shader.outputs[\"BSDF\"], output.inputs[\"Surface\"])\n\n    # Load textures and create the corresponding nodes\n\n    textures = {\n        \"Base Color\": os.path.join(texture_dir, \"pirate_flag_albedo.jpg\"),\n        \"Metallic\": os.path.join(texture_dir, \"pirate_flag_metallic.jpg\"),\n        \"Normal\": os.path.join(texture_dir, \"pirate_flag_normal.png\"),\n        \"Roughness\": os.path.join(texture_dir, \"pirate_flag_roughness.jpg\"),\n    }\n    # ... [rest of the script]\n\n    ao_texture = nodes.new(type=\"ShaderNodeTexImage\")\n    ao_texture.location = (-400, -200)\n    ao_texture.image = bpy.data.images.load(\n        filepath=os.path.join(texture_dir, \"pirate_flag_AO.jpg\")\n    )  # Adjust filepath if needed\n\n    mix_rgb = nodes.new(type=\"ShaderNodeMixRGB\")\n    mix_rgb.location = (-200, 0)\n    mix_rgb.blend_type = \"MULTIPLY\"\n    mix_rgb.inputs[\n        0\n    ].default_value = 1.0  # Factor to 1 to fully use the multiply operation\n\n    mat.node_tree.links.new(ao_texture.outputs[\"Color\"], mix_rgb.inputs[2])\n\n    for i, (input_name, filename) in enumerate(textures.items()):\n        tex_image = nodes.new(type=\"ShaderNodeTexImage\")\n        tex_image.location = (-400, i * 200)\n        tex_image.image = bpy.data.images.load(\n            filepath=filename\n        )  # Adjust filepath if needed\n\n        if input_name == \"Base Color\":\n            mat.node_tree.links.new(tex_image.outputs[\"Color\"], mix_rgb.inputs[1])\n            mat.node_tree.links.new(mix_rgb.outputs[\"Color\"], shader.inputs[input_name])\n        elif input_name == \"Normal\":\n            normal_map_node = nodes.new(type=\"ShaderNodeNormalMap\")\n            normal_map_node.location = (-200, i * 200)\n            mat.node_tree.links.new(\n                
tex_image.outputs[\"Color\"], normal_map_node.inputs[\"Color\"]\n            )\n            mat.node_tree.links.new(\n                normal_map_node.outputs[\"Normal\"], shader.inputs[\"Normal\"]\n            )\n        else:\n            mat.node_tree.links.new(\n                tex_image.outputs[\"Color\"], shader.inputs[input_name]\n            )\n\n\ndef main():\n    import point_cloud_utils as pcu\n\n    argv = sys.argv\n    argv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\n    print(argv)\n    inp_fpx_path = argv[0]  # input mesh path\n    output_dir = argv[1]  # output dir\n    # num_frames = int(argv[2])\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    img_output_dir = os.path.join(output_dir, \"images\")\n    if not os.path.exists(img_output_dir):\n        os.makedirs(img_output_dir)\n\n    # Create the 3D mesh in Blender from your data. no normalize\n    obj = create_mesh_from_fpx(inp_fpx_path)\n    my_mesh_name = \"U3DMesh\"  #  \"flag\"  # \"U3DMesh\"  # \"flag\"  # \"U3DMesh\" for dragon\n\n    # get_textures()\n    # normalize_mesh(bpy.context.scene.objects[my_mesh_name])\n    object_center = bpy.context.scene.objects[my_mesh_name].location + Vector(\n        (0, 0, 2)\n    )  # (0, 0, 2) for dragon\n\n    print(\"look at object center: \", object_center)\n    # Number of viewpoints\n    num_views = 270  # 180   # 240 for dragon\n    radius = 20  #  16 for dragon\n\n    setup_light()\n    # Set up rendering parameters\n    bpy.context.scene.render.image_settings.file_format = \"PNG\"\n    bpy.context.scene.render.resolution_x = 1080\n    bpy.context.scene.render.resolution_y = 720\n\n    camera = create_camera((1, 1, 1), (0, 0, 0))\n    colmap_camera_dict = get_colmap_camera(\n        camera,\n        (bpy.context.scene.render.resolution_x, bpy.context.scene.render.resolution_y),\n    )\n\n    transform_dict = {\n        \"frames\": [],\n        \"camera_angle_x\": focal2fov(\n            
colmap_camera_dict[0].params[0], colmap_camera_dict[0].width\n        ),\n    }\n    img_indx = 0\n    num_elevations = 8\n    for j in range(num_elevations):\n        num_imgs = num_views // num_elevations\n        for i in range(num_imgs):\n            angle = 2 * math.pi * i / num_imgs\n            x = object_center.x + radius * math.cos(angle)\n            y = object_center.y + radius * math.sin(angle)\n            z = (\n                object_center.z\n                + (j - num_elevations / 2.0) * (radius * 2) / num_elevations\n            )  # Adjust this if you want the camera to be above or below the object's center\n\n            camera = create_camera((x, y, z), (0, 0, 0))\n            rot_quant = set_camera_look_at(camera, object_center)\n            tvec = np.array([x, y, z])\n            bpy.context.view_layer.update()\n\n            # also prepare transforms.json\n            fname = f\"images/img_{img_indx}\"\n            cam2world = np.array(camera.matrix_world)\n            transform_dict[\"frames\"].append(\n                {\"file_path\": fname, \"transform_matrix\": cam2world.tolist()}\n            )\n\n            render_scene(\n                camera,\n                os.path.join(img_output_dir, f\"img_{img_indx}.png\"),\n                my_mesh_name,\n            )\n            img_indx += 1\n\n    trans_fpath = os.path.join(output_dir, \"transforms_train.json\")\n    import json\n\n    with open(trans_fpath, \"w\") as f:\n        json.dump(transform_dict, f)\n\n    transform_dict[\"frames\"] = transform_dict[\"frames\"][::10]\n    trans_fpath = os.path.join(output_dir, \"transforms_test.json\")\n\n    with open(trans_fpath, \"w\") as f:\n        json.dump(transform_dict, f)\n\n\ndef find_material():\n    import point_cloud_utils as pcu\n\n    argv = sys.argv\n    argv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\n    print(argv)\n    inp_fpx_path = argv[0]  # input mesh path\n    output_dir = argv[1]  # output dir\n\n    
bpy.ops.import_scene.fbx(filepath=inp_fpx_path)\n\n    print(\"inspecting materials\")\n    for material in bpy.data.materials:\n        if material.use_nodes:\n            for node in material.node_tree.nodes:\n                if node.type == \"TEX_IMAGE\":\n                    print(\n                        f\"Material: {material.name}, Image: {node.image.name}, Path: {node.image.filepath}\"\n                    )\n\n\nif __name__ == \"__main__\":\n    main()\n    # find_material()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/render_obj.py",
    "content": "import bpy\nimport os\nimport numpy as np\nimport math\nimport sys\nimport struct\nimport collections\nfrom mathutils import Matrix, Quaternion, Vector\nfrom scipy.spatial.transform import Rotation\n\n\ndef focal2fov(focal, pixels):\n    return 2 * math.atan(pixels / (2 * focal))\n\n\ndef create_camera(location, rotation):\n    # Create a new camera\n    bpy.ops.object.camera_add(location=location, rotation=rotation)\n    return bpy.context.active_object\n\n\ndef set_camera_look_at(camera, target_point):\n    # Compute the direction vector from the camera to the target point\n    direction = target_point - camera.location\n    # Compute the rotation matrix to align the camera's -Z axis to this direction\n    rot_quat = direction.to_track_quat(\"-Z\", \"Y\")\n    camera.rotation_euler = rot_quat.to_euler()\n\n    return rot_quat\n\n\ndef setup_alpha_mask(obj_name, pass_index=1):\n    # Set the object's pass index\n    obj = bpy.data.objects[obj_name]\n    obj.pass_index = pass_index\n\n    # Enable the Object Index pass for the active render layer\n    bpy.context.view_layer.use_pass_object_index = True\n\n    # Enable 'Use Nodes':\n    bpy.context.scene.use_nodes = True\n    tree = bpy.context.scene.node_tree\n\n    # Clear default nodes\n    for node in tree.nodes:\n        tree.nodes.remove(node)\n\n    # Add Render Layers node\n    render_layers = tree.nodes.new(\"CompositorNodeRLayers\")\n\n    # Add Composite node (output)\n    composite = tree.nodes.new(\"CompositorNodeComposite\")\n\n    # Add ID Mask node\n    id_mask = tree.nodes.new(\"CompositorNodeIDMask\")\n    id_mask.index = pass_index\n\n    # Add Set Alpha node\n    set_alpha = tree.nodes.new(\"CompositorNodeSetAlpha\")\n\n    # Connect nodes\n    tree.links.new(render_layers.outputs[\"Image\"], set_alpha.inputs[\"Image\"])\n    tree.links.new(render_layers.outputs[\"IndexOB\"], id_mask.inputs[0])\n    tree.links.new(id_mask.outputs[0], set_alpha.inputs[\"Alpha\"])\n    
tree.links.new(set_alpha.outputs[\"Image\"], composite.inputs[\"Image\"])\n\n\ndef render_scene(camera, output_path, mask_name=\"U3DMesh\"):\n    bpy.context.scene.render.film_transparent = True\n\n    setup_alpha_mask(mask_name, 1)\n    # Set the active camera\n    bpy.context.scene.render.image_settings.color_mode = \"RGBA\"\n\n    bpy.context.scene.camera = camera\n\n    # Set the output path for the render\n    bpy.context.scene.render.filepath = output_path\n\n    # Render the scene\n    bpy.ops.render.render(write_still=True)\n\n\ndef setup_light():\n    # Add first directional light (Sun lamp)\n    light_data_1 = bpy.data.lights.new(name=\"Directional_Light_1\", type=\"SUN\")\n    light_data_1.energy = 3  # Adjust energy as needed\n    light_1 = bpy.data.objects.new(name=\"Directional_Light_1\", object_data=light_data_1)\n    bpy.context.collection.objects.link(light_1)\n    light_1.location = (20, 20, 20)  # Adjust location as needed\n    light_1.rotation_euler = (\n        np.radians(45),\n        np.radians(0),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n    # Add second directional light (Sun lamp)\n    light_data_2 = bpy.data.lights.new(name=\"Directional_Light_2\", type=\"SUN\")\n    light_data_2.energy = 5  # Adjust energy as needed\n    light_2 = bpy.data.objects.new(name=\"Directional_Light_2\", object_data=light_data_2)\n    bpy.context.collection.objects.link(light_2)\n    light_2.location = (20, -20, 20)  # Adjust location as needed\n    light_2.rotation_euler = (\n        np.radians(45),\n        np.radians(180),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n    # Add third directional light (Sun lamp)\n    light_data_3 = bpy.data.lights.new(name=\"Directional_Light_3\", type=\"SUN\")\n    light_data_3.energy = 3  # Adjust energy as needed\n    light_3 = bpy.data.objects.new(name=\"Directional_Light_3\", object_data=light_data_3)\n    bpy.context.collection.objects.link(light_3)\n    
light_3.location = (-20, 20, 20)  # Adjust location as needed\n    light_3.rotation_euler = (\n        np.radians(-135),\n        np.radians(0),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n\ndef create_mesh_from_obj(obj_file_path):\n    # Clear existing mesh objects in the scene\n    bpy.ops.object.select_all(action=\"DESELECT\")\n    bpy.ops.object.select_by_type(type=\"MESH\")\n    bpy.ops.object.delete()\n\n    bpy.ops.import_scene.obj(filepath=obj_file_path)\n\n    # Assuming the imported object is the active object\n    obj = bpy.context.active_object\n    num_obj = 0\n    for obj in bpy.context.selected_objects:\n        print(\"obj mesh name: \", obj.name, obj.type)\n        num_obj += 1\n    if num_obj > 1:\n        raise ValueError(\"More than one object in the scene.\")\n    mesh_objects = [obj for obj in bpy.context.selected_objects if obj.type == \"MESH\"]\n\n    return obj.name, mesh_objects\n\n\ndef get_focal_length(camera_obj, render_resolution):\n    \"\"\"\n    Extract the intrinsic matrix from a Blender camera.\n\n    Args:\n    - camera_obj: The Blender camera object.\n    - render_resolution: Tuple of (width, height) indicating the render resolution.\n\n    Returns:\n    - colmap_camera: dict of [\"id\", \"model\", \"width\", \"height\", \"params\"]\n    \"\"\"\n\n    # Get the camera data\n    cam = camera_obj.data\n\n    # Ensure it's a perspective camera\n    if cam.type != \"PERSP\":\n        raise ValueError(\"Only 'PERSP' camera type is supported.\")\n\n    # Image resolution\n    width, height = render_resolution\n\n    # Sensor width and height in millimeters\n    sensor_width_mm = cam.sensor_width\n    sensor_height_mm = cam.sensor_height\n\n    # Calculate the focal length in pixels\n    fx = (cam.lens / sensor_width_mm) * width\n    fy = (cam.lens / sensor_height_mm) * height\n\n    return fx, fy\n\n\ndef normalize_mesh(transform_meta_path, mesh_objects):\n    import json\n\n    if 
os.path.exists(transform_meta_path):\n        with open(transform_meta_path, \"r\") as f:\n            meta_dict = json.load(f)\n    # obj = bpy.context.active_object\n\n    for obj in mesh_objects:\n        # Ensure the object is in object mode\n        # bpy.ops.object.mode_set(mode=\"OBJECT\")\n\n        scale_ = 1.0 / meta_dict[\"scale\"]\n        center = Vector(meta_dict[\"center\"])\n        # Apply the scale\n        print(\"old scale: \", obj.scale)\n        # obj.location -= center\n        obj.scale *= scale_\n\n\ndef apply_rotation(mesh_objects):\n    for obj in mesh_objects:\n        R_np = [[1.0, 0, 0], [0, 0, 1.0], [0, 1.0, 0]]\n        R_blender = Matrix(R_np).transposed()\n\n        # Convert the rotation matrix to a quaternion\n        quaternion = R_blender.to_quaternion()\n\n        # Set the active object's rotation to this quaternion\n        print(\"rotation\", quaternion, obj.rotation_quaternion)\n        obj.rotation_quaternion = obj.rotation_quaternion @ quaternion\n\n\ndef main():\n    argv = sys.argv\n    argv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\n    print(argv)\n    inp_fpx_path = argv[0]  # input mesh path\n    output_dir = argv[1]  # output dir\n    num_views = int(argv[2])\n    radius = 5\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    img_output_dir = os.path.join(output_dir, \"images\")\n    if not os.path.exists(img_output_dir):\n        os.makedirs(img_output_dir)\n\n    transform_meta_path = os.path.join(os.path.dirname(inp_fpx_path), \"meta.json\")\n\n    # Create the 3D mesh in Blender from your data. 
no normalize\n    my_mesh_name, mesh_objects = create_mesh_from_obj(inp_fpx_path)\n\n    normalize_mesh(transform_meta_path, mesh_objects)\n    # apply_rotation(mesh_objects)\n\n    object_center = Vector((0.0, 0.0, 1.0))\n\n    print(\"look at object center: \", object_center)\n\n    setup_light()\n    # Set up rendering parameters\n    bpy.context.scene.render.image_settings.file_format = \"PNG\"\n    # bpy.context.scene.render.resolution_x = 1080\n    # bpy.context.scene.render.resolution_y = 720\n\n    bpy.context.scene.render.resolution_x = 720\n    bpy.context.scene.render.resolution_y = 480\n\n    camera = create_camera((1, 1, 1), (0, 0, 0))\n    fx, fy = get_focal_length(\n        camera,\n        (bpy.context.scene.render.resolution_x, bpy.context.scene.render.resolution_y),\n    )\n\n    transform_dict = {\n        \"frames\": [],\n        \"camera_angle_x\": focal2fov(fx, bpy.context.scene.render.resolution_x),\n    }\n    img_indx = 0\n    num_elevations = 6  # 9 for init gaussians\n    for j in range(num_elevations):\n        num_imgs = num_views // num_elevations\n        for i in range(num_imgs):\n            angle = 2 * math.pi * i / num_imgs\n            x = object_center.x + radius * math.cos(angle)\n            y = object_center.y + radius * math.sin(angle)\n            z = (\n                object_center.z\n                + (j - num_elevations / 2.0) * radius / num_elevations * 1.5\n            )  # Adjust this if you want the camera to be above or below the object's center\n\n            camera = create_camera((x, y, z), (0, 0, 0))\n            rot_quant = set_camera_look_at(camera, object_center)\n            bpy.context.view_layer.update()\n\n            # also prepare transforms.json\n            fname = f\"images/img_{img_indx}\"\n            cam2world = np.array(camera.matrix_world)\n            transform_dict[\"frames\"].append(\n                {\"file_path\": fname, \"transform_matrix\": cam2world.tolist()}\n            )\n\n          
  render_scene(\n                camera,\n                os.path.join(img_output_dir, f\"img_{img_indx}.png\"),\n                my_mesh_name,\n            )\n            img_indx += 1\n\n    trans_fpath = os.path.join(output_dir, \"transforms_train.json\")\n    import json\n\n    with open(trans_fpath, \"w\") as f:\n        json.dump(transform_dict, f)\n\n    transform_dict[\"frames\"] = transform_dict[\"frames\"][::4]\n    trans_fpath = os.path.join(output_dir, \"transforms_test.json\")\n\n    with open(trans_fpath, \"w\") as f:\n        json.dump(transform_dict, f)\n\n\nif __name__ == \"__main__\":\n    main()\n    # find_material()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/render_obj_external_texture.py",
    "content": "import bpy\nimport os\nimport numpy as np\nimport math\nimport sys\nimport struct\nimport collections\nfrom mathutils import Matrix, Quaternion, Vector\nfrom scipy.spatial.transform import Rotation\n\n\ndef focal2fov(focal, pixels):\n    return 2 * math.atan(pixels / (2 * focal))\n\n\ndef create_camera(location, rotation):\n    # Create a new camera\n    bpy.ops.object.camera_add(location=location, rotation=rotation)\n    return bpy.context.active_object\n\n\ndef set_camera_look_at(camera, target_point):\n    # Compute the direction vector from the camera to the target point\n    direction = target_point - camera.location\n    # Compute the rotation matrix to align the camera's -Z axis to this direction\n    rot_quat = direction.to_track_quat(\"-Z\", \"Y\")\n    camera.rotation_euler = rot_quat.to_euler()\n\n    return rot_quat\n\n\ndef setup_alpha_mask(obj_name, pass_index=1):\n    # Set the object's pass index\n    obj = bpy.data.objects[obj_name]\n    obj.pass_index = pass_index\n\n    # Enable the Object Index pass for the active render layer\n    bpy.context.view_layer.use_pass_object_index = True\n\n    # Enable 'Use Nodes':\n    bpy.context.scene.use_nodes = True\n    tree = bpy.context.scene.node_tree\n\n    # Clear default nodes\n    for node in tree.nodes:\n        tree.nodes.remove(node)\n\n    # Add Render Layers node\n    render_layers = tree.nodes.new(\"CompositorNodeRLayers\")\n\n    # Add Composite node (output)\n    composite = tree.nodes.new(\"CompositorNodeComposite\")\n\n    # Add ID Mask node\n    id_mask = tree.nodes.new(\"CompositorNodeIDMask\")\n    id_mask.index = pass_index\n\n    # Add Set Alpha node\n    set_alpha = tree.nodes.new(\"CompositorNodeSetAlpha\")\n\n    # Connect nodes\n    tree.links.new(render_layers.outputs[\"Image\"], set_alpha.inputs[\"Image\"])\n    tree.links.new(render_layers.outputs[\"IndexOB\"], id_mask.inputs[0])\n    tree.links.new(id_mask.outputs[0], set_alpha.inputs[\"Alpha\"])\n    
tree.links.new(set_alpha.outputs[\"Image\"], composite.inputs[\"Image\"])\n\n\ndef render_scene(camera, output_path, mask_name=\"U3DMesh\"):\n    bpy.context.scene.render.film_transparent = True\n\n    setup_alpha_mask(mask_name, 1)\n    # Set the active camera\n    bpy.context.scene.render.image_settings.color_mode = \"RGBA\"\n\n    bpy.context.scene.camera = camera\n\n    # Set the output path for the render\n    bpy.context.scene.render.filepath = output_path\n\n    # Render the scene\n    bpy.ops.render.render(write_still=True)\n\n\ndef setup_light():\n    # Add first directional light (Sun lamp)\n    light_data_1 = bpy.data.lights.new(name=\"Directional_Light_1\", type=\"SUN\")\n    light_data_1.energy = 3  # Adjust energy as needed\n    light_1 = bpy.data.objects.new(name=\"Directional_Light_1\", object_data=light_data_1)\n    bpy.context.collection.objects.link(light_1)\n    light_1.location = (20, 20, 20)  # Adjust location as needed\n    light_1.rotation_euler = (\n        np.radians(45),\n        np.radians(0),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n    # Add second directional light (Sun lamp)\n    light_data_2 = bpy.data.lights.new(name=\"Directional_Light_2\", type=\"SUN\")\n    light_data_2.energy = 5  # Adjust energy as needed\n    light_2 = bpy.data.objects.new(name=\"Directional_Light_2\", object_data=light_data_2)\n    bpy.context.collection.objects.link(light_2)\n    light_2.location = (20, -20, 20)  # Adjust location as needed\n    light_2.rotation_euler = (\n        np.radians(45),\n        np.radians(180),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n    # Add third directional light (Sun lamp)\n    light_data_3 = bpy.data.lights.new(name=\"Directional_Light_3\", type=\"SUN\")\n    light_data_3.energy = 3  # Adjust energy as needed\n    light_3 = bpy.data.objects.new(name=\"Directional_Light_3\", object_data=light_data_3)\n    bpy.context.collection.objects.link(light_3)\n    
light_3.location = (-20, 20, 20)  # Adjust location as needed\n    light_3.rotation_euler = (\n        np.radians(-135),\n        np.radians(0),\n        np.radians(45),\n    )  # Adjust rotation for direction\n\n\ndef create_mesh_from_obj(obj_file_path):\n    # Clear existing mesh objects in the scene\n    bpy.ops.object.select_all(action=\"DESELECT\")\n    bpy.ops.object.select_by_type(type=\"MESH\")\n    bpy.ops.object.delete()\n\n    bpy.ops.import_scene.obj(filepath=obj_file_path)\n\n    # Assuming the imported object is the active object\n    obj = bpy.context.active_object\n    num_obj = 0\n    for obj in bpy.context.selected_objects:\n        print(\"obj mesh name: \", obj.name, obj.type)\n        num_obj += 1\n    if num_obj > 2:\n        raise ValueError(\"More than one object in the scene.\")\n    mesh_objects = [obj for obj in bpy.context.selected_objects if obj.type == \"MESH\"]\n\n    return obj.name, mesh_objects\n\n\ndef get_focal_length(camera_obj, render_resolution):\n    \"\"\"\n    Extract the intrinsic matrix from a Blender camera.\n\n    Args:\n    - camera_obj: The Blender camera object.\n    - render_resolution: Tuple of (width, height) indicating the render resolution.\n\n    Returns:\n    - colmap_camera: dict of [\"id\", \"model\", \"width\", \"height\", \"params\"]\n    \"\"\"\n\n    # Get the camera data\n    cam = camera_obj.data\n\n    # Ensure it's a perspective camera\n    if cam.type != \"PERSP\":\n        raise ValueError(\"Only 'PERSP' camera type is supported.\")\n\n    # Image resolution\n    width, height = render_resolution\n\n    # Sensor width and height in millimeters\n    sensor_width_mm = cam.sensor_width\n    sensor_height_mm = cam.sensor_height\n\n    # Calculate the focal length in pixels\n    fx = (cam.lens / sensor_width_mm) * width\n    fy = (cam.lens / sensor_height_mm) * height\n\n    return fx, fy\n\n\ndef normalize_mesh(transform_meta_path, mesh_objects):\n    import json\n\n    if 
os.path.exists(transform_meta_path):\n        with open(transform_meta_path, \"r\") as f:\n            meta_dict = json.load(f)\n    # obj = bpy.context.active_object\n\n    for obj in mesh_objects:\n        # Ensure the object is in object mode\n        # bpy.ops.object.mode_set(mode=\"OBJECT\")\n\n        scale_ = 1.0 / meta_dict[\"scale\"]\n        center = Vector(meta_dict[\"center\"])\n        # Apply the scale\n        print(\"old scale: \", obj.scale)\n        # obj.location -= center\n        obj.scale *= scale_\n\n\ndef apply_rotation(mesh_objects):\n    for obj in mesh_objects:\n        R_np = [[1.0, 0, 0], [0, 0, 1.0], [0, 1.0, 0]]\n        R_blender = Matrix(R_np).transposed()\n\n        # Convert the rotation matrix to a quaternion\n        quaternion = R_blender.to_quaternion()\n\n        # Set the active object's rotation to this quaternion\n        print(\"rotation\", quaternion, obj.rotation_quaternion)\n        obj.rotation_quaternion = obj.rotation_quaternion @ quaternion\n\n\ndef get_textures(\n    texture_dir=\"/local/cg/rundi/data/motion_dataset/pirate-flag-animated/source/textures\",\n):\n    # Ensure the \"flag\" object is selected\n    # bpy.context.view_layer.objects.active = bpy.data.objects[\"flag\"]\n\n    obj = bpy.data.objects[\"flag.001_Plane.001\"]\n\n    # Create a new material or get the existing one\n    if not obj.data.materials:\n        mat = bpy.data.materials.new(name=\"FBX_Material\")\n        obj.data.materials.append(mat)\n    else:\n        mat = obj.data.materials[0]\n\n    # Use nodes for the material\n    mat.use_nodes = True\n    nodes = mat.node_tree.nodes\n\n    # Clear default nodes\n    for node in nodes:\n        nodes.remove(node)\n\n    # Add a Principled BSDF shader and connect it to the Material Output\n    shader = nodes.new(type=\"ShaderNodeBsdfPrincipled\")\n    shader.location = (0, 0)\n\n    output = nodes.new(type=\"ShaderNodeOutputMaterial\")\n    output.location = (400, 0)\n    
mat.node_tree.links.new(shader.outputs[\"BSDF\"], output.inputs[\"Surface\"])\n\n    # Load textures and create the corresponding nodes\n\n    textures = {\n        \"Base Color\": os.path.join(texture_dir, \"pirate_flag_albedo.jpg\"),\n        \"Metallic\": os.path.join(texture_dir, \"pirate_flag_metallic.jpg\"),\n        \"Normal\": os.path.join(texture_dir, \"pirate_flag_normal.png\"),\n        \"Roughness\": os.path.join(texture_dir, \"pirate_flag_roughness.jpg\"),\n    }\n    # ... [rest of the script]\n\n    ao_texture = nodes.new(type=\"ShaderNodeTexImage\")\n    ao_texture.location = (-400, -200)\n    ao_texture.image = bpy.data.images.load(\n        filepath=os.path.join(texture_dir, \"pirate_flag_AO.jpg\")\n    )  # Adjust filepath if needed\n\n    mix_rgb = nodes.new(type=\"ShaderNodeMixRGB\")\n    mix_rgb.location = (-200, 0)\n    mix_rgb.blend_type = \"MULTIPLY\"\n    mix_rgb.inputs[\n        0\n    ].default_value = 1.0  # Factor to 1 to fully use the multiply operation\n\n    mat.node_tree.links.new(ao_texture.outputs[\"Color\"], mix_rgb.inputs[2])\n\n    for i, (input_name, filename) in enumerate(textures.items()):\n        tex_image = nodes.new(type=\"ShaderNodeTexImage\")\n        tex_image.location = (-400, i * 200)\n        tex_image.image = bpy.data.images.load(\n            filepath=filename\n        )  # Adjust filepath if needed\n\n        if input_name == \"Base Color\":\n            mat.node_tree.links.new(tex_image.outputs[\"Color\"], mix_rgb.inputs[1])\n            mat.node_tree.links.new(mix_rgb.outputs[\"Color\"], shader.inputs[input_name])\n        elif input_name == \"Normal\":\n            normal_map_node = nodes.new(type=\"ShaderNodeNormalMap\")\n            normal_map_node.location = (-200, i * 200)\n            mat.node_tree.links.new(\n                tex_image.outputs[\"Color\"], normal_map_node.inputs[\"Color\"]\n            )\n            mat.node_tree.links.new(\n                normal_map_node.outputs[\"Normal\"], 
shader.inputs[\"Normal\"]\n            )\n        else:\n            mat.node_tree.links.new(\n                tex_image.outputs[\"Color\"], shader.inputs[input_name]\n            )\n\n\ndef main():\n    argv = sys.argv\n    argv = argv[argv.index(\"--\") + 1 :]  # get all args after \"--\"\n    print(argv)\n    inp_fpx_path = argv[0]  # input mesh path\n    output_dir = argv[1]  # output dir\n    num_views = int(argv[2])\n    radius = 3\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n\n    img_output_dir = os.path.join(output_dir, \"images\")\n    if not os.path.exists(img_output_dir):\n        os.makedirs(img_output_dir)\n\n    transform_meta_path = os.path.join(os.path.dirname(inp_fpx_path), \"meta.json\")\n\n    # Create the 3D mesh in Blender from your data. no normalize\n    my_mesh_name, mesh_objects = create_mesh_from_obj(inp_fpx_path)\n\n    normalize_mesh(transform_meta_path, mesh_objects)\n    # apply_rotation(mesh_objects)\n\n    get_textures()\n\n    object_center = Vector((0.0, 0.0, 0.5))\n\n    print(\"look at object center: \", object_center)\n\n    setup_light()\n    # Set up rendering parameters\n    bpy.context.scene.render.image_settings.file_format = \"PNG\"\n    bpy.context.scene.render.resolution_x = 1080\n    bpy.context.scene.render.resolution_y = 720\n\n    # bpy.context.scene.render.resolution_x = 720\n    # bpy.context.scene.render.resolution_y = 480\n\n    camera = create_camera((1, 1, 1), (0, 0, 0))\n    fx, fy = get_focal_length(\n        camera,\n        (bpy.context.scene.render.resolution_x, bpy.context.scene.render.resolution_y),\n    )\n\n    transform_dict = {\n        \"frames\": [],\n        \"camera_angle_x\": focal2fov(fx, bpy.context.scene.render.resolution_x),\n    }\n    img_indx = 0\n    num_elevations = 6\n    for j in range(num_elevations):\n        num_imgs = num_views // num_elevations\n        for i in range(num_imgs):\n            angle = 2 * math.pi * i / num_imgs + math.pi / 6.0\n      
      x = object_center.x + radius * math.cos(angle)\n            y = object_center.y + radius * math.sin(angle)\n            z = (\n                object_center.z + (j - num_elevations / 2.0) * radius / num_elevations\n            )  # Adjust this if you want the camera to be above or below the object's center\n\n            camera = create_camera((x, y, z), (0, 0, 0))\n            rot_quant = set_camera_look_at(camera, object_center)\n            bpy.context.view_layer.update()\n\n            # also prepare transforms.json\n            fname = f\"images/img_{img_indx}\"\n            cam2world = np.array(camera.matrix_world)\n            transform_dict[\"frames\"].append(\n                {\"file_path\": fname, \"transform_matrix\": cam2world.tolist()}\n            )\n\n            render_scene(\n                camera,\n                os.path.join(img_output_dir, f\"img_{img_indx}.png\"),\n                my_mesh_name,\n            )\n            img_indx += 1\n\n    trans_fpath = os.path.join(output_dir, \"transforms_train.json\")\n    import json\n\n    with open(trans_fpath, \"w\") as f:\n        json.dump(transform_dict, f)\n\n    transform_dict[\"frames\"] = transform_dict[\"frames\"][::4]\n    trans_fpath = os.path.join(output_dir, \"transforms_test.json\")\n\n    with open(trans_fpath, \"w\") as f:\n        json.dump(transform_dict, f)\n\n\nif __name__ == \"__main__\":\n    main()\n    # find_material()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/test_colmap_camera.py",
    "content": "import numpy as np\nimport os\nimport sys\nimport argparse\nimport collections\nimport struct\nfrom typing import NamedTuple\nimport math\nimport cv2\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"]\n)\nCamera = collections.namedtuple(\"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"]\n)\nPoint3D = collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"]\n)\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12),\n}\nCAMERA_MODEL_IDS = dict(\n    [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]\n)\nCAMERA_MODEL_NAMES = dict(\n    [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]\n)\n\n\ndef qvec2rotmat(qvec):\n    return np.array(\n        [\n            [\n                1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,\n                2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n                2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],\n            ],\n            [\n                2 * 
qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n                1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,\n                2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],\n            ],\n            [\n                2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n                2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n                1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,\n            ],\n        ]\n    )\n\n\ndef rotmat2qvec(R):\n    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat\n    K = (\n        np.array(\n            [\n                [Rxx - Ryy - Rzz, 0, 0, 0],\n                [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],\n                [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],\n                [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],\n            ]\n        )\n        / 3.0\n    )\n    eigvals, eigvecs = np.linalg.eigh(K)\n    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]\n    if qvec[0] < 0:\n        qvec *= -1\n    return qvec\n\n\ndef fov2focal(fov, pixels):\n    return pixels / (2 * math.tan(fov / 2))\n\n\nclass Image(BaseImage):\n    def qvec2rotmat(self):\n        return qvec2rotmat(self.qvec)\n\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\n\ndef read_points3D_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DText(const std::string& path)\n        void Reconstruction::WritePoints3DText(const std::string& path)\n    \"\"\"\n    xyzs = None\n    rgbs = None\n    errors = None\n    num_points = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                num_points += 1\n\n    xyzs = np.empty((num_points, 3))\n    rgbs = np.empty((num_points, 3))\n    errors = np.empty((num_points, 1))\n    count = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                xyz = np.array(tuple(map(float, elems[1:4])))\n                rgb = np.array(tuple(map(int, elems[4:7])))\n                error = np.array(float(elems[7]))\n                xyzs[count] = xyz\n                rgbs[count] = rgb\n                errors[count] = error\n                count += 1\n\n    return xyzs, rgbs, errors\n\n\ndef read_points3D_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n\n    with open(path_to_model_file, \"rb\") as fid:\n        num_points = read_next_bytes(fid, 8, 
\"Q\")[0]\n\n        xyzs = np.empty((num_points, 3))\n        rgbs = np.empty((num_points, 3))\n        errors = np.empty((num_points, 1))\n\n        for p_id in range(num_points):\n            binary_point_line_properties = read_next_bytes(\n                fid, num_bytes=43, format_char_sequence=\"QdddBBBd\"\n            )\n            xyz = np.array(binary_point_line_properties[1:4])\n            rgb = np.array(binary_point_line_properties[4:7])\n            error = np.array(binary_point_line_properties[7])\n            track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence=\"Q\")[\n                0\n            ]\n            track_elems = read_next_bytes(\n                fid,\n                num_bytes=8 * track_length,\n                format_char_sequence=\"ii\" * track_length,\n            )\n            xyzs[p_id] = xyz\n            rgbs[p_id] = rgb\n            errors[p_id] = error\n    return xyzs, rgbs, errors\n\n\ndef read_intrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                assert (\n                    model == \"PINHOLE\"\n                ), \"While the loader support other types, the rest of the code assumes PINHOLE\"\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(\n                    id=camera_id, model=model, width=width, height=height, params=params\n                )\n    return cameras\n\n\ndef 
read_extrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    images = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_reg_images = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_reg_images):\n            binary_image_properties = read_next_bytes(\n                fid, num_bytes=64, format_char_sequence=\"idddddddi\"\n            )\n            image_id = binary_image_properties[0]\n            qvec = np.array(binary_image_properties[1:5])\n            tvec = np.array(binary_image_properties[5:8])\n            camera_id = binary_image_properties[8]\n            image_name = \"\"\n            current_char = read_next_bytes(fid, 1, \"c\")[0]\n            while current_char != b\"\\x00\":  # look for the ASCII 0 entry\n                image_name += current_char.decode(\"utf-8\")\n                current_char = read_next_bytes(fid, 1, \"c\")[0]\n            num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence=\"Q\")[\n                0\n            ]\n            x_y_id_s = read_next_bytes(\n                fid,\n                num_bytes=24 * num_points2D,\n                format_char_sequence=\"ddq\" * num_points2D,\n            )\n            xys = np.column_stack(\n                [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]\n            )\n            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))\n            images[image_id] = Image(\n                id=image_id,\n                qvec=qvec,\n                tvec=tvec,\n                camera_id=camera_id,\n                name=image_name,\n                xys=xys,\n                point3D_ids=point3D_ids,\n            )\n    return images\n\n\ndef read_intrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: 
src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_cameras):\n            camera_properties = read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\"\n            )\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(\n                fid, num_bytes=8 * num_params, format_char_sequence=\"d\" * num_params\n            )\n            cameras[camera_id] = Camera(\n                id=camera_id,\n                model=model_name,\n                width=width,\n                height=height,\n                params=np.array(params),\n            )\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_extrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    images = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                image_id = int(elems[0])\n                qvec = np.array(tuple(map(float, elems[1:5])))\n                tvec = np.array(tuple(map(float, elems[5:8])))\n                camera_id = int(elems[8])\n                image_name = elems[9]\n                elems = fid.readline().split()\n                xys = 
np.column_stack(\n                    [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))]\n                )\n                point3D_ids = np.array(tuple(map(int, elems[2::3])))\n                images[image_id] = Image(\n                    id=image_id,\n                    qvec=qvec,\n                    tvec=tvec,\n                    camera_id=camera_id,\n                    name=image_name,\n                    xys=xys,\n                    point3D_ids=point3D_ids,\n                )\n    return images\n\n\ndef read_colmap_bin_array(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py\n\n    :param path: path to the colmap binary file.\n    :return: nd array with the floating point values in the value\n    \"\"\"\n    with open(path, \"rb\") as fid:\n        width, height, channels = np.genfromtxt(\n            fid, delimiter=\"&\", max_rows=1, usecols=(0, 1, 2), dtype=int\n        )\n        fid.seek(0)\n        num_delimiter = 0\n        byte = fid.read(1)\n        while True:\n            if byte == b\"&\":\n                num_delimiter += 1\n                if num_delimiter >= 3:\n                    break\n            byte = fid.read(1)\n        array = np.fromfile(fid, np.float32)\n    array = array.reshape((width, height, channels), order=\"F\")\n    return np.transpose(array, (1, 0, 2)).squeeze()\n\n\nclass CameraInfo(NamedTuple):\n    uid: int\n    R: np.array\n    T: np.array\n    FovY: np.array\n    FovX: np.array\n    image: np.array\n    image_path: str\n    image_name: str\n    width: int\n    height: int\n\n\ndef focal2fov(focal, pixels):\n    return 2 * math.atan(pixels / (2 * focal))\n\n\ndef readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):\n    cam_infos = []\n    for idx, key in enumerate(cam_extrinsics):\n        sys.stdout.write(\"\\r\")\n        # the exact output you're looking for:\n        sys.stdout.write(\"Reading camera {}/{}\".format(idx + 1, 
len(cam_extrinsics)))\n        sys.stdout.flush()\n\n        extr = cam_extrinsics[key]\n        intr = cam_intrinsics[extr.camera_id]\n        height = intr.height\n        width = intr.width\n\n        uid = intr.id\n        R = np.transpose(qvec2rotmat(extr.qvec))\n        T = np.array(extr.tvec)\n\n        if intr.model == \"SIMPLE_PINHOLE\":\n            focal_length_x = intr.params[0]\n            FovY = focal2fov(focal_length_x, height)\n            FovX = focal2fov(focal_length_x, width)\n        elif intr.model == \"PINHOLE\":\n            focal_length_x = intr.params[0]\n            focal_length_y = intr.params[1]\n            FovY = focal2fov(focal_length_y, height)\n            FovX = focal2fov(focal_length_x, width)\n        else:\n            assert (\n                False\n            ), \"Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!\"\n\n        image_path = os.path.join(images_folder, os.path.basename(extr.name))\n        image_name = os.path.basename(image_path).split(\".\")[0]\n        # image = Image.open(image_path)\n\n        cam_info = CameraInfo(\n            uid=uid,\n            R=R,\n            T=T,\n            FovY=FovY,\n            FovX=FovX,\n            image=None,\n            image_path=image_path,\n            image_name=image_name,\n            width=width,\n            height=height,\n        )\n        cam_infos.append(cam_info)\n    sys.stdout.write(\"\\n\")\n    return cam_infos\n\n\ndef read_camera_points(dir_path):\n    cameras_extrinsic_file = os.path.join(dir_path, \"sparse/0\", \"images.bin\")\n    cameras_intrinsic_file = os.path.join(dir_path, \"sparse/0\", \"cameras.bin\")\n    bin_path = os.path.join(dir_path, \"sparse/0\", \"points3D.bin\")\n    cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)\n    cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)\n\n    reading_dir = \"images\"\n    cam_infos_unsorted = readColmapCameras(\n 
       cam_extrinsics=cam_extrinsics,\n        cam_intrinsics=cam_intrinsics,\n        images_folder=os.path.join(dir_path, reading_dir),\n    )\n\n    xyz, rgb, _ = read_points3D_binary(bin_path)\n\n    return cam_infos_unsorted, xyz, rgb\n\n\ndef extract_projection_matrix(cam_info):\n    \"\"\"\n    Args:\n        cam_info: CameraInfo\n    Returns:\n        P: [3, 4]\n    \"\"\"\n    # change intrinsic to projection matrix\n\n    fovx, fovy = cam_info.FovX, cam_info.FovY\n    R, T = np.transpose(cam_info.R), cam_info.T\n\n    # R = np.transpose(R)\n\n    fx, fy = fov2focal(fovx, cam_info.width), fov2focal(fovy, cam_info.height)\n\n    K = np.array([[fx, 0, cam_info.width / 2], [0, fy, cam_info.height / 2], [0, 0, 1]])\n    # K[:, 1:3] *= -1.0\n\n    P = K @ np.hstack((R, T.reshape(3, 1)))\n\n    # P[1:3, :] *= -1.0\n\n    return P\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"None description\")\n\n    parser.add_argument(\"--input\", type=str, help=\"input dir\")\n\n    parser.add_argument(\"--output\", type=str, help=\"output dir\")\n\n    args = parser.parse_args()\n\n    if not os.path.exists(args.output):\n        os.makedirs(args.output)\n\n    cam_infos_unsorted, xyz, rgb = read_camera_points(args.input)\n\n    # xyz[:, 1:] *= -1.0\n    # xyz *= -1.0\n    print(xyz.shape)\n\n    # sort cam_infos_unsorted by uid\n    # cam_infos = sorted(cam_infos_unsorted, key=lambda x: x.uid)\n    cam_infos = cam_infos_unsorted\n\n    for i in range(10):\n        cam_info = cam_infos[i]\n        print(\"name\", cam_info.image_name)\n\n        projeciton_matrix = extract_projection_matrix(cam_info)\n\n        img = np.zeros((cam_info.height, cam_info.width, 3), dtype=np.uint8)\n\n        points2d = np.matmul(\n            projeciton_matrix[np.newaxis, :, :],\n            np.hstack((xyz, np.ones((xyz.shape[0], 1))))[:, :, np.newaxis],\n        )\n        points2d = points2d[:, :2] / points2d[:, 2:]\n        points2d = 
np.round(points2d).astype(np.int32).squeeze(axis=-1)\n\n        # filter out points that are out of image\n        valid_mask = (\n            (points2d[:, 0] >= 0)\n            & (points2d[:, 0] < cam_info.width)\n            & (points2d[:, 1] >= 0)\n            & (points2d[:, 1] < cam_info.height)\n        )\n\n        points2d = points2d[valid_mask]\n        valid_rgb = rgb[valid_mask]\n\n        # img[points2d[:, 1], points2d[:, 0]] = valid_rgb\n        # draw circles\n        print(\"num valid points: \", points2d.shape[0])\n        for j in range(points2d.shape[0]):\n            cv2.circle(\n                img,\n                (points2d[j, 0], points2d[j, 1]),\n                3,\n                tuple(valid_rgb[j].astype(np.int32).tolist()),\n                -1,\n            )\n\n        out_img_path = os.path.join(args.output, f\"img{i}.png\")\n\n        cv2.imwrite(out_img_path, img)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/datatools/transform_obj_for_blender.py",
    "content": "import point_cloud_utils as pcu\nimport argparse\nimport os\nimport json\nimport numpy as np\n\n\ndef transform_vertex(vertex: np.ndarray, transform_dict):\n    \"\"\"\n    Args:\n        vertex: shape [n, 3]\n    \"\"\"\n    if transform_dict is not None:\n        center = np.array(transform_dict[\"center\"])\n        scale = transform_dict[\"scale\"]\n\n    else:\n        center = np.mean(vertex, axis=0)\n        scale = np.max(np.abs(vertex - center))\n\n    new_vertex = (vertex - center) / scale\n\n    return new_vertex, center, scale\n\n\ndef colmap_to_blender_transform(vertex: np.ndarray):\n    R_mat = np.array(\n        [[1.0, 0, 0], [0, 0, 1.0], [0, 1.0, 0]],\n    )\n    vertex = R_mat[np.newaxis, :, :] @ vertex[:, :, np.newaxis]\n\n    return vertex.squeeze(axis=-1)\n\n\ndef copy_mtl_file(obj_path, transformed_obj_path):\n    mtl_path = obj_path.replace(\".obj\", \".mtl\")\n\n    dummy_mtl_path = transformed_obj_path + \".mtl\"\n    if os.path.exists(dummy_mtl_path):\n        os.remove(dummy_mtl_path)\n\n    if os.path.exists(mtl_path):\n        os.system(\"cp {} {}\".format(mtl_path, dummy_mtl_path))\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--obj_path\", type=str, required=True)\n\n    parser.add_argument(\"--save_transform\", action=\"store_true\", default=False)\n\n    args = parser.parse_args()\n\n    dir_name = os.path.dirname(args.obj_path)\n    _name = os.path.basename(dir_name)\n\n    dir_name_father = os.path.dirname(dir_name)\n\n    transformed_dir = os.path.join(dir_name_father, \"transformed_{}\".format(_name))\n    if not os.path.exists(transformed_dir):\n        os.makedirs(transformed_dir)\n\n    transformed_obj_path = os.path.join(\n        transformed_dir, os.path.basename(args.obj_path)\n    )\n\n    if os.path.exists(transformed_obj_path):\n        print(\"Transformed object already exists.\")\n        # return\n\n    meta_path = os.path.join(dir_name, \"meta.json\")\n    if 
os.path.exists(meta_path):\n        with open(meta_path, \"r\") as f:\n            meta_dict = json.load(f)\n    else:\n        print(\"transforming without meta.json\")\n        meta_dict = None\n\n    mesh = pcu.load_triangle_mesh(args.obj_path)\n    vertex = mesh.v\n    vertex, center, scale = transform_vertex(vertex, meta_dict)\n    vertex = colmap_to_blender_transform(vertex)\n\n    mesh.vertex_data.positions = vertex\n\n    mesh.save(transformed_obj_path)\n\n    copy_mtl_file(args.obj_path, transformed_obj_path)\n\n    if args.save_transform:\n        transform_dict = {\"center\": center.tolist(), \"scale\": scale}\n        with open(os.path.join(dir_name, \"meta.json\"), \"w\") as f:\n            json.dump(transform_dict, f)\n\n        print(\"Saved transform dict to {}\".format(os.path.join(dir_name, \"meta.json\")))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/builder.py",
    "content": "from . import gaussian_diffusion as gd\nfrom .respace import SpacedDiffusion, space_timesteps\n\n\ndef create_gaussian_diffusion(\n    *,\n    steps=1000,\n    learn_sigma=False,\n    sigma_small=False,\n    noise_schedule=\"linear\",\n    use_kl=False,\n    predict_xstart=False,\n    rescale_timesteps=False,\n    rescale_learned_sigmas=False,\n    timestep_respacing=\"\",\n):\n    betas = gd.get_named_beta_schedule(noise_schedule, steps)\n    if use_kl:\n        loss_type = gd.LossType.RESCALED_KL\n    elif rescale_learned_sigmas:\n        loss_type = gd.LossType.RESCALED_MSE\n    else:\n        loss_type = gd.LossType.MSE\n    if not timestep_respacing:\n        timestep_respacing = [steps]\n    return SpacedDiffusion(\n        use_timesteps=space_timesteps(steps, timestep_respacing),\n        betas=betas,\n        model_mean_type=(\n            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X\n        ),\n        model_var_type=(\n            (\n                gd.ModelVarType.FIXED_LARGE  # used this. What is the difference?\n                if not sigma_small\n                else gd.ModelVarType.FIXED_SMALL\n            )\n            if not learn_sigma\n            else gd.ModelVarType.LEARNED_RANGE\n        ),\n        loss_type=loss_type,\n        rescale_timesteps=rescale_timesteps,\n    )\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/discretizer.py",
    "content": "import torch\n\nfrom sgm.modules.diffusionmodules.discretizer import Discretization\n\n\nclass EDMResShiftedDiscretization(Discretization):\n    def __init__(\n        self, sigma_min=0.002, sigma_max=80.0, rho=7.0, scale_shift=1.0\n    ):\n        self.sigma_min = sigma_min\n        self.sigma_max = sigma_max\n        self.rho = rho\n        self.scale_shift = scale_shift\n\n    def get_sigmas(self, n, device=\"cpu\"):\n        ramp = torch.linspace(0, 1, n, device=device)\n        min_inv_rho = self.sigma_min ** (1 / self.rho)\n        max_inv_rho = self.sigma_max ** (1 / self.rho)\n        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho\n        sigmas = sigmas * self.scale_shift\n        return sigmas\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/draft.py",
    "content": "\n\nimport numpy as np\n\ndef latent_sds(input_x, schduler, unet, t_range=[0.02, 0.98]):\n\n    # t_range_annel: [0.02, 0.98] => [0.50, 0.98]\n    # input_x: # [T, 4, H, W] \n\n\n\n    sigma = schduler.sample_sigma(t_range) # scalar\n\n    noise = randn_like(input_x)\n\n    noised_latent = input_x + sigma * noise\n\n    c, uc = None \n    # x0 prediction. \n    denoised_latent_c, denoised_latent_uc = unet(noised_latent, c, uc)\n\n    w = [1.0, 2.0, 3.0]\n    denoised_latent = denoised_latent_uc + w * (denoised_latent_c - denoised_latent_uc)\n\n    sds_grad = (input_x - denoised_latent) / sigma\n\n    loss_sds = MSE(input_x - (input_x - sds_grad).detach())\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/gaussian_diffusion.py",
    "content": "\"\"\"\nThis code started out as a PyTorch port of Ho et al's diffusion models:\nhttps://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py\n\nDocstrings have been added, as well as DDIM sampling and a new collection of beta schedules.\n\"\"\"\n\nimport enum\nimport math\n\nimport numpy as np\nimport torch as th\n\nfrom .losses import normal_kl, discretized_gaussian_log_likelihood\n\n# from utils.triplane_util import decompose_featmaps\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef get_named_beta_schedule(schedule_name, num_diffusion_timesteps):\n    \"\"\"\n    Get a pre-defined beta schedule for the given name.\n\n    The beta schedule library consists of beta schedules which remain similar\n    in the limit of num_diffusion_timesteps.\n    Beta schedules may be added, but should not be removed or changed once\n    they are committed to maintain backwards compatibility.\n    \"\"\"\n    if schedule_name == \"linear\":\n        # Linear schedule from Ho et al, extended to work for any number of\n        # diffusion steps.\n        scale = 1000 / num_diffusion_timesteps\n        beta_start = scale * 0.0001\n        beta_end = scale * 0.02\n        return np.linspace(\n            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64\n        )\n    elif schedule_name == \"cosine\":\n        return betas_for_alpha_bar(\n            num_diffusion_timesteps,\n            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,\n        )\n    else:\n        raise NotImplementedError(f\"unknown beta schedule: {schedule_name}\")\n\n\ndef betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function,\n    which defines the cumulative product of (1-beta) over 
time from t = [0,1].\n\n    :param num_diffusion_timesteps: the number of betas to produce.\n    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and\n                      produces the cumulative product of (1-beta) up to that\n                      part of the diffusion process.\n    :param max_beta: the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n    \"\"\"\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))\n    return np.array(betas)\n\n\nclass ModelMeanType(enum.Enum):\n    \"\"\"\n    Which type of output the model predicts.\n    \"\"\"\n\n    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}\n    START_X = enum.auto()  # the model predicts x_0\n    EPSILON = enum.auto()  # the model predicts epsilon\n\n\nclass ModelVarType(enum.Enum):\n    \"\"\"\n    What is used as the model's output variance.\n\n    The LEARNED_RANGE option has been added to allow the model to predict\n    values between FIXED_SMALL and FIXED_LARGE, making its job easier.\n    \"\"\"\n\n    LEARNED = enum.auto()\n    FIXED_SMALL = enum.auto()\n    FIXED_LARGE = enum.auto()\n    LEARNED_RANGE = enum.auto()\n\n\nclass LossType(enum.Enum):\n    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)\n    RESCALED_MSE = (\n        enum.auto()\n    )  # use raw MSE loss (with RESCALED_KL when learning variances)\n    KL = enum.auto()  # use the variational lower-bound\n    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB\n\n    def is_vb(self):\n        return self == LossType.KL or self == LossType.RESCALED_KL\n\n\nclass GaussianDiffusion:\n    \"\"\"\n    Utilities for training and sampling diffusion models.\n\n    Ported directly from here, and then adapted over time to further experimentation.\n    
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42\n\n    :param betas: a 1-D numpy array of betas for each diffusion timestep,\n                  starting at T and going to 1.\n    :param model_mean_type: a ModelMeanType determining what the model outputs.\n    :param model_var_type: a ModelVarType determining how variance is output.\n    :param loss_type: a LossType determining the loss function to use.\n    :param rescale_timesteps: if True, pass floating point timesteps into the\n                              model so that they are always scaled like in the\n                              original paper (0 to 1000).\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        betas,\n        model_mean_type,\n        model_var_type,\n        loss_type,\n        rescale_timesteps=False,\n    ):\n        self.model_mean_type = model_mean_type\n        self.model_var_type = model_var_type\n        self.loss_type = loss_type\n        self.rescale_timesteps = rescale_timesteps\n\n        # Use float64 for accuracy.\n        betas = np.array(betas, dtype=np.float64)\n        self.betas = betas\n        assert len(betas.shape) == 1, \"betas must be 1-D\"\n        assert (betas > 0).all() and (betas <= 1).all()\n\n        self.num_timesteps = int(betas.shape[0])\n\n        alphas = 1.0 - betas\n        self.alphas_cumprod = np.cumprod(alphas, axis=0)\n        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])\n        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)\n        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)\n        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)\n        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)\n        
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)\n        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        self.posterior_variance = (\n            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)\n        )\n        # log calculation clipped because the posterior variance is 0 at the\n        # beginning of the diffusion chain.\n        self.posterior_log_variance_clipped = np.log(\n            np.append(self.posterior_variance[1], self.posterior_variance[1:])\n        )\n        self.posterior_mean_coef1 = (\n            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)\n        )\n        self.posterior_mean_coef2 = (\n            (1.0 - self.alphas_cumprod_prev)\n            * np.sqrt(alphas)\n            / (1.0 - self.alphas_cumprod)\n        )\n\n    def q_mean_variance(self, x_start, t):\n        \"\"\"\n        Get the distribution q(x_t | x_0).\n\n        :param x_start: the [N x C x ...] tensor of noiseless inputs.\n        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.\n        :return: A tuple (mean, variance, log_variance), all of x_start's shape.\n        \"\"\"\n        mean = (\n            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n        )\n        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)\n        log_variance = _extract_into_tensor(\n            self.log_one_minus_alphas_cumprod, t, x_start.shape\n        )\n        return mean, variance, log_variance\n\n    def q_sample(self, x_start, t, noise=None):\n        \"\"\"\n        Diffuse the data for a given number of diffusion steps.\n\n        In other words, sample from q(x_t | x_0).\n\n        :param x_start: the initial data batch.\n        :param t: the number of diffusion steps (minus 1). 
Here, 0 means one step.\n        :param noise: if specified, the split-out normal noise.\n        :return: A noisy version of x_start.\n        \"\"\"\n        if noise is None:\n            noise = th.randn_like(x_start)\n        assert noise.shape == x_start.shape\n        return (\n            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)\n            * noise\n        )\n\n    def q_posterior_mean_variance(self, x_start, x_t, t):\n        \"\"\"\n        Compute the mean and variance of the diffusion posterior:\n\n            q(x_{t-1} | x_t, x_0)\n\n        \"\"\"\n        assert x_start.shape == x_t.shape\n        posterior_mean = (\n            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start\n            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = _extract_into_tensor(\n            self.posterior_log_variance_clipped, t, x_t.shape\n        )\n        assert (\n            posterior_mean.shape[0]\n            == posterior_variance.shape[0]\n            == posterior_log_variance_clipped.shape[0]\n            == x_start.shape[0]\n        )\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(\n        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None\n    ):\n        \"\"\"\n        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of\n        the initial x, x_0.\n\n        :param model: the model, which takes a signal and a batch of timesteps\n                      as input.\n        :param x: the [N x C x ...] 
tensor at time t.\n        :param t: a 1-D Tensor of timesteps.\n        :param clip_denoised: if True, clip the denoised signal into [-1, 1].\n        :param denoised_fn: if not None, a function which applies to the\n            x_start prediction before it is used to sample. Applies before\n            clip_denoised.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. This can be used for conditioning.\n        :return: a dict with the following keys:\n                 - 'mean': the model mean output.\n                 - 'variance': the model variance output.\n                 - 'log_variance': the log of 'variance'.\n                 - 'pred_xstart': the prediction for x_0.\n        \"\"\"\n        if model_kwargs is None:\n            model_kwargs = {}\n\n        B, C = x.shape[:2]\n        assert t.shape == (B,)\n        model_output = model(x, self._scale_timesteps(t), **model_kwargs)\n\n        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:\n            assert model_output.shape == (B, C * 2, *x.shape[2:])\n            model_output, model_var_values = th.split(model_output, C, dim=1)\n            if self.model_var_type == ModelVarType.LEARNED:\n                model_log_variance = model_var_values\n                model_variance = th.exp(model_log_variance)\n            else:\n                min_log = _extract_into_tensor(\n                    self.posterior_log_variance_clipped, t, x.shape\n                )\n                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)\n                # The model_var_values is [-1, 1] for [min_var, max_var].\n                frac = (model_var_values + 1) / 2\n                model_log_variance = frac * max_log + (1 - frac) * min_log\n                model_variance = th.exp(model_log_variance)\n        else:\n            model_variance, model_log_variance = {\n                # for fixedlarge, we set the initial 
(log-)variance like so\n                # to get a better decoder log likelihood.\n                ModelVarType.FIXED_LARGE: (\n                    np.append(self.posterior_variance[1], self.betas[1:]),\n                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),\n                ),\n                ModelVarType.FIXED_SMALL: (\n                    self.posterior_variance,\n                    self.posterior_log_variance_clipped,\n                ),\n            }[self.model_var_type]\n            model_variance = _extract_into_tensor(model_variance, t, x.shape)\n            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)\n\n        def process_xstart(x):\n            if denoised_fn is not None:\n                x = denoised_fn(x)\n            if clip_denoised:\n                return x.clamp(-1, 1)\n            return x\n\n        if self.model_mean_type == ModelMeanType.PREVIOUS_X:\n            pred_xstart = process_xstart(\n                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)\n            )\n            model_mean = model_output\n        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:\n            if self.model_mean_type == ModelMeanType.START_X:\n                pred_xstart = process_xstart(model_output)\n            else:\n                pred_xstart = process_xstart(\n                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)\n                )\n            model_mean, _, _ = self.q_posterior_mean_variance(\n                x_start=pred_xstart, x_t=x, t=t\n            )\n        else:\n            raise NotImplementedError(self.model_mean_type)\n\n        assert (\n            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape\n        )\n        return {\n            \"mean\": model_mean,\n            \"variance\": model_variance,\n            \"log_variance\": model_log_variance,\n            \"pred_xstart\": 
pred_xstart,\n        }\n\n    def _predict_xstart_from_eps(self, x_t, t, eps):\n        assert x_t.shape == eps.shape\n        return (\n            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t\n            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps\n        )\n\n    def _predict_xstart_from_xprev(self, x_t, t, xprev):\n        assert x_t.shape == xprev.shape\n        return (  # (xprev - coef2*x_t) / coef1\n            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev\n            - _extract_into_tensor(\n                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape\n            )\n            * x_t\n        )\n\n    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):\n        return (\n            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t\n            - pred_xstart\n        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n\n    def _scale_timesteps(self, t):\n        if self.rescale_timesteps:\n            return t.float() * (1000.0 / self.num_timesteps)\n        return t\n\n    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):\n        \"\"\"\n        Compute the mean for the previous step, given a function cond_fn that\n        computes the gradient of a conditional log probability with respect to\n        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to\n        condition on y.\n\n        This uses the conditioning strategy from Sohl-Dickstein et al. 
(2015).\n        \"\"\"\n        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)\n        new_mean = (\n            p_mean_var[\"mean\"].float() + p_mean_var[\"variance\"] * gradient.float()\n        )\n        return new_mean\n\n    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):\n        \"\"\"\n        Compute what the p_mean_variance output would have been, should the\n        model's score function be conditioned by cond_fn.\n\n        See condition_mean() for details on cond_fn.\n\n        Unlike condition_mean(), this instead uses the conditioning strategy\n        from Song et al (2020).\n        \"\"\"\n        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)\n\n        eps = self._predict_eps_from_xstart(x, t, p_mean_var[\"pred_xstart\"])\n        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(\n            x, self._scale_timesteps(t), **model_kwargs\n        )\n\n        out = p_mean_var.copy()\n        out[\"pred_xstart\"] = self._predict_xstart_from_eps(x, t, eps)\n        out[\"mean\"], _, _ = self.q_posterior_mean_variance(\n            x_start=out[\"pred_xstart\"], x_t=x, t=t\n        )\n        return out\n\n    def p_sample(\n        self,\n        model,\n        x,\n        t,\n        clip_denoised=True,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n    ):\n        \"\"\"\n        Sample x_{t-1} from the model at the given timestep.\n\n        :param model: the model to sample from.\n        :param x: the current tensor at x_{t-1}.\n        :param t: the value of t, starting at 0 for the first diffusion step.\n        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].\n        :param denoised_fn: if not None, a function which applies to the\n            x_start prediction before it is used to sample.\n        :param cond_fn: if not None, this is a gradient function that acts\n                        similarly to the model.\n        
:param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. This can be used for conditioning.\n        :return: a dict containing the following keys:\n                 - 'sample': a random sample from the model.\n                 - 'pred_xstart': a prediction of x_0.\n        \"\"\"\n        out = self.p_mean_variance(\n            model,\n            x,\n            t,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            model_kwargs=model_kwargs,\n        )\n        noise = th.randn_like(x)\n        nonzero_mask = (\n            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))\n        )  # no noise when t == 0\n        if cond_fn is not None:\n            out[\"mean\"] = self.condition_mean(\n                cond_fn, out, x, t, model_kwargs=model_kwargs\n            )\n        sample = out[\"mean\"] + nonzero_mask * th.exp(0.5 * out[\"log_variance\"]) * noise\n        return {\"sample\": sample, \"pred_xstart\": out[\"pred_xstart\"]}\n\n    def p_sample_loop(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=True,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n    ):\n        \"\"\"\n        Generate samples from the model.\n\n        :param model: the model module.\n        :param shape: the shape of the samples, (N, C, H, W).\n        :param noise: if specified, the noise from the encoder to sample.\n                      Should be of the same shape as `shape`.\n        :param clip_denoised: if True, clip x_start predictions to [-1, 1].\n        :param denoised_fn: if not None, a function which applies to the\n            x_start prediction before it is used to sample.\n        :param cond_fn: if not None, this is a gradient function that acts\n                        similarly to the model.\n        :param model_kwargs: if not None, a dict of extra 
keyword arguments to\n            pass to the model. This can be used for conditioning.\n        :param device: if specified, the device to create the samples on.\n                       If not specified, use a model parameter's device.\n        :param progress: if True, show a tqdm progress bar.\n        :return: a non-differentiable batch of samples.\n        \"\"\"\n        final = None\n        for sample in self.p_sample_loop_progressive(\n            model,\n            shape,\n            noise=noise,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            cond_fn=cond_fn,\n            model_kwargs=model_kwargs,\n            device=device,\n            progress=progress,\n        ):\n            final = sample\n        return final[\"sample\"]\n\n    def p_sample_loop_progressive(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=True,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n    ):\n        \"\"\"\n        Generate samples from the model and yield intermediate samples from\n        each timestep of diffusion.\n\n        Arguments are the same as p_sample_loop().\n        Returns a generator over dicts, where each dict is the return value of\n        p_sample().\n        \"\"\"\n        if device is None:\n            device = next(model.parameters()).device\n        assert isinstance(shape, (tuple, list))\n        if noise is not None:\n            img = noise\n        else:\n            img = th.randn(*shape, device=device)\n        indices = list(range(self.num_timesteps))[::-1]\n\n        if progress:\n            # Lazy import so that we don't depend on tqdm.\n            from tqdm.auto import tqdm\n\n            indices = tqdm(indices)\n\n        for i in indices:\n            t = th.tensor([i] * shape[0], device=device)\n            with th.no_grad():\n                out = self.p_sample(\n  
                  model,\n                    img,\n                    t,\n                    clip_denoised=clip_denoised,\n                    denoised_fn=denoised_fn,\n                    cond_fn=cond_fn,\n                    model_kwargs=model_kwargs,\n                )\n                yield out\n                img = out[\"sample\"]\n\n    def ddim_sample(\n        self,\n        model,\n        x,\n        t,\n        clip_denoised=True,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        eta=0.0,\n        y0=None,\n        mask=None,\n        is_mask_t0=False,\n    ):\n        \"\"\"\n        Sample x_{t-1} from the model using DDIM.\n\n        Same usage as p_sample().\n        \"\"\"\n        out = self.p_mean_variance(\n            model,\n            x,\n            t,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            model_kwargs=model_kwargs,\n        )\n        if cond_fn is not None:\n            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)\n        # masked generation\n        if y0 is not None and mask is not None:\n            assert y0.shape == x.shape\n            assert mask.shape == x.shape\n            if is_mask_t0:\n                out[\"pred_xstart\"] = mask * y0 + (1 - mask) * out[\"pred_xstart\"]\n            else:\n                nonzero_mask = (\n                    (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))\n                )  # no noise when t == 0\n                out[\"pred_xstart\"] = (\n                    mask * y0 + (1 - mask) * out[\"pred_xstart\"]\n                ) * nonzero_mask + out[\"pred_xstart\"] * (1 - nonzero_mask)\n\n        # Usually our model outputs epsilon, but we re-derive it\n        # in case we used x_start or x_prev prediction.\n        eps = self._predict_eps_from_xstart(x, t, out[\"pred_xstart\"])\n\n        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)\n        
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)\n        sigma = (\n            eta\n            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))\n            * th.sqrt(1 - alpha_bar / alpha_bar_prev)\n        )\n        # Equation 12.\n        noise = th.randn_like(x)\n        mean_pred = (\n            out[\"pred_xstart\"] * th.sqrt(alpha_bar_prev)\n            + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps\n        )\n        nonzero_mask = (\n            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))\n        )  # no noise when t == 0\n        sample = mean_pred + nonzero_mask * sigma * noise\n        return {\"sample\": sample, \"pred_xstart\": out[\"pred_xstart\"]}\n\n    def ddim_reverse_sample(\n        self,\n        model,\n        x,\n        t,\n        clip_denoised=True,\n        denoised_fn=None,\n        model_kwargs=None,\n        eta=0.0,\n    ):\n        \"\"\"\n        Sample x_{t+1} from the model using DDIM reverse ODE.\n        \"\"\"\n        assert eta == 0.0, \"Reverse ODE only for deterministic path\"\n        out = self.p_mean_variance(\n            model,\n            x,\n            t,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            model_kwargs=model_kwargs,\n        )\n        # Usually our model outputs epsilon, but we re-derive it\n        # in case we used x_start or x_prev prediction.\n        eps = (\n            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x\n            - out[\"pred_xstart\"]\n        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)\n        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)\n\n        # Equation 12. 
reversed\n        mean_pred = (\n            out[\"pred_xstart\"] * th.sqrt(alpha_bar_next)\n            + th.sqrt(1 - alpha_bar_next) * eps\n        )\n\n        return {\"sample\": mean_pred, \"pred_xstart\": out[\"pred_xstart\"]}\n\n    def ddim_sample_loop(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=True,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n        eta=0.0,\n        y0=None,\n        mask=None,\n        is_mask_t0=False,\n    ):\n        \"\"\"\n        Generate samples from the model using DDIM.\n\n        Same usage as p_sample_loop().\n        \"\"\"\n        final = None\n        for sample in self.ddim_sample_loop_progressive(\n            model,\n            shape,\n            noise=noise,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            cond_fn=cond_fn,\n            model_kwargs=model_kwargs,\n            device=device,\n            progress=progress,\n            eta=eta,\n            y0=y0,\n            mask=mask,\n            is_mask_t0=is_mask_t0,\n        ):\n            final = sample\n        return final[\"sample\"]\n\n    def ddim_sample_loop_progressive(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=True,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n        eta=0.0,\n        y0=None,\n        mask=None,\n        is_mask_t0=False,\n    ):\n        \"\"\"\n        Use DDIM to sample from the model and yield intermediate samples from\n        each timestep of DDIM.\n\n        Same usage as p_sample_loop_progressive().\n        \"\"\"\n        if device is None:\n            device = next(model.parameters()).device\n        assert isinstance(shape, (tuple, list))\n        if noise is not None:\n            img = noise\n        else:\n            
img = th.randn(*shape, device=device)\n        indices = list(range(self.num_timesteps))[::-1]\n\n        if progress:\n            # Lazy import so that we don't depend on tqdm.\n            from tqdm.auto import tqdm\n\n            indices = tqdm(indices)\n\n        for i in indices:\n            t = th.tensor([i] * shape[0], device=device)\n            with th.no_grad():\n                out = self.ddim_sample(\n                    model,\n                    img,\n                    t,\n                    clip_denoised=clip_denoised,\n                    denoised_fn=denoised_fn,\n                    cond_fn=cond_fn,\n                    model_kwargs=model_kwargs,\n                    eta=eta,\n                    y0=y0,\n                    mask=mask,\n                    is_mask_t0=is_mask_t0,\n                )\n                yield out\n                img = out[\"sample\"]\n\n    def _vb_terms_bpd(\n        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None\n    ):\n        \"\"\"\n        Get a term for the variational lower-bound.\n\n        The resulting units are bits (rather than nats, as one might expect).\n        This allows for comparison to other papers.\n\n        :return: a dict with the following keys:\n                 - 'output': a shape [N] tensor of NLLs or KLs.\n                 - 'pred_xstart': the x_0 predictions.\n        \"\"\"\n        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(\n            x_start=x_start, x_t=x_t, t=t\n        )\n        out = self.p_mean_variance(\n            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs\n        )\n        kl = normal_kl(\n            true_mean, true_log_variance_clipped, out[\"mean\"], out[\"log_variance\"]\n        )\n        kl = mean_flat(kl) / np.log(2.0)\n\n        decoder_nll = -discretized_gaussian_log_likelihood(\n            x_start, means=out[\"mean\"], log_scales=0.5 * out[\"log_variance\"]\n        )\n        
assert decoder_nll.shape == x_start.shape\n        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)\n\n        # At the first timestep return the decoder NLL,\n        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))\n        output = th.where((t == 0), decoder_nll, kl)\n        return {\"output\": output, \"pred_xstart\": out[\"pred_xstart\"]}\n\n    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):\n        \"\"\"\n        Compute training losses for a single timestep.\n\n        :param model: the model to evaluate loss on.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :param t: a batch of timestep indices.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. This can be used for conditioning.\n        :param noise: if specified, the specific Gaussian noise to try to remove.\n        :return: a dict with the key \"loss\" containing a tensor of shape [N].\n                 Some mean or variance settings may also have other keys.\n        \"\"\"\n        if model_kwargs is None:\n            model_kwargs = {}\n        if noise is None:\n            noise = th.randn_like(x_start)\n        x_t = self.q_sample(x_start, t, noise=noise)  # sample\n\n        terms = {}\n\n        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:\n            raise NotImplementedError\n            terms[\"loss\"] = self._vb_terms_bpd(\n                model=model,\n                x_start=x_start,\n                x_t=x_t,\n                t=t,\n                clip_denoised=False,\n                model_kwargs=model_kwargs,\n            )[\"output\"]\n            if self.loss_type == LossType.RESCALED_KL:\n                terms[\"loss\"] *= self.num_timesteps\n        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:\n            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)\n\n            
if self.model_var_type in [\n                ModelVarType.LEARNED,\n                ModelVarType.LEARNED_RANGE,\n            ]:\n                B, C = x_t.shape[:2]\n                assert model_output.shape == (\n                    B,\n                    C * 2,\n                    *x_t.shape[2:],\n                )  # why the output channel is doubled? mean and var?\n                model_output, model_var_values = th.split(model_output, C, dim=1)\n                # Learn the variance using the variational bound, but don't let\n                # it affect our mean prediction.\n                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)\n                terms[\"vb\"] = self._vb_terms_bpd(\n                    model=lambda *args, r=frozen_out: r,\n                    x_start=x_start,\n                    x_t=x_t,\n                    t=t,\n                    clip_denoised=False,\n                )[\"output\"]\n                if self.loss_type == LossType.RESCALED_MSE:\n                    # Divide by 1000 for equivalence with initial implementation.\n                    # Without a factor of 1/1000, the VB term hurts the MSE term.\n                    terms[\"vb\"] *= self.num_timesteps / 1000.0\n\n            target = {\n                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(\n                    x_start=x_start, x_t=x_t, t=t\n                )[0],\n                ModelMeanType.START_X: x_start,\n                ModelMeanType.EPSILON: noise,\n            }[self.model_mean_type]\n            assert model_output.shape == target.shape == x_start.shape\n\n            target_tx, target_ty, target_tz = th.split(\n                target, target.shape[-1] // 3, dim=-1\n            )\n            output_tx, output_ty, output_tz = th.split(\n                model_output, model_output.shape[-1] // 3, dim=-1\n            )\n\n            terms[\"mse_tx\"] = mean_flat((target_tx - output_tx) ** 2)\n            terms[\"mse_ty\"] 
= mean_flat((target_ty - output_ty) ** 2)\n            terms[\"mse_tz\"] = mean_flat((target_tz - output_tz) ** 2)\n            # terms[\"mse\"] = mean_flat((target - model_output) ** 2)\n\n            if \"vb\" in terms:\n                terms[\"loss\"] = (\n                    terms[\"mse_tx\"] + terms[\"mse_ty\"] + terms[\"mse_tz\"] + terms[\"vb\"]\n                )\n                # terms[\"loss\"] = terms[\"mse\"] + terms[\"vb\"]\n            else:\n                terms[\"loss\"] = terms[\"mse_tx\"] + terms[\"mse_ty\"] + terms[\"mse_tz\"]\n                # terms[\"loss\"] = terms[\"mse\"]\n        else:\n            raise NotImplementedError(self.loss_type)\n\n        return terms\n\n    def _prior_bpd(self, x_start):\n        \"\"\"\n        Get the prior KL term for the variational lower-bound, measured in\n        bits-per-dim.\n\n        This term can't be optimized, as it only depends on the encoder.\n\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :return: a batch of [N] KL values (in bits), one per batch element.\n        \"\"\"\n        batch_size = x_start.shape[0]\n        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)\n        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)\n        kl_prior = normal_kl(\n            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0\n        )\n        return mean_flat(kl_prior) / np.log(2.0)\n\n    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):\n        \"\"\"\n        Compute the entire variational lower-bound, measured in bits-per-dim,\n        as well as other related quantities.\n\n        :param model: the model to evaluate loss on.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :param clip_denoised: if True, clip denoised samples.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. 
This can be used for conditioning.\n\n        :return: a dict containing the following keys:\n                 - total_bpd: the total variational lower-bound, per batch element.\n                 - prior_bpd: the prior term in the lower-bound.\n                 - vb: an [N x T] tensor of terms in the lower-bound.\n                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.\n                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.\n        \"\"\"\n        device = x_start.device\n        batch_size = x_start.shape[0]\n\n        vb = []\n        xstart_mse = []\n        mse = []\n        for t in list(range(self.num_timesteps))[::-1]:\n            t_batch = th.tensor([t] * batch_size, device=device)\n            noise = th.randn_like(x_start)\n            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)\n            # Calculate VLB term at the current timestep\n            with th.no_grad():\n                out = self._vb_terms_bpd(\n                    model,\n                    x_start=x_start,\n                    x_t=x_t,\n                    t=t_batch,\n                    clip_denoised=clip_denoised,\n                    model_kwargs=model_kwargs,\n                )\n            vb.append(out[\"output\"])\n            xstart_mse.append(mean_flat((out[\"pred_xstart\"] - x_start) ** 2))\n            eps = self._predict_eps_from_xstart(x_t, t_batch, out[\"pred_xstart\"])\n            mse.append(mean_flat((eps - noise) ** 2))\n\n        vb = th.stack(vb, dim=1)\n        xstart_mse = th.stack(xstart_mse, dim=1)\n        mse = th.stack(mse, dim=1)\n\n        prior_bpd = self._prior_bpd(x_start)\n        total_bpd = vb.sum(dim=1) + prior_bpd\n        return {\n            \"total_bpd\": total_bpd,\n            \"prior_bpd\": prior_bpd,\n            \"vb\": vb,\n            \"xstart_mse\": xstart_mse,\n            \"mse\": mse,\n        }\n\n\ndef _extract_into_tensor(arr, timesteps, broadcast_shape):\n    \"\"\"\n  
  Extract values from a 1-D numpy array for a batch of indices.\n\n    :param arr: the 1-D numpy array.\n    :param timesteps: a tensor of indices into the array to extract.\n    :param broadcast_shape: a larger shape of K dimensions with the batch\n                            dimension equal to the length of timesteps.\n    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.\n    \"\"\"\n    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()\n    while len(res.shape) < len(broadcast_shape):\n        res = res[..., None]\n    return res.expand(broadcast_shape)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/losses.py",
    "content": "\"\"\"\nHelpers for various likelihood-based losses. These are ported from the original\nHo et al. diffusion models codebase:\nhttps://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py\n\"\"\"\n\nimport numpy as np\n\nimport torch as th\n\n\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n    \"\"\"\n    Compute the KL divergence between two gaussians.\n\n    Shapes are automatically broadcasted, so batches can be compared to\n    scalars, among other use cases.\n    \"\"\"\n    tensor = None\n    for obj in (mean1, logvar1, mean2, logvar2):\n        if isinstance(obj, th.Tensor):\n            tensor = obj\n            break\n    assert tensor is not None, \"at least one argument must be a Tensor\"\n\n    # Force variances to be Tensors. Broadcasting helps convert scalars to\n    # Tensors, but it does not work for th.exp().\n    logvar1, logvar2 = [\n        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)\n        for x in (logvar1, logvar2)\n    ]\n\n    return 0.5 * (\n        -1.0\n        + logvar2\n        - logvar1\n        + th.exp(logvar1 - logvar2)\n        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)\n    )\n\n\ndef approx_standard_normal_cdf(x):\n    \"\"\"\n    A fast approximation of the cumulative distribution function of the\n    standard normal.\n    \"\"\"\n    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))\n\n\ndef discretized_gaussian_log_likelihood(x, *, means, log_scales):\n    \"\"\"\n    Compute the log-likelihood of a Gaussian distribution discretizing to a\n    given image.\n\n    :param x: the target images. 
It is assumed that this was uint8 values,\n              rescaled to the range [-1, 1].\n    :param means: the Gaussian mean Tensor.\n    :param log_scales: the Gaussian log stddev Tensor.\n    :return: a tensor like x of log probabilities (in nats).\n    \"\"\"\n    assert x.shape == means.shape == log_scales.shape\n    centered_x = x - means\n    inv_stdv = th.exp(-log_scales)\n    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)\n    cdf_plus = approx_standard_normal_cdf(plus_in)\n    min_in = inv_stdv * (centered_x - 1.0 / 255.0)\n    cdf_min = approx_standard_normal_cdf(min_in)\n    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))\n    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))\n    cdf_delta = cdf_plus - cdf_min\n    log_probs = th.where(\n        x < -0.999,\n        log_cdf_plus,\n        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),\n    )\n    assert log_probs.shape == x.shape\n    return log_probs\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/resample.py",
    "content": "\"\"\"\nCode borrowed from https://github.com/Sin3DM/Sin3DM/blob/9c3ac12a655157469c71632346ebf569354ae7f6/src/diffusion/resample.py\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nimport numpy as np\nimport torch as th\nimport torch.distributed as dist\n\n\ndef create_named_schedule_sampler(name, diffusion):\n    \"\"\"\n    Create a ScheduleSampler from a library of pre-defined samplers.\n\n    :param name: the name of the sampler.\n    :param diffusion: the diffusion object to sample for.\n    \"\"\"\n    if name == \"uniform\":\n        return UniformSampler(diffusion)\n    elif name == \"loss-second-moment\":\n        return LossSecondMomentResampler(diffusion)\n    else:\n        raise NotImplementedError(f\"unknown schedule sampler: {name}\")\n\n\nclass ScheduleSampler(ABC):\n    \"\"\"\n    A distribution over timesteps in the diffusion process, intended to reduce\n    variance of the objective.\n\n    By default, samplers perform unbiased importance sampling, in which the\n    objective's mean is unchanged.\n    However, subclasses may override sample() to change how the resampled\n    terms are reweighted, allowing for actual changes in the objective.\n    \"\"\"\n\n    @abstractmethod\n    def weights(self):\n        \"\"\"\n        Get a numpy array of weights, one per diffusion step.\n\n        The weights needn't be normalized, but must be positive.\n        \"\"\"\n\n    def sample(self, batch_size, device):\n        \"\"\"\n        Importance-sample timesteps for a batch.\n\n        :param batch_size: the number of timesteps.\n        :param device: the torch device to save to.\n        :return: a tuple (timesteps, weights):\n                 - timesteps: a tensor of timestep indices.\n                 - weights: a tensor of weights to scale the resulting losses.\n        \"\"\"\n        w = self.weights()\n        p = w / np.sum(w)\n        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)\n        indices = 
th.from_numpy(indices_np).long().to(device)\n        weights_np = 1 / (len(p) * p[indices_np])\n        weights = th.from_numpy(weights_np).float().to(device)\n        return indices, weights\n\n\nclass UniformSampler(ScheduleSampler):\n    def __init__(self, diffusion):\n        self.diffusion = diffusion\n        self._weights = np.ones([diffusion.num_timesteps])\n\n    def weights(self):\n        return self._weights\n\n\nclass LossAwareSampler(ScheduleSampler):\n    def update_with_local_losses(self, local_ts, local_losses):\n        \"\"\"\n        Update the reweighting using losses from a model.\n\n        Call this method from each rank with a batch of timesteps and the\n        corresponding losses for each of those timesteps.\n        This method will perform synchronization to make sure all of the ranks\n        maintain the exact same reweighting.\n\n        :param local_ts: an integer Tensor of timesteps.\n        :param local_losses: a 1D Tensor of losses.\n        \"\"\"\n        batch_sizes = [\n            th.tensor([0], dtype=th.int32, device=local_ts.device)\n            for _ in range(dist.get_world_size())\n        ]\n        dist.all_gather(\n            batch_sizes,\n            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),\n        )\n\n        # Pad all_gather batches to be the maximum batch size.\n        batch_sizes = [x.item() for x in batch_sizes]\n        max_bs = max(batch_sizes)\n\n        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]\n        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]\n        dist.all_gather(timestep_batches, local_ts)\n        dist.all_gather(loss_batches, local_losses)\n        timesteps = [\n            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]\n        ]\n        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]\n        self.update_with_all_losses(timesteps, losses)\n\n    
@abstractmethod\n    def update_with_all_losses(self, ts, losses):\n        \"\"\"\n        Update the reweighting using losses from a model.\n\n        Sub-classes should override this method to update the reweighting\n        using losses from the model.\n\n        This method directly updates the reweighting without synchronizing\n        between workers. It is called by update_with_local_losses from all\n        ranks with identical arguments. Thus, it should have deterministic\n        behavior to maintain state across workers.\n\n        :param ts: a list of int timesteps.\n        :param losses: a list of float losses, one per timestep.\n        \"\"\"\n\n\nclass LossSecondMomentResampler(LossAwareSampler):\n    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):\n        self.diffusion = diffusion\n        self.history_per_term = history_per_term\n        self.uniform_prob = uniform_prob\n        self._loss_history = np.zeros(\n            [diffusion.num_timesteps, history_per_term], dtype=np.float64\n        )\n        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)\n\n    def weights(self):\n        if not self._warmed_up():\n            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)\n        weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))\n        weights /= np.sum(weights)\n        weights *= 1 - self.uniform_prob\n        weights += self.uniform_prob / len(weights)\n        return weights\n\n    def update_with_all_losses(self, ts, losses):\n        for t, loss in zip(ts, losses):\n            if self._loss_counts[t] == self.history_per_term:\n                # Shift out the oldest loss term.\n                self._loss_history[t, :-1] = self._loss_history[t, 1:]\n                self._loss_history[t, -1] = loss\n            else:\n                self._loss_history[t, self._loss_counts[t]] = loss\n                self._loss_counts[t] += 1\n\n    def _warmed_up(self):\n        
return (self._loss_counts == self.history_per_term).all()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/respace.py",
    "content": "import numpy as np\nimport torch as th\n\nfrom .gaussian_diffusion import GaussianDiffusion\n\n\ndef space_timesteps(num_timesteps, section_counts):\n    \"\"\"\n    Create a list of timesteps to use from an original diffusion process,\n    given the number of timesteps we want to take from equally-sized portions\n    of the original process.\n\n    For example, if there's 300 timesteps and the section counts are [10,15,20]\n    then the first 100 timesteps are strided to be 10 timesteps, the second 100\n    are strided to be 15 timesteps, and the final 100 are strided to be 20.\n\n    If the stride is a string starting with \"ddim\", then the fixed striding\n    from the DDIM paper is used, and only one section is allowed.\n\n    :param num_timesteps: the number of diffusion steps in the original\n                          process to divide up.\n    :param section_counts: either a list of numbers, or a string containing\n                           comma-separated numbers, indicating the step count\n                           per section. 
As a special case, use \"ddimN\" where N\n                           is a number of steps to use the striding from the\n                           DDIM paper.\n    :return: a set of diffusion steps from the original process to use.\n    \"\"\"\n    if isinstance(section_counts, str):\n        if section_counts.startswith(\"ddim\"):\n            desired_count = int(section_counts[len(\"ddim\") :])\n            for i in range(1, num_timesteps):\n                if len(range(0, num_timesteps, i)) == desired_count:\n                    return set(range(0, num_timesteps, i))\n            raise ValueError(\n                f\"cannot create exactly {desired_count} steps with an integer stride\"\n            )\n        section_counts = [int(x) for x in section_counts.split(\",\")]\n    size_per = num_timesteps // len(section_counts)\n    extra = num_timesteps % len(section_counts)\n    start_idx = 0\n    all_steps = []\n    for i, section_count in enumerate(section_counts):\n        size = size_per + (1 if i < extra else 0)\n        if size < section_count:\n            raise ValueError(\n                f\"cannot divide section of {size} steps into {section_count}\"\n            )\n        if section_count <= 1:\n            frac_stride = 1\n        else:\n            frac_stride = (size - 1) / (section_count - 1)\n        cur_idx = 0.0\n        taken_steps = []\n        for _ in range(section_count):\n            taken_steps.append(start_idx + round(cur_idx))\n            cur_idx += frac_stride\n        all_steps += taken_steps\n        start_idx += size\n    return set(all_steps)\n\n\nclass SpacedDiffusion(GaussianDiffusion):\n    \"\"\"\n    A diffusion process which can skip steps in a base diffusion process.\n\n    :param use_timesteps: a collection (sequence or set) of timesteps from the\n                          original diffusion process to retain.\n    :param kwargs: the kwargs to create the base diffusion process.\n    \"\"\"\n\n    def __init__(self, 
use_timesteps, **kwargs):\n        self.use_timesteps = set(use_timesteps)\n        self.timestep_map = []\n        self.original_num_steps = len(kwargs[\"betas\"])\n\n        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa\n        last_alpha_cumprod = 1.0\n        new_betas = []\n        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):\n            if i in self.use_timesteps:\n                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)\n                last_alpha_cumprod = alpha_cumprod\n                self.timestep_map.append(i)\n        kwargs[\"betas\"] = np.array(new_betas)\n        super().__init__(**kwargs)\n\n    def p_mean_variance(\n        self, model, *args, **kwargs\n    ):  # pylint: disable=signature-differs\n        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)\n\n    def training_losses(\n        self, model, *args, **kwargs\n    ):  # pylint: disable=signature-differs\n        return super().training_losses(self._wrap_model(model), *args, **kwargs)\n\n    def condition_mean(self, cond_fn, *args, **kwargs):\n        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)\n\n    def condition_score(self, cond_fn, *args, **kwargs):\n        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)\n\n    def _wrap_model(self, model):\n        if isinstance(model, _WrappedModel):\n            return model\n        return _WrappedModel(\n            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps\n        )\n\n    def _scale_timesteps(self, t):\n        # Scaling is done by the wrapped model.\n        return t\n\n\nclass _WrappedModel:\n    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):\n        self.model = model\n        self.timestep_map = timestep_map\n        self.rescale_timesteps = rescale_timesteps\n        self.original_num_steps = original_num_steps\n\n    def 
__call__(self, x, ts, **kwargs):\n        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)\n        new_ts = map_tensor[ts]\n        if self.rescale_timesteps:\n            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)\n        return self.model(x, new_ts, **kwargs)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/sigma_sampling.py",
    "content": "import torch\nfrom inspect import isfunction\n\n# import sgm\n\n\ndef exists(x):\n    return x is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\nclass EDMSamplingWithResShift:\n    def __init__(self, p_mean=-1.2, p_std=1.2, scale_shift=320.0 / 576):\n        self.p_mean = p_mean\n        self.p_std = p_std\n        self.scale_shift = scale_shift\n\n    def __call__(self, n_samples, rand=None):\n        log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))\n\n        sigma = log_sigma.exp() * self.scale_shift\n        return sigma\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/sv_diffusion_engine.py",
    "content": "import math\nfrom contextlib import contextmanager\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport pytorch_lightning as pl\nimport torch\nfrom omegaconf import ListConfig, OmegaConf\nfrom safetensors.torch import load_file as load_safetensors\nfrom torch.optim.lr_scheduler import LambdaLR\n\nfrom sgm.modules import UNCONDITIONAL_CONFIG\n\nfrom sgm.modules.autoencoding.temporal_ae import VideoDecoder\nfrom sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER\nfrom sgm.modules.ema import LitEma\nfrom sgm.util import (\n    default,\n    disabled_train,\n    get_obj_from_str,\n    instantiate_from_config,\n    log_txt_as_img,\n)\n\n\nclass SVDiffusionEngine(pl.LightningModule):\n    \"\"\"\n    stable video diffusion engine\n    \"\"\"\n\n    def __init__(\n        self,\n        network_config,\n        denoiser_config,\n        first_stage_config,\n        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        network_wrapper: Union[None, str] = None,\n        ckpt_path: Union[None, str] = None,\n        use_ema: bool = False,\n        ema_decay_rate: float = 0.9999,\n        scale_factor: float = 1.0,\n        disable_first_stage_autocast=False,\n        input_key: str = \"jpg\",\n        log_keys: Union[List, None] = None,\n        no_cond_log: bool = False,\n        compile_model: bool = False,\n        en_and_decode_n_samples_a_time: Optional[int] = None,\n    ):\n        super().__init__()\n        self.log_keys = log_keys\n        self.input_key = input_key\n        self.optimizer_config = default(\n            optimizer_config, {\"target\": \"torch.optim.AdamW\"}\n        )\n        model = 
instantiate_from_config(network_config)\n        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(\n            model, compile_model=compile_model\n        )\n\n        # TODO\n        # add lora to the model if lora input\n        # change forward\n        # print(self.model)\n        for name, child in self.model.named_modules():\n            # print(name, \"named child\")\n            pass\n\n        self.denoiser = instantiate_from_config(denoiser_config)\n        self.sampler = (\n            instantiate_from_config(sampler_config)\n            if sampler_config is not None\n            else None\n        )\n        self.conditioner = instantiate_from_config(\n            default(conditioner_config, UNCONDITIONAL_CONFIG)\n        )\n        self.scheduler_config = scheduler_config\n        self._init_first_stage(first_stage_config)\n\n        self.loss_fn = (\n            instantiate_from_config(loss_fn_config)\n            if loss_fn_config is not None\n            else None\n        )\n\n        self.use_ema = use_ema\n        if self.use_ema:\n            self.model_ema = LitEma(self.model, decay=ema_decay_rate)\n            print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        self.scale_factor = scale_factor\n        self.disable_first_stage_autocast = disable_first_stage_autocast\n        self.no_cond_log = no_cond_log\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path)\n\n        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time\n\n    def init_from_ckpt(\n        self,\n        path: str,\n    ) -> None:\n        print(\"init svd engine from\", path)\n        if path.endswith(\"ckpt\"):\n            sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        elif path.endswith(\"safetensors\"):\n            sd = load_safetensors(path)\n        elif path.endswith(\"bin\"):\n            sd = torch.load(path, map_location=\"cpu\")\n        else:\n    
        raise NotImplementedError\n\n        missing, unexpected = self.load_state_dict(sd, strict=False)\n        print(\n            f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\"\n        )\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n        if len(unexpected) > 0:\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def _init_first_stage(self, config):\n        model = instantiate_from_config(config).eval()\n        model.train = disabled_train\n        for param in model.parameters():\n            param.requires_grad = False\n        self.first_stage_model = model\n\n    def get_input(self, batch):\n        # assuming unified data format, dataloader returns a dict.\n        # image tensors should be scaled to -1 ... 1 and in bchw format\n        return batch[self.input_key]\n\n    @torch.no_grad()\n    def decode_first_stage(self, z):\n        z = 1.0 / self.scale_factor * z\n        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])\n\n        n_rounds = math.ceil(z.shape[0] / n_samples)\n        all_out = []\n        with torch.autocast(\"cuda\", enabled=not self.disable_first_stage_autocast):\n            for n in range(n_rounds):\n                if isinstance(self.first_stage_model.decoder, VideoDecoder):\n                    kwargs = {\"timesteps\": len(z[n * n_samples : (n + 1) * n_samples])}\n                else:\n                    kwargs = {}\n                out = self.first_stage_model.decode(\n                    z[n * n_samples : (n + 1) * n_samples], **kwargs\n                )\n                all_out.append(out)\n        out = torch.cat(all_out, dim=0)\n        return out\n\n    @torch.no_grad()\n    def encode_first_stage(self, x):\n        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])\n        n_rounds = math.ceil(x.shape[0] / n_samples)\n        all_out = []\n        with torch.autocast(\"cuda\", enabled=not 
self.disable_first_stage_autocast):\n            for n in range(n_rounds):\n                out = self.first_stage_model.encode(\n                    x[n * n_samples : (n + 1) * n_samples]\n                )\n                all_out.append(out)\n        z = torch.cat(all_out, dim=0)\n        z = self.scale_factor * z\n        return z\n\n    def forward(self, batch, training=True):\n        assert training, \"DiffusionEngine forward function is only for training.\"\n\n        x = self.get_input(batch)\n        x = self.encode_first_stage(x)\n        batch[\"global_step\"] = self.global_step\n\n        x.requires_grad = True\n        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)\n        loss_mean = loss.mean()\n        loss_dict = {\"loss\": loss_mean}\n        return loss_mean, loss_dict\n\n    def shared_step(self, batch: Dict) -> Any:\n        batch[\"global_step\"] = self.global_step\n        loss, loss_dict = self(batch)\n        return loss, loss_dict\n\n    def training_step(self, batch, batch_idx):\n        loss, loss_dict = self.shared_step(batch)\n\n        self.log_dict(\n            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False\n        )\n\n        self.log(\n            \"global_step\",\n            self.global_step,\n            prog_bar=True,\n            logger=True,\n            on_step=True,\n            on_epoch=False,\n        )\n\n        if self.scheduler_config is not None:\n            lr = self.optimizers().param_groups[0][\"lr\"]\n            self.log(\n                \"lr_abs\", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False\n            )\n\n        return loss\n\n    def on_train_start(self, *args, **kwargs):\n        if self.sampler is None or self.loss_fn is None:\n            raise ValueError(\"Sampler and loss function need to be set for training.\")\n\n    def on_train_batch_end(self, 
*args, **kwargs):\n        if self.use_ema:\n            self.model_ema(self.model)\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.model.parameters())\n            self.model_ema.copy_to(self.model)\n            if context is not None:\n                print(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.model.parameters())\n                if context is not None:\n                    print(f\"{context}: Restored training weights\")\n\n    def instantiate_optimizer_from_config(self, params, lr, cfg):\n        return get_obj_from_str(cfg[\"target\"])(\n            params, lr=lr, **cfg.get(\"params\", dict())\n        )\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        params = list(self.model.parameters())\n        for embedder in self.conditioner.embedders:\n            if embedder.is_trainable:\n                params = params + list(embedder.parameters())\n        opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)\n        if self.scheduler_config is not None:\n            scheduler = instantiate_from_config(self.scheduler_config)\n            print(\"Setting up LambdaLR scheduler...\")\n            scheduler = [\n                {\n                    \"scheduler\": LambdaLR(opt, lr_lambda=scheduler.schedule),\n                    \"interval\": \"step\",\n                    \"frequency\": 1,\n                }\n            ]\n            return [opt], scheduler\n        return opt\n\n    def get_trainable_parameters(self):\n        params = list(self.model.parameters())\n        embedder_params = []\n        for embedder in self.conditioner.embedders:\n            if embedder.is_trainable:\n                embedder_params = embedder_params + list(embedder.parameters())\n        print(\n            
\"number of trainable parameters: {} - from embeder: {} \".format(\n                len(params), len(embedder_params)\n            )\n        )\n        params = params + embedder_params\n        return params\n\n    @torch.no_grad()\n    def sample(\n        self,\n        cond: Dict,\n        uc: Union[Dict, None] = None,\n        batch_size: int = 16,\n        shape: Union[None, Tuple, List] = None,\n        **kwargs,\n    ):\n        randn = torch.randn(batch_size, *shape).to(self.device)\n\n        denoiser = lambda input, sigma, c: self.denoiser(\n            self.model, input, sigma, c, **kwargs\n        )\n        samples = self.sampler(denoiser, randn, cond, uc=uc)\n        return samples\n\n    @torch.no_grad()\n    def log_conditionings(self, batch: Dict, n: int) -> Dict:\n        \"\"\"\n        Defines heuristics to log different conditionings.\n        These can be lists of strings (text-to-image), tensors, ints, ...\n        \"\"\"\n        image_h, image_w = batch[self.input_key].shape[2:]\n        log = dict()\n\n        for embedder in self.conditioner.embedders:\n            if (\n                (self.log_keys is None) or (embedder.input_key in self.log_keys)\n            ) and not self.no_cond_log:\n                x = batch[embedder.input_key][:n]\n                if isinstance(x, torch.Tensor):\n                    if x.dim() == 1:\n                        # class-conditional, convert integer to string\n                        x = [str(x[i].item()) for i in range(x.shape[0])]\n                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)\n                    elif x.dim() == 2:\n                        # size and crop cond and the like\n                        x = [\n                            \"x\".join([str(xx) for xx in x[i].tolist()])\n                            for i in range(x.shape[0])\n                        ]\n                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)\n            
        else:\n                        raise NotImplementedError()\n                elif isinstance(x, (List, ListConfig)):\n                    if isinstance(x[0], str):\n                        # strings\n                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)\n                    else:\n                        raise NotImplementedError()\n                else:\n                    raise NotImplementedError()\n                log[embedder.input_key] = xc\n        return log\n\n    @torch.no_grad()\n    def log_images(\n        self,\n        batch: Dict,\n        N: int = 8,\n        sample: bool = True,\n        ucg_keys: List[str] = None,\n        **kwargs,\n    ) -> Dict:\n        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]\n        if ucg_keys:\n            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (\n                \"Each defined ucg key for sampling must be in the provided conditioner input keys,\"\n                f\"but we have {ucg_keys} vs. 
{conditioner_input_keys}\"\n            )\n        else:\n            ucg_keys = conditioner_input_keys\n        log = dict()\n\n        x = self.get_input(batch)\n\n        c, uc = self.conditioner.get_unconditional_conditioning(\n            batch,\n            force_uc_zero_embeddings=ucg_keys\n            if len(self.conditioner.embedders) > 0\n            else [],\n        )\n\n        sampling_kwargs = {}\n\n        N = min(x.shape[0], N)\n        x = x.to(self.device)[:N]\n        log[\"inputs\"] = x\n        z = self.encode_first_stage(x)\n        log[\"reconstructions\"] = self.decode_first_stage(z)\n        log.update(self.log_conditionings(batch, N))\n\n        for k in c:\n            if isinstance(c[k], torch.Tensor):\n                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))\n\n        if sample:\n            with self.ema_scope(\"Plotting\"):\n                samples = self.sample(\n                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs\n                )\n            samples = self.decode_first_stage(samples)\n            log[\"samples\"] = samples\n        return log\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/svd_conditioner.py",
    "content": "\"\"\"\nModified from https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/encoders/modules.py\n\"\"\"\nimport math\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange, repeat\nimport numpy as np\nfrom omegaconf import ListConfig\nfrom sgm.util import append_dims, instantiate_from_config\nfrom sgm.modules.encoders.modules import GeneralConditioner\nimport random\n\n\nclass SVDConditioner(GeneralConditioner):\n    OUTPUT_DIM2KEYS = {2: \"vector\", 3: \"crossattn\", 4: \"concat\", 5: \"concat\"}\n    KEY2CATDIM = {\"vector\": 1, \"crossattn\": 2, \"concat\": 1}\n\n    def __init__(self, emb_models: Union[List, ListConfig]):\n        super().__init__(emb_models)\n        \n    def forward(\n        self, batch: Dict, force_zero_embeddings: Optional[List] = None\n    ) -> Dict:\n        output = dict()\n        \n        if force_zero_embeddings is None:\n            force_zero_embeddings = []\n        \n            if self.training:\n                img_ucg_rate = 0\n                for embedder in self.embedders:\n                    if embedder.input_key == \"cond_frames_without_noise\":\n                        img_ucg_rate = embedder.ucg_rate \n                        break\n                if img_ucg_rate > 0:\n                    if random.random() < img_ucg_rate:\n                        force_zero_embeddings.append(\"cond_frames_without_noise\")\n                        force_zero_embeddings.append(\"cond_frames\")\n\n        for embedder in self.embedders:\n            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad\n            with embedding_context():\n                if hasattr(embedder, \"input_key\") and (embedder.input_key is not None):\n                    if embedder.legacy_ucg_val is not None:\n                        batch = 
self.possibly_get_ucg_val(embedder, batch)\n                    emb_out = embedder(batch[embedder.input_key])\n                elif hasattr(embedder, \"input_keys\"):\n                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])\n            assert isinstance(\n                emb_out, (torch.Tensor, list, tuple)\n            ), f\"encoder outputs must be tensors or a sequence, but got {type(emb_out)}\"\n            if not isinstance(emb_out, (list, tuple)):\n                emb_out = [emb_out]\n            for emb in emb_out:\n                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]\n                \n                if (\n                    hasattr(embedder, \"input_key\")\n                    and embedder.input_key in force_zero_embeddings\n                ):\n                    emb = torch.zeros_like(emb)\n                if out_key in output:\n                    output[out_key] = torch.cat(\n                        (output[out_key], emb), self.KEY2CATDIM[out_key]\n                    )\n                else:\n                    output[out_key] = emb\n        return output\n        "
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/svd_sds_engine.py",
    "content": "import math\nfrom contextlib import contextmanager\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport pytorch_lightning as pl\nimport torch\nfrom omegaconf import ListConfig, OmegaConf\nfrom safetensors.torch import load_file as load_safetensors\n\nfrom sgm.modules import UNCONDITIONAL_CONFIG\n\nfrom sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER\nfrom sgm.util import (\n    default,\n    disabled_train,\n    get_obj_from_str,\n    instantiate_from_config,\n    append_dims,\n)\n\nfrom motionrep.utils.svd_helpper import (\n    get_batch,\n    get_unique_embedder_keys_from_conditioner,\n)\nfrom einops import rearrange, repeat\n\nimport torch.nn.functional as F\n\nimport numpy as np\n\n\nclass SVDSDSEngine(pl.LightningModule):\n    \"\"\"\n    stable video diffusion engine\n    \"\"\"\n\n    def __init__(\n        self,\n        network_config,\n        denoiser_config,\n        first_stage_config,\n        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        discretization_config: Union[None, Dict, ListConfig, OmegaConf] = None,  # Added\n        network_wrapper: Union[None, str] = None,\n        ckpt_path: Union[None, str] = None,\n        scale_factor: float = 1.0,\n        disable_first_stage_autocast=False,\n        input_key: str = \"jpg\",\n        compile_model: bool = False,\n        en_and_decode_n_samples_a_time: Optional[int] = None,\n    ):\n        super().__init__()\n        self.input_key = input_key\n\n        model = instantiate_from_config(network_config)\n        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(\n            model, compile_model=compile_model\n        )\n        self.model.eval()\n\n        self.denoiser = instantiate_from_config(denoiser_config)\n        assert self.denoiser is not 
None, \"need denoiser\"\n\n        self.sampler = (\n            instantiate_from_config(sampler_config)\n            if sampler_config is not None\n            else None\n        )\n        self.conditioner = instantiate_from_config(\n            default(conditioner_config, UNCONDITIONAL_CONFIG)\n        )\n\n        self._init_first_stage(first_stage_config)\n\n        self.loss_fn = (\n            instantiate_from_config(loss_fn_config)\n            if loss_fn_config is not None\n            else None\n        )\n\n        self.scale_factor = scale_factor\n        self.disable_first_stage_autocast = disable_first_stage_autocast\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path)\n\n        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time\n\n        assert discretization_config is not None, \"need discretizer\"\n        self.discretizer = instantiate_from_config(discretization_config)\n\n        # [1000]\n        sigmas_all = self.discretizer.get_sigmas(1000)\n        self.register_buffer(\"sigmas_all\", sigmas_all)\n\n    def init_from_ckpt(\n        self,\n        path: str,\n    ) -> None:\n        print(\"init svd engine from\", path)\n        if path.endswith(\"ckpt\"):\n            sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        elif path.endswith(\"safetensors\"):\n            sd = load_safetensors(path)\n        elif path.endswith(\"bin\"):\n            sd = torch.load(path, map_location=\"cpu\")\n        else:\n            raise NotImplementedError\n\n        missing, unexpected = self.load_state_dict(sd, strict=False)\n        print(\n            f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\"\n        )\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n        if len(unexpected) > 0:\n            # print(f\"Unexpected Keys: {unexpected}\")\n            pass\n\n    def _init_first_stage(self, config):\n        
model = instantiate_from_config(config).eval()\n        model.train = disabled_train\n        for param in model.parameters():\n            param.requires_grad = False\n        self.first_stage_model = model\n\n        del self.first_stage_model.decoder\n        self.first_stage_model.decoder = None\n\n    def get_input(self, batch):\n        # assuming unified data format, dataloader returns a dict.\n        # image tensors should be scaled to -1 ... 1 and in bchw format\n        return batch[self.input_key]\n\n    def encode_first_stage(self, x):\n        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])\n        n_rounds = math.ceil(x.shape[0] / n_samples)\n        all_out = []\n        with torch.autocast(\"cuda\", enabled=not self.disable_first_stage_autocast):\n            for n in range(n_rounds):\n                out = self.first_stage_model.encode(\n                    x[n * n_samples : (n + 1) * n_samples]\n                )\n                all_out.append(out)\n        z = torch.cat(all_out, dim=0)\n        z = self.scale_factor * z\n        return z\n\n    def forward(self, batch, sample_time_range=[0.02, 0.98]):\n        \"\"\"\n        Args:\n            batch[\"jpg\"]: [BT, 3, H, W]. Videos range in\n                [-1, 1]? TODO Dec 16. Check\n            batch[\"cond_image\"]: [B, 3, H, W]. 
in [-1, 1]?\n                TODO: check shape\n        \"\"\"\n        x = self.get_input(batch)  # [BT, 3, H, W]\n        T = batch[\"num_video_frames\"]\n        batch_size = x.shape[0] // T\n        z = self.encode_first_stage(x)  # [BT, C, H_latent, W_latent]\n        batch[\"global_step\"] = self.global_step\n\n        with torch.no_grad():\n            sds_grad = self.edm_sds(z, batch, sample_time_range)\n            target = (z - sds_grad).detach()\n\n        loss_sds = 0.5 * F.mse_loss(z, target, reduction=\"sum\") / batch_size\n        log_loss_dict = {\n            \"loss_sds_video\": loss_sds.item(),\n            \"sds_delta_norm\": sds_grad.norm().item(),\n        }\n\n        return loss_sds, log_loss_dict\n\n    def forward_with_encoder_chunk(\n        self, batch, chunk_size=2, sample_time_range=[0.02, 0.98]\n    ):\n        with torch.no_grad():\n            x = self.get_input(batch)  # [BT, 3, H, W]\n            T = batch[\"num_video_frames\"]\n            batch_size = x.shape[0] // T\n            z = self.encode_first_stage(x)  # [BT, C, H_latent, W_latent]\n            batch[\"global_step\"] = self.global_step\n            sds_grad, denoised_latent = self.edm_sds(z, batch, sample_time_range)\n\n        num_chunks = math.ceil(z.shape[0] / chunk_size)\n\n        for n in range(num_chunks):\n            end_ind = min((n + 1) * chunk_size, z.shape[0])\n            x_chunk = x[n * chunk_size : end_ind]\n            z_chunk_recompute = self.encode_first_stage(x_chunk)\n\n            target_chunk = (\n                z_chunk_recompute - sds_grad[n * chunk_size : end_ind]\n            ).detach()\n\n            this_chunk_size = x_chunk.shape[0]\n            assert this_chunk_size > 0\n            # loss_sds_chunk = (\n            #     0.5\n            #     * F.mse_loss(z_chunk_recompute, target_chunk, reduction=\"mean\")\n            #     * this_chunk_size\n            #     / z.shape[0]\n            #     / batch_size\n            # )\n            
loss_sds_chunk = 0.5 * F.mse_loss(z_chunk_recompute, target_chunk, reduction=\"sum\") / batch_size\n\n            loss_sds_chunk.backward()\n\n        with torch.no_grad():\n            target = (z - sds_grad).detach()\n            loss_sds = 0.5 * F.mse_loss(z, target, reduction=\"sum\") / batch_size\n            log_loss_dict = {\n                \"latent_loss_sds\": loss_sds.item(),\n                \"latent_sds_norm\": sds_grad.norm().item(),\n                \"latent_sds_max\": sds_grad.max().item(),\n                \"latent_sds_mean\": sds_grad.mean().item(),\n            }\n\n            video_space_sds_grad = x.grad\n\n        return video_space_sds_grad, log_loss_dict, denoised_latent\n\n    @torch.no_grad()\n    def edm_sds(self, input_x, extra_input, sample_time_range=[0.02, 0.98]):\n        \"\"\"\n        Args:\n            input_x: [BT, C, H, W] in latent\n            extra_input: dict\n                \"fps_id\": [B]\n                \"motion_bucket_id\": [B]\n                \"cond_aug\": [B]\n                \"cond_frames_without_noise\": [B, C, H, W]\n                \"cond_frames\": [B, C, H, W]\n            sample_time_range: [t_min, t_max]\n        \"\"\"\n\n        # step-1: prepare inputs\n        num_frames = extra_input[\"num_video_frames\"]\n        batch_size = input_x.shape[0] // num_frames\n        device = input_x.device\n        # video = video.contiguous()\n\n        extra_input[\"num_video_frames\"] = num_frames\n\n        # prepare c and uc\n\n        batch, batch_uc = get_batch(\n            get_unique_embedder_keys_from_conditioner(self.conditioner),\n            extra_input,\n            [1, num_frames],\n            T=num_frames,\n            device=device,\n        )\n\n        # keys would be be ['crossattn', 'vector', 'concat']\n        c, uc = self.conditioner.get_unconditional_conditioning(\n            batch,\n            batch_uc=batch_uc,\n            force_uc_zero_embeddings=[\n                \"cond_frames\",\n       
         \"cond_frames_without_noise\",\n            ],\n        )\n\n        for k in [\"crossattn\", \"concat\"]:\n            uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n            uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n            c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n            c[k] = rearrange(c[k], \"b t ... -> (b t) ...\", t=num_frames)\n\n        # after this should be\n        # crossattn [14, 1, 1024];  vector [14, 768]; concat [14, 4, 72, 128]\n        additional_model_inputs = {}\n        additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n            int(2 * batch_size), num_frames\n        ).to(device)\n        additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n        # step-2: sample t and sigmas, then noise\n        sampled_t = np.random.randint(\n            low=int(sample_time_range[0] * self.sigmas_all.shape[0]),\n            high=int(sample_time_range[1] * self.sigmas_all.shape[0]),\n            size=(batch_size),\n        ).tolist() # list of index time t [B]\n        sigmas = self.sigmas_all[sampled_t]\n\n        # sigmas = self.loss_fn.sigma_sampler(batch_size).to(input_x)\n        sigmas = repeat(sigmas, \"b ... -> b t ...\", t=num_frames)\n        sigmas = rearrange(sigmas, \"b t ... 
-> (b t) ...\", t=num_frames)\n\n        noise = torch.randn_like(input_x)  # [BT, C, H, W]\n\n        sigmas_bc = append_dims(sigmas, input_x.ndim)  # [14, 1, 1, 1]\n\n        noised_input = self.loss_fn.get_noised_input(\n            sigmas_bc, noise, input_x\n        )  # [BT, C, H, W]\n\n        # step-3: prepare conditional and unconditional inputs\n        # [2BT, C, H, W], [2BT]\n        bathced_xt, bathced_sigmas, bathched_c = self.sampler.guider.prepare_inputs(\n            noised_input, sigmas, c, uc\n        )\n        # bathched_c[\"crossattn\"] => [2BT, 1, C] ;   bathched_c[\"concat\"] => [2BT, C, H, W]; bathched_c[\"vector\"] => [2BT, C_feat]\n\n        # output shape [2BT, C, H, W]\n        denoised = self.denoiser(\n            self.model,\n            bathced_xt,\n            bathced_sigmas,\n            bathched_c,\n            **additional_model_inputs,\n        )\n\n        # step-4: cfg guidance and compute sds_grad\n        # [BT, C, H, W]\n        denoised = self.sampler.guider(denoised, bathced_sigmas)\n\n        sds_grad = (input_x - denoised) / sigmas_bc        \n\n        return sds_grad, denoised\n\n    @torch.no_grad()\n    def edm_sds_multistep(self, input_x, extra_input, sample_time_range=[0.02, 0.84], num_step=4, total_steps=25):\n        \"\"\"\n        From t = 20 sample to t = 980. 
\n        Args:\n            input_x: [BT, C, H, W] in latent\n            extra_input: dict\n                \"fps_id\": [B]\n                \"motion_bucket_id\": [B]\n                \"cond_aug\": [B]\n                \"cond_frames_without_noise\": [B, C, H, W]\n                \"cond_frames\": [B, C, H, W]\n            sample_time_range: [t_min, t_max]\n        \"\"\"\n\n        # step-1: prepare inputs\n        num_frames = extra_input[\"num_video_frames\"]\n        batch_size = input_x.shape[0] // num_frames\n        device = input_x.device\n        # video = video.contiguous()\n\n        extra_input[\"num_video_frames\"] = num_frames\n\n        # prepare c and uc\n\n        batch, batch_uc = get_batch(\n            get_unique_embedder_keys_from_conditioner(self.conditioner),\n            extra_input,\n            [1, num_frames],\n            T=num_frames,\n            device=device,\n        )\n\n        # keys would be be ['crossattn', 'vector', 'concat']\n        c, uc = self.conditioner.get_unconditional_conditioning(\n            batch,\n            batch_uc=batch_uc,\n            force_uc_zero_embeddings=[\n                \"cond_frames\",\n                \"cond_frames_without_noise\",\n            ],\n        )\n\n        for k in [\"crossattn\", \"concat\"]:\n            uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n            uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n            c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n            c[k] = rearrange(c[k], \"b t ... 
-> (b t) ...\", t=num_frames)\n\n        # after this should be\n        # crossattn [14, 1, 1024];  vector [14, 768]; concat [14, 4, 72, 128]\n        additional_model_inputs = {}\n        additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n            int(2 * batch_size), num_frames\n        ).to(device)\n        additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n        # step-2: sample t and sigmas, then noise\n        sampled_t = np.random.randint(\n            low=int(sample_time_range[0] * self.sigmas_all.shape[0]),\n            high=int(sample_time_range[1] * self.sigmas_all.shape[0]),\n            size=(batch_size),\n        ) # np.array of index time t [B]\n\n        step_stride = len(self.sigmas_all) // total_steps\n\n        sigma_sum = 0.0\n        for i in range(num_step):\n            sampled_t += step_stride * i \n            sampled_t = np.clip(sampled_t, 0, len(self.sigmas_all) - 2)\n            \n\n            # [B]\n            sigmas = self.sigmas_all[sampled_t]\n\n            # sigmas = self.loss_fn.sigma_sampler(batch_size).to(input_x)\n            sigmas = repeat(sigmas, \"b ... -> b t ...\", t=num_frames)\n            sigmas = rearrange(sigmas, \"b t ... 
-> (b t) ...\", t=num_frames)\n\n            sigmas_bc = append_dims(sigmas, input_x.ndim)  # [14, 1, 1, 1]\n\n            if i == 0:\n\n                noise = torch.randn_like(input_x)  # [BT, C, H, W]\n\n                noised_input = self.loss_fn.get_noised_input(\n                    sigmas_bc, noise, input_x\n                )  # [BT, C, H, W]\n            else:\n                # dt is negative\n                dt = append_dims(sigmas - prev_sigmas, input_x.ndim)\n                \n                dx = (noised_input - denoised) / append_dims(prev_sigmas, input_x.ndim)\n                noised_input = noised_input + dt * dx\n\n            denoised = self.sampler_step(sigmas, noised_input, c, uc, \n                                         num_frames=num_frames, additional_model_inputs=additional_model_inputs)\n            prev_sigmas = sigmas\n            sigma_sum += sigmas_bc\n            \n        # TODO, so many sigmas, which to use?\n        # sds_grad = (input_x - denoised) / sigmas_bc\n        sds_grad = (input_x - denoised) / sigma_sum\n\n        return sds_grad, denoised\n    \n\n    def sampler_step(self, sigma, noised_input, c, uc=None, num_frames=None, additional_model_inputs=None):\n        \n        # step-3: prepare conditional and unconditional inputs\n        # [2BT, C, H, W], [2BT]\n        bathced_xt, bathced_sigmas, bathched_c = self.sampler.guider.prepare_inputs(\n            noised_input, sigma, c, uc\n        )\n        # bathched_c[\"crossattn\"] => [2BT, 1, C] ;   bathched_c[\"concat\"] => [2BT, C, H, W]; bathched_c[\"vector\"] => [2BT, C_feat]\n\n        # output shape [2BT, C, H, W]\n        denoised = self.denoiser(\n            self.model,\n            bathced_xt,\n            bathced_sigmas,\n            bathched_c,\n            **additional_model_inputs,\n        )\n\n        # step-4: cfg guidance and compute sds_grad\n        # [BT, C, H, W]\n        denoised = self.sampler.guider(denoised, bathced_sigmas)\n\n        return 
denoised\n\n\n\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/svd_sds_engine_backup.py",
    "content": "import math\nfrom contextlib import contextmanager\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport pytorch_lightning as pl\nimport torch\nfrom omegaconf import ListConfig, OmegaConf\nfrom safetensors.torch import load_file as load_safetensors\n\nfrom sgm.modules import UNCONDITIONAL_CONFIG\n\nfrom sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER\nfrom sgm.util import (\n    default,\n    disabled_train,\n    get_obj_from_str,\n    instantiate_from_config,\n    append_dims,\n)\n\nfrom motionrep.utils.svd_helpper import (\n    get_batch,\n    get_unique_embedder_keys_from_conditioner,\n)\nfrom einops import rearrange, repeat\n\nimport torch.nn.functional as F\n\n\nclass SVDSDSEngine(pl.LightningModule):\n    \"\"\"\n    stable video diffusion engine\n    \"\"\"\n\n    def __init__(\n        self,\n        network_config,\n        denoiser_config,\n        first_stage_config,\n        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        network_wrapper: Union[None, str] = None,\n        ckpt_path: Union[None, str] = None,\n        scale_factor: float = 1.0,\n        disable_first_stage_autocast=False,\n        input_key: str = \"jpg\",\n        compile_model: bool = False,\n        en_and_decode_n_samples_a_time: Optional[int] = None,\n    ):\n        super().__init__()\n        self.input_key = input_key\n\n        model = instantiate_from_config(network_config)\n        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(\n            model, compile_model=compile_model\n        )\n        self.model.eval()\n\n        self.denoiser = instantiate_from_config(denoiser_config)\n        self.sampler = (\n            instantiate_from_config(sampler_config)\n            if sampler_config is not None\n            else None\n       
 )\n        self.conditioner = instantiate_from_config(\n            default(conditioner_config, UNCONDITIONAL_CONFIG)\n        )\n\n        self._init_first_stage(first_stage_config)\n\n        self.loss_fn = (\n            instantiate_from_config(loss_fn_config)\n            if loss_fn_config is not None\n            else None\n        )\n\n        self.scale_factor = scale_factor\n        self.disable_first_stage_autocast = disable_first_stage_autocast\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path)\n\n        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time\n\n    def init_from_ckpt(\n        self,\n        path: str,\n    ) -> None:\n        print(\"init svd engine from\", path)\n        if path.endswith(\"ckpt\"):\n            sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        elif path.endswith(\"safetensors\"):\n            sd = load_safetensors(path)\n        elif path.endswith(\"bin\"):\n            sd = torch.load(path, map_location=\"cpu\")\n        else:\n            raise NotImplementedError\n\n        missing, unexpected = self.load_state_dict(sd, strict=False)\n        print(\n            f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\"\n        )\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n        if len(unexpected) > 0:\n            # print(f\"Unexpected Keys: {unexpected}\")\n            pass\n\n    def _init_first_stage(self, config):\n        model = instantiate_from_config(config).eval()\n        model.train = disabled_train\n        for param in model.parameters():\n            param.requires_grad = False\n        self.first_stage_model = model\n\n        del self.first_stage_model.decoder\n        self.first_stage_model.decoder = None\n\n    def get_input(self, batch):\n        # assuming unified data format, dataloader returns a dict.\n        # image tensors should be scaled to -1 ... 
1 and in bchw format\n        return batch[self.input_key]\n\n    def encode_first_stage(self, x):\n        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])\n        n_rounds = math.ceil(x.shape[0] / n_samples)\n        all_out = []\n        with torch.autocast(\"cuda\", enabled=not self.disable_first_stage_autocast):\n            for n in range(n_rounds):\n                out = self.first_stage_model.encode(\n                    x[n * n_samples : (n + 1) * n_samples]\n                )\n                all_out.append(out)\n        z = torch.cat(all_out, dim=0)\n        z = self.scale_factor * z\n        return z\n\n    def forward(self, batch, training=True):\n        assert training, \"DiffusionEngine forward function is only for training.\"\n\n        x = self.get_input(batch)  # [BT, 3, H, W]\n        T = batch[\"num_video_frames\"]\n        batch_size = x.shape[0] // T\n        z = self.encode_first_stage(x)  # [BT, C, H_latent, W_latent]\n        batch[\"global_step\"] = self.global_step\n\n        sds_grad = self.emd_sds(z, batch)\n        target = (z - sds_grad).detach()\n\n        loss_sds = 0.5 * F.mse_loss(z, target, reduction=\"mean\") / batch_size\n        log_loss_dict = {\n            \"loss_sds_video\": loss_sds.item(),\n            \"grad_norm\": sds_grad.norm().item(),\n        }\n\n        return loss_sds, log_loss_dict\n\n    def forward_with_encoder_chunk(self, batch, chunk_size=2):\n        with torch.no_grad():\n            x = self.get_input(batch)  # [BT, 3, H, W]\n            T = batch[\"num_video_frames\"]\n            batch_size = x.shape[0] // T\n            z = self.encode_first_stage(x)  # [BT, C, H_latent, W_latent]\n            batch[\"global_step\"] = self.global_step\n            sds_grad = self.emd_sds(z, batch)\n\n        num_chunks = math.ceil(z.shape[0] / chunk_size)\n\n        for n in range(num_chunks):\n            end_ind = min((n + 1) * chunk_size, z.shape[0])\n            x_chunk = x[n * chunk_size : 
end_ind]\n            z_chunk_recompute = self.encode_first_stage(x_chunk)\n\n            target_chunk = (\n                z_chunk_recompute - sds_grad[n * chunk_size : end_ind]\n            ).detach()\n\n            this_chunk_size = x_chunk.shape[0]\n            loss_sds_chunk = (\n                0.5\n                * F.mse_loss(z_chunk_recompute, target_chunk, reduction=\"mean\")\n                * this_chunk_size\n                / z.shape[0]\n                / batch_size\n            )\n\n            loss_sds_chunk.backward()\n\n        with torch.no_grad():\n            target = (z - sds_grad).detach()\n            loss_sds = 0.5 * F.mse_loss(z, target, reduction=\"mean\") / batch_size\n            log_loss_dict = {\n                \"loss_sds_video\": loss_sds.item(),\n                \"grad_norm\": sds_grad.norm().item(),\n            }\n\n            video_space_sds_grad = x.grad\n\n        return video_space_sds_grad, loss_sds, log_loss_dict\n\n    @torch.no_grad()\n    def emd_sds(self, input_x, extra_input):\n        \"\"\"\n        Args:\n            input_x: [BT, C, H, W] in latent\n            extra_input: dict\n                \"fps_id\": [B]\n                \"motion_bucket_id\": [B]\n                \"cond_aug\": [B]\n                \"cond_frames_without_noise\": [B, C, H, W]\n                \"cond_frames\": [B, C, H, W]\n        \"\"\"\n\n        num_frames = extra_input[\"num_video_frames\"]\n        batch_size = input_x.shape[0] // num_frames\n        device = input_x.device\n        # video = video.contiguous()\n\n        extra_input[\"num_video_frames\"] = num_frames\n\n        # prepare c and uc\n\n        batch, batch_uc = get_batch(\n            get_unique_embedder_keys_from_conditioner(self.conditioner),\n            extra_input,\n            [1, num_frames],\n            T=num_frames,\n            device=device,\n        )\n\n        # keys would be be ['crossattn', 'vector', 'concat']\n        c, uc = 
self.conditioner.get_unconditional_conditioning(\n            batch,\n            batch_uc=batch_uc,\n            force_uc_zero_embeddings=[\n                \"cond_frames\",\n                \"cond_frames_without_noise\",\n            ],\n        )\n\n        for k in [\"crossattn\", \"concat\"]:\n            uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n            uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n            c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n            c[k] = rearrange(c[k], \"b t ... -> (b t) ...\", t=num_frames)\n\n        # after this should be\n        # crossattn torch.Size([14, 1, 1024])\n        # vector torch.Size([14, 768])\n        # concat torch.Size([14, 4, 72, 128])\n\n        additional_model_inputs = {}\n        additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n            int(2 * batch_size), num_frames\n        ).to(device)\n        additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n        sigmas = self.loss_fn.sigma_sampler(batch_size).to(input_x)\n        sigmas = repeat(sigmas, \"b ... -> b t ...\", t=num_frames)\n        sigmas = rearrange(sigmas, \"b t ... 
-> (b t) ...\", t=num_frames)\n\n        noise = torch.randn_like(input_x)  # [BT, C, H, W]\n\n        sigmas_bc = append_dims(sigmas, input_x.ndim)  # [14, 1, 1, 1]\n        noised_input = self.loss_fn.get_noised_input(\n            sigmas_bc, noise, input_x\n        )  # [BT, C, H, W]\n\n        # [2BT, C, H, W], [2BT]\n        bathced_xt, bathced_sigmas, bathched_c = self.sampler.guider.prepare_inputs(\n            noised_input, sigmas, c, uc\n        )\n        # bathched_c[crossattn] => [2BT, 1, C] ;   bathched_c[\"concat\"] => [2BT, C, H, W]; bathched_c[\"vector\"] => [2BT, C_feat]\n\n        # output shape [2BT, C, H, W]\n\n        denoised = self.denoiser(\n            self.model,\n            bathced_xt,\n            bathced_sigmas,\n            bathched_c,\n            **additional_model_inputs,\n        )\n\n        # [BT, C, H, W]\n        denoised = self.sampler.guider(denoised, bathced_sigmas)\n\n        sds_grad = (denoised - input_x) / sigmas_bc\n\n        return sds_grad\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/svd_sds_wdecoder_engine.py",
    "content": "import math\nfrom contextlib import contextmanager\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport pytorch_lightning as pl\nimport torch\nfrom omegaconf import ListConfig, OmegaConf\nfrom safetensors.torch import load_file as load_safetensors\n\nfrom sgm.modules import UNCONDITIONAL_CONFIG\n\nfrom sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER\nfrom sgm.util import (\n    default,\n    disabled_train,\n    get_obj_from_str,\n    instantiate_from_config,\n    append_dims,\n)\n\nfrom motionrep.utils.svd_helpper import (\n    get_batch,\n    get_unique_embedder_keys_from_conditioner,\n)\nfrom einops import rearrange, repeat\n\nimport torch.nn.functional as F\nfrom sgm.modules.autoencoding.temporal_ae import VideoDecoder\n\nimport numpy as np\n\n\nclass SVDWDecSDSEngine(pl.LightningModule):\n    \"\"\"\n    stable video diffusion engine\n    \"\"\"\n\n    def __init__(\n        self,\n        network_config,\n        denoiser_config,\n        first_stage_config,\n        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,\n        discretization_config: Union[None, Dict, ListConfig, OmegaConf] = None,  # Added\n        network_wrapper: Union[None, str] = None,\n        ckpt_path: Union[None, str] = None,\n        scale_factor: float = 1.0,\n        disable_first_stage_autocast=False,\n        input_key: str = \"jpg\",\n        compile_model: bool = False,\n        en_and_decode_n_samples_a_time: Optional[int] = None,\n    ):\n        super().__init__()\n        self.input_key = input_key\n\n        model = instantiate_from_config(network_config)\n        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(\n            model, compile_model=compile_model\n        )\n        self.model.eval()\n\n        self.denoiser = 
instantiate_from_config(denoiser_config)\n        assert self.denoiser is not None, \"need denoiser\"\n\n        self.sampler = (\n            instantiate_from_config(sampler_config)\n            if sampler_config is not None\n            else None\n        )\n        self.conditioner = instantiate_from_config(\n            default(conditioner_config, UNCONDITIONAL_CONFIG)\n        )\n\n        self._init_first_stage(first_stage_config)\n\n        self.loss_fn = (\n            instantiate_from_config(loss_fn_config)\n            if loss_fn_config is not None\n            else None\n        )\n\n        self.scale_factor = scale_factor\n        self.disable_first_stage_autocast = disable_first_stage_autocast\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path)\n\n        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time\n\n        assert discretization_config is not None, \"need discretizer\"\n        self.discretizer = instantiate_from_config(discretization_config)\n\n        # [1000]\n        sigmas_all = self.discretizer.get_sigmas(1000)\n        self.register_buffer(\"sigmas_all\", sigmas_all)\n\n    def init_from_ckpt(\n        self,\n        path: str,\n    ) -> None:\n        print(\"init svd engine from\", path)\n        if path.endswith(\"ckpt\"):\n            sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        elif path.endswith(\"safetensors\"):\n            sd = load_safetensors(path)\n        elif path.endswith(\"bin\"):\n            sd = torch.load(path, map_location=\"cpu\")\n        else:\n            raise NotImplementedError\n\n        missing, unexpected = self.load_state_dict(sd, strict=False)\n        print(\n            f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\"\n        )\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n        if len(unexpected) > 0:\n            # print(f\"Unexpected Keys: 
{unexpected}\")\n            pass\n\n    def _init_first_stage(self, config):\n        model = instantiate_from_config(config).eval()\n        model.train = disabled_train\n        for param in model.parameters():\n            param.requires_grad = False\n        self.first_stage_model = model\n\n    def get_input(self, batch):\n        # assuming unified data format, dataloader returns a dict.\n        # image tensors should be scaled to -1 ... 1 and in bchw format\n        return batch[self.input_key]\n\n    def encode_first_stage(self, x):\n        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])\n        n_rounds = math.ceil(x.shape[0] / n_samples)\n        all_out = []\n        with torch.autocast(\"cuda\", enabled=not self.disable_first_stage_autocast):\n            for n in range(n_rounds):\n                out = self.first_stage_model.encode(\n                    x[n * n_samples : (n + 1) * n_samples]\n                )\n                all_out.append(out)\n        z = torch.cat(all_out, dim=0)\n        z = self.scale_factor * z\n        return z\n\n    @torch.no_grad()\n    def decode_first_stage(self, z):\n        z = 1.0 / self.scale_factor * z\n        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])\n\n        n_rounds = math.ceil(z.shape[0] / n_samples)\n        all_out = []\n        with torch.autocast(\"cuda\", enabled=not self.disable_first_stage_autocast):\n            for n in range(n_rounds):\n                if isinstance(self.first_stage_model.decoder, VideoDecoder):\n                    kwargs = {\"timesteps\": len(z[n * n_samples : (n + 1) * n_samples])}\n                else:\n                    kwargs = {}\n                out = self.first_stage_model.decode(\n                    z[n * n_samples : (n + 1) * n_samples], **kwargs\n                )\n                all_out.append(out)\n        out = torch.cat(all_out, dim=0)\n        return out\n\n    def forward(self, batch, sample_time_range=[0.02, 
0.98]):\n        \"\"\"\n        Args:\n            batch[\"jpg\"]: [BT, 3, H, W]. Videos range in\n                [-1, 1]? TODO Dec 16. Check\n            batch[\"cond_image\"]: [B, 3, H, W]. in [-1, 1]?\n                TODO: check shape\n        \"\"\"\n        x = self.get_input(batch)  # [BT, 3, H, W]\n        T = batch[\"num_video_frames\"]\n        batch_size = x.shape[0] // T\n        z = self.encode_first_stage(x)  # [BT, C, H_latent, W_latent]\n        batch[\"global_step\"] = self.global_step\n\n        with torch.no_grad():\n            sds_grad = self.edm_sds(z, batch, sample_time_range)\n            target = (z - sds_grad).detach()\n\n        loss_sds = 0.5 * F.mse_loss(z, target, reduction=\"mean\") / batch_size\n        log_loss_dict = {\n            \"loss_sds_video\": loss_sds.item(),\n            \"sds_delta_norm\": sds_grad.norm().item(),\n        }\n\n        return loss_sds, log_loss_dict\n\n    def forward_with_encoder_chunk(\n        self, batch, chunk_size=2, sample_time_range=[0.02, 0.98]\n    ):\n        with torch.no_grad():\n            x = self.get_input(batch)  # [BT, 3, H, W]\n            T = batch[\"num_video_frames\"]\n            batch_size = x.shape[0] // T\n            z = self.encode_first_stage(x)  # [BT, C, H_latent, W_latent]\n            batch[\"global_step\"] = self.global_step\n            sds_grad, denoised_latent = self.edm_sds(z, batch, sample_time_range)\n\n        num_chunks = math.ceil(z.shape[0] / chunk_size)\n\n        for n in range(num_chunks):\n            end_ind = min((n + 1) * chunk_size, z.shape[0])\n            x_chunk = x[n * chunk_size : end_ind]\n            z_chunk_recompute = self.encode_first_stage(x_chunk)\n\n            target_chunk = (\n                z_chunk_recompute - sds_grad[n * chunk_size : end_ind]\n            ).detach()\n\n            this_chunk_size = x_chunk.shape[0]\n            assert this_chunk_size > 0\n            # loss_sds_chunk = (\n            #     0.5\n            #     * 
F.mse_loss(z_chunk_recompute, target_chunk, reduction=\"mean\")\n            #     * this_chunk_size\n            #     / z.shape[0]\n            #     / batch_size\n            # )\n            loss_sds_chunk = 0.5 * F.mse_loss(z_chunk_recompute, target_chunk, reduction=\"sum\") / batch_size\n\n            loss_sds_chunk.backward()\n\n        with torch.no_grad():\n            target = (z - sds_grad).detach()\n            loss_sds = 0.5 * F.mse_loss(z, target, reduction=\"sum\") / batch_size\n            log_loss_dict = {\n                \"latent_loss_sds\": loss_sds.item(),\n                \"latent_sds_norm\": sds_grad.norm().item(),\n                \"latent_sds_max\": sds_grad.abs().max().item(),\n                \"latent_sds_mean\": sds_grad.abs().mean().item(),\n            }\n\n            video_space_sds_grad = x.grad\n\n        return video_space_sds_grad, log_loss_dict, denoised_latent\n\n    @torch.no_grad()\n    def edm_sds(self, input_x, extra_input, sample_time_range=[0.02, 0.98]):\n        \"\"\"\n        Args:\n            input_x: [BT, C, H, W] in latent\n            extra_input: dict\n                \"fps_id\": [B]\n                \"motion_bucket_id\": [B]\n                \"cond_aug\": [B]\n                \"cond_frames_without_noise\": [B, C, H, W]\n                \"cond_frames\": [B, C, H, W]\n            sample_time_range: [t_min, t_max]\n        \"\"\"\n\n        # step-1: prepare inputs\n        num_frames = extra_input[\"num_video_frames\"]\n        batch_size = input_x.shape[0] // num_frames\n        device = input_x.device\n        # video = video.contiguous()\n\n        extra_input[\"num_video_frames\"] = num_frames\n\n        # prepare c and uc\n\n        batch, batch_uc = get_batch(\n            get_unique_embedder_keys_from_conditioner(self.conditioner),\n            extra_input,\n            [1, num_frames],\n            T=num_frames,\n            device=device,\n        )\n\n        # keys would be ['crossattn', 'vector', 
'concat']\n        c, uc = self.conditioner.get_unconditional_conditioning(\n            batch,\n            batch_uc=batch_uc,\n            force_uc_zero_embeddings=[\n                \"cond_frames\",\n                \"cond_frames_without_noise\",\n            ],\n        )\n\n        for k in [\"crossattn\", \"concat\"]:\n            uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n            uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n            c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n            c[k] = rearrange(c[k], \"b t ... -> (b t) ...\", t=num_frames)\n\n        # after this should be\n        # crossattn [14, 1, 1024];  vector [14, 768]; concat [14, 4, 72, 128]\n        additional_model_inputs = {}\n        additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n            int(2 * batch_size), num_frames\n        ).to(device)\n        additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n        # step-2: sample t and sigmas, then noise\n        sampled_t = np.random.randint(\n            low=int(sample_time_range[0] * self.sigmas_all.shape[0]),\n            high=int(sample_time_range[1] * self.sigmas_all.shape[0]),\n            size=(batch_size),\n        ).tolist() # list of index time t [B]\n        sigmas = self.sigmas_all[sampled_t]\n\n        # sigmas = self.loss_fn.sigma_sampler(batch_size).to(input_x)\n        sigmas = repeat(sigmas, \"b ... -> b t ...\", t=num_frames)\n        sigmas = rearrange(sigmas, \"b t ... 
-> (b t) ...\", t=num_frames)\n\n        noise = torch.randn_like(input_x)  # [BT, C, H, W]\n\n        sigmas_bc = append_dims(sigmas, input_x.ndim)  # [14, 1, 1, 1]\n\n        noised_input = self.loss_fn.get_noised_input(\n            sigmas_bc, noise, input_x\n        )  # [BT, C, H, W]\n\n        # step-3: prepare conditional and unconditional inputs\n        # [2BT, C, H, W], [2BT]\n        bathced_xt, bathced_sigmas, bathched_c = self.sampler.guider.prepare_inputs(\n            noised_input, sigmas, c, uc\n        )\n        # bathched_c[\"crossattn\"] => [2BT, 1, C] ;   bathched_c[\"concat\"] => [2BT, C, H, W]; bathched_c[\"vector\"] => [2BT, C_feat]\n\n        # output shape [2BT, C, H, W]\n        denoised = self.denoiser(\n            self.model,\n            bathced_xt,\n            bathced_sigmas,\n            bathched_c,\n            **additional_model_inputs,\n        )\n\n        # step-4: cfg guidance and compute sds_grad\n        # [BT, C, H, W]\n        denoised = self.sampler.guider(denoised, bathced_sigmas)\n\n        # sds_grad = (input_x - denoised) / sigmas_bc\n        sds_grad = (input_x - denoised) / torch.norm((input_x - denoised))\n\n        return sds_grad, denoised\n\n    @torch.no_grad()\n    def edm_sds_multistep(self, input_x, extra_input, sample_time_range=[0.02, 0.84], num_step=4, total_steps=25):\n        \"\"\"\n        From t = 20 sample to t = 980.\n        Args:\n            input_x: [BT, C, H, W] in latent\n            extra_input: dict\n                \"fps_id\": [B]\n                \"motion_bucket_id\": [B]\n                \"cond_aug\": [B]\n                \"cond_frames_without_noise\": [B, C, H, W]\n                \"cond_frames\": [B, C, H, W]\n            sample_time_range: [t_min, t_max]\n        \"\"\"\n\n        # step-1: prepare inputs\n        num_frames = extra_input[\"num_video_frames\"]\n        batch_size = input_x.shape[0] // num_frames\n        device = input_x.device\n        # video = 
video.contiguous()\n\n        extra_input[\"num_video_frames\"] = num_frames\n\n        # prepare c and uc\n\n        batch, batch_uc = get_batch(\n            get_unique_embedder_keys_from_conditioner(self.conditioner),\n            extra_input,\n            [1, num_frames],\n            T=num_frames,\n            device=device,\n        )\n\n        # keys would be ['crossattn', 'vector', 'concat']\n        c, uc = self.conditioner.get_unconditional_conditioning(\n            batch,\n            batch_uc=batch_uc,\n            force_uc_zero_embeddings=[\n                \"cond_frames\",\n                \"cond_frames_without_noise\",\n            ],\n        )\n\n        for k in [\"crossattn\", \"concat\"]:\n            uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n            uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n            c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n            c[k] = rearrange(c[k], \"b t ... -> (b t) ...\", t=num_frames)\n\n        # after this should be\n        # crossattn [14, 1, 1024];  vector [14, 768]; concat [14, 4, 72, 128]\n        additional_model_inputs = {}\n        additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n            int(2 * batch_size), num_frames\n        ).to(device)\n        additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n        # step-2: sample t and sigmas, then noise\n        sampled_t = np.random.randint(\n            low=int(sample_time_range[0] * self.sigmas_all.shape[0]),\n            high=int(sample_time_range[1] * self.sigmas_all.shape[0]),\n            size=(batch_size),\n        ) # np.array of index time t [B]\n\n        step_stride = len(self.sigmas_all) // total_steps\n\n        sigma_sum = 0.0\n        for i in range(num_step):\n            sampled_t += step_stride * i\n            sampled_t = np.clip(sampled_t, 0, len(self.sigmas_all) - 2)\n\n\n            # [B]\n            sigmas = 
self.sigmas_all[sampled_t]\n\n            # sigmas = self.loss_fn.sigma_sampler(batch_size).to(input_x)\n            sigmas = repeat(sigmas, \"b ... -> b t ...\", t=num_frames)\n            sigmas = rearrange(sigmas, \"b t ... -> (b t) ...\", t=num_frames)\n\n            sigmas_bc = append_dims(sigmas, input_x.ndim)  # [14, 1, 1, 1]\n\n            if i == 0:\n\n                noise = torch.randn_like(input_x)  # [BT, C, H, W]\n\n                noised_input = self.loss_fn.get_noised_input(\n                    sigmas_bc, noise, input_x\n                )  # [BT, C, H, W]\n            else:\n                # dt is negative\n                dt = append_dims(sigmas - prev_sigmas, input_x.ndim)\n\n                dx = (noised_input - denoised) / append_dims(prev_sigmas, input_x.ndim)\n                noised_input = noised_input + dt * dx\n\n            denoised = self.sampler_step(sigmas, noised_input, c, uc,\n                                         num_frames=num_frames, additional_model_inputs=additional_model_inputs)\n            prev_sigmas = sigmas\n            sigma_sum += sigmas_bc\n\n        # TODO, so many sigmas, which to use?\n        # sds_grad = (input_x - denoised) / sigmas_bc\n        # sds_grad = (input_x - denoised) / sigma_sum\n        sds_grad = (input_x - denoised) / torch.norm((input_x - denoised))\n\n        return sds_grad, denoised\n\n    @torch.no_grad()\n    def resample_multistep(self, input_x, extra_input, sample_time_range=[0.02, 0.84], num_step=4):\n        \"\"\"\n        From t = 20 sample to t = 980.\n        Args:\n            input_x: [BT, C, H, W] in latent\n            extra_input: dict\n                \"fps_id\": [B]\n                \"motion_bucket_id\": [B]\n                \"cond_aug\": [B]\n                \"cond_frames_without_noise\": [B, C, H, W]\n                \"cond_frames\": [B, C, H, W]\n            sample_time_range: [t_min, t_max]\n        \"\"\"\n\n        # step-1: prepare inputs\n        num_frames = 
extra_input[\"num_video_frames\"]\n        batch_size = input_x.shape[0] // num_frames\n        device = input_x.device\n        # video = video.contiguous()\n\n        extra_input[\"num_video_frames\"] = num_frames\n\n        # prepare c and uc\n\n        batch, batch_uc = get_batch(\n            get_unique_embedder_keys_from_conditioner(self.conditioner),\n            extra_input,\n            [1, num_frames],\n            T=num_frames,\n            device=device,\n        )\n\n        # keys would be ['crossattn', 'vector', 'concat']\n        c, uc = self.conditioner.get_unconditional_conditioning(\n            batch,\n            batch_uc=batch_uc,\n            force_uc_zero_embeddings=[\n                \"cond_frames\",\n                \"cond_frames_without_noise\",\n            ],\n        )\n\n        for k in [\"crossattn\", \"concat\"]:\n            uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n            uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n            c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n            c[k] = rearrange(c[k], \"b t ... 
-> (b t) ...\", t=num_frames)\n\n        # after this should be\n        # crossattn [14, 1, 1024];  vector [14, 768]; concat [14, 4, 72, 128]\n        additional_model_inputs = {}\n        additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n            int(2 * batch_size), num_frames\n        ).to(device)\n        additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n        # step-2: sample t and sigmas, then noise\n        sampled_t = np.random.randint(\n            low=int(sample_time_range[0] * self.sigmas_all.shape[0]),\n            high=int(sample_time_range[1] * self.sigmas_all.shape[0]),\n            size=(batch_size),\n        ) # np.array of index time t [B]\n\n        sampled_steps = np.linspace(sampled_t[0], len(self.sigmas_all) - 1, num_step, dtype=int)\n        sigma_sum = 0.0\n        for i in range(len(sampled_steps)):\n            sampled_t = sampled_steps[[i]]\n\n            # sampled_t = np.clip(sampled_t, 0, len(self.sigmas_all) - 2)\n\n            # [B]\n            sigmas = self.sigmas_all[sampled_t]\n\n            # sigmas = self.loss_fn.sigma_sampler(batch_size).to(input_x)\n            sigmas = repeat(sigmas, \"b ... -> b t ...\", t=num_frames)\n            sigmas = rearrange(sigmas, \"b t ... 
-> (b t) ...\", t=num_frames)\n\n            sigmas_bc = append_dims(sigmas, input_x.ndim)  # [14, 1, 1, 1]\n\n            if i == 0:\n\n                noise = torch.randn_like(input_x)  # [BT, C, H, W]\n\n                noised_input = self.loss_fn.get_noised_input(\n                    sigmas_bc, noise, input_x\n                )  # [BT, C, H, W]\n            else:\n                # dt is negative\n                dt = append_dims(sigmas - prev_sigmas, input_x.ndim)\n\n                dx = (noised_input - denoised) / append_dims(prev_sigmas, input_x.ndim)\n                noised_input = noised_input + dt * dx\n\n            denoised = self.sampler_step(sigmas, noised_input, c, uc,\n                                         num_frames=num_frames, additional_model_inputs=additional_model_inputs)\n            prev_sigmas = sigmas\n            sigma_sum += sigmas_bc\n\n        # TODO, so many sigmas, which to use?\n        # sds_grad = (input_x - denoised) / sigmas_bc\n        # sds_grad = (input_x - denoised) / sigma_sum\n        # sds_grad = (input_x - denoised) / torch.norm((input_x - denoised))\n\n        return denoised\n\n    def sampler_step(self, sigma, noised_input, c, uc=None, num_frames=None, additional_model_inputs=None):\n\n        # step-3: prepare conditional and unconditional inputs\n        # [2BT, C, H, W], [2BT]\n        bathced_xt, bathced_sigmas, bathched_c = self.sampler.guider.prepare_inputs(\n            noised_input, sigma, c, uc\n        )\n        # bathched_c[\"crossattn\"] => [2BT, 1, C] ;   bathched_c[\"concat\"] => [2BT, C, H, W]; bathched_c[\"vector\"] => [2BT, C_feat]\n\n        # output shape [2BT, C, H, W]\n        denoised = self.denoiser(\n            self.model,\n            bathced_xt,\n            bathced_sigmas,\n            bathched_c,\n            **additional_model_inputs,\n        )\n\n        # step-4: cfg guidance and compute sds_grad\n        # [BT, C, H, W]\n        denoised = self.sampler.guider(denoised, 
bathced_sigmas)\n\n        return denoised\n\n\n\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/diffusion/video_diffusion_loss.py",
    "content": "from typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom sgm.modules.autoencoding.lpips.loss.lpips import LPIPS\nfrom sgm.modules.encoders.modules import GeneralConditioner\nfrom sgm.util import append_dims, instantiate_from_config\nfrom sgm.modules.diffusionmodules.denoiser import Denoiser\nfrom einops import rearrange, repeat\n\n\nclass StandardVideoDiffusionLoss(nn.Module):\n    def __init__(\n        self,\n        sigma_sampler_config: dict,\n        loss_weighting_config: dict,\n        loss_type: str = \"l2\",\n        offset_noise_level: float = 0.0,\n        batch2model_keys: Optional[Union[str, List[str]]] = None,\n    ):\n        super().__init__()\n\n        assert loss_type in [\"l2\", \"l1\", \"lpips\"]\n\n        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)\n        self.loss_weighting = instantiate_from_config(loss_weighting_config)\n\n        self.loss_type = loss_type\n        self.offset_noise_level = offset_noise_level\n\n        if loss_type == \"lpips\":\n            self.lpips = LPIPS().eval()\n\n        if not batch2model_keys:\n            batch2model_keys = []\n\n        if isinstance(batch2model_keys, str):\n            batch2model_keys = [batch2model_keys]\n\n        self.batch2model_keys = set(batch2model_keys)\n\n    def get_noised_input(\n        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor\n    ) -> torch.Tensor:\n        noised_input = input + noise * sigmas_bc\n        return noised_input\n\n    def forward(\n        self,\n        network: nn.Module,\n        denoiser: Denoiser,\n        conditioner: GeneralConditioner,\n        input: torch.Tensor,\n        batch: Dict,\n    ) -> torch.Tensor:\n        cond = conditioner(batch)\n        num_frames = batch[\"num_video_frames\"]\n        for k in [\"crossattn\", \"concat\"]:\n            cond[k] = repeat(cond[k], \"b ... 
-> b t ...\", t=num_frames)\n            cond[k] = rearrange(cond[k], \"b t ... -> (b t) ...\", t=num_frames)\n\n        return self._forward(network, denoiser, cond, input, batch)\n\n    def _forward(\n        self,\n        network: nn.Module,\n        denoiser: Denoiser,\n        cond: Dict,\n        input: torch.Tensor,\n        batch: Dict,\n    ) -> Tuple[torch.Tensor, Dict]:\n        additional_model_inputs = {\n            key: batch[key] for key in self.batch2model_keys.intersection(batch)\n        }\n        # print(\"pre check additional inputs\", additional_model_inputs.keys())\n        num_frames = batch[\"num_video_frames\"]\n        batch_size = input.shape[0] // num_frames\n        additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n            batch_size, num_frames\n        ).to(input.device)\n        additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n\n        # sigmas = self.sigma_sampler(input.shape[0]).to(input)\n        sigmas = self.sigma_sampler(batch_size).to(input)\n        sigmas = repeat(sigmas, \"b ... -> b t ...\", t=num_frames)\n        sigmas = rearrange(sigmas, \"b t ... 
-> (b t) ...\", t=num_frames)\n\n        noise = torch.randn_like(input)\n        if self.offset_noise_level > 0.0:\n            offset_shape = (\n                (input.shape[0], 1, input.shape[2])\n                if self.n_frames is not None\n                else (input.shape[0], input.shape[1])\n            )\n            noise = noise + self.offset_noise_level * append_dims(\n                torch.randn(offset_shape, device=input.device),\n                input.ndim,\n            )\n        sigmas_bc = append_dims(sigmas, input.ndim)\n        noised_input = self.get_noised_input(sigmas_bc, noise, input)\n\n        model_output = denoiser(\n            network, noised_input, sigmas, cond, **additional_model_inputs\n        )\n        w = append_dims(self.loss_weighting(sigmas), input.ndim)\n        return self.get_loss(model_output, input, w)\n\n    def get_loss(self, model_output, target, w):\n        if self.loss_type == \"l2\":\n            return torch.mean(\n                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1\n            )\n        elif self.loss_type == \"l1\":\n            return torch.mean(\n                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1\n            )\n        elif self.loss_type == \"lpips\":\n            loss = self.lpips(model_output, target).reshape(-1)\n            return loss\n        else:\n            raise NotImplementedError(f\"Unknown loss type {self.loss_type}\")\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/field_components/encoding.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Optional, Sequence, Tuple, List\nfrom motionrep.losses.smoothness_loss import compute_plane_smoothness, compute_plane_tv\n\n\nclass TemporalKplanesEncoding(nn.Module):\n    \"\"\"\n\n    Args:\n        resolutions (Sequence[int]): xyzt resolutions.\n    \"\"\"\n\n    def __init__(\n        self,\n        resolutions: Sequence[int],\n        feat_dim: int = 32,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"sum\",  # Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n    ):\n        super().__init__()\n\n        self.resolutions = resolutions\n\n        if reduce == \"cat\":\n            feat_dim = feat_dim // 3\n        self.feat_dim = feat_dim\n\n        self.reduce = reduce\n\n        self.in_dim = 4\n\n        self.plane_coefs = nn.ParameterList()\n\n        self.coo_combs = [[0, 3], [1, 3], [2, 3]]\n        # [(x, t), (y, t), (z, t)]\n        for coo_comb in self.coo_combs:\n            # [feat_dim, time_resolution, spatial_resolution]\n            new_plane_coef = nn.Parameter(\n                torch.empty(\n                    [\n                        self.feat_dim,\n                        resolutions[coo_comb[1]],\n                        resolutions[coo_comb[0]],  # flip?\n                    ]\n                )\n            )\n\n            # when init to ones?\n\n            nn.init.uniform_(new_plane_coef, a=init_a, b=init_b)\n            self.plane_coefs.append(new_plane_coef)\n\n    def forward(self, inp: Float[Tensor, \"*bs 4\"]):\n        output = 1.0 if self.reduce == \"product\" else 0.0\n        if self.reduce == \"cat\":\n            output = []\n        for ci, coo_comb in enumerate(self.coo_combs):\n            grid = self.plane_coefs[ci].unsqueeze(0)  # [1, feature_dim, reso1, reso2]\n            coords = inp[..., coo_comb].view(1, 1, -1, 2)  # [1, 1, 
flattened_bs, 2]\n\n            interp = F.grid_sample(\n                grid, coords, align_corners=True, padding_mode=\"border\"\n            )  # [1, output_dim, 1, flattened_bs]\n            interp = interp.view(self.feat_dim, -1).T  # [flattened_bs, output_dim]\n\n            if self.reduce == \"product\":\n                output = output * interp\n            elif self.reduce == \"sum\":\n                output = output + interp\n            elif self.reduce == \"cat\":\n                output.append(interp)\n\n        if self.reduce == \"cat\":\n            # [flattened_bs, output_dim * 3]\n            output = torch.cat(output, dim=-1)\n\n        return output\n\n    def compute_temporal_smoothness(\n        self,\n    ):\n        ret_loss = 0.0\n\n        for plane_coef in self.plane_coefs:\n            ret_loss += compute_plane_smoothness(plane_coef)\n\n        return ret_loss\n\n    def compute_plane_tv(\n        self,\n    ):\n        ret_loss = 0.0\n\n        for plane_coef in self.plane_coefs:\n            ret_loss += compute_plane_tv(plane_coef)\n\n        return ret_loss\n\n    def visualize(\n        self,\n    ) -> Tuple[Float[Tensor, \"3 H W\"]]:\n        \"\"\"Visualize the encoding as a RGB images\n\n        Returns:\n            Tuple[Float[Tensor, \"3 H W\"]]\n        \"\"\"\n        pass\n\n    @staticmethod\n    def functional_forward(\n        plane_coefs: List[Float[Tensor, \"feat_dim H W\"]],\n        inp: Float[Tensor, \"*bs 4\"],\n        reduce: str = \"sum\",\n        coo_combs: Optional[List[List[int]]] = [[0, 3], [1, 3], [2, 3]],\n    ):\n        assert reduce in [\"sum\", \"product\", \"cat\"]\n        output = 1.0 if reduce == \"product\" else 0.0\n\n        if reduce == \"cat\":\n            output = []\n        for ci, coo_comb in enumerate(coo_combs):\n            grid = plane_coefs[ci].unsqueeze(0)  # [1, feature_dim, reso1, reso2]\n            feat_dim = grid.shape[1]\n            coords = inp[..., coo_comb].view(1, 1, -1, 
2)  # [1, 1, flattened_bs, 2]\n\n            interp = F.grid_sample(\n                grid, coords, align_corners=True, padding_mode=\"border\"\n            )  # [1, output_dim, 1, flattened_bs]\n            interp = interp.view(feat_dim, -1).T  # [flattened_bs, output_dim]\n\n            if reduce == \"product\":\n                output = output * interp\n            elif reduce == \"sum\":\n                output = output + interp\n            elif reduce == \"cat\":\n                output.append(interp)\n\n        if reduce == \"cat\":\n            # [flattened_bs, output_dim * 3]\n            output = torch.cat(output, dim=-1)\n\n        return output\n\n\nclass TriplanesEncoding(nn.Module):\n    \"\"\"\n\n    Args:\n        resolutions (Sequence[int]): xyz resolutions.\n    \"\"\"\n\n    def __init__(\n        self,\n        resolutions: Sequence[int],\n        feat_dim: int = 32,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"sum\",  # Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n    ):\n        super().__init__()\n\n        self.resolutions = resolutions\n\n        if reduce == \"cat\":\n            feat_dim = feat_dim#  // 3\n        self.feat_dim = feat_dim\n\n        self.reduce = reduce\n\n        self.in_dim = 3\n\n        self.plane_coefs = nn.ParameterList()\n\n        self.coo_combs = [[0, 1], [0, 2], [1, 2]]\n        # [(x, t), (y, t), (z, t)]\n        for coo_comb in self.coo_combs:\n            new_plane_coef = nn.Parameter(\n                torch.empty(\n                    [\n                        self.feat_dim,\n                        resolutions[coo_comb[1]],\n                        resolutions[coo_comb[0]],\n                    ]\n                )\n            )\n\n            # when init to ones?\n\n            nn.init.uniform_(new_plane_coef, a=init_a, b=init_b)\n            self.plane_coefs.append(new_plane_coef)\n\n    def forward(self, inp: Float[Tensor, \"*bs 3\"]):\n        output = 1.0 if 
self.reduce == \"product\" else 0.0\n        if self.reduce == \"cat\":\n            output = []\n        for ci, coo_comb in enumerate(self.coo_combs):\n            grid = self.plane_coefs[ci].unsqueeze(0)  # [1, feature_dim, reso1, reso2]\n            coords = inp[..., coo_comb].view(1, 1, -1, 2)  # [1, 1, flattened_bs, 2]\n\n            interp = F.grid_sample(\n                grid, coords, align_corners=True, padding_mode=\"border\"\n            )  # [1, output_dim, 1, flattened_bs]\n            interp = interp.view(self.feat_dim, -1).T  # [flattened_bs, output_dim]\n\n            if self.reduce == \"product\":\n                output = output * interp\n            elif self.reduce == \"sum\":\n                output = output + interp\n            elif self.reduce == \"cat\":\n                output.append(interp)\n\n        if self.reduce == \"cat\":\n            # [flattened_bs, output_dim * 3]\n            output = torch.cat(output, dim=-1)\n\n        return output\n\n    def compute_plane_tv(\n        self,\n    ):\n        ret_loss = 0.0\n\n        for plane_coef in self.plane_coefs:\n            ret_loss += compute_plane_tv(plane_coef)\n\n        return ret_loss\n\n\nclass PlaneEncoding(nn.Module):\n    \"\"\"\n\n    Args:\n        resolutions (Sequence[int]): xyz resolutions.\n    \"\"\"\n\n    def __init__(\n        self,\n        resolutions: Sequence[int],  # [y_res, x_res]\n        feat_dim: int = 32,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n    ):\n        super().__init__()\n\n        self.resolutions = resolutions\n\n        self.feat_dim = feat_dim\n        self.in_dim = 2\n\n        self.plane_coefs = nn.ParameterList()\n\n        self.coo_combs = [[0, 1]]\n        for coo_comb in self.coo_combs:\n            new_plane_coef = nn.Parameter(\n                torch.empty(\n                    [\n                        self.feat_dim,\n                        resolutions[coo_comb[1]],\n                        
resolutions[coo_comb[0]],\n                    ]\n                )\n            )\n\n            # when init to ones?\n\n            nn.init.uniform_(new_plane_coef, a=init_a, b=init_b)\n            self.plane_coefs.append(new_plane_coef)\n\n    def forward(self, inp: Float[Tensor, \"*bs 2\"]):\n\n        for ci, coo_comb in enumerate(self.coo_combs):\n            grid = self.plane_coefs[ci].unsqueeze(0)  # [1, feature_dim, reso1, reso2]\n            coords = inp[..., coo_comb].view(1, 1, -1, 2)  # [1, 1, flattened_bs, 2]\n\n            interp = F.grid_sample(\n                grid, coords, align_corners=True, padding_mode=\"border\"\n            )  # [1, output_dim, 1, flattened_bs]\n            interp = interp.view(self.feat_dim, -1).T  # [flattened_bs, output_dim]\n\n            output = interp\n\n        return output\n\n    def compute_plane_tv(\n        self,\n    ):\n        ret_loss = 0.0\n\n        for plane_coef in self.plane_coefs:\n            ret_loss += compute_plane_tv(plane_coef)\n\n        return ret_loss\n\n\nclass TemporalNeRFEncoding(nn.Module):\n    def __init__(\n        self,\n        in_dim,  # : int,\n        num_frequencies: int,\n        min_freq_exp: float,\n        max_freq_exp: float,\n        log_scale: bool = False,\n        include_input: bool = False,\n    ) -> None:\n        super().__init__()\n        self.in_dim = in_dim\n        self.num_frequencies = num_frequencies\n        self.min_freq = min_freq_exp\n        self.max_freq = max_freq_exp\n        self.log_scale = log_scale\n        self.include_input = include_input\n\n    def get_out_dim(self) -> int:\n        if self.in_dim is None:\n            raise ValueError(\"Input dimension has not been set\")\n        out_dim = self.in_dim * self.num_frequencies * 2\n        if self.include_input:\n            out_dim += self.in_dim\n        return out_dim\n\n    def forward(\n        self,\n        in_tensor: Float[Tensor, \"*bs input_dim\"],\n    ) -> Float[Tensor, \"*bs 
output_dim\"]:\n        \"\"\"Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed\n            in mip-NeRF.\n\n        Args:\n            in_tensor: For best performance, the input tensor should be between 0 and 1.\n            covs: Covariances of input points.\n        Returns:\n            Output values will be between -1 and 1\n        \"\"\"\n        scaled_in_tensor = 2 * torch.pi * in_tensor  # scale to [0, 2pi]\n    \n        # freqs = 2 ** torch.linspace(\n        freqs = torch.linspace(\n            self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device\n        )\n        if self.log_scale:\n            freqs = 2 ** freqs\n        scaled_inputs = (\n            scaled_in_tensor[..., None] * freqs\n        )  # [..., \"input_dim\", \"num_scales\"]\n        scaled_inputs = scaled_inputs.view(\n            *scaled_inputs.shape[:-2], -1\n        )  # [..., \"input_dim\" * \"num_scales\"]\n\n        encoded_inputs = torch.sin(\n            torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1)\n        )\n        return encoded_inputs\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/field_components/mlp.py",
    "content": "\"\"\"\nMostly from nerfstudio: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/field_components/mlp.py\n\"\"\"\nfrom typing import Optional, Set, Tuple, Union\n\nimport torch\nfrom jaxtyping import Float\nfrom torch import Tensor, nn\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        in_dim: int,\n        num_layers: int,\n        layer_width: int,\n        out_dim: Optional[int] = None,\n        skip_connections: Optional[Tuple[int]] = None,\n        activation: Optional[nn.Module] = nn.ReLU(),\n        out_activation: Optional[nn.Module] = None,\n        zero_init = False,\n    ) -> None:\n        super().__init__()\n        self.in_dim = in_dim\n        assert self.in_dim > 0\n        self.out_dim = out_dim if out_dim is not None else layer_width\n        self.num_layers = num_layers\n        self.layer_width = layer_width\n        self.skip_connections = skip_connections\n        self._skip_connections: Set[int] = (\n            set(skip_connections) if skip_connections else set()\n        )\n        self.activation = activation\n        self.out_activation = out_activation\n        self.net = None\n        self.zero_init = zero_init\n\n        self.build_nn_modules()\n\n    def build_nn_modules(self) -> None:\n        \"\"\"Initialize multi-layer perceptron.\"\"\"\n        layers = []\n        if self.num_layers == 1:\n            layers.append(nn.Linear(self.in_dim, self.out_dim))\n        else:\n            for i in range(self.num_layers - 1):\n                if i == 0:\n                    assert (\n                        i not in self._skip_connections\n                    ), \"Skip connection at layer 0 doesn't make sense.\"\n                    layers.append(nn.Linear(self.in_dim, self.layer_width))\n                elif i in self._skip_connections:\n                    layers.append(\n                        nn.Linear(self.layer_width + self.in_dim, self.layer_width)\n                    )\n    
            else:\n                    layers.append(nn.Linear(self.layer_width, self.layer_width))\n            layers.append(nn.Linear(self.layer_width, self.out_dim))\n        self.layers = nn.ModuleList(layers)\n\n        if self.zero_init:\n            torch.nn.init.zeros_(self.layers[-1].weight)\n            torch.nn.init.zeros_(self.layers[-1].bias)\n\n    def pytorch_fwd(\n        self, in_tensor: Float[Tensor, \"*bs in_dim\"]\n    ) -> Float[Tensor, \"*bs out_dim\"]:\n        \"\"\"Process input with a multilayer perceptron.\n\n        Args:\n            in_tensor: Network input\n\n        Returns:\n            MLP network output\n        \"\"\"\n        x = in_tensor\n        for i, layer in enumerate(self.layers):\n            # as checked in `build_nn_modules`, 0 should not be in `_skip_connections`\n            if i in self._skip_connections:\n                x = torch.cat([in_tensor, x], -1)\n            x = layer(x)\n            if self.activation is not None and i < len(self.layers) - 1:\n                x = self.activation(x)\n        if self.out_activation is not None:\n            x = self.out_activation(x)\n        return x\n\n    def forward(\n        self, in_tensor: Float[Tensor, \"*bs in_dim\"]\n    ) -> Float[Tensor, \"*bs out_dim\"]:\n        return self.pytorch_fwd(in_tensor)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/fields/dct_trajectory_field.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom motionrep.utils.dct import dct, idct, dct3d, idct_3d\n\n\nclass DCTTrajctoryField(nn.Module):\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        pass\n\n    def forward(self, x):\n        pass\n\n    def query_points_at_time(self, x, t):\n        pass\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/fields/discrete_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Literal, Optional, Sequence, Tuple\nfrom motionrep.field_components.encoding import (\n    TemporalKplanesEncoding,\n    TriplanesEncoding,\n)\nfrom motionrep.field_components.mlp import MLP\nfrom motionrep.operators.rotation import rotation_6d_to_matrix, quaternion_to_matrix\nfrom motionrep.data.scene_box import SceneBox\n\n\nclass PointSetMotionSE3(nn.Module):\n    \"\"\"Temporal Kplanes SE(3) fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z ,t].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        inp_x: Float[Tensor, \"*bs 3\"],\n        aabb: Float[Tensor, \"2 3\"],\n        rotation_type: Literal[\"quaternion\", \"6d\"] = \"6d\",\n        num_frames: int = 20,\n        distance_lamba=100.0,\n        topk_nn: int = 20,  # the same neighboor size as dynamic gaussian\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        output_dim_dict = {\"quaternion\": 4 + 3, \"6d\": 6 + 3}\n        self.output_dim = output_dim_dict[rotation_type]\n        self.rotation_type = rotation_type\n\n        self.register_buffer(\"inp_x\", inp_x.detach())\n\n        self.num_frames = num_frames\n\n        # init parameters:\n        translation = nn.Parameter(\n            torch.zeros(num_frames + 1, inp_x.shape[0], 3).requires_grad_(True)\n        )\n        rotation = nn.Parameter(\n            torch.ones(\n                (num_frames + 1, inp_x.shape[0], self.output_dim - 3)\n            ).requires_grad_(True)\n        )\n        self.register_parameter(\"translation\", translation)\n        self.register_parameter(\"rotation\", rotation)\n\n        # [num_points, topk]\n        print(inp_x.shape, 
\"input shape gaussian\")\n        knn_dist, knn_ind = self.construct_knn(inp_x, topk=topk_nn)\n\n        # [num_points, topk]\n        self.distance_weight = torch.exp(-1.0 * distance_lamba * knn_dist)\n        self.knn_index = knn_ind  # torch.long\n\n        self.precompute_isometry = self.prepare_isometry(inp_x, knn_ind)\n\n        self.inp_time_list = []\n\n    def construct_knn(self, inpx: Float[Tensor, \"*bs 3\"], topk=10, chunk_size=5000):\n        # compute topk nearest neighbors for each point, and the distance\n\n        knn_dist_list, knn_ind_list = [], []\n        num_step = inpx.shape[0] // chunk_size + 1\n\n        with torch.no_grad():\n            for i in range(num_step):\n                end_ind = min((i + 1) * chunk_size, inpx.shape[0])\n\n                src_points = inpx[i * chunk_size : end_ind]\n                # compute the distance matrix\n                cdist = torch.cdist(src_points, inpx)\n\n                print(cdist.shape, \"cdist\")\n                # get the topk nearest neighbors\n                knn_dist, knn_ind = torch.topk(cdist, topk, dim=1, largest=False)\n                knn_dist_list.append(knn_dist)\n                knn_ind_list.append(knn_ind)\n\n            knn_dist = torch.cat(knn_dist_list, dim=0)\n            knn_ind = torch.cat(knn_ind_list, dim=0)\n        return knn_dist, knn_ind\n\n    def prepare_isometry(self, points, knn_ind):\n        # [num_points, topk, 3]\n        p_nn = points[knn_ind]\n\n        dsp = points[:, None, :] - p_nn\n\n        distance = torch.norm(dsp, dim=-1)\n\n        # [num_points, topk]\n        return distance\n\n    def _forward_single_time(self, time_ind: int):\n        if self.rotation_type == \"6d\":\n            rotation_6d, translation = (\n                self.rotation[time_ind],\n                self.translation[time_ind],\n            )\n            R_mat = rotation_6d_to_matrix(rotation_6d)\n\n        elif self.rotation_type == \"quaternion\":\n            quat, translation = 
self.rotation[time_ind], self.translation[time_ind]\n\n            quat = torch.tanh(quat)\n            R_mat = quaternion_to_matrix(quat)\n\n        return R_mat, translation\n\n    def forward(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        **kwargs,\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        time_ind = torch.round(inpt * (self.num_frames)).long()[0].item()\n        R_mat, translation = self._forward_single_time(time_ind)\n\n        self.inp_time_list.append(time_ind)\n        if len(self.inp_time_list) > 20:\n            self.inp_time_list.pop(0)\n\n        return R_mat, translation\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        # temporal_smoothness_loss = torch.tensor([0.0]).cuda()\n        temporal_smoothness_loss = self.compute_isometry_loss()\n        smothness_loss = self.compute_arap_loss()\n\n        return temporal_smoothness_loss, smothness_loss\n\n    def compute_arap_loss(\n        self,\n    ):\n        arap_loss = 0.0\n\n        # random sample 16 frames\n        random_frame_ind_list = torch.randint(0, self.num_frames - 1, (16,))\n\n        for i in self.inp_time_list:\n            r1, t1 = self._forward_single_time(i)\n            r2, t2 = self._forward_single_time(i + 1)\n\n            # [num_points, topk, 3, 3], [num_points, topk, 3]\n            r1_nn, t1_nn = r1[self.knn_index], t1[self.knn_index]\n            r2_nn, t2_nn = r2[self.knn_index], t2[self.knn_index]\n\n            # displacement between neighboor points\n            #   shape of [num_points, topk, 3]\n            dsp_t0 = t1_nn - t1[:, None, :]\n            dsp_t1 = t2_nn - t2[:, None, :]\n\n            # rotation matrix from frame-1 to frame-0\n\n            r_mat_1to0 = torch.bmm(r1, r2.transpose(1, 2))  # [N, 3, 3]\n            # [N, 3, 3] => [N, topk, 3, 3]\n            r_mat_1to0 = r_mat_1to0.unsqueeze(1).repeat(\n                1, 
self.knn_index.shape[1], 1, 1\n            )\n            dsp_t1_to_0 = torch.matmul(r_mat_1to0, dsp_t1[:, :, :, None]).squeeze(-1)\n            # compute the arap loss\n            arap_loss += torch.mean(\n                torch.norm(dsp_t0 - dsp_t1_to_0, dim=-1) * self.distance_weight\n            )\n        return arap_loss\n\n    def compute_isometry_loss(\n        self,\n    ):\n        iso_loss = 0.0\n        # random sample 16 frames\n        random_frame_ind_list = torch.randint(0, self.num_frames - 1, (16,))\n\n        for i in self.inp_time_list:\n            r1, t1 = self._forward_single_time(i)\n            points = self.inp_x + t1\n            distance_mat = self.prepare_isometry(points, self.knn_index)\n\n            iso_loss += torch.mean(\n                torch.abs(distance_mat - self.precompute_isometry)\n                * self.distance_weight\n            )\n        return iso_loss\n\n    def compute_loss(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        trajectory: Float[Tensor, \"*bs 3\"],\n        loss_func,\n    ):\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        R, t = self(inp)\n\n        rec_traj = torch.bmm(R, inpx.unsqueeze(-1)).squeeze(-1) + t\n\n        rec_loss = loss_func(rec_traj, trajectory)\n\n        return rec_loss\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/fields/mul_offset_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Literal, Optional, Sequence, Tuple, List\nfrom motionrep.field_components.encoding import (\n    TemporalKplanesEncoding,\n    TriplanesEncoding,\n)\nfrom motionrep.field_components.mlp import MLP\nfrom motionrep.operators.rotation import rotation_6d_to_matrix, quaternion_to_matrix\nfrom motionrep.data.scene_box import SceneBox\n\n\nclass MulTemporalKplanesOffsetfields(nn.Module):\n    \"\"\"Multiple Temporal Kplanes SE(3) fields.\n\n        Decoder is shared, but plane coefs are different.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z ,t].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions_list: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        add_spatial_triplane: bool = True,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = 3\n\n        self.temporal_kplanes_encoding_list = nn.ModuleList(\n            [\n                TemporalKplanesEncoding(resolutions, feat_dim, init_a, init_b, reduce)\n                for resolutions in resolutions_list\n            ]\n        )\n\n        self.add_spatial_triplane = add_spatial_triplane\n        if add_spatial_triplane:\n            self.spatial_kplanes_encoding_list = nn.ModuleList(\n                [\n                    TriplanesEncoding(\n                        resolutions[:-1], feat_dim, init_a, init_b, reduce\n                    )\n                    for resolutions in 
resolutions_list\n                ]\n            )\n            feat_dim = feat_dim * 2\n\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n        )\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 4\"], dataset_indx: Int[Tensor, \"1\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n\n        # for loop in batch dimension\n\n        output = self.temporal_kplanes_encoding_list[dataset_indx](inp)\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding_list[dataset_indx](inp)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        return output\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        temporal_smoothness_loss = 0.0\n        for temporal_kplanes_encoding in self.temporal_kplanes_encoding_list:\n            temporal_smoothness_loss += (\n                temporal_kplanes_encoding.compute_temporal_smoothness()\n            )\n\n        smothness_loss = 0.0\n        for temporal_kplanes_encoding in self.temporal_kplanes_encoding_list:\n            smothness_loss += temporal_kplanes_encoding.compute_plane_tv()\n\n        if self.add_spatial_triplane:\n            for spatial_kplanes_encoding in self.spatial_kplanes_encoding_list:\n                smothness_loss += spatial_kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss, temporal_smoothness_loss\n\n    def compute_loss(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        
trajectory: Float[Tensor, \"*bs 3\"],\n        loss_func,\n    ):\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        output = self(inp)\n\n        rec_traj = inpx + output\n\n        rec_loss = loss_func(rec_traj, trajectory)\n\n        return rec_loss\n\n    def arap_loss(self, inp):\n        pass\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/fields/mul_se3_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Literal, Optional, Sequence, Tuple\nfrom motionrep.field_components.encoding import (\n    TemporalKplanesEncoding,\n    TriplanesEncoding,\n)\nfrom motionrep.field_components.mlp import MLP\nfrom motionrep.operators.rotation import rotation_6d_to_matrix, quaternion_to_matrix\nfrom motionrep.data.scene_box import SceneBox\n\n\nclass MulTemporalKplanesSE3fields(nn.Module):\n    \"\"\"Multiple Temporal Kplanes SE(3) fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z ,t].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions_list: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        rotation_type: Literal[\"quaternion\", \"6d\"] = \"6d\",\n        add_spatial_triplane: bool = True,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        output_dim_dict = {\"quaternion\": 4 + 3, \"6d\": 6 + 3}\n        self.output_dim = output_dim_dict[rotation_type]\n        self.rotation_type = rotation_type\n\n        self.temporal_kplanes_encoding_list = nn.ModuleList(\n            [\n                TemporalKplanesEncoding(resolutions, feat_dim, init_a, init_b, reduce)\n                for resolutions in resolutions_list\n            ]\n        )\n\n        self.add_spatial_triplane = add_spatial_triplane\n        if add_spatial_triplane:\n            self.spatial_kplanes_encoding_list = nn.ModuleList(\n                [\n                    TriplanesEncoding(\n      
                  resolutions[:-1], feat_dim, init_a, init_b, reduce\n                    )\n                    for resolutions in resolutions_list\n                ]\n            )\n            feat_dim = feat_dim * 2\n\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n        )\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 4\"], dataset_indx: Int[Tensor, \"1\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n\n        # for loop in batch dimension\n\n        output = self.temporal_kplanes_encoding_list[dataset_indx](inp)\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding_list[dataset_indx](inp)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        if self.rotation_type == \"6d\":\n            rotation_6d, translation = output[:, :6], output[:, 6:]\n            R_mat = rotation_6d_to_matrix(rotation_6d)\n\n        elif self.rotation_type == \"quaternion\":\n            quat, translation = output[:, :4], output[:, 4:]\n\n            # tanh and normalize\n            quat = torch.tanh(quat)\n\n            R_mat = quaternion_to_matrix(quat)\n\n        return R_mat, translation\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        temporal_smoothness_loss = 0.0\n        for temporal_kplanes_encoding in self.temporal_kplanes_encoding_list:\n            temporal_smoothness_loss += (\n                
temporal_kplanes_encoding.compute_temporal_smoothness()\n            )\n\n        smothness_loss = 0.0\n        for temporal_kplanes_encoding in self.temporal_kplanes_encoding_list:\n            smothness_loss += temporal_kplanes_encoding.compute_plane_tv()\n\n        if self.add_spatial_triplane:\n            for spatial_kplanes_encoding in self.spatial_kplanes_encoding_list:\n                smothness_loss += spatial_kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss, temporal_smoothness_loss\n\n    def compute_loss(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        trajectory: Float[Tensor, \"*bs 3\"],\n        loss_func,\n    ):\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        R, t = self(inp)\n\n        rec_traj = torch.bmm(R, inpx.unsqueeze(-1)).squeeze(-1) + t\n\n        rec_loss = loss_func(rec_traj, trajectory)\n\n        return rec_loss\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/fields/offset_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Literal, Optional, Sequence, Tuple, List\nfrom motionrep.field_components.encoding import (\n    TemporalKplanesEncoding,\n    TriplanesEncoding,\n)\nfrom motionrep.field_components.mlp import MLP\nfrom motionrep.operators.rotation import rotation_6d_to_matrix, quaternion_to_matrix\nfrom motionrep.data.scene_box import SceneBox\n\n\nclass TemporalKplanesOffsetfields(nn.Module):\n    \"\"\"Temporal Offsets fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z ,t].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        add_spatial_triplane: bool = True,\n        zero_init: bool = True,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = 3\n\n        self.temporal_kplanes_encoding = TemporalKplanesEncoding(\n            resolutions, feat_dim, init_a, init_b, reduce\n        )\n\n        self.add_spatial_triplane = add_spatial_triplane\n        if add_spatial_triplane:\n            self.spatial_kplanes_encoding = TriplanesEncoding(\n                resolutions[:-1], feat_dim, init_a, init_b, reduce\n            )\n            feat_dim = feat_dim * 2\n\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            
activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 4\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n        output = self.temporal_kplanes_encoding(inp)\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding(inpx)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        return output\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.temporal_kplanes_encoding.compute_plane_tv()\n        temporal_smoothness_loss = (\n            self.temporal_kplanes_encoding.compute_temporal_smoothness()\n        )\n\n        if self.add_spatial_triplane:\n            smothness_loss += self.spatial_kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss + temporal_smoothness_loss\n\n    def compute_loss(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        trajectory: Float[Tensor, \"*bs 3\"],\n        loss_func,\n    ):\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        output = self(inp)\n\n        rec_traj = inpx + output\n\n        rec_loss = loss_func(rec_traj, trajectory)\n\n        return rec_loss\n\n    def arap_loss(self, inp):\n        pass\n\n    def forward_with_plane_coefs(\n        self,\n        plane_coefs: List[Float[Tensor, \"feat_dim H W\"]],\n        inp: Float[Tensor, \"*bs 4\"],\n    ):\n        \"\"\"\n        Args:\n            pass\n        \"\"\"\n\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = 
inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n        output = self.temporal_kplanes_encoding.functional_forward(\n            plane_coefs, inp, reduce=self.temporal_kplanes_encoding.reduce\n        )\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding(inpx)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        return output\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/fields/se3_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Literal, Optional, Sequence, Tuple\nfrom motionrep.field_components.encoding import (\n    TemporalKplanesEncoding,\n    TriplanesEncoding,\n)\nfrom motionrep.field_components.mlp import MLP\nfrom motionrep.operators.rotation import rotation_6d_to_matrix, quaternion_to_matrix\nfrom motionrep.data.scene_box import SceneBox\n\n\nclass TemporalKplanesSE3fields(nn.Module):\n    \"\"\"Temporal Kplanes SE(3) fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z ,t].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        rotation_type: Literal[\"quaternion\", \"6d\"] = \"6d\",\n        add_spatial_triplane: bool = True,\n        zero_init: bool = True,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        output_dim_dict = {\"quaternion\": 4 + 3, \"6d\": 6 + 3}\n        self.output_dim = output_dim_dict[rotation_type]\n        self.rotation_type = rotation_type\n\n        self.temporal_kplanes_encoding = TemporalKplanesEncoding(\n            resolutions, feat_dim, init_a, init_b, reduce\n        )\n\n        self.add_spatial_triplane = add_spatial_triplane\n        if add_spatial_triplane:\n            self.spatial_kplanes_encoding = TriplanesEncoding(\n                resolutions[:-1], feat_dim, init_a, init_b, reduce\n            )\n            feat_dim = feat_dim * 2\n\n        self.decoder = MLP(\n      
      feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n    def forward(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        compute_smoothess_loss: bool = False,\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        if compute_smoothess_loss:\n            smothness_loss, temporal_smoothness_loss = self.compute_smoothess_loss()\n            return smothness_loss + temporal_smoothness_loss\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inpx, self.aabb) * 2.0 - 1.0\n\n        inpt = inpt * 2.0 - 1.0\n\n        inp = torch.cat([inpx, inpt], dim=-1)\n        output = self.temporal_kplanes_encoding(inp)\n\n        if self.add_spatial_triplane:\n            spatial_output = self.spatial_kplanes_encoding(inpx)\n            output = torch.cat([output, spatial_output], dim=-1)\n\n        output = self.decoder(output)\n\n        if self.rotation_type == \"6d\":\n            rotation_6d, translation = output[:, :6], output[:, 6:]\n            R_mat = rotation_6d_to_matrix(rotation_6d)\n\n        elif self.rotation_type == \"quaternion\":\n            quat, translation = output[:, :4], output[:, 4:]\n\n            # tanh and normalize\n            quat = torch.tanh(quat)\n\n            R_mat = quaternion_to_matrix(quat)\n\n            # --------------- remove below --------------- #\n            # add normalization\n            # r = quat\n            # norm = torch.sqrt(\n            #     r[:, 0] * r[:, 0]\n            #     + r[:, 1] * r[:, 1]\n            #     + r[:, 2] * r[:, 2]\n            #     + r[:, 3] * r[:, 3]\n            # )\n            # q = r / norm[:, None]\n            # R_mat = q\n            # --------------- remove 
above --------------- #\n\n        return R_mat, translation\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.temporal_kplanes_encoding.compute_plane_tv()\n        temporal_smoothness_loss = (\n            self.temporal_kplanes_encoding.compute_temporal_smoothness()\n        )\n\n        if self.add_spatial_triplane:\n            smothness_loss += self.spatial_kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss, temporal_smoothness_loss\n\n    def compute_loss(\n        self,\n        inp: Float[Tensor, \"*bs 4\"],\n        trajectory: Float[Tensor, \"*bs 3\"],\n        loss_func,\n    ):\n        inpx, inpt = inp[:, :3], inp[:, 3:]\n\n        R, t = self(inp)\n\n        rec_traj = torch.bmm(R, inpx.unsqueeze(-1)).squeeze(-1) + t\n\n        rec_loss = loss_func(rec_traj, trajectory)\n\n        return rec_loss\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/fields/triplane_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Optional, Sequence, Tuple, List\nfrom motionrep.field_components.encoding import TriplanesEncoding\nfrom motionrep.field_components.mlp import MLP\nfrom motionrep.data.scene_box import SceneBox\n\n\nclass TriplaneFields(nn.Module):\n    \"\"\"Temporal Kplanes SE(3) fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"sum\",  #: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        output_dim: int = 96,\n        zero_init: bool = False,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = output_dim\n\n        self.kplanes_encoding = TriplanesEncoding(\n            resolutions, feat_dim, init_a, init_b, reduce\n        )\n\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 3\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, \"*bs 3\"]]:\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inp, self.aabb) * 2.0 - 1.0\n\n        output = self.kplanes_encoding(inpx)\n\n        output = self.decoder(output)\n\n        # 
split_size = output.shape[-1] // 3\n        # output = torch.stack(torch.split(output, split_size, dim=-1), dim=-1)\n\n        return output\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss\n\n\ndef compute_entropy(p):\n    return -torch.sum(p * torch.log(p + 1e-5), dim=1).mean()  # Adding a small constant to prevent log(0)\n\n\nclass TriplaneFieldsWithEntropy(nn.Module):\n    \"\"\"Temporal Kplanes SE(3) fields.\n\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,z) point.\n            aabb[1] is the maximum (x,y,z) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, z]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"sum\",  #: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        output_dim: int = 96,\n        zero_init: bool = False,\n        num_cls: int = 3,\n    ):\n        super().__init__()\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = output_dim\n        self.num_cls = num_cls\n\n        self.kplanes_encoding = TriplanesEncoding(\n            resolutions, feat_dim, init_a, init_b, reduce\n        )\n\n        self.decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=self.num_cls,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n        self.cls_embedding = torch.nn.Embedding(num_cls, output_dim)\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 3\"]\n    ) -> Tuple[Float[Tensor, \"*bs 3 3\"], Float[Tensor, 
\"1\"]]:\n        # shift to [-1, 1]\n        inpx = SceneBox.get_normalized_positions(inp, self.aabb) * 2.0 - 1.0\n\n        output = self.kplanes_encoding(inpx)\n\n        output = self.decoder(output)\n\n        prob = F.softmax(output, dim=-1)\n\n        entropy = compute_entropy(prob)\n\n        cls_index = torch.tensor([0, 1, 2]).to(inp.device)\n        cls_emb = self.cls_embedding(cls_index)\n\n        output = torch.matmul(prob, cls_emb)\n\n        \n        return output, entropy\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.kplanes_encoding.compute_plane_tv()\n\n        return smothness_loss\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/fields/video_triplane_disp_field.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom jaxtyping import Float, Int, Shaped\nfrom torch import Tensor, nn\nfrom typing import Optional, Sequence, Tuple, List\nfrom motionrep.field_components.encoding import (\n    TriplanesEncoding,\n    PlaneEncoding,\n    TemporalNeRFEncoding,\n)\nfrom motionrep.field_components.mlp import MLP\nfrom motionrep.data.scene_box import SceneBox\nfrom einops import rearrange, repeat\n\n\nclass TriplaneDispFields(nn.Module):\n    \"\"\"Kplanes Displacement fields.\n        [x, t, t] => [dx, dy]\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,t) point.\n            aabb[1] is the maximum (x,y,t) point.\n        resolutions: resolutions of the kplanes. in an order of [x, y, t]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"cat\",  #: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        output_dim: int = 2,\n        zero_init: bool = False,\n    ):\n        super().__init__()\n\n        if aabb is None:\n            aabb = (\n                torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]], dtype=torch.float32)\n                * 1.1\n            )\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = output_dim\n\n        self.canonical_encoding = PlaneEncoding(\n            resolutions[:2], feat_dim, init_a, init_b\n        )\n        self.canonical_decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=3,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n        )\n\n        self.kplanes_encoding = TriplanesEncoding(\n            resolutions, feat_dim, 
init_a, init_b, reduce\n        )\n\n        if reduce == \"cat\":\n            feat_dim = int(feat_dim * 3)\n\n        self.decoder = MLP(\n            feat_dim,\n            int(num_decoder_layers * 3),\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=(2, 4),\n            activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n    def forward(\n        self, inp: Float[Tensor, \"*bs 3\"]\n    ) -> Tuple[Float[Tensor, \"*bs 2\"], Float[Tensor, \"*bs 3\"]]:\n        # shift to [-1, 1]\n        inp_norm = SceneBox.get_normalized_positions(inp, self.aabb) * 2.0 - 1.0\n\n        output = self.kplanes_encoding(inp_norm)\n\n        # [*bs, 2]\n        output = self.decoder(output)\n\n        inpyx = inp_norm[..., :2].reshape(-1, 2)\n\n        canonical_yx = inpyx + output\n\n        ret_rgb_feat = self.canonical_encoding(canonical_yx)\n        ret_rgb = self.canonical_decoder(ret_rgb_feat)\n\n        return output, ret_rgb\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.kplanes_encoding.compute_plane_tv()\n\n        smothness_canonical = self.canonical_encoding.compute_plane_tv()\n        return smothness_loss + smothness_canonical\n\n    def get_canonical(\n        self, canonical_grid: Float[Tensor, \"*bs 2\"]\n    ) -> Float[Tensor, \"*bs 3\"]:\n        pad_can_grid = torch.cat(\n            [canonical_grid, torch.zeros_like(canonical_grid[..., :1])], dim=-1\n        )\n        pad_can_norm = (\n            SceneBox.get_normalized_positions(pad_can_grid, self.aabb) * 2.0 - 1.0\n        )\n\n        inp_can_grid = pad_can_norm[..., :2]\n\n        ret_rgb_feat = self.canonical_encoding(inp_can_grid)\n        ret_rgb = self.canonical_decoder(ret_rgb_feat)\n\n        return ret_rgb\n\n    def sample_canonical(\n        self,\n        inp: Float[Tensor, \"bs hw 3\"],\n        canonical_frame: Float[Tensor, \"1 H W 
3\"],\n        canonical_grid_yx: Float[Tensor, \"bs hw 2\"],\n    ) -> Float[Tensor, \"bs h w 3\"]:\n        #\n        inp_norm = SceneBox.get_normalized_positions(inp, self.aabb) * 2.0 - 1.0\n\n        output = self.kplanes_encoding(inp_norm)\n\n        # [-1, 2]\n        output = self.decoder(output)\n\n        inpyx = inp_norm[..., :2].reshape(-1, 2)\n\n        canonical_yx = inpyx + output\n        canonical_yx = canonical_yx * 1.1\n\n        can_ymin, can_ymax = (\n            canonical_grid_yx[..., 0].min(),\n            canonical_grid_yx[..., 0].max(),\n        )\n        can_xmin, can_xmax = (\n            canonical_grid_yx[..., 1].min(),\n            canonical_grid_yx[..., 1].max(),\n        )\n        canonical_yx[..., 0] = (canonical_yx[..., 0] - can_ymin) / (\n            can_ymax - can_ymin\n        ) * 2.0 - 1.0\n        canonical_yx[..., 1] = (canonical_yx[..., 1] - can_xmin) / (\n            can_xmax - can_xmin\n        ) * 2.0 - 1.0\n\n        canonical_xy = torch.cat(\n            [canonical_yx[..., 1:2], canonical_yx[..., 0:1]], dim=-1\n        )\n        # use grid sample to sample the canonical frame\n\n        # [B, C, H, W]\n        canonical_frame = canonical_frame.permute(0, 3, 1, 2).expand(\n            inp.shape[0], -1, -1, -1\n        )\n        H, W = canonical_frame.shape[-2:]\n        canonical_xy = canonical_xy.reshape(-1, H, W, 2)\n\n        rec = F.grid_sample(canonical_frame, canonical_xy, align_corners=True)\n\n        rec = rearrange(rec, \"b c h w -> b h w c\")\n\n        return rec\n\n\nclass PlaneDynamicDispFields(nn.Module):\n    \"\"\"Plane Displacement fields.\n        [x, t, t] => [dx, dy]\n    Args:\n        aabb: axis-aligned bounding box.\n            aabb[0] is the minimum (x,y,t) point.\n            aabb[1] is the maximum (x,y,t) point.\n        resolutions: resolutions of the kplanes. 
in an order of [x, y, t]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        aabb: Float[Tensor, \"2 3\"],\n        resolutions: Sequence[int],\n        feat_dim: int = 64,\n        init_a: float = 0.1,\n        init_b: float = 0.5,\n        reduce=\"cat\",  #: Literal[\"sum\", \"product\", \"cat\"] = \"sum\",\n        num_decoder_layers=2,\n        decoder_hidden_size=64,\n        output_dim: int = 2,\n        zero_init: bool = False,\n        num_temporal_freq: int = 20,\n        freq_min: float = 0.0,\n        freq_max: float = 20,\n    ):\n        super().__init__()\n\n        if aabb is None:\n            aabb = (\n                torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]], dtype=torch.float32)\n                * 1.1\n            )\n\n        self.register_buffer(\"aabb\", aabb)\n        self.output_dim = output_dim\n\n        self.canonical_encoding = PlaneEncoding(\n            resolutions[:2], feat_dim, init_a, init_b\n        )\n        self.canonical_decoder = MLP(\n            feat_dim,\n            num_decoder_layers,\n            layer_width=decoder_hidden_size,\n            out_dim=3,\n            skip_connections=None,\n            activation=nn.ReLU(),\n            out_activation=None,\n        )\n\n        self.deform_planes_encoding = PlaneEncoding(\n            resolutions[:2], feat_dim, init_a, init_b\n        )\n\n        self.num_temporal_freq = num_temporal_freq\n        self.temporal_pos_encoding = TemporalNeRFEncoding(\n            1,\n            num_temporal_freq,\n            freq_min,\n            freq_max,\n            log_scale=False,\n        )\n\n        self.decoder = MLP(\n            feat_dim + self.temporal_pos_encoding.get_out_dim(),\n            int(num_decoder_layers * 3),\n            layer_width=decoder_hidden_size,\n            out_dim=self.output_dim,\n            skip_connections=(2, 4),\n            activation=nn.ReLU(),\n            out_activation=None,\n            zero_init=zero_init,\n        )\n\n    
def forward(\n        self, inp: Float[Tensor, \"*bs 3\"]\n    ) -> Tuple[Float[Tensor, \"*bs 2\"], Float[Tensor, \"*bs 3\"]]:\n        # shift to [-1, 1]\n        inp_norm = SceneBox.get_normalized_positions(inp, self.aabb) * 2.0 - 1.0\n\n        inp_yx, inp_t = inp_norm[..., 0:2], inp_norm[..., 2:3]\n\n        spatial_feat = self.deform_planes_encoding(inp_yx)\n\n        temporal_enc = self.temporal_pos_encoding(inp_t)\n        # [*bs, 2]\n\n        output = self.decoder(\n            torch.cat(\n                [spatial_feat, temporal_enc.view(-1, temporal_enc.shape[-1])], dim=-1\n            )\n        )\n\n        canonical_yx = inp_yx.reshape(-1, 2) + output\n\n        ret_rgb_feat = self.canonical_encoding(canonical_yx)\n        ret_rgb = self.canonical_decoder(ret_rgb_feat)\n\n        return output, ret_rgb\n\n    def compute_smoothess_loss(\n        self,\n    ):\n        smothness_loss = self.deform_planes_encoding.compute_plane_tv()\n\n        smothness_canonical = self.canonical_encoding.compute_plane_tv()\n        return smothness_loss + smothness_canonical\n\n    def get_canonical(\n        self, canonical_grid: Float[Tensor, \"*bs 2\"]\n    ) -> Float[Tensor, \"*bs 3\"]:\n        pad_can_grid = torch.cat(\n            [canonical_grid, torch.zeros_like(canonical_grid[..., :1])], dim=-1\n        )\n        pad_can_norm = (\n            SceneBox.get_normalized_positions(pad_can_grid, self.aabb) * 2.0 - 1.0\n        )\n\n        inp_can_grid = pad_can_norm[..., :2]\n\n        ret_rgb_feat = self.canonical_encoding(inp_can_grid)\n        ret_rgb = self.canonical_decoder(ret_rgb_feat)\n\n        return ret_rgb\n\n    def sample_canonical(\n        self,\n        inp: Float[Tensor, \"bs hw 3\"],\n        canonical_frame: Float[Tensor, \"1 H W 3\"],\n        canonical_grid_yx: Float[Tensor, \"bs hw 2\"],\n    ) -> Float[Tensor, \"bs h w 3\"]:\n        inp_norm = SceneBox.get_normalized_positions(inp, self.aabb) * 2.0 - 1.0\n\n        inp_yx, inp_t = 
inp_norm[..., 0:2], inp_norm[..., 2:3]\n        inp_yx = inp_yx.reshape(-1, 2)\n        spatial_feat = self.deform_planes_encoding(inp_yx)\n\n        temporal_enc = self.temporal_pos_encoding(inp_t.view(-1, 1))\n        # [*bs, 2]\n        output = self.decoder(torch.cat([spatial_feat, temporal_enc], dim=-1))\n\n        canonical_yx = inp_yx + output\n        canonical_yx = canonical_yx * 1.1\n\n        can_ymin, can_ymax = (\n            canonical_grid_yx[..., 0].min(),\n            canonical_grid_yx[..., 0].max(),\n        )\n        can_xmin, can_xmax = (\n            canonical_grid_yx[..., 1].min(),\n            canonical_grid_yx[..., 1].max(),\n        )\n        canonical_yx[..., 0] = (canonical_yx[..., 0] - can_ymin) / (\n            can_ymax - can_ymin\n        ) * 2.0 - 1.0\n        canonical_yx[..., 1] = (canonical_yx[..., 1] - can_xmin) / (\n            can_xmax - can_xmin\n        ) * 2.0 - 1.0\n\n        canonical_xy = torch.cat(\n            [canonical_yx[..., 1:2], canonical_yx[..., 0:1]], dim=-1\n        )\n        # use grid sample to sample the canonical frame\n\n        # [B, C, H, W]\n        canonical_frame = canonical_frame.permute(0, 3, 1, 2).expand(\n            inp.shape[0], -1, -1, -1\n        )\n        H, W = canonical_frame.shape[-2:]\n        canonical_xy = canonical_xy.reshape(-1, H, W, 2)\n\n        rec = F.grid_sample(canonical_frame, canonical_xy, align_corners=True)\n\n        rec = rearrange(rec, \"b c h w -> b h w c\")\n\n        return rec\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/arguments/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom argparse import ArgumentParser, Namespace\nimport sys\nimport os\n\nclass GroupParams:\n    pass\n\nclass ParamGroup:\n    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):\n        group = parser.add_argument_group(name)\n        for key, value in vars(self).items():\n            shorthand = False\n            if key.startswith(\"_\"):\n                shorthand = True\n                key = key[1:]\n            t = type(value)\n            value = value if not fill_none else None \n            if shorthand:\n                if t == bool:\n                    group.add_argument(\"--\" + key, (\"-\" + key[0:1]), default=value, action=\"store_true\")\n                else:\n                    group.add_argument(\"--\" + key, (\"-\" + key[0:1]), default=value, type=t)\n            else:\n                if t == bool:\n                    group.add_argument(\"--\" + key, default=value, action=\"store_true\")\n                else:\n                    group.add_argument(\"--\" + key, default=value, type=t)\n\n    def extract(self, args):\n        group = GroupParams()\n        for arg in vars(args).items():\n            if arg[0] in vars(self) or (\"_\" + arg[0]) in vars(self):\n                setattr(group, arg[0], arg[1])\n        return group\n\nclass ModelParams(ParamGroup): \n    def __init__(self, parser, sentinel=False):\n        self.sh_degree = 3\n        self._source_path = \"\"\n        self._model_path = \"\"\n        self._images = \"images\"\n        self._resolution = -1\n        self._white_background = False\n        self.data_device = \"cuda\"\n        self.eval = False\n        super().__init__(parser, 
\"Loading Parameters\", sentinel)\n\n    def extract(self, args):\n        g = super().extract(args)\n        g.source_path = os.path.abspath(g.source_path)\n        return g\n\nclass PipelineParams(ParamGroup):\n    def __init__(self, parser):\n        self.convert_SHs_python = False\n        self.compute_cov3D_python = False\n        self.debug = False\n        super().__init__(parser, \"Pipeline Parameters\")\n\nclass OptimizationParams(ParamGroup):\n    def __init__(self, parser):\n        self.iterations = 30_000\n        self.position_lr_init = 0.00016\n        self.position_lr_final = 0.0000016\n        self.position_lr_delay_mult = 0.01\n        self.position_lr_max_steps = 30_000\n        self.feature_lr = 0.0025\n        self.opacity_lr = 0.05\n        self.scaling_lr = 0.005\n        self.rotation_lr = 0.001\n        self.percent_dense = 0.01\n        self.lambda_dssim = 0.2\n        self.densification_interval = 100\n        self.opacity_reset_interval = 3000\n        self.densify_from_iter = 500\n        self.densify_until_iter = 15_000\n        self.densify_grad_threshold = 0.0002\n        super().__init__(parser, \"Optimization Parameters\")\n\ndef get_combined_args(parser : ArgumentParser):\n    cmdlne_string = sys.argv[1:]\n    cfgfile_string = \"Namespace()\"\n    args_cmdline = parser.parse_args(cmdlne_string)\n\n    try:\n        cfgfilepath = os.path.join(args_cmdline.model_path, \"cfg_args\")\n        print(\"Looking for config file in\", cfgfilepath)\n        with open(cfgfilepath) as cfg_file:\n            print(\"Config file found: {}\".format(cfgfilepath))\n            cfgfile_string = cfg_file.read()\n    except TypeError:\n        print(\"Config file not found at\")\n        pass\n    args_cfgfile = eval(cfgfile_string)\n\n    merged_dict = vars(args_cfgfile).copy()\n    for k,v in vars(args_cmdline).items():\n        if v != None:\n            merged_dict[k] = v\n    return Namespace(**merged_dict)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/gaussian_renderer/__init__.py",
    "content": ""
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/gaussian_renderer/depth_uv_render.py",
    "content": "import torch\nfrom motionrep.gaussian_3d.scene.gaussian_model import GaussianModel\nimport math\n\nfrom diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nfrom typing import Callable\n\n\ndef render_uv_depth_w_gaussian(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n\n    Args:\n        point_disp: [N, 3]\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. 
If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n    \n    shs = None\n    colors_precomp = None\n    \n    # project point motion to 2D using camera:\n    w2c = viewpoint_camera.world_view_transform.transpose(0, 1)\n    cam_plane_2_img = viewpoint_camera.cam_plane_2_img  # [2, 2]\n\n    R = w2c[:3, :3].unsqueeze(0)  # [1, 3, 3]\n    t = w2c[:3, 3].unsqueeze(0)  # [1, 3]\n\n    # [N, 3, 1]\n    pts = torch.cat([pc._xyz, torch.ones_like(pc._xyz[:, 0:1])], dim=-1)\n    pts_cam = w2c.unsqueeze(0) @ pts.unsqueeze(-1)  # [N, 4, 1]\n    # pts_cam = R @ (pc._xyz.unsqueeze(-1)) + t[:, None]\n    depth = pts_cam[:, 2, 0]  # [N]\n    # print(\"depth\", depth.shape, depth.max(), depth.mean(), depth.min())\n\n   # [N, 2]\n    pts_cam_xy = pts_cam[:, :2, 0] / depth.unsqueeze(-1)\n    \n    \n    pts_cam_xy_pixel = cam_plane_2_img.unsqueeze(0) @ pts_cam_xy.unsqueeze(-1)  # [N, 2, 1]\n    pts_cam_xy_pixel = pts_cam_xy_pixel.squeeze(-1)  # [N, 2]\n\n    colors_precomp = torch.cat(\n        [pts_cam_xy_pixel, depth.unsqueeze(dim=-1)], dim=-1\n    )  # [N, 3]\n\n    # print(\"converted 2D motion precompute: \", colors_precomp.shape, shs, colors_precomp.max(), colors_precomp.min(), colors_precomp.mean())\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    rendered_image, radii = rasterizer(\n        means3D=means3D,\n        means2D=means2D,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n     
   cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n\n    return {\n        \"render\": rendered_image,\n        \"visibility_filter\": radii > 0,\n        \"radii\": radii,\n        \"pts_depth\": depth,\n        \"pts_cam_xy_pixel\": pts_cam_xy_pixel,\n    }\n\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/gaussian_renderer/feat_render.py",
    "content": "import torch\nfrom motionrep.gaussian_3d.scene.gaussian_model import GaussianModel\nimport math\n\nfrom diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nfrom typing import Callable\n\n\ndef render_feat_gaussian(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    bg_color: torch.Tensor,\n    points_feat: torch.Tensor,\n    scaling_modifier=1.0,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n\n    Args:\n        point_disp: [N, 3]\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. 
If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n    \n    shs = None\n    colors_precomp = points_feat\n    assert (points_feat.shape[1] == 3) and (points_feat.shape[0] == means3D.shape[0])\n\n    # print(\"converted 2D motion precompute: \", colors_precomp.shape, shs, colors_precomp.max(), colors_precomp.min(), colors_precomp.mean())\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    rendered_image, radii = rasterizer(\n        means3D=means3D,\n        means2D=means2D,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n        cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n\n    return {\n        \"render\": rendered_image,\n        \"visibility_filter\": radii > 0,\n        \"radii\": radii,\n    }\n\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/gaussian_renderer/flow_depth_render.py",
    "content": "import torch\nfrom motionrep.gaussian_3d.scene.gaussian_model import GaussianModel\nimport math\n\nfrom diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nfrom typing import Callable\n\n\ndef render_flow_depth_w_gaussian(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    point_disp: torch.Tensor,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n\n    Args:\n        point_disp: [N, 3]\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. 
If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n\n    # project point motion to 2D using camera:\n    w2c = viewpoint_camera.world_view_transform.transpose(0, 1)\n    cam_plane_2_img = viewpoint_camera.cam_plane_2_img  # [2, 2]\n\n    R = w2c[:3, :3].unsqueeze(0)  # [1, 3, 3]\n    t = w2c[:3, 3].unsqueeze(0)  # [1, 3]\n\n    # [N, 3, 1]\n    pts = torch.cat([pc._xyz, torch.ones_like(pc._xyz[:, 0:1])], dim=-1)\n    pts_cam = w2c.unsqueeze(0) @ pts.unsqueeze(-1)  # [N, 4, 1]\n    # pts_cam = R @ (pc._xyz.unsqueeze(-1)) + t[:, None]\n    depth = pts_cam[:, 2, 0]  # [N]\n    # print(\"depth\", depth.shape, depth.max(), depth.mean(), depth.min())\n\n    point_disp_pad = torch.cat(\n        [point_disp, torch.zeros_like(point_disp[:, 0:1])], dim=-1\n    )  # [N, 4]\n\n    pts_motion = w2c.unsqueeze(0) @ point_disp_pad.unsqueeze(-1)  # [N, 4, 1]\n\n    # [N, 2]\n    pts_motion_xy = pts_motion[:, :2, 0] / depth.unsqueeze(-1)\n\n    pts_motion_xy_pixel = cam_plane_2_img.unsqueeze(0) @ pts_motion_xy.unsqueeze(\n        -1\n    )  # [N, 2, 1]\n    pts_motion_xy_pixel = pts_motion_xy_pixel.squeeze(-1)  # [N, 2]\n\n    colors_precomp = torch.cat(\n        [pts_motion_xy_pixel, depth.unsqueeze(dim=-1)], dim=-1\n    )  # [N, 3]\n\n    # print(\"converted 2D motion precompute: \", colors_precomp.shape, shs, colors_precomp.max(), colors_precomp.min(), colors_precomp.mean())\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    
rendered_image, radii = rasterizer(\n        means3D=means3D,\n        means2D=means2D,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n        cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n\n    # return {\n    #     \"render\": rendered_image,\n    #     \"viewspace_points\": screenspace_points,\n    #     \"visibility_filter\": radii > 0,\n    #     \"radii\": radii,\n    # }\n\n    return {\"render\": rendered_image}\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/gaussian_renderer/motion_renderer.py",
    "content": "import torch\nfrom motionrep.gaussian_3d.scene.gaussian_model import GaussianModel\nimport math\n\nfrom diff_gaussian_rasterization_wmotion import GaussianRasterizationWMotionSettings as GaussianRasterizationSettings_wmotion\nfrom diff_gaussian_rasterization_wmotion import GaussianRasterizerWMotion as GaussianRasterizer_wmotion\nfrom typing import Callable\n\ndef render_motion_w_gaussian(\n    viewpoint_camera,\n    pc: GaussianModel,\n    motion_fields: Callable,\n    pipe,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n    point_motion=None,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n\n    Args:\n        point_motion: [N, num_feat, 3] or None\n            if None.  motion_fields will be called to sample point motion\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings_wmotion = GaussianRasterizationSettings_wmotion(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n    )\n\n    rasterizer_wm = GaussianRasterizer_wmotion(raster_settings=raster_settings_wmotion)\n\n    means3D = 
pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n    if point_motion is None:\n        xyz = pc._xyz \n        # [N, num_feat, 3]\n        point_motion = motion_fields(xyz) \n\n    # project point motion to 2D using camera:\n    w2c = viewpoint_camera.world_view_transform.transpose(0, 1)\n    cam_plane_2_img = viewpoint_camera.cam_plane_2_img # [2, 2]\n    \n    R = w2c[:3, :3].unsqueeze(0) # [1, 3, 3]\n    t = w2c[:3, 3].unsqueeze(0) # [1, 3]\n\n    # [N, 3, 1]\n    pts = torch.cat([pc._xyz, torch.ones_like(pc._xyz[:, 0:1])], dim=-1)\n    pts_cam = w2c.unsqueeze(0) @ pts.unsqueeze(-1) # [N, 4, 1]\n    # pts_cam = R @ (pc._xyz.unsqueeze(-1)) + t[:, None]\n    depth = pts_cam[:, 2, 0] # [N]\n    # print(\"depth\", depth.shape, depth.max(), depth.mean(), depth.min())\n    \n    # pts = torch.cat([pc._xyz, torch.ones_like(pc._xyz[:, 0:1])], dim=-1)\n    # cam_pts = pts.unsqueeze(1) @ viewpoint_camera.full_proj_transform.unsqueeze(0) # [N, 1, 4] @ [N, 1, 4]\n    # cam_pts = cam_pts.squeeze(1) # [N, 4]\n    # depth = cam_pts[:, 3] # [N]\n    \n    point_motion_pad = torch.cat([point_motion, torch.zeros_like(point_motion[:, :, 0:1])], dim=-1) # [N, num_feat, 4]\n\n    pts_motion = w2c.unsqueeze(0).unsqueeze(0) @ point_motion_pad.unsqueeze(-1) # [N, num_feat, 4, 1] \n    # pts_motion = R.unsqueeze(1) @ 
(point_motion.unsqueeze(-1)) # [N, num_feat, 3, 1]\n    # [N, num_feat, 2]\n    pts_motion_xy = pts_motion[:, :, :2, 0] / depth.unsqueeze(-1).unsqueeze(-1) \n    # [N, num_feat, 2]\n\n\n    pts_motion_xy_pixel = cam_plane_2_img.unsqueeze(0).unsqueeze(0) @ pts_motion_xy.unsqueeze(-1) # [N, num_feat, 2, 1]\n    pts_motion_xy_pixel = pts_motion_xy_pixel.squeeze(-1) # [N, num_feat, 2]\n    pts_motion = pts_motion_xy_pixel.flatten(1, 2) # [N, num_feat * 2]\n\n    colors_precomp = pts_motion\n    \n    # print(\"converted 2D motion precompute: \", colors_precomp.shape, shs, colors_precomp.max(), colors_precomp.min(), colors_precomp.mean())\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    rendered_image, radii = rasterizer_wm(\n        means3D=means3D,\n        means2D=means2D,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n        cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n\n    # return {\n    #     \"render\": rendered_image,\n    #     \"viewspace_points\": screenspace_points,\n    #     \"visibility_filter\": radii > 0,\n    #     \"radii\": radii,\n    # }\n\n    return {\"render\": rendered_image}\n\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/gaussian_renderer/render.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport math\nfrom diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nfrom motionrep.gaussian_3d.scene.gaussian_model import GaussianModel\n\n\ndef render_gaussian(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n    override_color=None,\n    cov3D_precomp=None,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = 
screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n\n    if cov3D_precomp is None and pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    elif cov3D_precomp is None:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n    if override_color is None:\n        if pipe.convert_SHs_python:\n            shs_view = pc.get_features.transpose(1, 2).view(\n                -1, 3, (pc.max_sh_degree + 1) ** 2\n            )\n            dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(\n                pc.get_features.shape[0], 1\n            )\n            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)\n            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)\n            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)\n        else:\n            shs = pc.get_features\n    else:\n        colors_precomp = override_color\n\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    rendered_image, radii = rasterizer(\n        means3D=means3D,\n        means2D=means2D,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n        cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n    return {\n        \"render\": rendered_image,\n        \"viewspace_points\": screenspace_points,\n        
\"visibility_filter\": radii > 0,\n        \"radii\": radii,\n    }\n    # return {\"render\": rendered_image}\n\n\ndef gaussian_intrin_scale(x_or_y: torch.Tensor, w_or_h: float):\n\n    ret = ((x_or_y + 1.0) * w_or_h - 1.0) * 0.5\n\n    return ret\n\n\ndef render_arrow_in_screen(viewpoint_camera, points_3d):\n\n    # project point motion to 2D using camera:\n    w2c = viewpoint_camera.world_view_transform.transpose(0, 1)\n    cam_plane_2_img = viewpoint_camera.cam_plane_2_img  # [2, 2]\n    cam_plane_2_img = viewpoint_camera.projection_matrix.transpose(0, 1)  # [4, 4]\n\n    full_proj_mat = viewpoint_camera.full_proj_transform\n\n    # [N, 4]\n    pts = torch.cat([points_3d, torch.ones_like(points_3d[:, 0:1])], dim=-1)\n    # [N, 1, 4] <-  [N, 1, 4] @ [1, 4, 4]\n    pts_cam = pts.unsqueeze(-2) @ full_proj_mat.unsqueeze(0)  # [N, 1, 4]\n\n    # start here\n\n    # pts: [N, 4]\n    # [1, 4, 4] @ [N, 4, 1] -> [N, 4, 1]\n    # from IPython import embed\n\n    # embed()\n    # pts_cam = torch.bmm(\n    #     full_proj_mat.T.unsqueeze(0), pts.unsqueeze(-1)\n    # )  # K*[R,T]*[x,y,z,1]^T to get 2D projection of Gaussians\n    # end here\n    pts_cam = full_proj_mat.T.unsqueeze(0) @ pts.unsqueeze(-1)\n\n    # print(pts_cam.shape)\n\n    pts_cam = pts_cam.squeeze(-1)  # [N, 4]\n    pts_cam = pts_cam[:, :3] / pts_cam[:, 3:]  # [N, 1, 3]\n\n    # print(pts_cam, \"after proj\")\n\n    pts_cam_yx_pixel = pts_cam[:, :2]\n    #  [N, 2] yx => xy\n    # pts_cam_xy_pixel = torch.cat(\n    #     [pts_cam_xy_pixel[:, [1]], pts_cam_xy_pixel[:, [0]]], dim=-1\n    # )\n\n    pts_cam_x, pts_cam_y = pts_cam_yx_pixel[:, 0], pts_cam_yx_pixel[:, 1]\n\n    w, h = viewpoint_camera.image_width, viewpoint_camera.image_height\n\n    pts_cam_x = gaussian_intrin_scale(pts_cam_x, w)\n    pts_cam_y = gaussian_intrin_scale(pts_cam_y, h)\n\n    ret_pts_cam_xy = torch.cat(\n        [pts_cam_x.unsqueeze(-1), pts_cam_y.unsqueeze(-1)], dim=-1\n    )\n\n    # print(ret_pts_cam_xy)\n\n    return 
ret_pts_cam_xy\n\n\ndef render_arrow_in_screen_back(viewpoint_camera, points_3d):\n\n    # project point motion to 2D using camera:\n    w2c = viewpoint_camera.world_view_transform.transpose(0, 1)\n    cam_plane_2_img = viewpoint_camera.cam_plane_2_img  # [2, 2]\n    cam_plane_2_img = viewpoint_camera.projection_matrix.transpose(0, 1)\n\n    from IPython import embed\n\n    embed()\n\n    R = w2c[:3, :3].unsqueeze(0)  # [1, 3, 3]\n    t = w2c[:3, 3].unsqueeze(0)  # [1, 3]\n\n    # [N, 3, 1]\n    pts = torch.cat([points_3d, torch.ones_like(points_3d[:, 0:1])], dim=-1)\n    pts_cam = w2c.unsqueeze(0) @ pts.unsqueeze(-1)  # [N, 4, 1]\n    # pts_cam = R @ (pc._xyz.unsqueeze(-1)) + t[:, None]\n    depth = pts_cam[:, 2, 0]  # [N]\n    # print(\"depth\", depth.shape, depth.max(), depth.mean(), depth.min())\n\n    # [N, 2]\n    pts_cam_xy = pts_cam[:, :2, 0] / depth.unsqueeze(-1)\n\n    pts_cam_xy_pixel = cam_plane_2_img.unsqueeze(0) @ pts_cam_xy.unsqueeze(\n        -1\n    )  # [N, 2, 1]\n    pts_cam_xy_pixel = pts_cam_xy_pixel.squeeze(-1)  # [N, 2]\n\n    #  [N, 2] yx => xy\n    pts_cam_xy_pixel = torch.cat(\n        [pts_cam_xy_pixel[:, [1]], pts_cam_xy_pixel[:, [0]]], dim=-1\n    )\n\n    return pts_cam_xy_pixel\n\n\n# for spherecal harmonics\n\n\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n    1.0925484305920792,\n    -1.0925484305920792,\n    0.31539156525252005,\n    -1.0925484305920792,\n    0.5462742152960396,\n]\nC3 = [\n    -0.5900435899266435,\n    2.890611442640554,\n    -0.4570457994644658,\n    0.3731763325901154,\n    -0.4570457994644658,\n    1.445305721320277,\n    -0.5900435899266435,\n]\nC4 = [\n    2.5033429417967046,\n    -1.7701307697799304,\n    0.9461746957575601,\n    -0.6690465435572892,\n    0.10578554691520431,\n    -0.6690465435572892,\n    0.47308734787878004,\n    -1.7701307697799304,\n    0.6258357354491761,\n]\n\n\ndef eval_sh(deg, sh, dirs):\n    \"\"\"\n    Evaluate spherical harmonics at unit directions\n    using 
hardcoded SH polynomials.\n    Works with torch/np/jnp.\n    ... Can be 0 or more batch dimensions.\n    Args:\n        deg: int SH deg. Currently, 0-3 supported\n        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]\n        dirs: jnp.ndarray unit directions [..., 3]\n    Returns:\n        [..., C]\n    \"\"\"\n    assert deg <= 4 and deg >= 0\n    coeff = (deg + 1) ** 2\n    assert sh.shape[-1] >= coeff\n\n    result = C0 * sh[..., 0]\n    if deg > 0:\n        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]\n        result = (\n            result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]\n        )\n\n        if deg > 1:\n            xx, yy, zz = x * x, y * y, z * z\n            xy, yz, xz = x * y, y * z, x * z\n            result = (\n                result\n                + C2[0] * xy * sh[..., 4]\n                + C2[1] * yz * sh[..., 5]\n                + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]\n                + C2[3] * xz * sh[..., 7]\n                + C2[4] * (xx - yy) * sh[..., 8]\n            )\n\n            if deg > 2:\n                result = (\n                    result\n                    + C3[0] * y * (3 * xx - yy) * sh[..., 9]\n                    + C3[1] * xy * z * sh[..., 10]\n                    + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]\n                    + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]\n                    + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]\n                    + C3[5] * z * (xx - yy) * sh[..., 14]\n                    + C3[6] * x * (xx - 3 * yy) * sh[..., 15]\n                )\n\n                if deg > 3:\n                    result = (\n                        result\n                        + C4[0] * xy * (xx - yy) * sh[..., 16]\n                        + C4[1] * yz * (3 * xx - yy) * sh[..., 17]\n                        + C4[2] * xy * (7 * zz - 1) * sh[..., 18]\n                        + C4[3] * yz * (7 * zz - 3) * sh[..., 19]\n                        
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]\n                        + C4[5] * xz * (7 * zz - 3) * sh[..., 21]\n                        + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]\n                        + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]\n                        + C4[8]\n                        * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))\n                        * sh[..., 24]\n                    )\n    return result\n\n\ndef RGB2SH(rgb):\n    return (rgb - 0.5) / C0\n\n\ndef SH2RGB(sh):\n    return sh * C0 + 0.5\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/scene/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport random\nimport numpy as np\nimport json\nfrom motionrep.gaussian_3d.utils.system_utils import searchForMaxIteration\nfrom motionrep.gaussian_3d.scene.dataset_readers import sceneLoadTypeCallbacks\nfrom motionrep.gaussian_3d.scene.gaussian_model import GaussianModel\nfrom motionrep.gaussian_3d.arguments import ModelParams\nfrom motionrep.gaussian_3d.utils.camera_utils import (\n    cameraList_from_camInfos,\n    camera_to_JSON,\n)\n\n\nclass Scene:\n    gaussians: GaussianModel\n\n    def __init__(\n        self,\n        args: ModelParams,\n        gaussians: GaussianModel,\n        load_iteration=None,\n        shuffle=True,\n        resolution_scales=[1.0],\n    ):\n        \"\"\"b\n        :param path: Path to colmap scene main folder.\n        \"\"\"\n        self.model_path = args.model_path\n        self.loaded_iter = None\n        self.gaussians = gaussians\n\n        if load_iteration:\n            if load_iteration == -1:\n                self.loaded_iter = searchForMaxIteration(\n                    os.path.join(self.model_path, \"point_cloud\")\n                )\n            else:\n                self.loaded_iter = load_iteration\n            print(\"Loading trained model at iteration {}\".format(self.loaded_iter))\n\n        self.train_cameras = {}\n        self.test_cameras = {}\n\n        if os.path.exists(os.path.join(args.source_path, \"sparse\")):\n            scene_info = sceneLoadTypeCallbacks[\"Colmap\"](\n                args.source_path, args.images, args.eval\n            )\n        elif os.path.exists(os.path.join(args.source_path, \"transforms_train.json\")):\n            print(\"Found transforms_train.json file, 
assuming Blender data set!\")\n            scene_info = sceneLoadTypeCallbacks[\"Blender\"](\n                args.source_path, args.white_background, args.eval\n            )\n        else:\n            assert False, \"Could not recognize scene type!\"\n\n        if not self.loaded_iter:\n            with open(scene_info.ply_path, \"rb\") as src_file, open(\n                os.path.join(self.model_path, \"input.ply\"), \"wb\"\n            ) as dest_file:\n                dest_file.write(src_file.read())\n            json_cams = []\n            camlist = []\n            if scene_info.test_cameras:\n                camlist.extend(scene_info.test_cameras)\n            if scene_info.train_cameras:\n                camlist.extend(scene_info.train_cameras)\n            for id, cam in enumerate(camlist):\n                json_cams.append(camera_to_JSON(id, cam))\n            with open(os.path.join(self.model_path, \"cameras.json\"), \"w\") as file:\n                json.dump(json_cams, file)\n\n        if shuffle:\n            random.shuffle(\n                scene_info.train_cameras\n            )  # Multi-res consistent random shuffling\n            random.shuffle(\n                scene_info.test_cameras\n            )  # Multi-res consistent random shuffling\n\n        self.cameras_extent = scene_info.nerf_normalization[\"radius\"]\n\n        for resolution_scale in resolution_scales:\n            print(\"Loading Training Cameras\")\n            self.train_cameras[resolution_scale] = cameraList_from_camInfos(\n                scene_info.train_cameras, resolution_scale, args\n            )\n            print(\"Loading Test Cameras\")\n            self.test_cameras[resolution_scale] = cameraList_from_camInfos(\n                scene_info.test_cameras, resolution_scale, args\n            )\n\n        if self.loaded_iter:\n            self.gaussians.load_ply(\n                os.path.join(\n                    self.model_path,\n                    \"point_cloud\",\n      
              \"iteration_\" + str(self.loaded_iter),\n                    \"point_cloud.ply\",\n                )\n            )\n        else:\n            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)\n\n    def save(self, iteration):\n        point_cloud_path = os.path.join(\n            self.model_path, \"point_cloud/iteration_{}\".format(iteration)\n        )\n        self.gaussians.save_ply(os.path.join(point_cloud_path, \"point_cloud.ply\"))\n\n    def getTrainCameras(self, scale=1.0):\n        return self.train_cameras[scale]\n\n    def getTestCameras(self, scale=1.0):\n        return self.test_cameras[scale]\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/scene/cameras.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nfrom torch import nn\nimport numpy as np\nfrom motionrep.gaussian_3d.utils.graphics_utils import (\n    getWorld2View2,\n    getProjectionMatrix,\n)\n\n\nclass Camera(nn.Module):\n    def __init__(\n        self,\n        colmap_id,\n        R,\n        T,\n        FoVx,\n        FoVy,\n        image,\n        gt_alpha_mask,\n        image_name,\n        uid,\n        trans=np.array([0.0, 0.0, 0.0]),\n        scale=1.0,\n        data_device=\"cuda\",\n    ):\n        super(Camera, self).__init__()\n\n        self.uid = uid\n        self.colmap_id = colmap_id\n        self.R = R\n        self.T = T\n        self.FoVx = FoVx\n        self.FoVy = FoVy\n        self.image_name = image_name\n\n        try:\n            self.data_device = torch.device(data_device)\n        except Exception as e:\n            print(e)\n            print(\n                f\"[Warning] Custom device {data_device} failed, fallback to default cuda device\"\n            )\n            self.data_device = torch.device(\"cuda\")\n\n        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)\n        self.image_width = self.original_image.shape[2]\n        self.image_height = self.original_image.shape[1]\n\n        if gt_alpha_mask is not None:\n            self.original_image *= gt_alpha_mask.to(self.data_device)\n        else:\n            self.original_image *= torch.ones(\n                (1, self.image_height, self.image_width), device=self.data_device\n            )\n\n        self.zfar = 100.0\n        self.znear = 0.01\n\n        self.trans = trans\n        self.scale = scale\n\n        self.world_view_transform = (\n            
torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()\n        )\n        self.projection_matrix = (\n            getProjectionMatrix(\n                znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy\n            )\n            .transpose(0, 1)\n            .cuda()\n        )\n        self.full_proj_transform = (\n            self.world_view_transform.unsqueeze(0).bmm(\n                self.projection_matrix.unsqueeze(0)\n            )\n        ).squeeze(0)\n        self.camera_center = self.world_view_transform.inverse()[3, :3]\n\n\nclass MiniCam:\n    def __init__(\n        self,\n        width,\n        height,\n        fovy,\n        fovx,\n        znear,\n        zfar,\n        world_view_transform,\n        full_proj_transform,\n    ):\n        self.image_width = width\n        self.image_height = height\n        self.FoVy = fovy\n        self.FoVx = fovx\n        self.znear = znear\n        self.zfar = zfar\n        self.world_view_transform = world_view_transform\n        self.full_proj_transform = full_proj_transform\n        view_inv = torch.inverse(self.world_view_transform)\n        self.camera_center = view_inv[3][:3]\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/scene/colmap_loader.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport numpy as np\nimport collections\nimport struct\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"])\nCamera = collections.namedtuple(\n    \"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"])\nPoint3D = collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"])\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12)\n}\nCAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)\n                         for camera_model in CAMERA_MODELS])\nCAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)\n                           for camera_model in CAMERA_MODELS])\n\n\ndef qvec2rotmat(qvec):\n    return np.array([\n        [1 - 2 * 
qvec[2]**2 - 2 * qvec[3]**2,\n         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],\n        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,\n         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],\n        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])\n\ndef rotmat2qvec(R):\n    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat\n    K = np.array([\n        [Rxx - Ryy - Rzz, 0, 0, 0],\n        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],\n        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],\n        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0\n    eigvals, eigvecs = np.linalg.eigh(K)\n    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]\n    if qvec[0] < 0:\n        qvec *= -1\n    return qvec\n\nclass Image(BaseImage):\n    def qvec2rotmat(self):\n        return qvec2rotmat(self.qvec)\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\ndef read_points3D_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DText(const std::string& path)\n        void Reconstruction::WritePoints3DText(const std::string& path)\n    \"\"\"\n    xyzs = None\n    rgbs = None\n    errors = None\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                xyz = np.array(tuple(map(float, elems[1:4])))\n                rgb = np.array(tuple(map(int, elems[4:7])))\n                error = np.array(float(elems[7]))\n                if xyzs is None:\n                    xyzs = xyz[None, ...]\n                    rgbs = rgb[None, ...]\n                    errors = error[None, ...]\n                else:\n                    xyzs = np.append(xyzs, xyz[None, ...], axis=0)\n                    rgbs = np.append(rgbs, rgb[None, ...], axis=0)\n                    errors = np.append(errors, error[None, ...], axis=0)\n    return xyzs, rgbs, errors\n\ndef read_points3D_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n\n\n    with open(path_to_model_file, \"rb\") as fid:\n        num_points = read_next_bytes(fid, 8, \"Q\")[0]\n\n        xyzs = np.empty((num_points, 3))\n        rgbs = np.empty((num_points, 3))\n        errors = np.empty((num_points, 1))\n\n        for p_id 
in range(num_points):\n            binary_point_line_properties = read_next_bytes(\n                fid, num_bytes=43, format_char_sequence=\"QdddBBBd\")\n            xyz = np.array(binary_point_line_properties[1:4])\n            rgb = np.array(binary_point_line_properties[4:7])\n            error = np.array(binary_point_line_properties[7])\n            track_length = read_next_bytes(\n                fid, num_bytes=8, format_char_sequence=\"Q\")[0]\n            track_elems = read_next_bytes(\n                fid, num_bytes=8*track_length,\n                format_char_sequence=\"ii\"*track_length)\n            xyzs[p_id] = xyz\n            rgbs[p_id] = rgb\n            errors[p_id] = error\n    return xyzs, rgbs, errors\n\ndef read_intrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                assert model == \"PINHOLE\", \"While the loader support other types, the rest of the code assumes PINHOLE\"\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(id=camera_id, model=model,\n                                            width=width, height=height,\n                                            params=params)\n    return cameras\n\ndef read_extrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    
images = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_reg_images = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_reg_images):\n            binary_image_properties = read_next_bytes(\n                fid, num_bytes=64, format_char_sequence=\"idddddddi\")\n            image_id = binary_image_properties[0]\n            qvec = np.array(binary_image_properties[1:5])\n            tvec = np.array(binary_image_properties[5:8])\n            camera_id = binary_image_properties[8]\n            image_name = \"\"\n            current_char = read_next_bytes(fid, 1, \"c\")[0]\n            while current_char != b\"\\x00\":   # look for the ASCII 0 entry\n                image_name += current_char.decode(\"utf-8\")\n                current_char = read_next_bytes(fid, 1, \"c\")[0]\n            num_points2D = read_next_bytes(fid, num_bytes=8,\n                                           format_char_sequence=\"Q\")[0]\n            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,\n                                       format_char_sequence=\"ddq\"*num_points2D)\n            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),\n                                   tuple(map(float, x_y_id_s[1::3]))])\n            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))\n            images[image_id] = Image(\n                id=image_id, qvec=qvec, tvec=tvec,\n                camera_id=camera_id, name=image_name,\n                xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_intrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_cameras):\n            camera_properties = 
read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\")\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(fid, num_bytes=8*num_params,\n                                     format_char_sequence=\"d\"*num_params)\n            cameras[camera_id] = Camera(id=camera_id,\n                                        model=model_name,\n                                        width=width,\n                                        height=height,\n                                        params=np.array(params))\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_extrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    images = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                image_id = int(elems[0])\n                qvec = np.array(tuple(map(float, elems[1:5])))\n                tvec = np.array(tuple(map(float, elems[5:8])))\n                camera_id = int(elems[8])\n                image_name = elems[9]\n                elems = fid.readline().split()\n                xys = np.column_stack([tuple(map(float, elems[0::3])),\n                                       tuple(map(float, elems[1::3]))])\n                point3D_ids = np.array(tuple(map(int, elems[2::3])))\n                images[image_id] = Image(\n                    id=image_id, qvec=qvec, tvec=tvec,\n                    camera_id=camera_id, 
name=image_name,\n                    xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_colmap_bin_array(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py\n\n    :param path: path to the colmap binary file.\n    :return: nd array with the floating point values in the value\n    \"\"\"\n    with open(path, \"rb\") as fid:\n        width, height, channels = np.genfromtxt(fid, delimiter=\"&\", max_rows=1,\n                                                usecols=(0, 1, 2), dtype=int)\n        fid.seek(0)\n        num_delimiter = 0\n        byte = fid.read(1)\n        while True:\n            if byte == b\"&\":\n                num_delimiter += 1\n                if num_delimiter >= 3:\n                    break\n            byte = fid.read(1)\n        array = np.fromfile(fid, np.float32)\n    array = array.reshape((width, height, channels), order=\"F\")\n    return np.transpose(array, (1, 0, 2)).squeeze()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/scene/dataset_readers.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport sys\nfrom PIL import Image\nfrom typing import NamedTuple\nfrom motionrep.gaussian_3d.scene.colmap_loader import (\n    read_extrinsics_text,\n    read_intrinsics_text,\n    qvec2rotmat,\n    read_extrinsics_binary,\n    read_intrinsics_binary,\n    read_points3D_binary,\n    read_points3D_text,\n)\nfrom motionrep.gaussian_3d.utils.graphics_utils import (\n    getWorld2View2,\n    focal2fov,\n    fov2focal,\n)\nimport numpy as np\nimport math\nimport json\nfrom pathlib import Path\nfrom plyfile import PlyData, PlyElement\nfrom motionrep.gaussian_3d.utils.sh_utils import SH2RGB\nfrom motionrep.gaussian_3d.scene.gaussian_model import BasicPointCloud\nimport torch\nimport torch.nn as nn\nfrom motionrep.gaussian_3d.utils.graphics_utils import (\n    getWorld2View2,\n    getProjectionMatrix,\n)\n\n\nclass CameraInfo(NamedTuple):\n    uid: int\n    R: np.array\n    T: np.array\n    FovY: np.array\n    FovX: np.array\n    image: np.array\n    image_path: str\n    image_name: str\n    width: int\n    height: int\n\n\nclass SceneInfo(NamedTuple):\n    point_cloud: BasicPointCloud\n    train_cameras: list\n    test_cameras: list\n    nerf_normalization: dict\n    ply_path: str\n\n\ndef getNerfppNorm(cam_info):\n    def get_center_and_diag(cam_centers):\n        cam_centers = np.hstack(cam_centers)\n        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)\n        center = avg_cam_center\n        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)\n        diagonal = np.max(dist)\n        return center.flatten(), diagonal\n\n    cam_centers = []\n\n    for cam in cam_info:\n        W2C = getWorld2View2(cam.R, cam.T)\n        
C2W = np.linalg.inv(W2C)\n        cam_centers.append(C2W[:3, 3:4])\n\n    center, diagonal = get_center_and_diag(cam_centers)\n    radius = diagonal * 1.1\n\n    translate = -center\n\n    return {\"translate\": translate, \"radius\": radius}\n\n\ndef readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):\n    cam_infos = []\n    for idx, key in enumerate(cam_extrinsics):\n        sys.stdout.write(\"\\r\")\n        # the exact output you're looking for:\n        sys.stdout.write(\"Reading camera {}/{}\".format(idx + 1, len(cam_extrinsics)))\n        sys.stdout.flush()\n\n        extr = cam_extrinsics[key]\n        intr = cam_intrinsics[extr.camera_id]\n        height = intr.height\n        width = intr.width\n\n        uid = intr.id\n        R = np.transpose(qvec2rotmat(extr.qvec))\n        T = np.array(extr.tvec)\n\n        if intr.model == \"SIMPLE_PINHOLE\":\n            focal_length_x = intr.params[0]\n            FovY = focal2fov(focal_length_x, height)\n            FovX = focal2fov(focal_length_x, width)\n        elif intr.model == \"PINHOLE\":\n            focal_length_x = intr.params[0]\n            focal_length_y = intr.params[1]\n            FovY = focal2fov(focal_length_y, height)\n            FovX = focal2fov(focal_length_x, width)\n        else:\n            assert (\n                False\n            ), \"Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!\"\n\n        image_path = os.path.join(images_folder, os.path.basename(extr.name))\n        image_name = os.path.basename(image_path).split(\".\")[0]\n        image = Image.open(image_path)\n\n        cam_info = CameraInfo(\n            uid=uid,\n            R=R,\n            T=T,\n            FovY=FovY,\n            FovX=FovX,\n            image=image,\n            image_path=image_path,\n            image_name=image_name,\n            width=width,\n            height=height,\n        )\n        cam_infos.append(cam_info)\n    
sys.stdout.write(\"\\n\")\n    return cam_infos\n\n\ndef fetchPly(path):\n    plydata = PlyData.read(path)\n    vertices = plydata[\"vertex\"]\n    positions = np.vstack([vertices[\"x\"], vertices[\"y\"], vertices[\"z\"]]).T\n    colors = np.vstack([vertices[\"red\"], vertices[\"green\"], vertices[\"blue\"]]).T / 255.0\n    normals = np.vstack([vertices[\"nx\"], vertices[\"ny\"], vertices[\"nz\"]]).T\n    return BasicPointCloud(points=positions, colors=colors, normals=normals)\n\n\ndef storePly(path, xyz, rgb):\n    # Define the dtype for the structured array\n    dtype = [\n        (\"x\", \"f4\"),\n        (\"y\", \"f4\"),\n        (\"z\", \"f4\"),\n        (\"nx\", \"f4\"),\n        (\"ny\", \"f4\"),\n        (\"nz\", \"f4\"),\n        (\"red\", \"u1\"),\n        (\"green\", \"u1\"),\n        (\"blue\", \"u1\"),\n    ]\n\n    normals = np.zeros_like(xyz)\n\n    elements = np.empty(xyz.shape[0], dtype=dtype)\n    attributes = np.concatenate((xyz, normals, rgb), axis=1)\n    elements[:] = list(map(tuple, attributes))\n\n    # Create the PlyData object and write to file\n    vertex_element = PlyElement.describe(elements, \"vertex\")\n    ply_data = PlyData([vertex_element])\n    ply_data.write(path)\n\n\ndef readColmapSceneInfo(path, images, eval, llffhold=8):\n    try:\n        cameras_extrinsic_file = os.path.join(path, \"sparse/0\", \"images.bin\")\n        cameras_intrinsic_file = os.path.join(path, \"sparse/0\", \"cameras.bin\")\n        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)\n        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)\n    except:\n        cameras_extrinsic_file = os.path.join(path, \"sparse/0\", \"images.txt\")\n        cameras_intrinsic_file = os.path.join(path, \"sparse/0\", \"cameras.txt\")\n        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)\n        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)\n\n    reading_dir = \"images\" if images == None else images\n    
cam_infos_unsorted = readColmapCameras(\n        cam_extrinsics=cam_extrinsics,\n        cam_intrinsics=cam_intrinsics,\n        images_folder=os.path.join(path, reading_dir),\n    )\n    cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)\n\n    if eval:\n        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]\n        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]\n    else:\n        train_cam_infos = cam_infos\n        test_cam_infos = []\n\n    nerf_normalization = getNerfppNorm(train_cam_infos)\n\n    ply_path = os.path.join(path, \"sparse/0/points3D.ply\")\n    bin_path = os.path.join(path, \"sparse/0/points3D.bin\")\n    txt_path = os.path.join(path, \"sparse/0/points3D.txt\")\n    if not os.path.exists(ply_path):\n        print(\n            \"Converting point3d.bin to .ply, will happen only the first time you open the scene.\"\n        )\n        try:\n            xyz, rgb, _ = read_points3D_binary(bin_path)\n        except:\n            xyz, rgb, _ = read_points3D_text(txt_path)\n        storePly(ply_path, xyz, rgb)\n    try:\n        pcd = fetchPly(ply_path)\n    except:\n        pcd = None\n\n    scene_info = SceneInfo(\n        point_cloud=pcd,\n        train_cameras=train_cam_infos,\n        test_cameras=test_cam_infos,\n        nerf_normalization=nerf_normalization,\n        ply_path=ply_path,\n    )\n    return scene_info\n\n\ndef readCamerasFromTransforms(path, transformsfile, white_background, extension=\".png\"):\n    cam_infos = []\n\n    with open(os.path.join(path, transformsfile)) as json_file:\n        contents = json.load(json_file)\n\n        # camera_angle_x is the horizontal field of view\n        # frames.file_path is the image name\n        # frame.transform_matrix is the camera-to-world transform\n\n        fovx = contents[\"camera_angle_x\"]\n\n        frames = contents[\"frames\"]\n        for idx, frame in enumerate(frames):\n            cam_name 
= os.path.join(path, frame[\"file_path\"] + extension)\n\n            # NeRF 'transform_matrix' is a camera-to-world transform\n            c2w = np.array(frame[\"transform_matrix\"])\n            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            c2w[:3, 1:3] *= -1\n\n            # get the world-to-camera transform and set R, T\n            w2c = np.linalg.inv(c2w)\n            R = np.transpose(\n                w2c[:3, :3]\n            )  # R is stored transposed due to 'glm' in CUDA code\n            T = w2c[:3, 3]\n\n            image_path = os.path.join(path, cam_name)\n            image_name = Path(cam_name).stem\n            image = Image.open(image_path)\n\n            im_data = np.array(image.convert(\"RGBA\"))\n\n            bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])\n\n            norm_data = im_data / 255.0\n            arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (\n                1 - norm_data[:, :, 3:4]\n            )\n            image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), \"RGB\")\n\n            fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])\n            FovY = fovy\n            FovX = fovx\n\n            cam_infos.append(\n                CameraInfo(\n                    uid=idx,\n                    R=R,\n                    T=T,\n                    FovY=FovY,\n                    FovX=FovX,\n                    image=image,\n                    image_path=image_path,\n                    image_name=image_name,\n                    width=image.size[0],\n                    height=image.size[1],\n                )\n            )\n\n    return cam_infos\n\n\ndef readNerfSyntheticInfo(path, white_background, eval, extension=\".png\"):\n    print(\"Reading Training Transforms\")\n    train_cam_infos = readCamerasFromTransforms(\n        path, \"transforms_train.json\", white_background, extension\n    )\n    print(\"Reading Test 
Transforms\")\n    test_cam_infos = readCamerasFromTransforms(\n        path, \"transforms_test.json\", white_background, extension\n    )\n\n    if not eval:\n        train_cam_infos.extend(test_cam_infos)\n        test_cam_infos = []\n\n    nerf_normalization = getNerfppNorm(train_cam_infos)\n\n    ply_path = os.path.join(path, \"points3d.ply\")\n    if not os.path.exists(ply_path):\n        # Since this data set has no colmap data, we start with random points\n        num_pts = 100_000\n        print(f\"Generating random point cloud ({num_pts})...\")\n\n        # We create random points inside the bounds of the synthetic Blender scenes\n        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3\n        shs = np.random.random((num_pts, 3)) / 255.0\n        pcd = BasicPointCloud(\n            points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))\n        )\n\n        storePly(ply_path, xyz, SH2RGB(shs) * 255)\n    try:\n        pcd = fetchPly(ply_path)\n    except:\n        pcd = None\n\n    scene_info = SceneInfo(\n        point_cloud=pcd,\n        train_cameras=train_cam_infos,\n        test_cameras=test_cam_infos,\n        nerf_normalization=nerf_normalization,\n        ply_path=ply_path,\n    )\n    return scene_info\n\n\nsceneLoadTypeCallbacks = {\n    \"Colmap\": readColmapSceneInfo,\n    \"Blender\": readNerfSyntheticInfo,\n}\n\n\n# below used for easy rendering\nclass NoImageCamera(nn.Module):\n    def __init__(\n        self,\n        colmap_id,\n        R,\n        T,\n        FoVx,\n        FoVy,\n        width,\n        height,\n        uid,\n        trans=np.array([0.0, 0.0, 0.0]),\n        scale=1.0,\n        data_device=\"cuda\",\n        img_path=None, # not needed\n    ):\n        super(NoImageCamera, self).__init__()\n\n        self.uid = uid\n        self.colmap_id = colmap_id\n        self.R = R\n        self.T = T\n        self.FoVx = FoVx\n        self.FoVy = FoVy\n        self.img_path = img_path \n        \n        try:\n            
self.data_device = torch.device(data_device)\n        except Exception as e:\n            print(e)\n            print(\n                f\"[Warning] Custom device {data_device} failed, fallback to default cuda device\"\n            )\n            self.data_device = torch.device(\"cuda\")\n\n        self.image_width = width\n        self.image_height = height\n\n        self.zfar = 100.0\n        self.znear = 0.01\n\n        self.trans = trans\n        self.scale = scale\n\n        # world to camera, then transpose.  # [4, 4]\n        #  w2c.transpose\n        self.world_view_transform = (\n            torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()\n        )\n\n        # [4, 4]  \n        self.projection_matrix = (\n            getProjectionMatrix(\n                znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy\n            )\n            .transpose(0, 1)\n            .cuda()\n        )\n\n        # # [4, 4].  points @ full_proj_transform => screen space. \n        self.full_proj_transform = (\n            self.world_view_transform.unsqueeze(0).bmm(\n                self.projection_matrix.unsqueeze(0)\n            )\n        ).squeeze(0)\n        self.camera_center = self.world_view_transform.inverse()[3, :3]\n\n        # [2, 2].  
\n        #  (w2c @ p) / depth => cam_plane\n        #  (p_in_cam / depth)[:2] @  cam_plane_2_img => [pixel_x, pixel_y]    cam_plane => img_plane \n        self.cam_plane_2_img = torch.tensor(\n            [[ 0.5 * width / math.tan(self.FoVx / 2.0), 0.0], \n             [0.0, 0.5 * height / math.tan(self.FoVy / 2.0)]]\n        ).cuda()\n\n\ndef fast_read_cameras_from_transform_file(file_path, width=1080, height=720):\n    cam_infos = []  \n\n    dir_name = os.path.dirname(file_path)\n\n    with open(file_path) as json_file:\n        contents = json.load(json_file)\n\n        # camera_angle_x is the horizontal field of view\n        # frames.file_path is the image name\n        # frame.transform_matrix is the camera-to-world transform\n\n        fovx = contents[\"camera_angle_x\"]\n\n        frames = contents[\"frames\"]\n        for idx, frame in enumerate(frames):\n            # NeRF 'transform_matrix' is a camera-to-world transform\n            c2w = np.array(frame[\"transform_matrix\"])\n            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            c2w[:3, 1:3] *= -1\n\n            # get the world-to-camera transform and set R, T\n            w2c = np.linalg.inv(c2w)\n            R = np.transpose(\n                w2c[:3, :3]\n            )  # R is stored transposed due to 'glm' in CUDA code\n            T = w2c[:3, 3]\n\n            fovy = focal2fov(fov2focal(fovx, width), height)\n            FovY = fovy\n            FovX = fovx\n\n            img_path = os.path.join(dir_name, frame[\"file_path\"] + \".png\")\n            cam_ = NoImageCamera(\n                colmap_id=idx,\n                R=R,\n                T=T,\n                FoVx=FovX,\n                FoVy=FovY,\n                width=width,\n                height=height,\n                uid=idx,\n                data_device=\"cuda\",\n                img_path=img_path,\n            )\n\n            cam_infos.append(cam_)\n\n    return cam_infos\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/scene/gaussian_model.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport numpy as np\nfrom motionrep.gaussian_3d.utils.general_utils import (\n    inverse_sigmoid,\n    get_expon_lr_func,\n    build_rotation,\n)\nfrom torch import nn\nimport os\nfrom motionrep.gaussian_3d.utils.system_utils import mkdir_p\nfrom plyfile import PlyData, PlyElement\nfrom motionrep.gaussian_3d.utils.sh_utils import RGB2SH\nfrom simple_knn._C import distCUDA2\nfrom motionrep.gaussian_3d.utils.graphics_utils import BasicPointCloud\nfrom motionrep.gaussian_3d.utils.general_utils import (\n    strip_symmetric,\n    build_scaling_rotation,\n)\nfrom motionrep.gaussian_3d.utils.rigid_body_utils import (\n    get_rigid_transform,\n    matrix_to_quaternion,\n    quaternion_multiply,\n)\n\n\nclass GaussianModel:\n    def setup_functions(self):\n        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):\n            L = build_scaling_rotation(scaling_modifier * scaling, rotation)\n            actual_covariance = L @ L.transpose(1, 2)\n            symm = strip_symmetric(actual_covariance)\n            return symm\n\n        self.scaling_activation = torch.exp\n        self.scaling_inverse_activation = torch.log\n\n        self.covariance_activation = build_covariance_from_scaling_rotation\n\n        self.opacity_activation = torch.sigmoid\n        self.inverse_opacity_activation = inverse_sigmoid\n\n        self.rotation_activation = torch.nn.functional.normalize\n\n    def __init__(self, sh_degree: int = 3):\n        self.active_sh_degree = 0\n        self.max_sh_degree = sh_degree\n        self._xyz = torch.empty(0)\n        self._features_dc = torch.empty(0)\n        self._features_rest = torch.empty(0)\n      
  self._scaling = torch.empty(0)\n        self._rotation = torch.empty(0)\n        self._opacity = torch.empty(0)\n        self.max_radii2D = torch.empty(0)\n        self.xyz_gradient_accum = torch.empty(0)\n        self.denom = torch.empty(0)\n        self.optimizer = None\n        self.percent_dense = 0\n        self.spatial_lr_scale = 0\n        self.setup_functions()\n\n        self.matched_inds = None\n\n    def capture(self):\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        return (\n            self.active_sh_degree,\n            self._xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n    def restore(self, model_args, training_args):\n        (\n            self.active_sh_degree,\n            self._xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self.max_radii2D,\n            xyz_gradient_accum,\n            denom,\n            opt_dict,\n            self.spatial_lr_scale,\n        ) = model_args\n\n        if training_args is not None:\n            self.training_setup(training_args)\n        self.xyz_gradient_accum = xyz_gradient_accum\n        self.denom = denom\n        if opt_dict is not None:\n            self.optimizer.load_state_dict(opt_dict)\n\n    def capture_training_args(\n        self,\n    ):\n        pass\n\n    @property\n    def get_scaling(self):\n        return self.scaling_activation(self._scaling)\n\n    @property\n    def get_rotation(self):\n        return self.rotation_activation(self._rotation)\n\n    @property\n    def get_xyz(self):\n        
return self._xyz\n\n    @property\n    def get_features(self):\n        features_dc = self._features_dc\n        features_rest = self._features_rest\n        return torch.cat((features_dc, features_rest), dim=1)\n\n    @property\n    def get_opacity(self):\n        return self.opacity_activation(self._opacity)\n\n    def get_covariance(self, scaling_modifier=1):\n        return self.covariance_activation(\n            self.get_scaling, scaling_modifier, self._rotation\n        )\n\n    def oneupSHdegree(self):\n        if self.active_sh_degree < self.max_sh_degree:\n            self.active_sh_degree += 1\n\n    def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):\n        self.spatial_lr_scale = spatial_lr_scale\n        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()\n        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())\n        features = (\n            torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))\n            .float()\n            .cuda()\n        )\n        features[:, :3, 0] = fused_color\n        # typo here?\n        features[:, 3:, 1:] = 0.0\n\n        print(\"Number of points at initialisation : \", fused_point_cloud.shape[0])\n\n        dist2 = torch.clamp_min(\n            distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()),\n            0.0000001,\n        )\n        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)\n        rots = torch.zeros((fused_point_cloud.shape[0], 4), device=\"cuda\")\n        rots[:, 0] = 1\n\n        opacities = inverse_sigmoid(\n            0.1\n            * torch.ones(\n                (fused_point_cloud.shape[0], 1), dtype=torch.float, device=\"cuda\"\n            )\n        )\n\n        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))\n        self._features_dc = nn.Parameter(\n            features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)\n        )\n        
self._features_rest = nn.Parameter(\n            features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)\n        )\n        self._scaling = nn.Parameter(scales.requires_grad_(True))\n        self._rotation = nn.Parameter(rots.requires_grad_(True))\n        self._opacity = nn.Parameter(opacities.requires_grad_(True))\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def training_setup(self, training_args):\n        self.percent_dense = training_args.percent_dense\n        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n\n        l = [\n            {\n                \"params\": [self._xyz],\n                \"lr\": training_args.position_lr_init * self.spatial_lr_scale,\n                \"name\": \"xyz\",\n            },\n            {\n                \"params\": [self._features_dc],\n                \"lr\": training_args.feature_lr,\n                \"name\": \"f_dc\",\n            },\n            {\n                \"params\": [self._features_rest],\n                \"lr\": training_args.feature_lr / 20.0,\n                \"name\": \"f_rest\",\n            },\n            {\n                \"params\": [self._opacity],\n                \"lr\": training_args.opacity_lr,\n                \"name\": \"opacity\",\n            },\n            {\n                \"params\": [self._scaling],\n                \"lr\": training_args.scaling_lr,\n                \"name\": \"scaling\",\n            },\n            {\n                \"params\": [self._rotation],\n                \"lr\": training_args.rotation_lr,\n                \"name\": \"rotation\",\n            },\n        ]\n\n        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)\n        self.xyz_scheduler_args = get_expon_lr_func(\n            lr_init=training_args.position_lr_init * self.spatial_lr_scale,\n            
lr_final=training_args.position_lr_final * self.spatial_lr_scale,\n            lr_delay_mult=training_args.position_lr_delay_mult,\n            max_steps=training_args.position_lr_max_steps,\n        )\n\n    def update_learning_rate(self, iteration):\n        \"\"\"Learning rate scheduling per step\"\"\"\n        for param_group in self.optimizer.param_groups:\n            if param_group[\"name\"] == \"xyz\":\n                lr = self.xyz_scheduler_args(iteration)\n                param_group[\"lr\"] = lr\n                return lr\n\n    def construct_list_of_attributes(self):\n        l = [\"x\", \"y\", \"z\", \"nx\", \"ny\", \"nz\"]\n        # All channels except the 3 DC\n        for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):\n            l.append(\"f_dc_{}\".format(i))\n        for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]):\n            l.append(\"f_rest_{}\".format(i))\n        l.append(\"opacity\")\n        for i in range(self._scaling.shape[1]):\n            l.append(\"scale_{}\".format(i))\n        for i in range(self._rotation.shape[1]):\n            l.append(\"rot_{}\".format(i))\n        return l\n\n    def save_ply(self, path):\n        mkdir_p(os.path.dirname(path))\n\n        xyz = self._xyz.detach().cpu().numpy()\n        normals = np.zeros_like(xyz)\n        f_dc = (\n            self._features_dc.detach()\n            .transpose(1, 2)\n            .flatten(start_dim=1)\n            .contiguous()\n            .cpu()\n            .numpy()\n        )\n        f_rest = (\n            self._features_rest.detach()\n            .transpose(1, 2)\n            .flatten(start_dim=1)\n            .contiguous()\n            .cpu()\n            .numpy()\n        )\n        opacities = self._opacity.detach().cpu().numpy()\n        scale = self._scaling.detach().cpu().numpy()\n        rotation = self._rotation.detach().cpu().numpy()\n\n        dtype_full = [\n            (attribute, \"f4\") for attribute 
in self.construct_list_of_attributes()\n        ]\n\n        elements = np.empty(xyz.shape[0], dtype=dtype_full)\n        attributes = np.concatenate(\n            (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1\n        )\n        elements[:] = list(map(tuple, attributes))\n        el = PlyElement.describe(elements, \"vertex\")\n        PlyData([el]).write(path)\n\n    def reset_opacity(self):\n        opacities_new = inverse_sigmoid(\n            torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01)\n        )\n        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, \"opacity\")\n        self._opacity = optimizable_tensors[\"opacity\"]\n\n    def load_ply(self, path):\n        plydata = PlyData.read(path)\n\n        xyz = np.stack(\n            (\n                np.asarray(plydata.elements[0][\"x\"]),\n                np.asarray(plydata.elements[0][\"y\"]),\n                np.asarray(plydata.elements[0][\"z\"]),\n            ),\n            axis=1,\n        )\n        opacities = np.asarray(plydata.elements[0][\"opacity\"])[..., np.newaxis]\n\n        features_dc = np.zeros((xyz.shape[0], 3, 1))\n        features_dc[:, 0, 0] = np.asarray(plydata.elements[0][\"f_dc_0\"])\n        features_dc[:, 1, 0] = np.asarray(plydata.elements[0][\"f_dc_1\"])\n        features_dc[:, 2, 0] = np.asarray(plydata.elements[0][\"f_dc_2\"])\n\n        extra_f_names = [\n            p.name\n            for p in plydata.elements[0].properties\n            if p.name.startswith(\"f_rest_\")\n        ]\n        extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split(\"_\")[-1]))\n        assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3\n        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))\n        for idx, attr_name in enumerate(extra_f_names):\n            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])\n        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)\n  
      features_extra = features_extra.reshape(\n            (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)\n        )\n\n        scale_names = [\n            p.name\n            for p in plydata.elements[0].properties\n            if p.name.startswith(\"scale_\")\n        ]\n        scale_names = sorted(scale_names, key=lambda x: int(x.split(\"_\")[-1]))\n        scales = np.zeros((xyz.shape[0], len(scale_names)))\n        for idx, attr_name in enumerate(scale_names):\n            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        rot_names = [\n            p.name for p in plydata.elements[0].properties if p.name.startswith(\"rot\")\n        ]\n        rot_names = sorted(rot_names, key=lambda x: int(x.split(\"_\")[-1]))\n        rots = np.zeros((xyz.shape[0], len(rot_names)))\n        for idx, attr_name in enumerate(rot_names):\n            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        self._xyz = nn.Parameter(\n            torch.tensor(xyz, dtype=torch.float, device=\"cuda\").requires_grad_(True)\n        )\n        self._features_dc = nn.Parameter(\n            torch.tensor(features_dc, dtype=torch.float, device=\"cuda\")\n            .transpose(1, 2)\n            .contiguous()\n            .requires_grad_(True)\n        )\n        self._features_rest = nn.Parameter(\n            torch.tensor(features_extra, dtype=torch.float, device=\"cuda\")\n            .transpose(1, 2)\n            .contiguous()\n            .requires_grad_(True)\n        )\n        self._opacity = nn.Parameter(\n            torch.tensor(opacities, dtype=torch.float, device=\"cuda\").requires_grad_(\n                True\n            )\n        )\n        self._scaling = nn.Parameter(\n            torch.tensor(scales, dtype=torch.float, device=\"cuda\").requires_grad_(True)\n        )\n        self._rotation = nn.Parameter(\n            torch.tensor(rots, dtype=torch.float, device=\"cuda\").requires_grad_(True)\n        )\n\n        
self.active_sh_degree = self.max_sh_degree\n\n    def replace_tensor_to_optimizer(self, tensor, name):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            if group[\"name\"] == name:\n                stored_state = self.optimizer.state.get(group[\"params\"][0], None)\n                stored_state[\"exp_avg\"] = torch.zeros_like(tensor)\n                stored_state[\"exp_avg_sq\"] = torch.zeros_like(tensor)\n\n                del self.optimizer.state[group[\"params\"][0]]\n                group[\"params\"][0] = nn.Parameter(tensor.requires_grad_(True))\n                self.optimizer.state[group[\"params\"][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def _prune_optimizer(self, mask):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            stored_state = self.optimizer.state.get(group[\"params\"][0], None)\n            if stored_state is not None:\n                stored_state[\"exp_avg\"] = stored_state[\"exp_avg\"][mask]\n                stored_state[\"exp_avg_sq\"] = stored_state[\"exp_avg_sq\"][mask]\n\n                del self.optimizer.state[group[\"params\"][0]]\n                group[\"params\"][0] = nn.Parameter(\n                    (group[\"params\"][0][mask].requires_grad_(True))\n                )\n                self.optimizer.state[group[\"params\"][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n                group[\"params\"][0] = nn.Parameter(\n                    group[\"params\"][0][mask].requires_grad_(True)\n                )\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def prune_points(self, mask):\n        valid_points_mask = ~mask\n        optimizable_tensors = self._prune_optimizer(valid_points_mask)\n\n        
self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n\n        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]\n\n        self.denom = self.denom[valid_points_mask]\n        self.max_radii2D = self.max_radii2D[valid_points_mask]\n\n    def cat_tensors_to_optimizer(self, tensors_dict):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            assert len(group[\"params\"]) == 1\n            extension_tensor = tensors_dict[group[\"name\"]]\n            stored_state = self.optimizer.state.get(group[\"params\"][0], None)\n            if stored_state is not None:\n                stored_state[\"exp_avg\"] = torch.cat(\n                    (stored_state[\"exp_avg\"], torch.zeros_like(extension_tensor)), dim=0\n                )\n                stored_state[\"exp_avg_sq\"] = torch.cat(\n                    (stored_state[\"exp_avg_sq\"], torch.zeros_like(extension_tensor)),\n                    dim=0,\n                )\n\n                del self.optimizer.state[group[\"params\"][0]]\n                group[\"params\"][0] = nn.Parameter(\n                    torch.cat(\n                        (group[\"params\"][0], extension_tensor), dim=0\n                    ).requires_grad_(True)\n                )\n                self.optimizer.state[group[\"params\"][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n                group[\"params\"][0] = nn.Parameter(\n                    torch.cat(\n                        (group[\"params\"][0], extension_tensor), dim=0\n                    ).requires_grad_(True)\n                )\n                
optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n\n        return optimizable_tensors\n\n    def densification_postfix(\n        self,\n        new_xyz,\n        new_features_dc,\n        new_features_rest,\n        new_opacities,\n        new_scaling,\n        new_rotation,\n    ):\n        d = {\n            \"xyz\": new_xyz,\n            \"f_dc\": new_features_dc,\n            \"f_rest\": new_features_rest,\n            \"opacity\": new_opacities,\n            \"scaling\": new_scaling,\n            \"rotation\": new_rotation,\n        }\n\n        optimizable_tensors = self.cat_tensors_to_optimizer(d)\n        self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n\n        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):\n        n_init_points = self.get_xyz.shape[0]\n        # Extract points that satisfy the gradient condition\n        padded_grad = torch.zeros((n_init_points), device=\"cuda\")\n        padded_grad[: grads.shape[0]] = grads.squeeze()\n        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)\n        selected_pts_mask = torch.logical_and(\n            selected_pts_mask,\n            torch.max(self.get_scaling, dim=1).values\n            > self.percent_dense * scene_extent,\n        )\n\n        stds = self.get_scaling[selected_pts_mask].repeat(N, 1)\n        means = torch.zeros((stds.size(0), 3), device=\"cuda\")\n        samples = torch.normal(mean=means, 
std=stds)\n        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)\n        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[\n            selected_pts_mask\n        ].repeat(N, 1)\n        new_scaling = self.scaling_inverse_activation(\n            self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)\n        )\n        new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)\n        new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)\n        new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)\n        new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)\n\n        self.densification_postfix(\n            new_xyz,\n            new_features_dc,\n            new_features_rest,\n            new_opacity,\n            new_scaling,\n            new_rotation,\n        )\n\n        prune_filter = torch.cat(\n            (\n                selected_pts_mask,\n                torch.zeros(N * selected_pts_mask.sum(), device=\"cuda\", dtype=bool),\n            )\n        )\n        self.prune_points(prune_filter)\n\n    def densify_and_clone(self, grads, grad_threshold, scene_extent):\n        # Extract points that satisfy the gradient condition\n        selected_pts_mask = torch.where(\n            torch.norm(grads, dim=-1) >= grad_threshold, True, False\n        )\n        selected_pts_mask = torch.logical_and(\n            selected_pts_mask,\n            torch.max(self.get_scaling, dim=1).values\n            <= self.percent_dense * scene_extent,\n        )\n\n        new_xyz = self._xyz[selected_pts_mask]\n        new_features_dc = self._features_dc[selected_pts_mask]\n        new_features_rest = self._features_rest[selected_pts_mask]\n        new_opacities = self._opacity[selected_pts_mask]\n        new_scaling = self._scaling[selected_pts_mask]\n        new_rotation = self._rotation[selected_pts_mask]\n\n        self.densification_postfix(\n            
new_xyz,\n            new_features_dc,\n            new_features_rest,\n            new_opacities,\n            new_scaling,\n            new_rotation,\n        )\n\n    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):\n        grads = self.xyz_gradient_accum / self.denom\n        grads[grads.isnan()] = 0.0\n\n        self.densify_and_clone(grads, max_grad, extent)\n        self.densify_and_split(grads, max_grad, extent)\n\n        prune_mask = (self.get_opacity < min_opacity).squeeze()\n        if max_screen_size:\n            big_points_vs = self.max_radii2D > max_screen_size\n            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent\n            prune_mask = torch.logical_or(\n                torch.logical_or(prune_mask, big_points_vs), big_points_ws\n            )\n        self.prune_points(prune_mask)\n\n        torch.cuda.empty_cache()\n\n    def add_densification_stats(self, viewspace_point_tensor, update_filter):\n        self.xyz_gradient_accum[update_filter] += torch.norm(\n            viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True\n        )\n        self.denom[update_filter] += 1\n\n    def apply_discrete_offset_filds(self, origin_points, offsets):\n        \"\"\"\n        Args:\n            origin_points: (N_r, 3)\n            offsets: (N_r, 3)\n        \"\"\"\n\n        # since origin points and self._xyz might not be matched, we need to first\n        #   compute the distance between origin points and self._xyz\n        #   then find the nearest point in self._xyz for each origin point\n\n        # compute the distance between origin points and self._xyz\n        # [N_r, num_points]\n        dist = torch.cdist(origin_points, self._xyz)\n        # find the nearest point in self._xyz for each origin point\n        _, idx = torch.min(dist, dim=0)\n\n        # apply offsets\n\n        new_xyz = self._xyz + offsets[idx]\n\n        if self.optimizer is None:\n            optim_state = 
None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def apply_discrete_offset_filds_with_R(self, origin_points, offsets, topk=6):\n        \"\"\"\n        Args:\n            origin_points: (N_r, 3)\n            offsets: (N_r, 3)\n        \"\"\"\n\n        # since origin points and self._xyz might not be matched, we need to first\n        #   compute the distance between origin points and self._xyz\n        #   then find the nearest point in self._xyz for each origin point\n\n        if self.matched_inds is None:\n            # compute the distance between origin points and self._xyz\n            # [N_r, num_points]\n            dist = torch.cdist(origin_points, self._xyz) * -1.0\n            # find the nearest point in self._xyz for each origin point\n\n            # idxs: [topk, num_points]\n            print(dist.shape, topk, dist[0])\n            _, idxs = torch.topk(dist, topk, dim=0)\n\n            self.matched_inds = idxs\n        else:\n            idxs = self.matched_inds\n\n        # [topk, num_points, 3] => [num_points, topk, 3]\n        matched_topk_offsets = offsets[idxs].transpose(0, 1)\n        source_points = origin_points[idxs].transpose(0, 1)\n\n        # [num_points, 3, 3/1]\n        R, t = get_rigid_transform(source_points, source_points + matched_topk_offsets)\n\n        # new_xyz = R @ self._xyz.unsqueeze(dim=-1) + t\n        # new_xyz = new_xyz.squeeze(dim=-1)\n\n        
avg_offsets = matched_topk_offsets.mean(dim=1)\n        new_xyz = self._xyz + avg_offsets  # offset directly\n\n        new_rotation = quaternion_multiply(matrix_to_quaternion(R), self._rotation)\n\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            new_rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def apply_se3_fields(\n        self,\n        se3_model,\n        timestamp: float,\n        freeze_mask=None,\n    ):\n        \"\"\"\n        Args:\n            se3_model: SE3Model\n            timestamp: float.  in range [0, 1]\n            freeze_mask: [N]\n        \"\"\"\n\n        inp_time = torch.ones_like(self._xyz[:, 0:1]) * timestamp\n        inp = torch.cat([self._xyz, inp_time], dim=-1)\n\n        if freeze_mask is not None:\n            moving_mask = torch.logical_not(freeze_mask)\n            inp = inp[moving_mask, ...]\n        # [bs, 3, 3]. [bs, 3]\n        R, t = se3_model(inp)\n\n        # print(\"abs t mean\", torch.abs(t).mean(dim=0))\n        # new_xyz = (R @ self._xyz.unsqueeze(dim=-1)).squeeze(dim=-1) + t\n\n        if freeze_mask is None:\n            new_xyz = self._xyz + t\n            new_rotation = quaternion_multiply(matrix_to_quaternion(R), self._rotation)\n        else:\n            new_xyz = self._xyz.clone()\n            new_xyz[moving_mask, ...] += t\n            new_rotation = self._rotation.clone()\n            new_rotation[moving_mask, ...] 
= quaternion_multiply(\n                matrix_to_quaternion(R), self._rotation[moving_mask, ...]\n            )\n\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            new_rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def apply_offset_fields(self, offset_field, timestamp: float):\n        \"\"\"\n        Args:\n            se3_model: SE3Model\n            timestamp: float.  in range [0, 1]\n        \"\"\"\n\n        inp_time = torch.ones_like(self._xyz[:, 0:1]) * timestamp\n        inp = torch.cat([self._xyz, inp_time], dim=-1)\n        # [bs, 3, 3]. 
[bs, 3]\n        offsets = offset_field(inp)\n\n        # print(\"abs t mean\", torch.abs(t).mean(dim=0))\n        new_xyz = self._xyz + offsets\n\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def apply_offset_fields_with_R(self, offset_field, timestamp: float, eps=1e-2):\n        \"\"\"\n        Args:\n            se3_model: SE3Model\n            timestamp: float.  
in range [0, 1]\n        \"\"\"\n\n        # [4, 3]\n        inp_perterb = (\n            torch.tensor(\n                [\n                    [0.0, 0.0, 0.0],  # add this will coplanar?\n                    [+eps, -eps, -eps],\n                    [-eps, -eps, +eps],\n                    [-eps, +eps, -eps],\n                    [+eps, +eps, +eps],\n                ],\n            )\n            .to(self._xyz.device)\n            .float()\n        )\n        #  => [N, 4, 3]\n        source_points = self._xyz.unsqueeze(dim=1) + inp_perterb.unsqueeze(dim=0)\n        num_points = source_points.shape[0]\n\n        inpx = source_points.flatten(end_dim=1)\n        inp_time = torch.ones_like(inpx[:, 0:1]) * timestamp\n\n        inp = torch.cat([inpx, inp_time], dim=-1)\n\n        sampled_offsets = offset_field(inp).reshape((num_points, -1, 3))\n\n        R, t = get_rigid_transform(source_points, source_points + sampled_offsets)\n\n        # new_xyz = R @ self._xyz.unsqueeze(dim=-1) + t\n        # new_xyz = new_xyz.squeeze(dim=-1)\n\n        avg_offsets = sampled_offsets.mean(dim=1)\n        new_xyz = self._xyz + avg_offsets  # offset directly\n\n        new_rotation = quaternion_multiply(matrix_to_quaternion(R), self._rotation)\n\n        if self.optimizer is None:\n            optim_state = None\n        else:\n            optim_state = self.optimizer.state_dict()\n\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            new_rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            optim_state,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    def init_from_mesh(\n        self,\n        mesh_path: str,\n   
     num_gaussians: int = 10000,\n    ):\n        import point_cloud_utils as pcu\n\n        mesh = pcu.load_triangle_mesh(mesh_path)\n\n        v, f = mesh.v, mesh.f\n\n        v_n = pcu.estimate_mesh_normals(v, f)\n        vert_colors = mesh.vertex_data.colors\n\n        fid, bc = pcu.sample_mesh_random(v, f, num_gaussians)\n\n        # Interpolate the vertex positions and normals using the returned barycentric coordinates\n        # to get sample positions and normals\n        rand_positions = pcu.interpolate_barycentric_coords(f, fid, bc, v)\n        rand_normals = pcu.interpolate_barycentric_coords(f, fid, bc, v_n)\n        rand_colors = pcu.interpolate_barycentric_coords(f, fid, bc, vert_colors)[:, :3]\n\n        # copy original pointcloud init functions\n\n        fused_point_cloud = torch.tensor(np.asarray(rand_positions)).float().cuda()\n        fused_color = RGB2SH(torch.tensor(np.asarray(rand_colors)).float().cuda())\n        features = (\n            torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))\n            .float()\n            .cuda()\n        )\n        features[:, :3, 0] = fused_color\n        # typo here?\n        features[:, 3:, 1:] = 0.0\n\n        print(\"Number of points at initialisation : \", fused_point_cloud.shape[0])\n\n        dist2 = torch.clamp_min(\n            distCUDA2(torch.from_numpy(np.asarray(rand_positions)).float().cuda()),\n            0.0000001,\n        )\n        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)\n        rots = torch.zeros((fused_point_cloud.shape[0], 4), device=\"cuda\")\n        rots[:, 0] = 1\n\n        opacities = inverse_sigmoid(\n            0.1\n            * torch.ones(\n                (fused_point_cloud.shape[0], 1), dtype=torch.float, device=\"cuda\"\n            )\n        )\n\n        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))\n        self._features_dc = nn.Parameter(\n            features[:, :, 0:1].transpose(1, 
2).contiguous().requires_grad_(True)\n        )\n        self._features_rest = nn.Parameter(\n            features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)\n        )\n        self._scaling = nn.Parameter(scales.requires_grad_(True))\n        self._rotation = nn.Parameter(rots.requires_grad_(True))\n        self._opacity = nn.Parameter(opacities.requires_grad_(True))\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def detach_grad(\n        self,\n    ):\n        self._xyz.requires_grad = False\n        self._features_dc.requires_grad = False\n        self._features_rest.requires_grad = False\n        self._scaling.requires_grad = False\n        self._rotation.requires_grad = False\n        self._opacity.requires_grad = False\n\n    def apply_mask(self, mask):\n        new_xyz = self._xyz[mask]\n        if self.xyz_gradient_accum.shape == self._xyz.shape:\n            new_xyz_gradient_accum = self.xyz_gradient_accum[mask]\n            new_denom = self.denom[mask]\n        else:\n            new_xyz_gradient_accum = self.xyz_gradient_accum\n            new_denom = self.denom\n        new_model_args = (\n            self.active_sh_degree,\n            new_xyz,\n            self._features_dc[mask],\n            self._features_rest[mask],\n            self._scaling[mask],\n            self._rotation[mask],\n            self._opacity[mask],\n            self.max_radii2D,\n            new_xyz_gradient_accum,\n            new_denom,\n            None,\n            self.spatial_lr_scale,\n        )\n\n        ret_gaussian = GaussianModel(self.max_sh_degree)\n        ret_gaussian.restore(new_model_args, None)\n\n        return ret_gaussian\n\n    @torch.no_grad()\n    def extract_fields(self, resolution=128, num_blocks=16, relax_ratio=1.5):\n        # resolution: resolution of field\n\n        block_size = 2 / num_blocks\n\n        assert resolution % block_size == 0\n        split_size = resolution // 
num_blocks\n\n        opacities = self.get_opacity\n\n        # pre-filter low opacity gaussians to save computation\n        mask = (opacities > 0.005).squeeze(1)\n\n        opacities = opacities[mask]\n        xyzs = self.get_xyz[mask]\n        stds = self.get_scaling[mask]\n\n        # normalize to ~ [-1, 1]\n        mn, mx = xyzs.amin(0), xyzs.amax(0)\n        self.center = (mn + mx) / 2\n        self.scale = 1.0 / (mx - mn).amax().item()\n\n        print(\"gaussian center, scale\", self.center, self.scale)\n        xyzs = (xyzs - self.center) * self.scale\n        stds = stds * self.scale\n\n        covs = self.covariance_activation(stds, 1, self._rotation[mask])\n\n        # tile\n        device = opacities.device\n        occ = torch.zeros([resolution] * 3, dtype=torch.float32, device=device)\n\n        X = torch.linspace(-1, 1, resolution).split(split_size)\n        Y = torch.linspace(-1, 1, resolution).split(split_size)\n        Z = torch.linspace(-1, 1, resolution).split(split_size)\n\n        # loop blocks (assume max size of gaussian is small than relax_ratio * block_size !!!)\n        for xi, xs in enumerate(X):\n            for yi, ys in enumerate(Y):\n                for zi, zs in enumerate(Z):\n                    xx, yy, zz = torch.meshgrid(xs, ys, zs)\n                    # sample points [M, 3]\n                    pts = torch.cat(\n                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],\n                        dim=-1,\n                    ).to(device)\n                    # in-tile gaussians mask\n                    vmin, vmax = pts.amin(0), pts.amax(0)\n                    vmin -= block_size * relax_ratio\n                    vmax += block_size * relax_ratio\n                    mask = (xyzs < vmax).all(-1) & (xyzs > vmin).all(-1)\n                    # if hit no gaussian, continue to next block\n                    if not mask.any():\n                        continue\n                    mask_xyzs = xyzs[mask]  # 
[L, 3]\n                    mask_covs = covs[mask]  # [L, 6]\n                    mask_opas = opacities[mask].view(1, -1)  # [L, 1] --> [1, L]\n\n                    # query per point-gaussian pair.\n                    g_pts = pts.unsqueeze(1).repeat(\n                        1, mask_covs.shape[0], 1\n                    ) - mask_xyzs.unsqueeze(\n                        0\n                    )  # [M, L, 3]\n                    g_covs = mask_covs.unsqueeze(0).repeat(\n                        pts.shape[0], 1, 1\n                    )  # [M, L, 6]\n\n                    # batch on gaussian to avoid OOM\n                    batch_g = 1024\n                    val = 0\n                    for start in range(0, g_covs.shape[1], batch_g):\n                        end = min(start + batch_g, g_covs.shape[1])\n                        w = gaussian_3d_coeff(\n                            g_pts[:, start:end].reshape(-1, 3),\n                            g_covs[:, start:end].reshape(-1, 6),\n                        ).reshape(\n                            pts.shape[0], -1\n                        )  # [M, l]\n                        val += (mask_opas[:, start:end] * w).sum(-1)\n\n                    # kiui.lo(val, mask_opas, w)\n\n                    occ[\n                        xi * split_size : xi * split_size + len(xs),\n                        yi * split_size : yi * split_size + len(ys),\n                        zi * split_size : zi * split_size + len(zs),\n                    ] = val.reshape(len(xs), len(ys), len(zs))\n\n        return occ\n\n    def extract_mesh(self, path, density_thresh=1, resolution=128, decimate_target=1e5):\n        os.makedirs(os.path.dirname(path), exist_ok=True)\n\n        from motionrep.gaussian_3d.scene.mesh import Mesh\n        from motionrep.gaussian_3d.scene.mesh_utils import decimate_mesh, clean_mesh\n\n        occ = self.extract_fields(resolution).detach().cpu().numpy()\n\n        print(occ.shape, occ.min(), occ.max(), occ.mean(), \"occ 
stats\")\n        print(np.percentile(occ, [0, 1, 5, 10, 50, 90, 95, 99, 100]), \"occ percentiles\")\n        import mcubes\n\n        vertices, triangles = mcubes.marching_cubes(occ, density_thresh)\n        vertices = vertices / (resolution - 1.0) * 2 - 1\n\n        # transform back to the original space\n        vertices = vertices / self.scale + self.center.detach().cpu().numpy()\n\n        vertices, triangles = clean_mesh(\n            vertices, triangles, remesh=True, remesh_size=0.015\n        )\n        if decimate_target > 0 and triangles.shape[0] > decimate_target:\n            vertices, triangles = decimate_mesh(vertices, triangles, decimate_target)\n\n        v = torch.from_numpy(vertices.astype(np.float32)).contiguous().cuda()\n        f = torch.from_numpy(triangles.astype(np.int32)).contiguous().cuda()\n\n        print(\n            f\"[INFO] marching cubes result: {v.shape} ({v.min().item()}-{v.max().item()}), {f.shape}\"\n        )\n\n        mesh = Mesh(v=v, f=f, device=\"cuda\")\n\n        return mesh\n\n\ndef gaussian_3d_coeff(xyzs, covs):\n    # xyzs: [N, 3]\n    # covs: [N, 6]\n    x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]\n    a, b, c, d, e, f = (\n        covs[:, 0],\n        covs[:, 1],\n        covs[:, 2],\n        covs[:, 3],\n        covs[:, 4],\n        covs[:, 5],\n    )\n\n    # eps must be small enough !!!\n    inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24)\n    inv_a = (d * f - e**2) * inv_det\n    inv_b = (e * c - b * f) * inv_det\n    inv_c = (e * b - c * d) * inv_det\n    inv_d = (a * f - c**2) * inv_det\n    inv_e = (b * c - e * a) * inv_det\n    inv_f = (a * d - b**2) * inv_det\n\n    power = (\n        -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f)\n        - x * y * inv_b\n        - x * z * inv_c\n        - y * z * inv_e\n    )\n\n    power[power > 0] = -1e10  # abnormal values... make weights 0\n\n    return torch.exp(power)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/scene/mesh.py",
    "content": "import os\nimport cv2\nimport torch\nimport trimesh\nimport numpy as np\n\n\ndef dot(x, y):\n    return torch.sum(x * y, -1, keepdim=True)\n\n\ndef length(x, eps=1e-20):\n    return torch.sqrt(torch.clamp(dot(x, x), min=eps))\n\n\ndef safe_normalize(x, eps=1e-20):\n    return x / length(x, eps)\n\n\nclass Mesh:\n    def __init__(\n        self,\n        v=None,\n        f=None,\n        vn=None,\n        fn=None,\n        vt=None,\n        ft=None,\n        albedo=None,\n        vc=None,  # vertex color\n        device=None,\n    ):\n        self.device = device\n        self.v = v\n        self.vn = vn\n        self.vt = vt\n        self.f = f\n        self.fn = fn\n        self.ft = ft\n        # only support a single albedo\n        self.albedo = albedo\n        # support vertex color is no albedo\n        self.vc = vc\n\n        self.ori_center = 0\n        self.ori_scale = 1\n\n    @classmethod\n    def load(\n        cls,\n        path=None,\n        resize=True,\n        renormal=True,\n        retex=False,\n        front_dir=\"+z\",\n        **kwargs,\n    ):\n        # assume init with kwargs\n        if path is None:\n            mesh = cls(**kwargs)\n        # obj supports face uv\n        elif path.endswith(\".obj\"):\n            mesh = cls.load_obj(path, **kwargs)\n        # trimesh only supports vertex uv, but can load more formats\n        else:\n            mesh = cls.load_trimesh(path, **kwargs)\n\n        print(f\"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}\")\n        # auto-normalize\n        if resize:\n            mesh.auto_size()\n        # auto-fix normal\n        if renormal or mesh.vn is None:\n            mesh.auto_normal()\n            print(f\"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}\")\n        # auto-fix texcoords\n        if retex or (mesh.albedo is not None and mesh.vt is None):\n            mesh.auto_uv(cache_path=path)\n            print(f\"[Mesh loading] vt: {mesh.vt.shape}, ft: 
{mesh.ft.shape}\")\n\n        # rotate front dir to +z\n        if front_dir != \"+z\":\n            # axis switch\n            if \"-z\" in front_dir:\n                T = torch.tensor(\n                    [[1, 0, 0], [0, 1, 0], [0, 0, -1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"+x\" in front_dir:\n                T = torch.tensor(\n                    [[0, 0, 1], [0, 1, 0], [1, 0, 0]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"-x\" in front_dir:\n                T = torch.tensor(\n                    [[0, 0, -1], [0, 1, 0], [1, 0, 0]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"+y\" in front_dir:\n                T = torch.tensor(\n                    [[1, 0, 0], [0, 0, 1], [0, 1, 0]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"-y\" in front_dir:\n                T = torch.tensor(\n                    [[1, 0, 0], [0, 0, -1], [0, 1, 0]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            else:\n                T = torch.tensor(\n                    [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            # rotation (how many 90 degrees)\n            if \"1\" in front_dir:\n                T @= torch.tensor(\n                    [[0, -1, 0], [1, 0, 0], [0, 0, 1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            elif \"2\" in front_dir:\n                T @= torch.tensor(\n                    [[1, 0, 0], [0, -1, 0], [0, 0, 1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n 
           elif \"3\" in front_dir:\n                T @= torch.tensor(\n                    [[0, 1, 0], [-1, 0, 0], [0, 0, 1]],\n                    device=mesh.device,\n                    dtype=torch.float32,\n                )\n            mesh.v @= T\n            mesh.vn @= T\n\n        return mesh\n\n    # load from obj file\n    @classmethod\n    def load_obj(cls, path, albedo_path=None, device=None):\n        assert os.path.splitext(path)[-1] == \".obj\"\n\n        mesh = cls()\n\n        # device\n        if device is None:\n            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        mesh.device = device\n\n        # load obj\n        with open(path, \"r\") as f:\n            lines = f.readlines()\n\n        def parse_f_v(fv):\n            # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)\n            # supported forms:\n            # f v1 v2 v3\n            # f v1/vt1 v2/vt2 v3/vt3\n            # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3\n            # f v1//vn1 v2//vn2 v3//vn3\n            xs = [int(x) - 1 if x != \"\" else -1 for x in fv.split(\"/\")]\n            xs.extend([-1] * (3 - len(xs)))\n            return xs[0], xs[1], xs[2]\n\n        # NOTE: we ignore usemtl, and assume the mesh ONLY uses one material (first in mtl)\n        vertices, texcoords, normals = [], [], []\n        faces, tfaces, nfaces = [], [], []\n        mtl_path = None\n\n        for line in lines:\n            split_line = line.split()\n            # empty line\n            if len(split_line) == 0:\n                continue\n            prefix = split_line[0].lower()\n            # mtllib\n            if prefix == \"mtllib\":\n                mtl_path = split_line[1]\n            # usemtl\n            elif prefix == \"usemtl\":\n                pass  # ignored\n            # v/vn/vt\n            elif prefix == \"v\":\n                vertices.append([float(v) for v in split_line[1:]])\n            elif prefix == \"vn\":\n 
               normals.append([float(v) for v in split_line[1:]])\n            elif prefix == \"vt\":\n                val = [float(v) for v in split_line[1:]]\n                texcoords.append([val[0], 1.0 - val[1]])\n            elif prefix == \"f\":\n                vs = split_line[1:]\n                nv = len(vs)\n                v0, t0, n0 = parse_f_v(vs[0])\n                for i in range(nv - 2):  # triangulate (assume vertices are ordered)\n                    v1, t1, n1 = parse_f_v(vs[i + 1])\n                    v2, t2, n2 = parse_f_v(vs[i + 2])\n                    faces.append([v0, v1, v2])\n                    tfaces.append([t0, t1, t2])\n                    nfaces.append([n0, n1, n2])\n\n        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)\n        mesh.vt = (\n            torch.tensor(texcoords, dtype=torch.float32, device=device)\n            if len(texcoords) > 0\n            else None\n        )\n        mesh.vn = (\n            torch.tensor(normals, dtype=torch.float32, device=device)\n            if len(normals) > 0\n            else None\n        )\n\n        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)\n        mesh.ft = (\n            torch.tensor(tfaces, dtype=torch.int32, device=device)\n            if len(texcoords) > 0\n            else None\n        )\n        mesh.fn = (\n            torch.tensor(nfaces, dtype=torch.int32, device=device)\n            if len(normals) > 0\n            else None\n        )\n\n        # see if there is vertex color\n        use_vertex_color = False\n        if mesh.v.shape[1] == 6:\n            use_vertex_color = True\n            mesh.vc = mesh.v[:, 3:]\n            mesh.v = mesh.v[:, :3]\n            print(f\"[load_obj] use vertex color: {mesh.vc.shape}\")\n\n        # try to load texture image\n        if not use_vertex_color:\n            # try to retrieve mtl file\n            mtl_path_candidates = []\n            if mtl_path is not None:\n                
mtl_path_candidates.append(mtl_path)\n                mtl_path_candidates.append(\n                    os.path.join(os.path.dirname(path), mtl_path)\n                )\n            mtl_path_candidates.append(path.replace(\".obj\", \".mtl\"))\n\n            mtl_path = None\n            for candidate in mtl_path_candidates:\n                if os.path.exists(candidate):\n                    mtl_path = candidate\n                    break\n\n            # if albedo_path is not provided, try retrieve it from mtl\n            if mtl_path is not None and albedo_path is None:\n                with open(mtl_path, \"r\") as f:\n                    lines = f.readlines()\n                for line in lines:\n                    split_line = line.split()\n                    # empty line\n                    if len(split_line) == 0:\n                        continue\n                    prefix = split_line[0]\n                    # NOTE: simply use the first map_Kd as albedo!\n                    if \"map_Kd\" in prefix:\n                        albedo_path = os.path.join(os.path.dirname(path), split_line[1])\n                        print(f\"[load_obj] use texture from: {albedo_path}\")\n                        break\n\n            # still not found albedo_path, or the path doesn't exist\n            if albedo_path is None or not os.path.exists(albedo_path):\n                # init an empty texture\n                print(f\"[load_obj] init empty albedo!\")\n                # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)\n                albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array(\n                    [0.5, 0.5, 0.5]\n                )  # default color\n            else:\n                albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)\n                albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)\n                albedo = albedo.astype(np.float32) / 255\n                print(f\"[load_obj] load texture: {albedo.shape}\")\n\n             
   # import matplotlib.pyplot as plt\n                # plt.imshow(albedo)\n                # plt.show()\n\n            mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)\n\n        return mesh\n\n    @classmethod\n    def load_trimesh(cls, path, device=None):\n        mesh = cls()\n\n        # device\n        if device is None:\n            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        mesh.device = device\n\n        # use trimesh to load ply/glb, assume only has one single RootMesh...\n        _data = trimesh.load(path)\n        if isinstance(_data, trimesh.Scene):\n            if len(_data.geometry) == 1:\n                _mesh = list(_data.geometry.values())[0]\n            else:\n                # manual concat, will lose texture\n                _concat = []\n                for g in _data.geometry.values():\n                    if isinstance(g, trimesh.Trimesh):\n                        _concat.append(g)\n                _mesh = trimesh.util.concatenate(_concat)\n        else:\n            _mesh = _data\n\n        if _mesh.visual.kind == \"vertex\":\n            vertex_colors = _mesh.visual.vertex_colors\n            vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255\n            mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)\n            print(f\"[load_trimesh] use vertex color: {mesh.vc.shape}\")\n        elif _mesh.visual.kind == \"texture\":\n            _material = _mesh.visual.material\n            if isinstance(_material, trimesh.visual.material.PBRMaterial):\n                texture = np.array(_material.baseColorTexture).astype(np.float32) / 255\n            elif isinstance(_material, trimesh.visual.material.SimpleMaterial):\n                texture = (\n                    np.array(_material.to_pbr().baseColorTexture).astype(np.float32)\n                    / 255\n                )\n            else:\n                raise 
NotImplementedError(\n                    f\"material type {type(_material)} not supported!\"\n                )\n            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)\n            print(f\"[load_trimesh] load texture: {texture.shape}\")\n        else:\n            texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array(\n                [0.5, 0.5, 0.5]\n            )\n            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)\n            print(f\"[load_trimesh] failed to load texture.\")\n\n        vertices = _mesh.vertices\n\n        try:\n            texcoords = _mesh.visual.uv\n            texcoords[:, 1] = 1 - texcoords[:, 1]\n        except Exception as e:\n            texcoords = None\n\n        try:\n            normals = _mesh.vertex_normals\n        except Exception as e:\n            normals = None\n\n        # trimesh only support vertex uv...\n        faces = tfaces = nfaces = _mesh.faces\n\n        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)\n        mesh.vt = (\n            torch.tensor(texcoords, dtype=torch.float32, device=device)\n            if texcoords is not None\n            else None\n        )\n        mesh.vn = (\n            torch.tensor(normals, dtype=torch.float32, device=device)\n            if normals is not None\n            else None\n        )\n\n        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)\n        mesh.ft = (\n            torch.tensor(tfaces, dtype=torch.int32, device=device)\n            if texcoords is not None\n            else None\n        )\n        mesh.fn = (\n            torch.tensor(nfaces, dtype=torch.int32, device=device)\n            if normals is not None\n            else None\n        )\n\n        return mesh\n\n    # aabb\n    def aabb(self):\n        return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values\n\n    # unit size\n    @torch.no_grad()\n    def auto_size(self):\n        
vmin, vmax = self.aabb()\n        self.ori_center = (vmax + vmin) / 2\n        self.ori_scale = 1.2 / torch.max(vmax - vmin).item()\n        self.v = (self.v - self.ori_center) * self.ori_scale\n\n    def auto_normal(self):\n        i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()\n        v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]\n\n        face_normals = torch.cross(v1 - v0, v2 - v0)\n\n        # Splat face normals to vertices\n        vn = torch.zeros_like(self.v)\n        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)\n        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)\n        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)\n\n        # Normalize, replace zero (degenerated) normals with some default value\n        vn = torch.where(\n            dot(vn, vn) > 1e-20,\n            vn,\n            torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),\n        )\n        vn = safe_normalize(vn)\n\n        self.vn = vn\n        self.fn = self.f\n\n    def auto_uv(self, cache_path=None, vmap=True):\n        # try to load cache\n        if cache_path is not None:\n            cache_path = os.path.splitext(cache_path)[0] + \"_uv.npz\"\n        if cache_path is not None and os.path.exists(cache_path):\n            data = np.load(cache_path)\n            vt_np, ft_np, vmapping = data[\"vt\"], data[\"ft\"], data[\"vmapping\"]\n        else:\n            import xatlas\n\n            v_np = self.v.detach().cpu().numpy()\n            f_np = self.f.detach().int().cpu().numpy()\n            atlas = xatlas.Atlas()\n            atlas.add_mesh(v_np, f_np)\n            chart_options = xatlas.ChartOptions()\n            # chart_options.max_iterations = 4\n            atlas.generate(chart_options=chart_options)\n            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]\n\n            # save to cache\n            if cache_path is not None:\n                np.savez(cache_path, 
vt=vt_np, ft=ft_np, vmapping=vmapping)\n\n        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)\n        ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)\n        self.vt = vt\n        self.ft = ft\n\n        if vmap:\n            # remap v/f to vt/ft, so each v correspond to a unique vt. (necessary for gltf)\n            vmapping = (\n                torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)\n            )\n            self.align_v_to_vt(vmapping)\n\n    def align_v_to_vt(self, vmapping=None):\n        # remap v/f and vn/vn to vt/ft.\n        if vmapping is None:\n            ft = self.ft.view(-1).long()\n            f = self.f.view(-1).long()\n            vmapping = torch.zeros(\n                self.vt.shape[0], dtype=torch.long, device=self.device\n            )\n            vmapping[ft] = f  # scatter, randomly choose one if index is not unique\n\n        self.v = self.v[vmapping]\n        self.f = self.ft\n        # assume fn == f\n        if self.vn is not None:\n            self.vn = self.vn[vmapping]\n            self.fn = self.ft\n\n    def to(self, device):\n        self.device = device\n        for name in [\"v\", \"f\", \"vn\", \"fn\", \"vt\", \"ft\", \"albedo\"]:\n            tensor = getattr(self, name)\n            if tensor is not None:\n                setattr(self, name, tensor.to(device))\n        return self\n\n    def write(self, path):\n        if path.endswith(\".ply\"):\n            self.write_ply(path)\n        elif path.endswith(\".obj\"):\n            self.write_obj(path)\n        elif path.endswith(\".glb\") or path.endswith(\".gltf\"):\n            self.write_glb(path)\n        else:\n            raise NotImplementedError(f\"format {path} not supported!\")\n\n    # write to ply file (only geom)\n    def write_ply(self, path):\n        v_np = self.v.detach().cpu().numpy()\n        f_np = self.f.detach().cpu().numpy()\n\n        _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)\n   
     _mesh.export(path)\n\n    # write to gltf/glb file (geom + texture)\n    def write_glb(self, path):\n        assert (\n            self.vn is not None and self.vt is not None\n        )  # should be improved to support export without texture...\n\n        # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]\n        if self.v.shape[0] != self.vt.shape[0]:\n            self.align_v_to_vt()\n\n        # assume f == fn == ft\n\n        import pygltflib\n\n        f_np = self.f.detach().cpu().numpy().astype(np.uint32)\n        v_np = self.v.detach().cpu().numpy().astype(np.float32)\n        # vn_np = self.vn.detach().cpu().numpy().astype(np.float32)\n        vt_np = self.vt.detach().cpu().numpy().astype(np.float32)\n\n        albedo = self.albedo.detach().cpu().numpy()\n        albedo = (albedo * 255).astype(np.uint8)\n        albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)\n\n        f_np_blob = f_np.flatten().tobytes()\n        v_np_blob = v_np.tobytes()\n        # vn_np_blob = vn_np.tobytes()\n        vt_np_blob = vt_np.tobytes()\n        albedo_blob = cv2.imencode(\".png\", albedo)[1].tobytes()\n\n        gltf = pygltflib.GLTF2(\n            scene=0,\n            scenes=[pygltflib.Scene(nodes=[0])],\n            nodes=[pygltflib.Node(mesh=0)],\n            meshes=[\n                pygltflib.Mesh(\n                    primitives=[\n                        pygltflib.Primitive(\n                            # indices to accessors (0 is triangles)\n                            attributes=pygltflib.Attributes(\n                                POSITION=1,\n                                TEXCOORD_0=2,\n                            ),\n                            indices=0,\n                            material=0,\n                        )\n                    ]\n                )\n            ],\n            materials=[\n                pygltflib.Material(\n                    pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(\n     
                   baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),\n                        metallicFactor=0.0,\n                        roughnessFactor=1.0,\n                    ),\n                    alphaCutoff=0,\n                    doubleSided=True,\n                )\n            ],\n            textures=[\n                pygltflib.Texture(sampler=0, source=0),\n            ],\n            samplers=[\n                pygltflib.Sampler(\n                    magFilter=pygltflib.LINEAR,\n                    minFilter=pygltflib.LINEAR_MIPMAP_LINEAR,\n                    wrapS=pygltflib.REPEAT,\n                    wrapT=pygltflib.REPEAT,\n                ),\n            ],\n            images=[\n                # use embedded (buffer) image\n                pygltflib.Image(bufferView=3, mimeType=\"image/png\"),\n            ],\n            buffers=[\n                pygltflib.Buffer(\n                    byteLength=len(f_np_blob)\n                    + len(v_np_blob)\n                    + len(vt_np_blob)\n                    + len(albedo_blob)\n                )\n            ],\n            # buffer view (based on dtype)\n            bufferViews=[\n                # triangles; as flatten (element) array\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteLength=len(f_np_blob),\n                    target=pygltflib.ELEMENT_ARRAY_BUFFER,  # GL_ELEMENT_ARRAY_BUFFER (34963)\n                ),\n                # positions; as vec3 array\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteOffset=len(f_np_blob),\n                    byteLength=len(v_np_blob),\n                    byteStride=12,  # vec3\n                    target=pygltflib.ARRAY_BUFFER,  # GL_ARRAY_BUFFER (34962)\n                ),\n                # texcoords; as vec2 array\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteOffset=len(f_np_blob) + 
len(v_np_blob),\n                    byteLength=len(vt_np_blob),\n                    byteStride=8,  # vec2\n                    target=pygltflib.ARRAY_BUFFER,\n                ),\n                # texture; as none target\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteOffset=len(f_np_blob) + len(v_np_blob) + len(vt_np_blob),\n                    byteLength=len(albedo_blob),\n                ),\n            ],\n            accessors=[\n                # 0 = triangles\n                pygltflib.Accessor(\n                    bufferView=0,\n                    componentType=pygltflib.UNSIGNED_INT,  # GL_UNSIGNED_INT (5125)\n                    count=f_np.size,\n                    type=pygltflib.SCALAR,\n                    max=[int(f_np.max())],\n                    min=[int(f_np.min())],\n                ),\n                # 1 = positions\n                pygltflib.Accessor(\n                    bufferView=1,\n                    componentType=pygltflib.FLOAT,  # GL_FLOAT (5126)\n                    count=len(v_np),\n                    type=pygltflib.VEC3,\n                    max=v_np.max(axis=0).tolist(),\n                    min=v_np.min(axis=0).tolist(),\n                ),\n                # 2 = texcoords\n                pygltflib.Accessor(\n                    bufferView=2,\n                    componentType=pygltflib.FLOAT,\n                    count=len(vt_np),\n                    type=pygltflib.VEC2,\n                    max=vt_np.max(axis=0).tolist(),\n                    min=vt_np.min(axis=0).tolist(),\n                ),\n            ],\n        )\n\n        # set actual data\n        gltf.set_binary_blob(f_np_blob + v_np_blob + vt_np_blob + albedo_blob)\n\n        # glb = b\"\".join(gltf.save_to_bytes())\n        gltf.save(path)\n\n    # write to obj file (geom + texture)\n    def write_obj(self, path):\n        mtl_path = path.replace(\".obj\", \".mtl\")\n        albedo_path = path.replace(\".obj\", 
\"_albedo.png\")\n\n        v_np = self.v.detach().cpu().numpy()\n        vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None\n        vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None\n        f_np = self.f.detach().cpu().numpy()\n        ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None\n        fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None\n\n        with open(path, \"w\") as fp:\n            fp.write(f\"mtllib {os.path.basename(mtl_path)} \\n\")\n\n            for v in v_np:\n                fp.write(f\"v {v[0]} {v[1]} {v[2]} \\n\")\n\n            if vt_np is not None:\n                for v in vt_np:\n                    fp.write(f\"vt {v[0]} {1 - v[1]} \\n\")\n\n            if vn_np is not None:\n                for v in vn_np:\n                    fp.write(f\"vn {v[0]} {v[1]} {v[2]} \\n\")\n\n            fp.write(f\"usemtl defaultMat \\n\")\n            for i in range(len(f_np)):\n                fp.write(\n                    f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else \"\"}/{fn_np[i, 0] + 1 if fn_np is not None else \"\"} \\\n                             {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else \"\"}/{fn_np[i, 1] + 1 if fn_np is not None else \"\"} \\\n                             {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else \"\"}/{fn_np[i, 2] + 1 if fn_np is not None else \"\"} \\n'\n                )\n\n        with open(mtl_path, \"w\") as fp:\n            fp.write(f\"newmtl defaultMat \\n\")\n            fp.write(f\"Ka 1 1 1 \\n\")\n            fp.write(f\"Kd 1 1 1 \\n\")\n            fp.write(f\"Ks 0 0 0 \\n\")\n            fp.write(f\"Tr 1 \\n\")\n            fp.write(f\"illum 1 \\n\")\n            fp.write(f\"Ns 0 \\n\")\n            fp.write(f\"map_Kd {os.path.basename(albedo_path)} \\n\")\n\n        if not (False or self.albedo is None):\n            albedo = self.albedo.detach().cpu().numpy()\n            
albedo = (albedo * 255).astype(np.uint8)\n            cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/scene/mesh_utils.py",
    "content": "import numpy as np\nimport pymeshlab as pml\n\n\ndef poisson_mesh_reconstruction(points, normals=None):\n    # points/normals: [N, 3] np.ndarray\n\n    import open3d as o3d\n\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(points)\n\n    # outlier removal\n    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)\n\n    # normals\n    if normals is None:\n        pcd.estimate_normals()\n    else:\n        pcd.normals = o3d.utility.Vector3dVector(normals[ind])\n\n    # visualize\n    o3d.visualization.draw_geometries([pcd], point_show_normal=False)\n\n    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(\n        pcd, depth=9\n    )\n    vertices_to_remove = densities < np.quantile(densities, 0.1)\n    mesh.remove_vertices_by_mask(vertices_to_remove)\n\n    # visualize\n    o3d.visualization.draw_geometries([mesh])\n\n    vertices = np.asarray(mesh.vertices)\n    triangles = np.asarray(mesh.triangles)\n\n    print(\n        f\"[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}\"\n    )\n\n    return vertices, triangles\n\n\ndef decimate_mesh(\n    verts, faces, target, backend=\"pymeshlab\", remesh=False, optimalplacement=True\n):\n    # optimalplacement: default is True, but for flat mesh must turn False to prevent spike artifect.\n\n    _ori_vert_shape = verts.shape\n    _ori_face_shape = faces.shape\n\n    if backend == \"pyfqmr\":\n        import pyfqmr\n\n        solver = pyfqmr.Simplify()\n        solver.setMesh(verts, faces)\n        solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)\n        verts, faces, normals = solver.getMesh()\n    else:\n        m = pml.Mesh(verts, faces)\n        ms = pml.MeshSet()\n        ms.add_mesh(m, \"mesh\")  # will copy!\n\n        # filters\n        # ms.meshing_decimation_clustering(threshold=pml.PercentageValue(1))\n        
ms.meshing_decimation_quadric_edge_collapse(\n            targetfacenum=int(target), optimalplacement=optimalplacement\n        )\n\n        if remesh:\n            # ms.apply_coord_taubin_smoothing()\n            ms.meshing_isotropic_explicit_remeshing(\n                iterations=3, targetlen=pml.PercentageValue(1)\n            )\n\n        # extract mesh\n        m = ms.current_mesh()\n        verts = m.vertex_matrix()\n        faces = m.face_matrix()\n\n    print(\n        f\"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}\"\n    )\n\n    return verts, faces\n\n\ndef clean_mesh(\n    verts,\n    faces,\n    v_pct=1,\n    min_f=64,\n    min_d=20,\n    repair=True,\n    remesh=True,\n    remesh_size=0.01,\n):\n    # verts: [N, 3]\n    # faces: [N, 3]\n\n    _ori_vert_shape = verts.shape\n    _ori_face_shape = faces.shape\n\n    m = pml.Mesh(verts, faces)\n    ms = pml.MeshSet()\n    ms.add_mesh(m, \"mesh\")  # will copy!\n\n    # filters\n    ms.meshing_remove_unreferenced_vertices()  # verts not refed by any faces\n\n    if v_pct > 0:\n        ms.meshing_merge_close_vertices(\n            threshold=pml.PercentageValue(v_pct)\n        )  # 1/10000 of bounding box diagonal\n\n    ms.meshing_remove_duplicate_faces()  # faces defined by the same verts\n    ms.meshing_remove_null_faces()  # faces with area == 0\n\n    if min_d > 0:\n        ms.meshing_remove_connected_component_by_diameter(\n            mincomponentdiag=pml.PercentageValue(min_d)\n        )\n\n    if min_f > 0:\n        ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)\n\n    if repair:\n        # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)\n        ms.meshing_repair_non_manifold_edges(method=0)\n        ms.meshing_repair_non_manifold_vertices(vertdispratio=0)\n\n    if remesh:\n        # ms.apply_coord_taubin_smoothing()\n        ms.meshing_isotropic_explicit_remeshing(\n            iterations=3, 
targetlen=pml.PureValue(remesh_size)\n        )\n\n    # extract mesh\n    m = ms.current_mesh()\n    verts = m.vertex_matrix()\n    faces = m.face_matrix()\n\n    print(\n        f\"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}\"\n    )\n\n    return verts, faces\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/utils/camera_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom motionrep.gaussian_3d.scene.cameras import Camera\nimport numpy as np\nfrom motionrep.gaussian_3d.utils.general_utils import PILtoTorch\nfrom motionrep.gaussian_3d.utils.graphics_utils import fov2focal\nimport torch\n\nWARNED = False\n\n\ndef loadCam(args, id, cam_info, resolution_scale):\n    orig_w, orig_h = cam_info.image.size\n\n    if args.resolution in [1, 2, 4, 8]:\n        resolution = round(orig_w / (resolution_scale * args.resolution)), round(\n            orig_h / (resolution_scale * args.resolution)\n        )\n    else:  # should be a type that converts to float\n        if args.resolution == -1:\n            if orig_w > 1600:\n                global WARNED\n                if not WARNED:\n                    print(\n                        \"[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\\n \"\n                        \"If this is not desired, please explicitly specify '--resolution/-r' as 1\"\n                    )\n                    WARNED = True\n                global_down = orig_w / 1600\n            else:\n                global_down = 1\n        else:\n            global_down = orig_w / args.resolution\n\n        scale = float(global_down) * float(resolution_scale)\n        resolution = (int(orig_w / scale), int(orig_h / scale))\n\n    resized_image_rgb = PILtoTorch(cam_info.image, resolution)\n\n    gt_image = resized_image_rgb[:3, ...]\n    loaded_mask = None\n\n    if resized_image_rgb.shape[1] == 4:\n        loaded_mask = resized_image_rgb[3:4, ...]\n\n    return Camera(\n        colmap_id=cam_info.uid,\n        R=cam_info.R,\n        T=cam_info.T,\n        FoVx=cam_info.FovX,\n        
FoVy=cam_info.FovY,\n        image=gt_image,\n        gt_alpha_mask=loaded_mask,\n        image_name=cam_info.image_name,\n        uid=id,\n        data_device=args.data_device,\n    )\n\n\ndef cameraList_from_camInfos(cam_infos, resolution_scale, args):\n    camera_list = []\n\n    for id, c in enumerate(cam_infos):\n        camera_list.append(loadCam(args, id, c, resolution_scale))\n\n    return camera_list\n\n\ndef camera_to_JSON(id, camera: Camera):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = camera.R.transpose()\n    Rt[:3, 3] = camera.T\n    Rt[3, 3] = 1.0\n\n    W2C = np.linalg.inv(Rt)\n    pos = W2C[:3, 3]\n    rot = W2C[:3, :3]\n    serializable_array_2d = [x.tolist() for x in rot]\n    camera_entry = {\n        \"id\": id,\n        \"img_name\": camera.image_name,\n        \"width\": camera.width,\n        \"height\": camera.height,\n        \"position\": pos.tolist(),\n        \"rotation\": serializable_array_2d,\n        \"fy\": fov2focal(camera.FovY, camera.height),\n        \"fx\": fov2focal(camera.FovX, camera.width),\n    }\n    return camera_entry\n\n\ndef look_at(from_point, to_point, up_vector=(0, 1, 0)):\n    \"\"\"\n    Compute the look-at matrix for a camera.\n\n    :param from_point: The position of the camera.\n    :param to_point: The point the camera is looking at.\n    :param up_vector: The up direction of the camera.\n    :return: The 4x4 look-at matrix.\n    \"\"\"\n\n    # minus z for opengl. 
z for colmap\n    forward = np.array(to_point) - np.array(from_point)\n    forward = forward / (np.linalg.norm(forward) + 1e-5)\n\n    # x-axis\n    # Right direction is the cross product of the forward vector and the up vector\n    right = np.cross(up_vector, forward)\n    right = right / (np.linalg.norm(right) + 1e-5)\n\n    # y axis\n    # True up direction is the cross product of the right vector and the forward vector\n    true_up = np.cross(forward, right)\n    true_up = true_up / (np.linalg.norm(true_up) + 1e-5)\n\n    # camera to world\n    rotation = np.array(\n        [\n            [right[0], true_up[0], forward[0]],\n            [right[1], true_up[1], forward[1]],\n            [right[2], true_up[2], forward[2]],\n        ]\n    )\n\n    # Construct the translation matrix\n    translation = np.array(\n        [\n            [-from_point[0]],\n            [-from_point[1]],\n            [-from_point[2]],\n        ]\n    )\n\n    # Combine the rotation and translation to get the look-at matrix\n    T = 1.0 * rotation.transpose() @ translation\n\n    return rotation.transpose(), T\n\n\ndef create_cameras_around_sphere(\n    radius=6,\n    elevation=0,\n    fovx=35,\n    resolutions=(720, 1080),\n    num_cams=60,\n    center=(0, 0, 0),\n):\n    \"\"\"\n    Create cameras around a sphere.\n\n    :param radius: The radius of the circle on which cameras are placed.\n    :param elevation: The elevation angle in degrees.\n    :param fovx: The horizontal field of view of the cameras.\n    :param resolutions: The resolution of the cameras.\n    :param num_cams: The number of cameras.\n    :param center: The center of the sphere.\n    :return: A list of camera extrinsics (world2camera transformations).\n    \"\"\"\n    extrinsics = []\n\n    # Convert elevation to radians\n    elevation_rad = np.radians(elevation)\n\n    # Compute the y-coordinate of the cameras based on the elevation\n    z = radius * np.sin(elevation_rad)\n\n    # Compute the radius of the circle 
at the given elevation\n    circle_radius = radius * np.cos(elevation_rad)\n\n    for i in range(num_cams):\n        # Compute the angle for the current camera\n        angle = 2 * np.pi * i / num_cams\n\n        # Compute the x and z coordinates of the camera\n        x = circle_radius * np.cos(angle) + center[0]\n        y = circle_radius * np.sin(angle) + center[1]\n\n        # Create the look-at matrix for the camera\n        R, T = look_at((x, y, z + center[2]), center)\n        extrinsics.append([R, T.squeeze(axis=-1)])\n\n    cam_list = []\n    dummy_image = torch.tensor(\n        np.zeros((3, resolutions[0], resolutions[1]), dtype=np.uint8)\n    )\n    for i in range(num_cams):\n        R, T = extrinsics[i]\n\n        # R is stored transposed due to 'glm' in CUDA code\n        R = R.transpose()\n        cam = Camera(\n            colmap_id=i,\n            R=R,\n            T=T,\n            FoVx=fovx,\n            FoVy=fovx * resolutions[1] / resolutions[0],\n            image_name=\"\",\n            uid=i,\n            data_device=\"cuda\",\n            image=dummy_image,\n            gt_alpha_mask=None,\n        )\n\n        cam_list.append(cam)\n\n    return cam_list\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/utils/general_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport sys\nfrom datetime import datetime\nimport numpy as np\nimport random\n\ndef inverse_sigmoid(x):\n    return torch.log(x/(1-x))\n\ndef PILtoTorch(pil_image, resolution):\n    resized_image_PIL = pil_image.resize(resolution)\n    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0\n    if len(resized_image.shape) == 3:\n        return resized_image.permute(2, 0, 1)\n    else:\n        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)\n\ndef get_expon_lr_func(\n    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000\n):\n    \"\"\"\n    Copied from Plenoxels\n\n    Continuous learning rate decay function. Adapted from JaxNeRF\n    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and\n    is log-linearly interpolated elsewhere (equivalent to exponential decay).\n    If lr_delay_steps>0 then the learning rate will be scaled by some smooth\n    function of lr_delay_mult, such that the initial learning rate is\n    lr_init*lr_delay_mult at the beginning of optimization but will be eased back\n    to the normal learning rate when steps>lr_delay_steps.\n    :param conf: config subtree 'lr' or similar\n    :param max_steps: int, the number of steps during optimization.\n    :return HoF which takes step as input\n    \"\"\"\n\n    def helper(step):\n        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):\n            # Disable this parameter\n            return 0.0\n        if lr_delay_steps > 0:\n            # A kind of reverse cosine decay.\n            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(\n                0.5 * np.pi * np.clip(step / 
lr_delay_steps, 0, 1)\n            )\n        else:\n            delay_rate = 1.0\n        t = np.clip(step / max_steps, 0, 1)\n        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)\n        return delay_rate * log_lerp\n\n    return helper\n\ndef strip_lowerdiag(L):\n    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=\"cuda\")\n\n    uncertainty[:, 0] = L[:, 0, 0]\n    uncertainty[:, 1] = L[:, 0, 1]\n    uncertainty[:, 2] = L[:, 0, 2]\n    uncertainty[:, 3] = L[:, 1, 1]\n    uncertainty[:, 4] = L[:, 1, 2]\n    uncertainty[:, 5] = L[:, 2, 2]\n    return uncertainty\n\ndef strip_symmetric(sym):\n    return strip_lowerdiag(sym)\n\ndef build_rotation(r):\n    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])\n\n    q = r / norm[:, None]\n\n    R = torch.zeros((q.size(0), 3, 3), device='cuda')\n\n    r = q[:, 0]\n    x = q[:, 1]\n    y = q[:, 2]\n    z = q[:, 3]\n\n    R[:, 0, 0] = 1 - 2 * (y*y + z*z)\n    R[:, 0, 1] = 2 * (x*y - r*z)\n    R[:, 0, 2] = 2 * (x*z + r*y)\n    R[:, 1, 0] = 2 * (x*y + r*z)\n    R[:, 1, 1] = 1 - 2 * (x*x + z*z)\n    R[:, 1, 2] = 2 * (y*z - r*x)\n    R[:, 2, 0] = 2 * (x*z - r*y)\n    R[:, 2, 1] = 2 * (y*z + r*x)\n    R[:, 2, 2] = 1 - 2 * (x*x + y*y)\n    return R\n\ndef build_scaling_rotation(s, r):\n    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=\"cuda\")\n    R = build_rotation(r)\n\n    L[:,0,0] = s[:,0]\n    L[:,1,1] = s[:,1]\n    L[:,2,2] = s[:,2]\n\n    L = R @ L\n    return L\n\ndef safe_state(silent):\n    old_f = sys.stdout\n    class F:\n        def __init__(self, silent):\n            self.silent = silent\n\n        def write(self, x):\n            if not self.silent:\n                if x.endswith(\"\\n\"):\n                    old_f.write(x.replace(\"\\n\", \" [{}]\\n\".format(str(datetime.now().strftime(\"%d/%m %H:%M:%S\")))))\n                else:\n                    old_f.write(x)\n\n        def flush(self):\n            
old_f.flush()\n\n    sys.stdout = F(silent)\n\n    random.seed(0)\n    np.random.seed(0)\n    torch.manual_seed(0)\n    torch.cuda.set_device(torch.device(\"cuda:0\"))\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/utils/graphics_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport math\nimport numpy as np\nfrom typing import NamedTuple\n\nclass BasicPointCloud(NamedTuple):\n    points : np.array\n    colors : np.array\n    normals : np.array\n\ndef geom_transform_points(points, transf_matrix):\n    P, _ = points.shape\n    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)\n    points_hom = torch.cat([points, ones], dim=1)\n    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))\n\n    denom = points_out[..., 3:] + 0.0000001\n    return (points_out[..., :3] / denom).squeeze(dim=0)\n\ndef getWorld2View(R, t):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()\n    Rt[:3, 3] = t\n    Rt[3, 3] = 1.0\n    return np.float32(Rt)\n\ndef getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()\n    Rt[:3, 3] = t\n    Rt[3, 3] = 1.0\n\n    C2W = np.linalg.inv(Rt)\n    cam_center = C2W[:3, 3]\n    cam_center = (cam_center + translate) * scale\n    C2W[:3, 3] = cam_center\n    Rt = np.linalg.inv(C2W)\n    return np.float32(Rt)\n\ndef getProjectionMatrix(znear, zfar, fovX, fovY):\n    tanHalfFovY = math.tan((fovY / 2))\n    tanHalfFovX = math.tan((fovX / 2))\n\n    top = tanHalfFovY * znear\n    bottom = -top\n    right = tanHalfFovX * znear\n    left = -right\n\n    P = torch.zeros(4, 4)\n\n    z_sign = 1.0\n\n    P[0, 0] = 2.0 * znear / (right - left)\n    P[1, 1] = 2.0 * znear / (top - bottom)\n    P[0, 2] = (right + left) / (right - left)\n    P[1, 2] = (top + bottom) / (top - bottom)\n    P[3, 2] = z_sign\n    P[2, 2] = z_sign * zfar / (zfar - znear)\n    P[2, 3] = -(zfar * znear) / (zfar - znear)\n    return 
P\n\ndef fov2focal(fov, pixels):\n    return pixels / (2 * math.tan(fov / 2))\n\ndef focal2fov(focal, pixels):\n    return 2*math.atan(pixels/(2*focal))"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/utils/image_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\n\ndef mse(img1, img2):\n    return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n\ndef psnr(img1, img2):\n    mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n    return 20 * torch.log10(1.0 / torch.sqrt(mse))\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/utils/loss_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom math import exp\n\ndef l1_loss(network_output, gt):\n    return torch.abs((network_output - gt)).mean()\n\ndef l2_loss(network_output, gt):\n    return ((network_output - gt) ** 2).mean()\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])\n    return gauss / gauss.sum()\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())\n    return window\n\ndef ssim(img1, img2, window_size=11, size_average=True):\n    channel = img1.size(-3)\n    window = create_window(window_size, channel)\n\n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n\n    return _ssim(img1, img2, window, window_size, channel, size_average)\n\ndef _ssim(img1, img2, window, window_size, channel, size_average=True):\n    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2\n\n    
C1 = 0.01 ** 2\n    C2 = 0.03 ** 2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1).mean(1).mean(1)\n\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/utils/rigid_body_utils.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\ndef get_rigid_transform(A, B):\n    \"\"\"\n    Estimate the rigid body transformation between two sets of 3D points.\n    A and B are Nx3 matrices where each row is a 3D point.\n    Returns a rotation matrix R and translation vector t.\n    Args:\n        A, B: [batch, N, 3] matrix of 3D points\n    Outputs:\n        R, t: [batch, 3, 3/1]\n        target = R @ source (source shape [3, 1]) + t\n    \"\"\"\n    assert A.shape == B.shape, \"Input matrices must have the same shape\"\n    assert A.shape[-1] == 3, \"Input matrices must have 3 columns (x, y, z coordinates)\"\n\n    # Compute centroids. [..., 1, 3]\n    centroid_A = torch.mean(A, dim=-2, keepdim=True)\n    centroid_B = torch.mean(B, dim=-2, keepdim=True)\n\n    # Center the point sets\n    A_centered = A - centroid_A\n    B_centered = B - centroid_B\n\n    # Compute the cross-covariance matrix. [..., 3, 3]\n    H = A_centered.transpose(-2, -1) @ B_centered\n\n    # Compute the Singular Value Decomposition. 
Along last two dimensions\n    U, S, Vt = torch.linalg.svd(H)\n\n    # Compute the rotation matrix\n    R = Vt.transpose(-2, -1) @ U.transpose(-2, -1)\n\n    # Ensure a right-handed coordinate system\n    flip_mask = (torch.det(R) < 0) * -2.0 + 1.0\n    # Vt[:, 2, :] *= flip_mask[..., None]\n\n    # [N] => [N, 3]\n    pad_flip_mask = torch.stack(\n        [torch.ones_like(flip_mask), torch.ones_like(flip_mask), flip_mask], dim=-1\n    )\n    Vt = Vt * pad_flip_mask[..., None]\n\n    # Compute the rotation matrix\n    R = Vt.transpose(-2, -1) @ U.transpose(-2, -1)\n\n    # print(R.shape, centroid_A.shape, centroid_B.shape, flip_mask.shape)\n    # Compute the translation\n    t = centroid_B - (R @ centroid_A.transpose(-2, -1)).transpose(-2, -1)\n    t = t.transpose(-2, -1)\n    return R, t\n\n\ndef _test_rigid_transform():\n    # Example usage:\n    A = torch.tensor([[1, 2, 3], [4, 5, 6], [9, 8, 10], [10, -5, 1]]) * 1.0\n\n    R_synthesized = torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) * 1.0\n    # init a random rotation matrix:\n\n    B = (R_synthesized @ A.T).T + 2.0  # Just an example offset\n\n    R, t = get_rigid_transform(A[None, ...], B[None, ...])\n    print(\"Rotation matrix R:\")\n    print(R)\n    print(\"\\nTranslation vector t:\")\n    print(t)\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Returns torch.sqrt(torch.max(0, x))\n    but with a zero subgradient where x is 0.\n    \"\"\"\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    ret[positive_mask] = torch.sqrt(x[positive_mask])\n    return ret\n\n\ndef matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    from pytorch3d. 
Based on trace_method like: https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L205\n    Convert rotations given as rotation matrices to quaternions.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n        matrix.reshape(batch_dim + (9,)), dim=-1\n    )\n\n    q_abs = _sqrt_positive_part(\n        torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    # we produce the desired quaternion multiplied by each of r, i, j, k\n    quat_by_rijk = torch.stack(\n        [\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),\n        ],\n        dim=-2,\n    )\n\n    # We floor here at 0.1 but the exact level is not important; if q_abs is small,\n    # the candidate won't be 
picked.\n    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))\n\n    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),\n    # forall i; we pick the best-conditioned one (with the largest denominator)\n\n    return quat_candidates[\n        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :\n    ].reshape(batch_dim + (4,))\n\n\ndef quternion_to_matrix(r):\n    norm = torch.sqrt(\n        r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]\n    )\n\n    q = r / norm[:, None]\n\n    R = torch.zeros((q.size(0), 3, 3), device=\"cuda\")\n\n    r = q[:, 0]\n    x = q[:, 1]\n    y = q[:, 2]\n    z = q[:, 3]\n\n    R[:, 0, 0] = 1 - 2 * (y * y + z * z)\n    R[:, 0, 1] = 2 * (x * y - r * z)\n    R[:, 0, 2] = 2 * (x * z + r * y)\n    R[:, 1, 0] = 2 * (x * y + r * z)\n    R[:, 1, 1] = 1 - 2 * (x * x + z * z)\n    R[:, 1, 2] = 2 * (y * z - r * x)\n    R[:, 2, 0] = 2 * (x * z - r * y)\n    R[:, 2, 1] = 2 * (y * z + r * x)\n    R[:, 2, 2] = 1 - 2 * (x * x + y * y)\n    return R\n\n\ndef standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    from Pytorch3d\n    Convert a unit quaternion to a standard form: one in which the real\n    part is non negative.\n\n    Args:\n        quaternions: Quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Standardized quaternions as tensor of shape (..., 4).\n    \"\"\"\n    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)\n\n\ndef quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    From pytorch3d\n    Multiply two quaternions.\n    Usual torch rules for broadcasting apply.\n\n    Args:\n        a: Quaternions as tensor of shape (..., 4), real part first.\n        b: Quaternions as tensor of shape (..., 4), real part first.\n\n    Returns:\n        The product 
of a and b, a tensor of quaternions shape (..., 4).\n    \"\"\"\n    aw, ax, ay, az = torch.unbind(a, -1)\n    bw, bx, by, bz = torch.unbind(b, -1)\n    ow = aw * bw - ax * bx - ay * by - az * bz\n    ox = aw * bx + ax * bw + ay * bz - az * by\n    oy = aw * by - ax * bz + ay * bw + az * bx\n    oz = aw * bz + ax * by - ay * bx + az * bw\n    ret = torch.stack((ow, ox, oy, oz), -1)\n    ret = standardize_quaternion(ret)\n    return ret\n\n\ndef _test_matrix_to_quaternion():\n    # init a random batch of quaternion\n    r = torch.randn((10, 4)).cuda()\n\n    norm = torch.sqrt(\n        r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]\n    )\n\n    q = r / norm[:, None]\n\n    q = standardize_quaternion(q)\n\n    R = quternion_to_matrix(q)\n\n    I_rec = R @ R.transpose(-2, -1)\n    I_rec_error = torch.abs(I_rec - torch.eye(3, device=\"cuda\")[None, ...]).max()\n\n    q_recovered = matrix_to_quaternion(R)\n    norm_ = torch.linalg.norm(q_recovered, dim=-1)\n    q_recovered = q_recovered / norm_[..., None]\n    q_recovered = standardize_quaternion(q_recovered)\n\n    print(q_recovered.shape, q.shape, R.shape)\n\n    rec = (q - q_recovered).abs().max()\n\n    print(\"rotation to I error:\", I_rec_error, \"quant rec error: \", rec)\n\n\ndef _test_matrix_to_quaternion_2():\n    R = (\n        torch.tensor(\n            [[[1, 0, 0], [0, -1, 0], [0, 0, -1]], [[1, 0, 0], [0, 0, 1], [0, -1, 0]]]\n        )\n        * 1.0\n    )\n\n    q_rec = matrix_to_quaternion(R.transpose(-2, -1))\n\n    R_rec = quternion_to_matrix(q_rec)\n\n    print(R_rec)\n\n\nif __name__ == \"__main__\":\n    # _test_rigid_transform()\n    _test_matrix_to_quaternion()\n\n    _test_matrix_to_quaternion_2()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/utils/sh_utils.py",
    "content": "#  Copyright 2021 The PlenOctree Authors.\n#  Redistribution and use in source and binary forms, with or without\n#  modification, are permitted provided that the following conditions are met:\n#\n#  1. Redistributions of source code must retain the above copyright notice,\n#  this list of conditions and the following disclaimer.\n#\n#  2. Redistributions in binary form must reproduce the above copyright notice,\n#  this list of conditions and the following disclaimer in the documentation\n#  and/or other materials provided with the distribution.\n#\n#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE\n#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n#  POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\n\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n    1.0925484305920792,\n    -1.0925484305920792,\n    0.31539156525252005,\n    -1.0925484305920792,\n    0.5462742152960396\n]\nC3 = [\n    -0.5900435899266435,\n    2.890611442640554,\n    -0.4570457994644658,\n    0.3731763325901154,\n    -0.4570457994644658,\n    1.445305721320277,\n    -0.5900435899266435\n]\nC4 = [\n    2.5033429417967046,\n    -1.7701307697799304,\n    0.9461746957575601,\n    -0.6690465435572892,\n    0.10578554691520431,\n    -0.6690465435572892,\n    0.47308734787878004,\n    -1.7701307697799304,\n    
0.6258357354491761,\n]   \n\n\ndef eval_sh(deg, sh, dirs):\n    \"\"\"\n    Evaluate spherical harmonics at unit directions\n    using hardcoded SH polynomials.\n    Works with torch/np/jnp.\n    ... Can be 0 or more batch dimensions.\n    Args:\n        deg: int SH deg. Currently, 0-3 supported\n        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]\n        dirs: jnp.ndarray unit directions [..., 3]\n    Returns:\n        [..., C]\n    \"\"\"\n    assert deg <= 4 and deg >= 0\n    coeff = (deg + 1) ** 2\n    assert sh.shape[-1] >= coeff\n\n    result = C0 * sh[..., 0]\n    if deg > 0:\n        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]\n        result = (result -\n                C1 * y * sh[..., 1] +\n                C1 * z * sh[..., 2] -\n                C1 * x * sh[..., 3])\n\n        if deg > 1:\n            xx, yy, zz = x * x, y * y, z * z\n            xy, yz, xz = x * y, y * z, x * z\n            result = (result +\n                    C2[0] * xy * sh[..., 4] +\n                    C2[1] * yz * sh[..., 5] +\n                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +\n                    C2[3] * xz * sh[..., 7] +\n                    C2[4] * (xx - yy) * sh[..., 8])\n\n            if deg > 2:\n                result = (result +\n                C3[0] * y * (3 * xx - yy) * sh[..., 9] +\n                C3[1] * xy * z * sh[..., 10] +\n                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +\n                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +\n                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +\n                C3[5] * z * (xx - yy) * sh[..., 14] +\n                C3[6] * x * (xx - 3 * yy) * sh[..., 15])\n\n                if deg > 3:\n                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +\n                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +\n                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +\n                            C4[3] * yz * (7 * zz - 
3) * sh[..., 19] +\n                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +\n                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +\n                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +\n                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +\n                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])\n    return result\n\ndef RGB2SH(rgb):\n    return (rgb - 0.5) / C0\n\ndef SH2RGB(sh):\n    return sh * C0 + 0.5"
  },
  {
    "path": "projects/uncleaned_train/motionrep/gaussian_3d/utils/system_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom errno import EEXIST\nfrom os import makedirs, path\nimport os\n\ndef mkdir_p(folder_path):\n    # Creates a directory. equivalent to using mkdir -p on the command line\n    try:\n        makedirs(folder_path)\n    except OSError as exc: # Python >2.5\n        if exc.errno == EEXIST and path.isdir(folder_path):\n            pass\n        else:\n            raise\n\ndef searchForMaxIteration(folder):\n    saved_iters = [int(fname.split(\"_\")[-1]) for fname in os.listdir(folder)]\n    return max(saved_iters)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/losses/se3_loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/losses/smoothness_loss.py",
    "content": "import torch\nfrom typing import Tuple\n\n\ndef compute_plane_tv(t: torch.Tensor, only_w: bool = False) -> float:\n    \"\"\"Computes total variance across a plane.\n    From nerf-studio\n\n    Args:\n        t: Plane tensor\n        only_w: Whether to only compute total variance across w dimension\n\n    Returns:\n        Total variance\n    \"\"\"\n    _, h, w = t.shape\n    w_tv = torch.square(t[..., :, 1:] - t[..., :, : w - 1]).mean()\n\n    if only_w:\n        return w_tv\n\n    h_tv = torch.square(t[..., 1:, :] - t[..., : h - 1, :]).mean()\n    return h_tv + w_tv\n\n\ndef compute_plane_smoothness(t: torch.Tensor) -> float:\n    \"\"\"Computes smoothness across the temporal axis of a plane\n    From nerf-studio\n    Args:\n        t: Plane tensor\n\n    Returns:\n        Time smoothness\n    \"\"\"\n    _, h, _ = t.shape\n    # Convolve with a second derivative filter, in the time dimension which is dimension 2\n    first_difference = t[..., 1:, :] - t[..., : h - 1, :]  # [c, h-1, w]\n    second_difference = (\n        first_difference[..., 1:, :] - first_difference[..., : h - 2, :]\n    )  # [c, h-2, w]\n    # Take the L2 norm of the result\n    return torch.square(second_difference).mean()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/operators/dct.py",
    "content": "\"\"\"\nCode from https://github.com/zh217/torch-dct/blob/master/torch_dct/_dct.py\n\"\"\"\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n\nimport torch.fft\n\n\ndef dct1_rfft_impl(x):\n    return torch.view_as_real(torch.fft.rfft(x, dim=1))\n\n\ndef dct_fft_impl(v):\n    return torch.view_as_real(torch.fft.fft(v, dim=1))\n\n\ndef idct_irfft_impl(V):\n    return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)\n\n\ndef dct(x, norm=None):\n    \"\"\"\n    Discrete Cosine Transform, Type II (a.k.a. the DCT)\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    if norm is None:\n              N-1\n    y[k] = 2* sum x[n]*cos(pi*k*(2n+1)/(2*N)), 0 <= k < N.\n              n=0\n\n    :param x: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the DCT-II of the signal over the last dimension\n    \"\"\"\n    x_shape = x.shape\n    N = x_shape[-1]\n    x = x.contiguous().view(-1, N)\n\n    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)\n\n    Vc = dct_fft_impl(v)\n\n    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)\n    W_r = torch.cos(k)\n    W_i = torch.sin(k)\n\n    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i\n\n    if norm == \"ortho\":\n        V[:, 0] /= np.sqrt(N) * 2\n        V[:, 1:] /= np.sqrt(N / 2) * 2\n\n    V = 2 * V.view(*x_shape)\n\n    return V\n\n\ndef idct(X, norm=None):\n    \"\"\"\n    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III\n\n    Our definition of idct is that idct(dct(x)) == x\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    :param X: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the inverse DCT-II of the signal over the last dimension\n    \"\"\"\n\n    x_shape = 
X.shape\n    N = x_shape[-1]\n\n    X_v = X.contiguous().view(-1, x_shape[-1]) / 2\n\n    if norm == \"ortho\":\n        X_v[:, 0] *= np.sqrt(N) * 2\n        X_v[:, 1:] *= np.sqrt(N / 2) * 2\n\n    k = (\n        torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :]\n        * np.pi\n        / (2 * N)\n    )\n    W_r = torch.cos(k)\n    W_i = torch.sin(k)\n\n    V_t_r = X_v\n    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)\n\n    V_r = V_t_r * W_r - V_t_i * W_i\n    V_i = V_t_r * W_i + V_t_i * W_r\n\n    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)\n\n    v = idct_irfft_impl(V)\n    x = v.new_zeros(v.shape)\n    x[:, ::2] += v[:, : N - (N // 2)]\n    x[:, 1::2] += v.flip([1])[:, : N // 2]\n\n    return x.view(*x_shape)\n\n\ndef dct_3d(x, norm=None):\n    \"\"\"\n    3-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT)\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    :param x: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the DCT-II of the signal over the last 3 dimensions\n    \"\"\"\n    X1 = dct(x, norm=norm)\n    X2 = dct(X1.transpose(-1, -2), norm=norm)\n    X3 = dct(X2.transpose(-1, -3), norm=norm)\n    return X3.transpose(-1, -3).transpose(-1, -2)\n\n\ndef idct_3d(X, norm=None):\n    \"\"\"\n    The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III\n\n    Our definition of idct is that idct_3d(dct_3d(x)) == x\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    :param X: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the DCT-II of the signal over the last 3 dimensions\n    \"\"\"\n    x1 = idct(X, norm=norm)\n    x2 = idct(x1.transpose(-1, -2), norm=norm)\n    x3 = idct(x2.transpose(-1, -3), norm=norm)\n    
return x3.transpose(-1, -3).transpose(-1, -2)\n\n\ndef code_test_dct3d():\n    # init a tensor of shape [100, 20, 3]\n    x = torch.rand(100, 20, 3)\n\n    dct_coef = dct_3d(x, norm=\"ortho\")\n    print(\"inp signal shape: \", x.shape, \"  dct coef shape: \", dct_coef.shape)\n\n    x_recon = idct_3d(dct_coef, norm=\"ortho\")\n    print(\"inp signal shape: \", x.shape, \"  recon signal shape: \", x_recon.shape)\n\n    print(\"max error: \", torch.max(torch.abs(x - x_recon)))\n\n    dct_coef[:, 0, :] = 0\n\n    x_recon = idct_3d(dct_coef, norm=\"ortho\")\n    print(\"max error after removing first order: \", torch.max(torch.abs(x - x_recon)))\n\n\nif __name__ == \"__main__\":\n    code_test_dct3d()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/operators/np_operators.py",
    "content": "import torch\nimport numpy as np\nfrom sklearn.decomposition import PCA\nimport matplotlib.pyplot as plt\n\n\ndef feature_map_to_rgb_pca(feature_map):\n    \"\"\"\n    Args:\n        feature_map: (C, H, W) feature map.\n    Outputs:\n        rgb_image: (H, W, 3) image.\n    \"\"\"\n    # Move feature map to CPU and convert to numpy\n    if isinstance(feature_map, torch.Tensor):\n        feature_map = feature_map.detach().cpu().numpy()\n\n    H, W = feature_map.shape[1:]\n    # Flatten spatial dimensions  # [N, C]\n    flattened_map = feature_map.reshape(feature_map.shape[0], -1).T\n\n    # Apply PCA and reduce channel dimension to 3\n    pca = PCA(n_components=3)\n    pca_result = pca.fit_transform(flattened_map)\n\n    # Reshape back to (H, W, 3)\n    rgb_image = pca_result.reshape(H, W, 3)\n\n    # Normalize to [0, 1]\n    rgb_image = (rgb_image - rgb_image.min()) / (\n        rgb_image.max() - rgb_image.min() + 1e-3\n    )\n\n    return rgb_image\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/operators/rotation.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix\n    using Gram--Schmidt orthogonalization per Section B of [1].\n    Args:\n        d6: 6D rotation representation, of size (*, 6)\n\n    Returns:\n        batch of rotation matrices of size (*, 3, 3)\n\n    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.\n    On the Continuity of Rotation Representations in Neural Networks.\n    IEEE Conference on Computer Vision and Pattern Recognition, 2019.\n    Retrieved from http://arxiv.org/abs/1812.07035\n    \"\"\"\n\n    a1, a2 = d6[..., :3], d6[..., 3:]\n    b1 = F.normalize(a1, dim=-1)\n    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1\n    b2 = F.normalize(b2, dim=-1)\n    b3 = torch.cross(b1, b2, dim=-1)\n    return torch.stack((b1, b2, b3), dim=-2)\n\n\ndef matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]\n    by dropping the last row. 
Note that 6D representation is not unique.\n    Args:\n        matrix: batch of rotation matrices of size (*, 3, 3)\n\n    Returns:\n        6D rotation representation, of size (*, 6)\n\n    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.\n    On the Continuity of Rotation Representations in Neural Networks.\n    IEEE Conference on Computer Vision and Pattern Recognition, 2019.\n    Retrieved from http://arxiv.org/abs/1812.07035\n    \"\"\"\n    batch_dim = matrix.size()[:-2]\n    return matrix[..., :2, :].clone().reshape(batch_dim + (6,))\n\n\ndef quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotations given as quaternions to rotation matrices.\n\n    Args:\n        quaternions: quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n    r, i, j, k = torch.unbind(quaternions, -1)\n    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Returns torch.sqrt(torch.max(0, x))\n    but with a zero subgradient where x is 0.\n    \"\"\"\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    ret[positive_mask] = torch.sqrt(x[positive_mask])\n    return ret\n\n\ndef matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotations given as rotation matrices 
to quaternions.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n        matrix.reshape(batch_dim + (9,)), dim=-1\n    )\n\n    q_abs = _sqrt_positive_part(\n        torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    # we produce the desired quaternion multiplied by each of r, i, j, k\n    quat_by_rijk = torch.stack(\n        [\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),\n        ],\n        dim=-2,\n    )\n\n    # We floor here at 0.1 but the exact level is not important; if q_abs is small,\n    # the candidate won't be picked.\n    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))\n\n    # if not 
for numerical problems, quat_candidates[i] should be same (up to a sign),\n    # forall i; we pick the best-conditioned one (with the largest denominator)\n\n    return quat_candidates[\n        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :\n    ].reshape(batch_dim + (4,))\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/camera_utils.py",
    "content": "import numpy as np\n\n\ndef normalize(x: np.ndarray) -> np.ndarray:\n    \"\"\"Normalization helper function.\"\"\"\n    return x / np.linalg.norm(x)\n\n\ndef viewmatrix(lookdir: np.ndarray, up: np.ndarray, position: np.ndarray) -> np.ndarray:\n    \"\"\"Construct lookat view matrix.\"\"\"\n    vec2 = normalize(lookdir)\n    vec0 = normalize(np.cross(up, vec2))\n    vec1 = normalize(np.cross(vec2, vec0))\n    m = np.stack([vec0, vec1, vec2, position], axis=1)\n    return m\n\n\ndef generate_spiral_path(\n    pose: np.ndarray,\n    radius: float,\n    lookat_pt: np.ndarray = np.array([0, 0, 0]),\n    up: np.ndarray = np.array([0, 0, 1]),\n    n_frames: int = 60,\n    n_rots: int = 1,\n    y_scale: float = 1.0,\n) -> np.ndarray:\n    \"\"\"Calculates a forward facing spiral path for rendering.\"\"\"\n    x_axis = pose[:3, 0]\n    y_axis = pose[:3, 1]\n    campos = pose[:3, 3]\n\n    render_poses = []\n    for theta in np.linspace(0.0, 2 * np.pi * n_rots, n_frames, endpoint=False):\n        t = (np.cos(theta) * x_axis + y_scale * np.sin(theta) * y_axis) * radius\n        position = campos + t\n        z_axis = position - lookat_pt\n        new_pose = np.eye(4)\n        new_pose[:3] = viewmatrix(z_axis, up, position)\n        render_poses.append(new_pose)\n    render_poses = np.stack(render_poses, axis=0)\n    return render_poses\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/colmap_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport numpy as np\nimport collections\nimport struct\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"])\nCamera = collections.namedtuple(\n    \"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"])\nPoint3D = collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"])\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12)\n}\nCAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)\n                         for camera_model in CAMERA_MODELS])\nCAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)\n                           for camera_model in CAMERA_MODELS])\n\n\ndef qvec2rotmat(qvec):\n    return np.array([\n        [1 - 2 * 
qvec[2]**2 - 2 * qvec[3]**2,\n         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],\n        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,\n         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],\n        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])\n\ndef rotmat2qvec(R):\n    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat\n    K = np.array([\n        [Rxx - Ryy - Rzz, 0, 0, 0],\n        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],\n        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],\n        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0\n    eigvals, eigvecs = np.linalg.eigh(K)\n    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]\n    if qvec[0] < 0:\n        qvec *= -1\n    return qvec\n\nclass Image(BaseImage):\n    def qvec2rotmat(self):\n        return qvec2rotmat(self.qvec)\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\ndef read_points3D_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DText(const std::string& path)\n        void Reconstruction::WritePoints3DText(const std::string& path)\n    \"\"\"\n    xyzs = None\n    rgbs = None\n    errors = None\n    num_points = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                num_points += 1\n\n\n    xyzs = np.empty((num_points, 3))\n    rgbs = np.empty((num_points, 3))\n    errors = np.empty((num_points, 1))\n    count = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                xyz = np.array(tuple(map(float, elems[1:4])))\n                rgb = np.array(tuple(map(int, elems[4:7])))\n                error = np.array(float(elems[7]))\n                xyzs[count] = xyz\n                rgbs[count] = rgb\n                errors[count] = error\n                count += 1\n\n    return xyzs, rgbs, errors\n\ndef read_points3D_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n\n\n    with open(path_to_model_file, \"rb\") as fid:\n        num_points = read_next_bytes(fid, 8, 
\"Q\")[0]\n\n        xyzs = np.empty((num_points, 3))\n        rgbs = np.empty((num_points, 3))\n        errors = np.empty((num_points, 1))\n\n        for p_id in range(num_points):\n            binary_point_line_properties = read_next_bytes(\n                fid, num_bytes=43, format_char_sequence=\"QdddBBBd\")\n            xyz = np.array(binary_point_line_properties[1:4])\n            rgb = np.array(binary_point_line_properties[4:7])\n            error = np.array(binary_point_line_properties[7])\n            track_length = read_next_bytes(\n                fid, num_bytes=8, format_char_sequence=\"Q\")[0]\n            track_elems = read_next_bytes(\n                fid, num_bytes=8*track_length,\n                format_char_sequence=\"ii\"*track_length)\n            xyzs[p_id] = xyz\n            rgbs[p_id] = rgb\n            errors[p_id] = error\n    return xyzs, rgbs, errors\n\ndef read_intrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                assert model == \"PINHOLE\", \"While the loader support other types, the rest of the code assumes PINHOLE\"\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(id=camera_id, model=model,\n                                            width=width, height=height,\n                                            params=params)\n    return cameras\n\ndef read_extrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n    
    void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    images = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_reg_images = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_reg_images):\n            binary_image_properties = read_next_bytes(\n                fid, num_bytes=64, format_char_sequence=\"idddddddi\")\n            image_id = binary_image_properties[0]\n            qvec = np.array(binary_image_properties[1:5])\n            tvec = np.array(binary_image_properties[5:8])\n            camera_id = binary_image_properties[8]\n            image_name = \"\"\n            current_char = read_next_bytes(fid, 1, \"c\")[0]\n            while current_char != b\"\\x00\":   # look for the ASCII 0 entry\n                image_name += current_char.decode(\"utf-8\")\n                current_char = read_next_bytes(fid, 1, \"c\")[0]\n            num_points2D = read_next_bytes(fid, num_bytes=8,\n                                           format_char_sequence=\"Q\")[0]\n            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,\n                                       format_char_sequence=\"ddq\"*num_points2D)\n            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),\n                                   tuple(map(float, x_y_id_s[1::3]))])\n            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))\n            images[image_id] = Image(\n                id=image_id, qvec=qvec, tvec=tvec,\n                camera_id=camera_id, name=image_name,\n                xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_intrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with 
open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_cameras):\n            camera_properties = read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\")\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(fid, num_bytes=8*num_params,\n                                     format_char_sequence=\"d\"*num_params)\n            cameras[camera_id] = Camera(id=camera_id,\n                                        model=model_name,\n                                        width=width,\n                                        height=height,\n                                        params=np.array(params))\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_extrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    images = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                image_id = int(elems[0])\n                qvec = np.array(tuple(map(float, elems[1:5])))\n                tvec = np.array(tuple(map(float, elems[5:8])))\n                camera_id = int(elems[8])\n                image_name = elems[9]\n                elems = fid.readline().split()\n                xys = np.column_stack([tuple(map(float, elems[0::3])),\n                                       tuple(map(float, elems[1::3]))])\n                point3D_ids = 
np.array(tuple(map(int, elems[2::3])))\n                images[image_id] = Image(\n                    id=image_id, qvec=qvec, tvec=tvec,\n                    camera_id=camera_id, name=image_name,\n                    xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_colmap_bin_array(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py\n\n    :param path: path to the colmap binary file.\n    :return: nd array with the floating point values in the value\n    \"\"\"\n    with open(path, \"rb\") as fid:\n        width, height, channels = np.genfromtxt(fid, delimiter=\"&\", max_rows=1,\n                                                usecols=(0, 1, 2), dtype=int)\n        fid.seek(0)\n        num_delimiter = 0\n        byte = fid.read(1)\n        while True:\n            if byte == b\"&\":\n                num_delimiter += 1\n                if num_delimiter >= 3:\n                    break\n            byte = fid.read(1)\n        array = np.fromfile(fid, np.float32)\n    array = array.reshape((width, height, channels), order=\"F\")\n    return np.transpose(array, (1, 0, 2)).squeeze()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/config.py",
    "content": "from omegaconf import OmegaConf\n\n\ndef load_config_with_merge(config_path: str):\n    cfg = OmegaConf.load(config_path)\n\n    path_ = cfg.get(\"_base\", None)\n\n    if path_ is not None:\n        print(f\"Merging base config from {path_}\")\n        cfg = OmegaConf.merge(load_config_with_merge(path_), cfg)\n    else:\n        return cfg\n    return cfg\n\n\ndef merge_without_none(base_cfg, override_cfg):\n    for key, value in override_cfg.items():\n        if value is not None:\n            base_cfg[key] = value\n        elif not (key in base_cfg):\n            base_cfg[key] = None\n    return base_cfg\n\n\ndef create_config(config_path, args, cli_args: list = []):\n    \"\"\"\n    Args:\n        config_path: path to config file\n        args: argparse object with known variables\n        cli_args: list of cli args in the format of\n            [\"lr=0.1\", \"model.name=alexnet\"]\n    \"\"\"\n    # recursively merge base config\n    cfg = load_config_with_merge(config_path)\n\n    # parse cli args, and merge them into cfg\n    cli_conf = OmegaConf.from_cli(cli_args)\n    arg_cfg = OmegaConf.create(vars(args))\n\n    # drop None in arg_cfg\n\n    arg_cfg = OmegaConf.merge(arg_cfg, cli_conf)\n\n    # cfg = OmegaConf.merge(cfg, arg_cfg, cli_conf)\n    cfg = merge_without_none(cfg, arg_cfg)\n\n    return cfg\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/dct.py",
    "content": "\"\"\"\nCode from https://github.com/zh217/torch-dct/blob/master/torch_dct/_dct.py\n\"\"\"\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n\nimport torch.fft\n\n\ndef dct1_rfft_impl(x):\n    return torch.view_as_real(torch.fft.rfft(x, dim=1))\n\n\ndef dct_fft_impl(v):\n    return torch.view_as_real(torch.fft.fft(v, dim=1))\n\n\ndef idct_irfft_impl(V):\n    return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)\n\n\ndef dct(x, norm=None):\n    \"\"\"\n    Discrete Cosine Transform, Type II (a.k.a. the DCT)\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    if norm is None:\n              N-1\n    y[k] = 2* sum x[n]*cos(pi*k*(2n+1)/(2*N)), 0 <= k < N.\n              n=0\n\n    :param x: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the DCT-II of the signal over the last dimension\n    \"\"\"\n    x_shape = x.shape\n    N = x_shape[-1]\n    x = x.contiguous().view(-1, N)\n\n    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)\n\n    Vc = dct_fft_impl(v)\n\n    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)\n    W_r = torch.cos(k)\n    W_i = torch.sin(k)\n\n    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i\n\n    if norm == \"ortho\":\n        V[:, 0] /= np.sqrt(N) * 2\n        V[:, 1:] /= np.sqrt(N / 2) * 2\n\n    V = 2 * V.view(*x_shape)\n\n    return V\n\n\ndef idct(X, norm=None):\n    \"\"\"\n    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III\n\n    Our definition of idct is that idct(dct(x)) == x\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    :param X: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the inverse DCT-II of the signal over the last dimension\n    \"\"\"\n\n    x_shape = 
X.shape\n    N = x_shape[-1]\n\n    X_v = X.contiguous().view(-1, x_shape[-1]) / 2\n\n    if norm == \"ortho\":\n        X_v[:, 0] *= np.sqrt(N) * 2\n        X_v[:, 1:] *= np.sqrt(N / 2) * 2\n\n    k = (\n        torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :]\n        * np.pi\n        / (2 * N)\n    )\n    W_r = torch.cos(k)\n    W_i = torch.sin(k)\n\n    V_t_r = X_v\n    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)\n\n    V_r = V_t_r * W_r - V_t_i * W_i\n    V_i = V_t_r * W_i + V_t_i * W_r\n\n    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)\n\n    v = idct_irfft_impl(V)\n    x = v.new_zeros(v.shape)\n    x[:, ::2] += v[:, : N - (N // 2)]\n    x[:, 1::2] += v.flip([1])[:, : N // 2]\n\n    return x.view(*x_shape)\n\n\ndef dct_3d(x, norm=None):\n    \"\"\"\n    3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    :param x: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the DCT-II of the signal over the last 3 dimensions\n    \"\"\"\n    X1 = dct(x, norm=norm)\n    X2 = dct(X1.transpose(-1, -2), norm=norm)\n    X3 = dct(X2.transpose(-1, -3), norm=norm)\n    return X3.transpose(-1, -3).transpose(-1, -2)\n\n\ndef idct_3d(X, norm=None):\n    \"\"\"\n    The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III\n\n    Our definition of idct is that idct_3d(dct_3d(x)) == x\n\n    For the meaning of the parameter `norm`, see:\n    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html\n\n    :param X: the input signal\n    :param norm: the normalization, None or 'ortho'\n    :return: the inverse DCT-II of the signal over the last 3 dimensions\n    \"\"\"\n    x1 = idct(X, norm=norm)\n    x2 = idct(x1.transpose(-1, -2), norm=norm)\n    x3 = idct(x2.transpose(-1, -3), norm=norm)\n    
return x3.transpose(-1, -3).transpose(-1, -2)\n\n\ndef code_test_dct3d():\n    # init a tensor of shape [100, 20, 3]\n    x = torch.rand(100, 20, 3)\n\n    dct_coef = dct_3d(x, norm=\"ortho\")\n    print(\"inp signal shape: \", x.shape, \"  dct coef shape: \", dct_coef.shape)\n\n    x_recon = idct_3d(dct_coef, norm=\"ortho\")\n    print(\"inp signal shape: \", x.shape, \"  recon signal shape: \", x_recon.shape)\n\n    print(\"max error: \", torch.max(torch.abs(x - x_recon)))\n\n    dct_coef[:, 0, :] = 0\n\n    x_recon = idct_3d(dct_coef, norm=\"ortho\")\n    print(\"max error after removing first order: \", torch.max(torch.abs(x - x_recon)))\n\n\ndef unwarp_phase(phase, frequency_array):\n    phase_lambda = torch.pi / frequency_array\n\n    phase = phase + phase_lambda\n\n    num_unwarp = phase // (2.0 * phase_lambda)\n    phase = phase - num_unwarp * phase_lambda * 2.0\n\n    phase = phase - phase_lambda\n\n    return phase\n\n\ndef get_mag_phase(fft_weights, s=3.0 / 16.0):\n    \"\"\"\n    Args:\n        fft_weights: [*bs, numK * 2, 3/2] # [B**, numK, 2]\n    Returns:\n        mag_phase: [*bs, numK * 2, 3/2]\n    \"\"\"\n\n    num_K = fft_weights.shape[-2] // 2\n\n    # [num_k, 1]\n    k_list = torch.arange(1, num_K + 1, device=fft_weights.device).unsqueeze(-1)\n    # k_list = torch.ones_like(k_list) # need to fix this\n    k_list = torch.pi * 2 * k_list * s\n\n    _t_shape = fft_weights.shape[:-2] + (num_K, 1)\n    k_list.expand(_t_shape)  # [B**, numK, 1]\n\n    # [*bs, numK, 3/2]\n    a, b = torch.split(fft_weights, num_K, dim=-2)\n\n    # [B**, numK, 3/2]\n    mag = torch.sqrt(a**2 + b**2 + 1e-10)\n\n    sin_k_theta = -1.0 * b / (mag.detach())  # Do I need to detach?\n    cos_k_theta = a / (mag.detach())  # Do I need to detach here?\n\n    # [-pi, pi]\n    k_theta = torch.atan2(sin_k_theta, cos_k_theta)\n    theta = k_theta / k_list\n\n    # [B**, numK * 2, 3/2]\n    mag_phase = torch.cat([mag, theta], dim=-2)\n\n    return mag_phase\n\n\ndef 
get_fft_from_mag_phase(mag_phase, s=3.0 / 16.0):\n    \"\"\"\n    Args:\n        mag_phase: [*bs, numK * 2, 3/2] # [B**, numK, 2]\n    Returns:\n        fft_weights: [*bs, numK * 2, 3/2]\n    \"\"\"\n\n    num_K = mag_phase.shape[-2] // 2\n\n    k_list = torch.arange(1, num_K + 1, device=mag_phase.device).unsqueeze(-1)\n    # k_list = torch.ones_like(k_list) # need to fix this\n    k_list = torch.pi * 2 * k_list * s  # scale to get frequency\n\n    _t_shape = mag_phase.shape[:-2] + (num_K, 1)\n    k_list.expand(_t_shape)  # [B**, numK, 1]\n\n    # [*bs, numK, 3/2]\n    mag, phase = torch.split(mag_phase, num_K, dim=-2)\n\n    theta = phase * k_list\n\n    a = mag * torch.cos(theta)\n    b = -1.0 * mag * torch.sin(theta)\n\n    # [B**, numK * 2, 3/2]\n    fft_weights = torch.cat([a, b], dim=-2)\n\n    return fft_weights\n\n\ndef get_displacements_from_fft_coeffs(fft_coe, t, s=3.0 / 16.0):\n    \"\"\"\n    Args:\n        fft_coe: [*bs, numK * 2, 3/2]\n        t: [*bs, 1]\n\n    Returns:\n        disp = a * cos(freq * t) - b * sin(freq * t).\n            Note that some formulation use\n            disp = a * cos(freq * t) + b * sin(freq * t)\n        shape of disp: [*bs, 3/2]\n    \"\"\"\n    num_K = fft_coe.shape[-2] // 2\n    k_list = torch.arange(1, num_K + 1, device=fft_coe.device)\n    # [num_K, 1]\n    freq_array = (torch.pi * 2 * k_list * s).unsqueeze(-1)\n\n    # expand front dims to match t\n    _tmp_shape = t.shape[:-1] + freq_array.shape\n    freq_array.expand(_tmp_shape)  # [*bs, num_K, 1]\n\n    cos_ = torch.cos(freq_array * t.unsqueeze(-2))\n    sin_ = -1.0 * torch.sin(freq_array * t.unsqueeze(-2))\n\n    # [*bs, num_K * 2] => [*bs, num_K]\n    basis = torch.cat([cos_, sin_], dim=-2).squeeze(dim=-1)  #\n\n    # [*bs, num_K * 2, 3/2] => [*bs, 3/2]\n    disp = (basis.unsqueeze(-1) * fft_coe).sum(dim=-2)\n\n    return disp\n\n\ndef bandpass_filter(signal: torch.Tensor, low_cutoff, high_cutoff, fs: int):\n    \"\"\"\n    Args:\n        signal: [T, ...]\n     
   low_cutoff: float\n        high_cutoff: float\n        fs: int\n    \"\"\"\n    # Apply FFT\n    fft_signal = torch.fft.fft(signal, dim=0)\n    freq = torch.fft.fftfreq(signal.size(0), d=1 / fs)\n\n    # Bandpass filter\n    mask = (freq <= low_cutoff) | (freq >= high_cutoff)\n    fft_signal[mask] = 0\n\n    # Apply inverse FFT\n    filtered_signal = torch.fft.ifft(fft_signal, dim=0)\n    return filtered_signal.real\n\n\ndef bandpass_filter_numpy(signal: np.ndarray, low_cutoff, high_cutoff, fs):\n    # Apply FFT\n    fft_signal = np.fft.fft(signal, axis=0)\n    freq = np.fft.fftfreq(signal.shape[0], d=1 / fs)\n\n    # Bandpass filter\n    fft_signal[(freq <= low_cutoff) | (freq >= high_cutoff)] = 0\n\n    # Apply inverse FFT\n    filtered_signal = np.fft.ifft(fft_signal, axis=0)\n    return filtered_signal.real\n\n\nif __name__ == \"__main__\":\n    code_test_dct3d()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/flow_utils.py",
    "content": "\nimport numpy as np\n\ndef flow_to_image(flow, display=False):\n    \"\"\"\n    Convert flow into middlebury color code image\n    :param flow: optical flow map\n    :return: optical flow image in middlebury color\n    \"\"\"\n    UNKNOWN_FLOW_THRESH = 100\n    u = flow[:, :, 0]\n    v = flow[:, :, 1]\n\n    maxu = -999.\n    maxv = -999.\n    minu = 999.\n    minv = 999.\n\n    idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)\n    u[idxUnknow] = 0\n    v[idxUnknow] = 0\n\n    maxu = max(maxu, np.max(u))\n    minu = min(minu, np.min(u))\n\n    maxv = max(maxv, np.max(v))\n    minv = min(minv, np.min(v))\n\n    # sqrt_rad = u**2 + v**2\n    rad = np.sqrt(u**2 + v**2)\n\n    maxrad = max(-1, np.max(rad))\n\n    if display:\n        print(\"max flow: %.4f\\nflow range:\\nu = %.3f .. %.3f\\nv = %.3f .. %.3f\" % (maxrad, minu,maxu, minv, maxv))\n\n    u = u/(maxrad + np.finfo(float).eps)\n    v = v/(maxrad + np.finfo(float).eps)\n\n    img = compute_color(u, v)\n\n    idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)\n    img[idx] = 0\n\n    return np.uint8(img)\n\n\ndef make_color_wheel():\n    \"\"\"\n    Generate color wheel according Middlebury color code\n    :return: Color wheel\n    \"\"\"\n    RY = 15\n    YG = 6\n    GC = 4\n    CB = 11\n    BM = 13\n    MR = 6\n\n    ncols = RY + YG + GC + CB + BM + MR\n\n    colorwheel = np.zeros([ncols, 3])\n\n    col = 0\n\n    # RY\n    colorwheel[0:RY, 0] = 255\n    colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))\n    col += RY\n\n    # YG\n    colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))\n    colorwheel[col:col+YG, 1] = 255\n    col += YG\n\n    # GC\n    colorwheel[col:col+GC, 1] = 255\n    colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))\n    col += GC\n\n    # CB\n    colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))\n    colorwheel[col:col+CB, 
2] = 255\n    col += CB\n\n    # BM\n    colorwheel[col:col+BM, 2] = 255\n    colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))\n    col += + BM\n\n    # MR\n    colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))\n    colorwheel[col:col+MR, 0] = 255\n\n    return colorwheel\n\n\ndef compute_color(u, v):\n    \"\"\"\n    compute optical flow color map\n    :param u: optical flow horizontal map\n    :param v: optical flow vertical map\n    :return: optical flow in color code\n    \"\"\"\n    [h, w] = u.shape\n    img = np.zeros([h, w, 3])\n    nanIdx = np.isnan(u) | np.isnan(v)\n    u[nanIdx] = 0\n    v[nanIdx] = 0\n\n    colorwheel = make_color_wheel()\n    ncols = np.size(colorwheel, 0)\n\n    rad = np.sqrt(u**2+v**2)\n\n    a = np.arctan2(-v, -u) / np.pi\n\n    fk = (a+1) / 2 * (ncols - 1) + 1\n\n    k0 = np.floor(fk).astype(int)\n\n    k1 = k0 + 1\n    k1[k1 == ncols+1] = 1\n    f = fk - k0\n\n    for i in range(0, np.size(colorwheel,1)):\n        tmp = colorwheel[:, i]\n        col0 = tmp[k0-1] / 255\n        col1 = tmp[k1-1] / 255\n        col = (1-f) * col0 + f * col1\n\n        idx = rad <= 1\n        col[idx] = 1-rad[idx]*(1-col[idx])\n        notidx = np.logical_not(idx)\n\n        col[notidx] *= 0.75\n        img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))\n\n    return img"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/img_utils.py",
    "content": "import torch\nimport torchvision\nimport cv2\nimport numpy as np\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom math import exp\n\n\ndef make_grid(imgs: torch.Tensor, scale=0.5):\n    \"\"\"\n    Args:\n        imgs: [B, C, H, W] in [0, 1]\n    Output:\n        x row of images, and 3 x column of images\n        which means 3 x ^ 2 <= B\n\n        img_grid: np.ndarray, [H', W', C]\n    \"\"\"\n\n    B, C, H, W = imgs.shape\n\n    num_row = int(np.sqrt(B / 3))\n    if num_row < 1:\n        num_row = 1\n    num_col = int(np.ceil(B / num_row))\n\n    img_grid = torchvision.utils.make_grid(imgs, nrow=num_col, padding=0)\n\n    img_grid = img_grid.permute(1, 2, 0).cpu().numpy()\n\n    # resize by scale\n    img_grid = cv2.resize(img_grid, None, fx=scale, fy=scale)\n    return img_grid\n\n\ndef compute_psnr(img1, img2, mask=None):\n    \"\"\"\n    Args:\n        img1: [B, C, H, W]\n        img2: [B, C, H, W]\n        mask: [B, 1, H, W] or [1, 1, H, W] or None\n    Outs:\n        psnr: [B]\n    \"\"\"\n    # batch dim is preserved\n    if mask is None:\n        mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n    else:\n        if mask.shape[0] != img1.shape[0]:\n            mask = mask.repeat(img1.shape[0], 1, 1, 1)\n        if mask.shape[1] != img1.shape[1]:\n            mask = mask.repeat(1, img1.shape[1], 1, 1)\n\n        diff = ((img1 - img2)) ** 2\n        diff = diff * mask\n        mse = diff.view(img1.shape[0], -1).sum(1, keepdim=True) / (\n            mask.view(img1.shape[0], -1).sum(1, keepdim=True) + 1e-8\n        )\n\n    return 20 * torch.log10(1.0 / torch.sqrt(mse))\n\n\ndef torch_rgb_to_gray(image):\n    # image is [B, C, H, W]\n    gray_image = (\n        0.299 * image[:, 0, :, :]\n        + 0.587 * image[:, 1, :, :]\n        + 0.114 * image[:, 2, :, :]\n    )\n    gray_image = gray_image.unsqueeze(1)\n\n    return gray_image\n\n\ndef compute_gradient_loss(pred, gt, mask=None):\n    
\"\"\"\n    Args:\n        pred: [B, C, H, W]\n        gt: [B, C, H, W]\n        mask: [B, 1, H, W] or None\n    \"\"\"\n    assert pred.shape == gt.shape, \"a and b must have the same shape\"\n\n    pred = torch_rgb_to_gray(pred)\n    gt = torch_rgb_to_gray(gt)\n\n    sobel_kernel_x = torch.tensor(\n        [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=pred.dtype, device=pred.device\n    )\n    sobel_kernel_y = torch.tensor(\n        [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=pred.dtype, device=pred.device\n    )\n\n    gradient_a_x = (\n        torch.nn.functional.conv2d(\n            pred.repeat(1, 3, 1, 1),\n            sobel_kernel_x.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1),\n            padding=1,\n        )\n        / 3\n    )\n    gradient_a_y = (\n        torch.nn.functional.conv2d(\n            pred.repeat(1, 3, 1, 1),\n            sobel_kernel_y.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1),\n            padding=1,\n        )\n        / 3\n    )\n    # gradient_a_magnitude = torch.sqrt(gradient_a_x ** 2 + gradient_a_y ** 2)\n\n    gradient_b_x = (\n        torch.nn.functional.conv2d(\n            gt.repeat(1, 3, 1, 1),\n            sobel_kernel_x.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1),\n            padding=1,\n        )\n        / 3\n    )\n    gradient_b_y = (\n        torch.nn.functional.conv2d(\n            gt.repeat(1, 3, 1, 1),\n            sobel_kernel_y.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1),\n            padding=1,\n        )\n        / 3\n    )\n    # gradient_b_magnitude = torch.sqrt(gradient_b_x ** 2 + gradient_b_y ** 2)\n\n    pred_grad = torch.cat([gradient_a_x, gradient_a_y], dim=1)\n    gt_grad = torch.cat([gradient_b_x, gradient_b_y], dim=1)\n\n    if mask is None:\n        gradient_difference = torch.abs(pred_grad - gt_grad).mean()\n    else:\n        gradient_difference = torch.abs(pred_grad - gt_grad).mean(dim=1, keepdim=True)[\n            mask\n        ].sum() / (mask.sum() + 1e-8)\n\n    return 
gradient_difference\n\n\ndef mark_image_with_red_squares(img):\n    # img, torch.Tensor of shape [B, H, W, C]\n\n    mark_color = torch.tensor([1.0, 0, 0], dtype=torch.float32)\n\n    for x_offset in range(4):\n        for y_offset in range(4):\n            img[:, x_offset::16, y_offset::16, :] = mark_color\n\n    return img\n\n\n# below for compute batched SSIM\ndef gaussian(window_size, sigma):\n\n    gauss = torch.Tensor(\n        [\n            exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))\n            for x in range(window_size)\n        ]\n    )\n    return gauss / gauss.sum()\n\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(\n        _2D_window.expand(channel, 1, window_size, window_size).contiguous()\n    )\n    return window\n\n\ndef compute_ssim(img1, img2, window_size=11, size_average=True):\n    channel = img1.size(-3)\n    window = create_window(window_size, channel)\n\n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n\n    return _ssim(img1, img2, window, window_size, channel, size_average)\n\n\ndef _ssim(img1, img2, window, window_size, channel, size_average=True):\n    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = (\n        F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq\n    )\n    sigma2_sq = (\n        F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq\n    )\n    sigma12 = (\n        F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)\n        - mu1_mu2\n    )\n\n    C1 = 0.01**2\n    C2 = 0.03**2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * 
sigma12 + C2)) / (\n        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)\n    )\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1).mean(1).mean(1)\n\n\n# above for compute batched SSIM\n\n\ndef compute_low_res_psnr(img1, img2, scale_factor):\n    \"\"\"\n    Args:\n        img1: [B, C, H, W]\n        img2: [B, C, H, W]\n        scale_factor: int\n    \"\"\"\n    img1 = F.interpolate(\n        img1, scale_factor=1 / scale_factor, mode=\"bilinear\", align_corners=False\n    )\n    img2 = F.interpolate(\n        img2, scale_factor=1 / scale_factor, mode=\"bilinear\", align_corners=False\n    )\n    return compute_psnr(img1, img2)\n\n\ndef compute_low_res_mse(img1, img2, scale_factor):\n    \"\"\"\n    Args:\n        img1: [B, C, H, W]\n        img2: [B, C, H, W]\n        scale_factor: int\n    \"\"\"\n    img1 = F.interpolate(\n        img1, scale_factor=1 / scale_factor, mode=\"bilinear\", align_corners=False\n    )\n    img2 = F.interpolate(\n        img2, scale_factor=1 / scale_factor, mode=\"bilinear\", align_corners=False\n    )\n    loss_mse = F.mse_loss(img1, img2, reduction=\"mean\")\n    return loss_mse\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/io_utils.py",
    "content": "import cv2\nimport imageio\nimport numpy as np\nimport mediapy\nimport os\nimport PIL\n\n\ndef read_video_cv2(video_path, rgb=True):\n    \"\"\"Read video using cv2, return [T, 3, H, W] array, fps\"\"\"\n\n    # BGR\n    cap = cv2.VideoCapture(video_path)\n    fps = cap.get(cv2.CAP_PROP_FPS)\n    num_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n    ret_list = []\n    for i in range(num_frame):\n        ret, frame = cap.read()\n        if ret:\n            if rgb:\n                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            frame = np.transpose(frame, [2, 0, 1])  # [3, H, W]\n            ret_list.append(frame[np.newaxis, ...])\n        else:\n            break\n    cap.release()\n    ret_array = np.concatenate(ret_list, axis=0)  # [T, 3, H, W]\n    return ret_array, fps\n\n\ndef save_video_cv2(video_path, img_list, fps):\n    # BGR\n\n    if len(img_list) == 0:\n        return\n    h, w = img_list[0].shape[:2]\n    fourcc = cv2.VideoWriter_fourcc(\n        *\"mp4v\"\n    )  # cv2.VideoWriter_fourcc('m', 'p', '4', 'v')\n    writer = cv2.VideoWriter(video_path, fourcc, fps, (w, h))\n\n    for frame in img_list:\n        writer.write(frame)\n    writer.release()\n\n\ndef save_video_imageio(video_path, img_list, fps):\n    \"\"\"\n    Img_list: [[H, W, 3]]\n    \"\"\"\n    if len(img_list) == 0:\n        return\n    writer = imageio.get_writer(video_path, fps=fps)\n    for frame in img_list:\n        writer.append_data(frame)\n\n    writer.close()\n\n\ndef save_gif_imageio(video_path, img_list, fps):\n    \"\"\"\n    Img_list: [[H, W, 3]]\n    \"\"\"\n    if len(img_list) == 0:\n        return\n    assert video_path.endswith(\".gif\")\n\n    imageio.mimsave(video_path, img_list, format=\"GIF\", fps=fps)\n\n\ndef save_video_mediapy(video_frames, output_video_path: str = None, fps: int = 14):\n    # video_frames: [N, H, W, 3]\n    if isinstance(video_frames[0], PIL.Image.Image):\n        video_frames = [np.array(frame) for frame in 
video_frames]\n    os.makedirs(os.path.dirname(output_video_path), exist_ok=True)\n    mediapy.write_video(output_video_path, video_frames, fps=fps, qp=18)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/optimizer.py",
    "content": "import torch\nfrom torch.optim.lr_scheduler import LambdaLR\n\n\ndef get_linear_schedule_with_warmup(\n    optimizer, num_warmup_steps, num_training_steps, last_epoch=-1\n):\n    \"\"\"\n    From diffusers.optimization\n    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after\n    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (`int`):\n            The total number of training steps.\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n\n    def lr_lambda(current_step: int):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        return max(\n            0.0,\n            float(num_training_steps - current_step)\n            / float(max(1, num_training_steps - num_warmup_steps)),\n        )\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/peft_utils.py",
    "content": "import peft\n\nfrom peft.utils.save_and_load import get_peft_model_state_dict\nfrom peft import PeftModel\n\n\ndef save_peft_adaptor(model: peft.PeftModel, dir, save_base_model=False):\n    # save the adaptor only\n    model.save_pretrained(dir)\n\n    if save_base_model:\n        raise NotImplementedError\n\n\ndef load_peft_adaptor_and_merge(adaptor_path, base_model):\n    model = PeftModel.from_pretrained(base_model, adaptor_path)\n    model = model.merge_and_unload()\n\n    return model\n\n\ndef _code_test_peft_load_save():\n    import torch.nn as nn\n    import torch\n    import copy\n\n    class MLP(nn.Module):\n        def __init__(self, num_units_hidden=10):\n            super().__init__()\n            self.seq = nn.Sequential(\n                nn.Linear(20, num_units_hidden),\n                nn.ReLU(),\n                nn.Linear(num_units_hidden, num_units_hidden),\n                nn.ReLU(),\n                nn.Linear(num_units_hidden, 2),\n                nn.LogSoftmax(dim=-1),\n            )\n\n        def forward(self, X):\n            return self.seq(X)\n\n    module = MLP()\n    print(\"=> Name of original model parameters:\")\n    for name, param in module.named_parameters():\n        print(name, param.shape)\n    module_copy = copy.deepcopy(module)\n    config = peft.LoraConfig(\n        r=8,\n        target_modules=[\"seq.0\", \"seq.2\"],\n        modules_to_save=[\"seq.4\"],\n    )\n    peft_model = peft.get_peft_model(module, config)\n\n    peft_model.print_trainable_parameters()\n\n    print(\"\\n=> Name of PeftModel's parameters:\")\n    for name, param in peft_model.named_parameters():\n        print(name, param.shape)\n\n    save_path = \"./tmp\"\n\n    save_peft_adaptor(peft_model, save_path)\n\n    loaded_merged_model = load_peft_adaptor_and_merge(save_path, module_copy)\n\n    print(\"\\n=> Name of Loaded and Merged model's parameters:\")\n    for name, param in loaded_merged_model.named_parameters():\n        print(name, 
param.shape)\n\n\nif __name__ == \"__main__\":\n    _code_test_peft_load_save()\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/print_utils.py",
    "content": "import torch.distributed as dist\n\n\ndef print_if_zero_rank(s):\n    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):\n        print(\"### \" + s)\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/pytorch_mssim.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom math import exp\nimport numpy as np\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])\n    return gauss/gauss.sum()\n\n\ndef create_window(window_size, channel=1):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)\n    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()\n    return window\n\ndef create_window_3d(window_size, channel=1):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t())\n    _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())\n    window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)\n    return window\n\n\ndef ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):\n    # Value range can be different from 255. 
Other common ranges are 1 (sigmoid) and 2 (tanh).\n    if val_range is None:\n        if torch.max(img1) > 128:\n            max_val = 255\n        else:\n            max_val = 1\n\n        if torch.min(img1) < -0.5:\n            min_val = -1\n        else:\n            min_val = 0\n        L = max_val - min_val\n    else:\n        L = val_range\n\n    padd = 0\n    (_, channel, height, width) = img1.size()\n    if window is None:\n        real_size = min(window_size, height, width)\n        window = create_window(real_size, channel=channel).to(img1.device)\n    \n    # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)\n    # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)\n    mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)\n    mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2\n\n    C1 = (0.01 * L) ** 2\n    C2 = (0.03 * L) ** 2\n\n    v1 = 2.0 * sigma12 + C2\n    v2 = sigma1_sq + sigma2_sq + C2\n    cs = torch.mean(v1 / v2)  # contrast sensitivity\n\n    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)\n\n    if size_average:\n        ret = ssim_map.mean()\n    else:\n        ret = ssim_map.mean(1).mean(1).mean(1)\n\n    if full:\n        return ret, cs\n    return ret\n\n\ndef ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):\n    \"\"\"\n    Args:\n        img1, img2: (N, C, H, W)\n    \"\"\"\n    # Value range can be different from 255. 
Other common ranges are 1 (sigmoid) and 2 (tanh).\n    if val_range is None:\n        if torch.max(img1) > 128:\n            max_val = 255\n        else:\n            max_val = 1\n\n        if torch.min(img1) < -0.5:\n            min_val = -1\n        else:\n            min_val = 0\n        L = max_val - min_val\n    else:\n        L = val_range\n\n    padd = 0\n    (_, _, height, width) = img1.size()\n    if window is None:\n        real_size = min(window_size, height, width)\n        window = create_window_3d(real_size, channel=1).to(img1.device)\n        # Channel is set to 1 since we consider color images as volumetric images\n\n    img1 = img1.unsqueeze(1)\n    img2 = img2.unsqueeze(1)\n\n    mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)\n    mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq\n    sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq\n    sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2\n\n    C1 = (0.01 * L) ** 2\n    C2 = (0.03 * L) ** 2\n\n    v1 = 2.0 * sigma12 + C2\n    v2 = sigma1_sq + sigma2_sq + C2\n    cs = torch.mean(v1 / v2)  # contrast sensitivity\n\n    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)\n\n    if size_average:\n        ret = ssim_map.mean()\n    else:\n        ret = ssim_map.mean(1).mean(1).mean(1)\n\n    if full:\n        return ret, cs\n    return ret\n\n\ndef msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):\n    device = img1.device\n    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)\n    levels = 
weights.size()[0]\n    mssim = []\n    mcs = []\n    for _ in range(levels):\n        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)\n        mssim.append(sim)\n        mcs.append(cs)\n\n        img1 = F.avg_pool2d(img1, (2, 2))\n        img2 = F.avg_pool2d(img2, (2, 2))\n\n    mssim = torch.stack(mssim)\n    mcs = torch.stack(mcs)\n\n    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)\n    if normalize:\n        mssim = (mssim + 1) / 2\n        mcs = (mcs + 1) / 2\n\n    pow1 = mcs ** weights\n    pow2 = mssim ** weights\n    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/\n    output = torch.prod(pow1[:-1] * pow2[-1])\n    return output\n\n\n# Classes to re-use window\nclass SSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, val_range=None):\n        super(SSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.val_range = val_range\n\n        # Assume 3 channel for SSIM\n        self.channel = 3\n        self.window = create_window(window_size, channel=self.channel)\n\n    def forward(self, img1, img2):\n        (_, channel, _, _) = img1.size()\n\n        if channel == self.channel and self.window.dtype == img1.dtype:\n            window = self.window\n        else:\n            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)\n            self.window = window\n            self.channel = channel\n\n        _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)\n        dssim = (1 - _ssim) / 2\n        return dssim\n\nclass MSSSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, channel=3):\n        super(MSSSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        
self.channel = channel\n\n    def forward(self, img1, img2):\n        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/svd_helpper.py",
    "content": "from glob import glob\nfrom sys import version\nfrom typing import Dict, List, Optional, Tuple, Union\nimport numpy as np\nimport torch\nimport os\n\nfrom omegaconf import ListConfig, OmegaConf\nfrom safetensors.torch import load_file as load_safetensors\n\nfrom sgm.inference.helpers import embed_watermark\nfrom sgm.modules.diffusionmodules.guiders import LinearPredictionGuider, VanillaCFG\nfrom sgm.util import append_dims, default, instantiate_from_config\nimport math\nfrom einops import repeat\n\n\ndef init_st(version_dict, load_ckpt=True, load_filter=True):\n    state = dict()\n    if not \"model\" in state:\n        config = version_dict[\"config\"]\n        ckpt = version_dict[\"ckpt\"]\n\n        config = OmegaConf.load(config)\n        model, msg = load_model_from_config(config, ckpt if load_ckpt else None)\n\n        state[\"msg\"] = msg\n        state[\"model\"] = model\n        state[\"ckpt\"] = ckpt if load_ckpt else None\n        state[\"config\"] = config\n        if load_filter:\n            return state\n            # from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\n            state[\"filter\"] = DeepFloydDataFiltering(verbose=False)\n    return state\n\n\ndef load_model_from_config(config, ckpt=None, verbose=True):\n    model = instantiate_from_config(config.model)\n\n    if ckpt is not None:\n        print(f\"Loading model from {ckpt}\")\n        if ckpt.endswith(\"ckpt\"):\n            pl_sd = torch.load(ckpt, map_location=\"cpu\")\n            if \"global_step\" in pl_sd:\n                global_step = pl_sd[\"global_step\"]\n                print(f\"Global Step: {pl_sd['global_step']}\")\n            sd = pl_sd[\"state_dict\"]\n        elif ckpt.endswith(\"safetensors\"):\n            sd = load_safetensors(ckpt)\n        else:\n            raise NotImplementedError\n\n        msg = None\n\n        m, u = model.load_state_dict(sd, strict=False)\n\n        if len(m) > 0 and verbose:\n        
    print(\"missing keys:\")\n            print(m)\n        if len(u) > 0 and verbose:\n            print(\"unexpected keys:\")\n            print(u)\n    else:\n        msg = None\n\n    model = initial_model_load(model)\n    # model.eval()  # ?\n    return model, msg\n\n\ndef load_model(model):\n    model.cuda()\n\n\nlowvram_mode = False\n\n\ndef set_lowvram_mode(mode):\n    global lowvram_mode\n    lowvram_mode = mode\n\n\ndef initial_model_load(model):\n    global lowvram_mode\n    if lowvram_mode:\n        model.model.half()\n    else:\n        model.cuda()\n    return model\n\n\ndef unload_model(model):\n    global lowvram_mode\n    if lowvram_mode:\n        model.cpu()\n        torch.cuda.empty_cache()\n\n\ndef get_unique_embedder_keys_from_conditioner(conditioner):\n    return list(set([x.input_key for x in conditioner.embedders]))\n\n\ndef get_batch(keys, value_dict, N, T, device):\n    batch = {}\n    batch_uc = {}\n\n    for key in keys:\n        if key == \"fps_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"fps_id\"]])\n                .to(device)\n                .repeat(int(math.prod(N)))\n            )\n        elif key == \"motion_bucket_id\":\n            batch[key] = (\n                torch.tensor([value_dict[\"motion_bucket_id\"]])\n                .to(device)\n                .repeat(int(math.prod(N)))\n            )\n        elif key == \"cond_aug\":\n            batch[key] = repeat(\n                torch.tensor([value_dict[\"cond_aug\"]]).to(device),\n                \"1 -> b\",\n                b=math.prod(N),\n            )\n        elif key == \"cond_frames\":\n            batch[key] = repeat(value_dict[\"cond_frames\"], \"1 ... -> b ...\", b=N[0])\n        elif key == \"cond_frames_without_noise\":\n            batch[key] = repeat(\n                value_dict[\"cond_frames_without_noise\"], \"1 ... 
-> b ...\", b=N[0]\n            )\n        else:\n            batch[key] = value_dict[key]\n\n    if T is not None:\n        batch[\"num_video_frames\"] = T\n\n    for key in batch.keys():\n        if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n            batch_uc[key] = torch.clone(batch[key])\n    return batch, batch_uc\n\n\nif __name__ == \"__main__\":\n    pass\n"
  },
  {
    "path": "projects/uncleaned_train/motionrep/utils/torch_utils.py",
    "content": "import torch\nimport time\n\n\ndef get_sync_time():\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    return time.time()\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/convert_gaussian_to_mesh.py",
    "content": "import os\nfrom random import gauss\nfrom fire import Fire\nfrom motionrep.gaussian_3d.scene import GaussianModel\nimport numpy as np\nimport torch\n\n\ndef convert_gaussian_to_mesh(gaussian_path, save_path=None):\n    if save_path is None:\n        dir_path = os.path.dirname(gaussian_path)\n        save_path = os.path.join(dir_path, \"gaussian_to_mesh.obj\")\n\n    gaussian_path = os.path.join(gaussian_path)\n\n    gaussians = GaussianModel(3)\n\n    gaussians.load_ply(gaussian_path)\n    gaussians.detach_grad()\n    print(\n        \"load gaussians from: {}\".format(gaussian_path),\n        \"... num gaussians: \",\n        gaussians._xyz.shape[0],\n    )\n\n    mesh = gaussians.extract_mesh(\n        save_path, density_thresh=1, resolution=128, decimate_target=1e5\n    )\n\n    mesh.write(save_path)\n\n\ndef internal_filling(gaussian_path, save_path=None, resolution=64):\n    if save_path is None:\n        dir_path = os.path.dirname(gaussian_path)\n        save_path = os.path.join(dir_path, \"gaussian_internal_fill.ply\")\n\n    gaussians = GaussianModel(3)\n\n    gaussians.load_ply(gaussian_path)\n    gaussians.detach_grad()\n\n    print(\n        \"load gaussians from: {}\".format(gaussian_path),\n        \"... 
num gaussians: \",\n        gaussians._xyz.shape[0],\n    )\n\n    # [res, res, res]\n    occ = (\n        gaussians.extract_fields(resolution=resolution, num_blocks=16, relax_ratio=1.5)\n        .detach()\n        .cpu()\n        .numpy()\n    )\n\n    xyzs = gaussians._xyz.detach().cpu().numpy()\n\n    center = gaussians.center.detach().cpu().numpy()\n    scale = gaussians.scale # float\n    xyzs = (xyzs - center) * scale # [-1.5, 1.5]?\n\n    percentile = [82, 84, 86][1]\n\n    # from IPython import embed\n    # embed()\n\n    thres = np.percentile(occ, percentile)\n    print(\"density threshold: {:.2f} -- in percentile: {:.1f} \".format(thres, percentile))\n    occ_large_thres = occ > thres\n    # get the xyz of the occupied voxels\n    # xyz = np.argwhere(occ)\n    # normalize to [-1, 1]\n    # xyz = xyz / (resolution - 1) * 2 - 1\n\n    voxel_counts = np.zeros((resolution, resolution, resolution))\n\n    points_xyzindex = ((xyzs + 1) / 2 * (resolution - 1)).astype(np.uint32)\n\n    for x, y, z in points_xyzindex:\n        voxel_counts[x, y, z] += 1\n    \n    add_points = np.logical_and(occ_large_thres, voxel_counts <= 1)\n\n    add_xyz = np.argwhere(add_points).astype(np.float32)\n    add_xyz = add_xyz / (resolution - 1) * 2 - 1\n\n    all_xyz = np.concatenate([xyzs, add_xyz], axis=0)\n\n    print(\"added points: \", add_xyz.shape[0])\n    \n    # save to ply\n    import point_cloud_utils as pcu\n\n    pcu.save_mesh_vf(save_path, all_xyz, np.zeros((0, 3), dtype=np.int32))\n\n    add_path = os.path.join(os.path.dirname(save_path), \"extra_filled_points.ply\")\n    pcu.save_mesh_vf(add_path, add_xyz, np.zeros((0, 3), dtype=np.int32))\n\n\n\nif __name__ == \"__main__\":\n    # Fire(convert_gaussian_to_mesh)\n\n    Fire(internal_filling)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/diff_warp_utils.py",
    "content": "import warp as wp\nimport warp.torch\nimport torch\nfrom typing import Optional, Union, Sequence, Any\nfrom torch import Tensor\nfrom warp_rewrite import from_torch_safe\n\n\n@wp.struct\nclass MPMStateStruct(object):\n    ###### essential #####\n    # particle\n    particle_x: wp.array(dtype=wp.vec3)  # current position\n    particle_v: wp.array(dtype=wp.vec3)  # particle velocity\n    particle_F: wp.array(dtype=wp.mat33)  # particle elastic deformation gradient\n    particle_init_cov: wp.array(dtype=float)  # initial covariance matrix\n    particle_cov: wp.array(dtype=float)  # current covariance matrix\n    particle_F_trial: wp.array(\n        dtype=wp.mat33\n    )  # apply return mapping on this to obtain elastic def grad\n    particle_R: wp.array(dtype=wp.mat33)  # rotation matrix\n    particle_stress: wp.array(dtype=wp.mat33)  # Kirchoff stress, elastic stress\n    particle_C: wp.array(dtype=wp.mat33)\n    particle_vol: wp.array(dtype=float)  # current volume\n    particle_mass: wp.array(dtype=float)  # mass\n    particle_density: wp.array(dtype=float)  # density\n    particle_Jp: wp.array(dtype=float)\n\n    particle_selection: wp.array(\n        dtype=int\n    )  # only particle_selection[p] = 0 will be simulated\n\n    # grid\n    grid_m: wp.array(dtype=float, ndim=3)\n    grid_v_in: wp.array(dtype=wp.vec3, ndim=3)  # grid node momentum/velocity\n    grid_v_out: wp.array(\n        dtype=wp.vec3, ndim=3\n    )  # grid node momentum/velocity, after grid update\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        # shape default is int. 
number of particles\n        self.particle_x = wp.empty(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_v = wp.zeros(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_F = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_init_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=False\n        )\n\n        self.particle_F_trial = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_R = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_stress = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_C = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_vol = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_mass = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_density = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_Jp = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_selection = wp.zeros(\n            shape, dtype=int, device=device, requires_grad=requires_grad\n        )\n\n        # grid: will init later\n        self.grid_m = wp.empty(\n            (10, 10, 10), dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.grid_v_in = wp.zeros(\n            
(10, 10, 10), dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.grid_v_out = wp.zeros(\n            (10, 10, 10), dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n\n    def init_grid(\n        self, grid_res: int, device: wp.context.Devicelike = None, requires_grad=False\n    ):\n        self.grid_m = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=float,\n            device=device,\n            requires_grad=False,\n        )\n        self.grid_v_in = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=wp.vec3,\n            device=device,\n            requires_grad=requires_grad,\n        )\n        self.grid_v_out = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=wp.vec3,\n            device=device,\n            requires_grad=requires_grad,\n        )\n\n    def from_torch(\n        self,\n        tensor_x: Tensor,\n        tensor_volume: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n        n_grid: int = 100,\n        grid_lim=1.0,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n        assert tensor_x.shape[0] == tensor_volume.shape[0]\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n        self.init_grid(grid_res=n_grid, device=device, requires_grad=requires_grad)\n\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_volume is not None:\n            print(self.particle_vol.shape, tensor_volume.shape)\n            volume_numpy = tensor_volume.detach().cpu().numpy()\n            self.particle_vol = wp.from_numpy(\n                volume_numpy, dtype=float, 
device=device, requires_grad=False\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, requires_grad=False\n            )\n            self.particle_init_cov = self.particle_cov\n\n        if tensor_velocity is not None:\n            self.particle_v = from_torch_safe(\n                tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        # initial trial deformation gradient is set to identity\n\n        print(\"Particles initialized from torch data.\")\n        print(\"Total particles: \", n_particles)\n\n    def reset_state(\n        self,\n        tensor_x: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        # reset p_c, p_v, p_C, p_F_trial\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, requires_grad=False\n            )\n            self.particle_cov = self.particle_init_cov\n\n        if 
tensor_velocity is not None:\n            self.particle_v = from_torch_safe(\n                tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_zero,\n            dim=n_particles,\n            inputs=[self.particle_C],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_zero,\n            dim=n_particles,\n            inputs=[self.particle_stress],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_zero,\n            dim=n_particles,\n            inputs=[self.particle_R],\n            device=device,\n        )\n    def set_require_grad(self, requires_grad=True):\n        self.particle_x.requires_grad = requires_grad\n        self.particle_v.requires_grad = requires_grad\n        self.particle_F.requires_grad = requires_grad\n        self.particle_F_trial.requires_grad = requires_grad\n        self.particle_stress.requires_grad = requires_grad\n\n        self.grid_v_out.requires_grad = requires_grad\n        self.grid_v_in.requires_grad = requires_grad\n\n\n@wp.struct\nclass ParticleStateStruct(object):\n    ###### essential #####\n    # particle\n    particle_x: wp.array(dtype=wp.vec3)  # current position\n    particle_v: wp.array(dtype=wp.vec3)  # particle velocity\n    particle_F: wp.array(dtype=wp.mat33)  # particle elastic deformation gradient\n    particle_init_cov: wp.array(dtype=float)  # initial covariance matrix\n    
particle_cov: wp.array(dtype=float)  # current covariance matrix\n    particle_F_trial: wp.array(\n        dtype=wp.mat33\n    )  # apply return mapping on this to obtain elastic def grad\n    particle_C: wp.array(dtype=wp.mat33)\n    particle_vol: wp.array(dtype=float)  # current volume\n\n    particle_selection: wp.array(\n        dtype=int\n    )  # only particle_selection[p] = 0 will be simulated\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        # shape default is int. number of particles\n        self.particle_x = wp.empty(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_v = wp.zeros(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_F = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_init_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.particle_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_F_trial = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_stress = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_C = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_vol = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_selection = wp.zeros(\n            shape, dtype=int, device=device, requires_grad=requires_grad\n        )\n\n    def from_torch(\n        self,\n        tensor_x: Tensor,\n        
tensor_volume: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n        n_grid: int = 100,\n        grid_lim=1.0,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n        assert tensor_x.shape[0] == tensor_volume.shape[0]\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n\n        if tensor_x is not None:\n            # print(self.particle_x.shape, tensor_x.shape)\n            # print(tensor_x.grad)\n            if tensor_x.requires_grad:\n                # tensor_x.grad = torch.zeros_like(tensor_x, requires_grad=False)\n                raise RuntimeError(\"tensor_x requires grad\")\n\n            # x_numpy = tensor_x.detach().clone().cpu().numpy()\n            # self.particle_x = wp.from_numpy(x_numpy, dtype=wp.vec3, requires_grad=True, device=device)\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_volume is not None:\n            print(self.particle_vol.shape, tensor_volume.shape)\n            volume_numpy = tensor_volume.detach().cpu().numpy()\n            # self.particle_vol = wp.from_torch(tensor_volume.contiguous(), dtype=float, device=device, requires_grad=requires_grad)\n            # self.particle_vol = wp.from_torch(tensor_volume.contiguous(), dtype=float, requires_grad=False)\n            self.particle_vol = wp.from_numpy(\n                volume_numpy, dtype=float, device=device, requires_grad=False\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, requires_grad=False\n            )\n            self.particle_cov = self.particle_init_cov\n\n      
  if tensor_velocity is not None:\n            if tensor_velocity.requires_grad:\n                tensor_velocity.grad = torch.zeros_like(\n                    tensor_velocity, requires_grad=False\n                )\n                raise RuntimeError(\"tensor_x requires grad\")\n            self.particle_v = from_torch_safe(\n                tensor_velocity.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        # initial trial deformation gradient is set to identity\n\n        print(\"Particles initialized from torch data.\")\n        print(\"Total particles: \", n_particles)\n\n    def set_require_grad(self, requires_grad=True):\n        self.particle_x.requires_grad = requires_grad\n        self.particle_v.requires_grad = requires_grad\n        self.particle_F.requires_grad = requires_grad\n        self.particle_F_trial.requires_grad = requires_grad\n        self.particle_stress.requires_grad = requires_grad\n\n\n@wp.struct\nclass MPMModelStruct(object):\n    ####### essential #######\n    grid_lim: float\n    n_particles: int\n    n_grid: int\n    dx: float\n    inv_dx: float\n    grid_dim_x: int\n    grid_dim_y: int\n    grid_dim_z: int\n    mu: wp.array(dtype=float)\n    lam: wp.array(dtype=float)\n    E: wp.array(dtype=float)\n    nu: wp.array(dtype=float)\n    material: int\n\n    ######## for plasticity ####\n    yield_stress: wp.array(dtype=float)\n    friction_angle: float\n    alpha: float\n    gravitational_accelaration: wp.vec3\n    hardening: float\n    xi: float\n    plastic_viscosity: float\n    softening: float\n\n    ####### for damping\n    rpic_damping: float\n    grid_v_damping_scale: float\n\n    ####### for PhysGaussian: covariance\n    
update_cov_with_F: int\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        self.E = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )  # young's modulus\n        self.nu = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )  # poisson's ratio\n\n        self.mu = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.lam = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.yield_stress = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n    def finalize_mu_lam(self, n_particles, device=\"cuda:0\"):\n        wp.launch(\n            kernel=compute_mu_lam_from_E_nu_clean,\n            dim=n_particles,\n            inputs=[self.mu, self.lam, self.E, self.nu],\n            device=device,\n        )\n\n    def init_other_params(self, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.grid_lim = grid_lim\n        self.n_grid = n_grid\n        self.grid_dim_x = n_grid\n        self.grid_dim_y = n_grid\n        self.grid_dim_z = n_grid\n        (\n            self.dx,\n            self.inv_dx,\n        ) = self.grid_lim / self.n_grid, float(\n            n_grid / grid_lim\n        )  # [0-1]?\n\n        self.update_cov_with_F = False\n\n        # material is used to switch between different elastoplastic models. 
0 is jelly\n        self.material = 0\n\n        self.plastic_viscosity = 0.0\n        self.softening = 0.1\n        self.friction_angle = 25.0\n        sin_phi = wp.sin(self.friction_angle / 180.0 * 3.14159265)\n        self.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        self.gravitational_accelaration = wp.vec3(0.0, 0.0, 0.0)\n\n        self.rpic_damping = 0.0  # 0.0 if no damping (apic). -1 if pic\n\n        self.grid_v_damping_scale = 1.1  # globally applied\n\n    def from_torch(\n        self, tensor_E: Tensor, tensor_nu: Tensor, device=\"cuda:0\", requires_grad=False\n    ):\n        self.E = wp.from_torch(tensor_E.contiguous(), requires_grad=requires_grad)\n        self.nu = wp.from_torch(tensor_nu.contiguous(), requires_grad=requires_grad)\n        n_particles = tensor_E.shape[0]\n        self.finalize_mu_lam(n_particles=n_particles, device=device)\n\n    def set_require_grad(self, requires_grad=True):\n        self.E.requires_grad = requires_grad\n        self.nu.requires_grad = requires_grad\n        self.mu.requires_grad = requires_grad\n        self.lam.requires_grad = requires_grad\n\n\n# for various boundary conditions\n@wp.struct\nclass Dirichlet_collider:\n    point: wp.vec3\n    normal: wp.vec3\n    direction: wp.vec3\n\n    start_time: float\n    end_time: float\n\n    friction: float\n    surface_type: int\n\n    velocity: wp.vec3\n\n    threshold: float\n    reset: int\n    index: int\n\n    x_unit: wp.vec3\n    y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    height: float\n    length: float\n    R: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    horizontal_axis_2: wp.vec3\n    half_height_and_radius: wp.vec2\n\n\n@wp.struct\nclass Impulse_modifier:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: float\n    force: wp.vec3\n    forceTimesDt: wp.vec3\n    numsteps: int\n\n    point: 
wp.vec3\n    size: wp.vec3\n    mask: wp.array(dtype=int)\n\n\n@wp.struct\nclass MPMtailoredStruct:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: float\n    friction: float\n    surface_type: int\n    velocity: wp.vec3\n    threshold: float\n    reset: int\n\n    point_rotate: wp.vec3\n    normal_rotate: wp.vec3\n    x_unit: wp.vec3\n    y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    point_plane: wp.vec3\n    normal_plane: wp.vec3\n    velocity_plane: wp.vec3\n    threshold_plane: float\n\n\n@wp.struct\nclass MaterialParamsModifier:\n    point: wp.vec3\n    size: wp.vec3\n    E: float\n    nu: float\n    density: float\n\n\n@wp.struct\nclass ParticleVelocityModifier:\n    point: wp.vec3\n    normal: wp.vec3\n    half_height_and_radius: wp.vec2\n    rotation_scale: float\n    translation_scale: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    horizontal_axis_2: wp.vec3\n\n    start_time: float\n\n    end_time: float\n\n    velocity: wp.vec3\n\n    mask: wp.array(dtype=int)\n\n\n@wp.kernel\ndef compute_mu_lam_from_E_nu_clean(\n    mu: wp.array(dtype=float),\n    lam: wp.array(dtype=float),\n    E: wp.array(dtype=float),\n    nu: wp.array(dtype=float),\n):\n    p = wp.tid()\n    mu[p] = E[p] / (2.0 * (1.0 + nu[p]))\n    lam[p] = E[p] * nu[p] / ((1.0 + nu[p]) * (1.0 - 2.0 * nu[p]))\n\n\n@wp.kernel\ndef set_vec3_to_zero(target_array: wp.array(dtype=wp.vec3)):\n    tid = wp.tid()\n    target_array[tid] = wp.vec3(0.0, 0.0, 0.0)\n\n\n@wp.kernel\ndef set_mat33_to_identity(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n\n\n@wp.kernel\ndef set_mat33_to_zero(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n\n@wp.kernel\ndef add_identity_to_mat33(target_array: 
wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.add(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef subtract_identity_to_mat33(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.sub(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef add_vec3_to_vec3(\n    first_array: wp.array(dtype=wp.vec3), second_array: wp.array(dtype=wp.vec3)\n):\n    tid = wp.tid()\n    first_array[tid] = wp.add(first_array[tid], second_array[tid])\n\n\n@wp.kernel\ndef set_value_to_float_array(target_array: wp.array(dtype=float), value: float):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef set_warpvalue_to_float_array(\n    target_array: wp.array(dtype=float), value: warp.types.float32\n):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef get_float_array_product(\n    arrayA: wp.array(dtype=float),\n    arrayB: wp.array(dtype=float),\n    arrayC: wp.array(dtype=float),\n):\n    tid = wp.tid()\n    arrayC[tid] = arrayA[tid] * arrayB[tid]\n\n\ndef torch2warp_quat(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 4\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.quat,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_float(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=warp.types.float32,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_vec3(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.vec3,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_mat33(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.mat33,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/engine_utils.py",
    "content": "import numpy as np\nimport h5py\nimport os\nimport sys\nimport warp as wp\nimport torch\n\ndef save_data_at_frame(mpm_solver, dir_name, frame, save_to_ply = True, save_to_h5 = False):\n    os.umask(0)\n    os.makedirs(dir_name, 0o777, exist_ok=True)\n    \n    fullfilename = dir_name + '/sim_' + str(frame).zfill(10) + '.h5'\n\n    if save_to_ply:\n        particle_position_to_ply(mpm_solver, fullfilename[:-2]+'ply')\n    \n    if save_to_h5:\n\n        if os.path.exists(fullfilename): os.remove(fullfilename)\n        newFile = h5py.File(fullfilename, \"w\")\n\n        x_np = mpm_solver.mpm_state.particle_x.numpy().transpose() # x_np has shape (3, n_particles)\n        newFile.create_dataset(\"x\", data=x_np) # position\n\n        currentTime = np.array([mpm_solver.time]).reshape(1,1)\n        newFile.create_dataset(\"time\", data=currentTime) # current time\n\n        f_tensor_np = mpm_solver.mpm_state.particle_F.numpy().reshape(-1,9).transpose() # shape = (9, n_particles)\n        newFile.create_dataset(\"f_tensor\", data=f_tensor_np) # deformation grad\n\n        v_np = mpm_solver.mpm_state.particle_v.numpy().transpose() # v_np has shape (3, n_particles)\n        newFile.create_dataset(\"v\", data=v_np) # particle velocity\n\n        C_np = mpm_solver.mpm_state.particle_C.numpy().reshape(-1,9).transpose() # shape = (9, n_particles)\n        newFile.create_dataset(\"C\", data=C_np) # particle C\n        print(\"save siumlation data at frame \", frame, \" to \", fullfilename)\n\ndef particle_position_to_ply(mpm_solver, filename):\n    # position is (n,3)\n    if os.path.exists(filename):\n        os.remove(filename)\n    position = mpm_solver.mpm_state.particle_x.numpy()\n    num_particles = (position).shape[0]\n    position = position.astype(np.float32)\n    with open(filename, 'wb') as f: # write binary\n        header = f\"\"\"ply\nformat binary_little_endian 1.0\nelement vertex {num_particles}\nproperty float x\nproperty float y\nproperty float 
z\nend_header\n\"\"\"\n        f.write(str.encode(header))\n        f.write(position.tobytes())\n        print(\"write\", filename)\n\ndef particle_position_tensor_to_ply(position_tensor, filename):\n    # position is (n,3)\n    if os.path.exists(filename):\n        os.remove(filename)\n    position = position_tensor.clone().detach().cpu().numpy()\n    num_particles = (position).shape[0]\n    position = position.astype(np.float32)\n    with open(filename, 'wb') as f: # write binary\n        header = f\"\"\"ply\nformat binary_little_endian 1.0\nelement vertex {num_particles}\nproperty float x\nproperty float y\nproperty float z\nend_header\n\"\"\"\n        f.write(str.encode(header))\n        f.write(position.tobytes())\n        print(\"write\", filename)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/grad_test.py",
    "content": "import warp as wp\nimport numpy as np\nimport torch\nimport os\nfrom mpm_solver_warp_diff import MPM_Simulator_WARPDiff\nfrom run_gaussian_static import load_gaussians, get_volume\nfrom tqdm import tqdm\nfrom fire import Fire\n\nfrom diff_warp_utils import MPMStateStruct, MPMModelStruct\nfrom warp_rewrite import MyTape\n\nfrom mpm_utils import *\n\n\ndef test(input_dir, output_dir=None, fps=6, device=0):\n    wp.init()\n    wp.config.verify_cuda = True\n\n    device = \"cuda:{}\".format(device)\n\n    gaussian_dict, scale, shift = load_gaussians(input_dir)\n\n    velocity_scaling = 0.5\n    init_velocity = velocity_scaling * gaussian_dict[\"velocity\"]\n    init_position = gaussian_dict[\"position\"]\n    init_cov = gaussian_dict[\"cov\"]\n\n    volume_array_path = os.path.join(input_dir, \"volume_array.npy\")\n    if os.path.exists(volume_array_path):\n        volume_array = np.load(volume_array_path)\n        volume_tensor = torch.from_numpy(volume_array).float().to(device)\n    else:\n        volume_array = get_volume(init_position)\n        np.save(volume_array_path, volume_array)\n        volume_tensor = torch.from_numpy(volume_array).float().to(device)\n\n    tensor_init_pos = torch.from_numpy(init_position).float().to(device)\n    tensor_init_cov = torch.from_numpy(init_cov).float().to(device)\n    tensor_init_velocity = torch.from_numpy(init_velocity).float().to(device)\n\n    # set boundary conditions\n    static_center_point = (\n        torch.from_numpy(gaussian_dict[\"satic_center_point\"]).float().to(device)\n    )\n    max_static_offset = (\n        torch.from_numpy(gaussian_dict[\"max_static_offset\"]).float().to(device)\n    )\n    velocity = torch.zeros_like(static_center_point)\n    # mpm_solver.enforce_particle_velocity_translation(static_center_point, max_static_offset, velocity,\n    #                                                  start_time=0, end_time=1000, device=device)\n\n    material_params = {\n        \"E\": 2.0,  # 
0.1-200 MPa\n        \"nu\": 0.1,  # > 0.35\n        \"material\": \"jelly\",\n        # \"material\": \"metal\",\n        # \"friction_angle\": 25,\n        \"g\": [0.0, 0.0, 0],\n        \"density\": 0.02,  # kg / m^3\n    }\n\n    n_particles = tensor_init_pos.shape[0]\n    mpm_state = MPMStateStruct()\n\n    mpm_state.init(init_position.shape[0], device=device, requires_grad=True)\n    mpm_state.from_torch(\n        tensor_init_pos,\n        volume_tensor,\n        tensor_init_cov,\n        tensor_init_velocity,\n        device=device,\n        requires_grad=True,\n        n_grid=100,\n        grid_lim=1.0,\n    )\n    mpm_state.set_require_grad(True)\n\n    next_mpm_state = MPMStateStruct()\n    next_mpm_state.init(init_position.shape[0], device=device, requires_grad=True)\n    next_mpm_state.from_torch(\n        tensor_init_pos.clone(),\n        volume_tensor.clone(),\n        tensor_init_cov.clone(),\n        tensor_init_velocity.clone(),\n        device=device,\n        requires_grad=True,\n        n_grid=100,\n        grid_lim=1.0,\n    )\n    next_mpm_state.set_require_grad(True)\n    # mpm_state.grid_v_out = wp.from_numpy(\n    #     np.ones((100, 100, 100, 3)), dtype=wp.vec3, requires_grad=True, device=device\n    # )\n\n    # tensor_init_pos.requires_grad = True\n    # tensor_init_cov.requires_grad = False\n    # tensor_init_velocity.requires_grad = True\n\n    # mpm_state.particle_x = wp.from_torch(tensor_init_pos, requires_grad=True)\n    # mpm_state.particle_x = wp.from_numpy(init_position, dtype=wp.vec3, requires_grad=True, device=device)\n    # mpm_state.particle_v = wp.from_numpy(init_velocity, dtype=wp.vec3, requires_grad=True, device=device)\n    # mpm_state.particle_vol = wp.from_numpy(volume_array, dtype=float, requires_grad=False, device=device)\n\n    mpm_model = MPMModelStruct()\n    mpm_model.init(n_particles, device=device, requires_grad=True)\n    mpm_model.init_other_params(n_grid=100, grid_lim=1.0, device=device)\n\n    E_tensor = 
(torch.ones(velocity.shape[0]) * 2.0).contiguous().to(device)\n    nu_tensor = (torch.ones(velocity.shape[0]) * 0.1).contiguous().to(device)\n    # E_warp = wp.from_torch(E_tensor, requires_grad=True)\n    # nu_warp = wp.from_torch(nu_tensor, requires_grad=True)\n\n    mpm_model.from_torch(E_tensor, nu_tensor, device=device, requires_grad=True)\n\n    total_time = 0.1\n    time_step = 0.01\n    total_iters = int(total_time / time_step)\n    total_iters = 3\n    loss = torch.zeros(1, device=device)\n    loss = wp.from_torch(loss, requires_grad=True)\n\n    dt = time_step\n    tape = MyTape()  # wp.Tape()\n\n    with tape:\n        # for k in tqdm(range(1, total_iters)):\n        k = 1\n        # mpm_solver.p2g2p(k, time_step, device=device)\n        for i in range(3):\n            wp.launch(\n                kernel=compute_stress_from_F_trial,\n                dim=n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )\n\n            wp.launch(\n                kernel=p2g_apic_with_stress,\n                dim=n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # apply p2g'\n\n            wp.launch(\n                kernel=grid_normalization_and_gravity,\n                dim=(100),\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )\n\n            wp.launch(\n                kernel=g2p_test,\n                dim=n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # x, v, C, F_trial are updated\n\n        wp.launch(\n            position_loss_kernel,\n            dim=n_particles,\n            inputs=[mpm_state, loss],\n            device=device,\n        )\n\n    print(loss, \"pre backward\")\n\n    tape.backward(loss)  # 75120.86\n\n    print(loss)\n\n    v_grad = mpm_state.particle_v.grad\n    x_grad = mpm_state.particle_x.grad\n    grid_v_grad 
= mpm_state.grid_v_out.grad\n    grid_v_in_grad = mpm_state.grid_v_in.grad\n    print(x_grad)\n    from IPython import embed\n\n    embed()\n\n\n@wp.kernel\ndef position_loss_kernel(mpm_state: MPMStateStruct, loss: wp.array(dtype=float)):\n    tid = wp.tid()\n\n    pos = mpm_state.particle_x[tid]\n    wp.atomic_add(loss, 0, pos[0] + pos[1] + pos[2])\n    # wp.atomic_add(loss, 0, mpm_state.particle_x[tid][0])\n\n\n@wp.kernel\ndef g2p_test(state: MPMStateStruct, model: MPMModelStruct, dt: float):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n\n        # new_v = wp.vec3(0.0, 0.0, 0.0)\n        # new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_v = wp.vec3(0.0)\n        new_C = wp.mat33(new_v, new_v, new_v)\n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_v = 
new_v + grid_v * weight\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        state.particle_v[p] = new_v\n        # wp.atomic_add(state.particle_x, p, dt * state.particle_v[p])\n        wp.atomic_add(state.particle_x, p, dt * new_v)\n\n        # might add clip here https://github.com/PingchuanMa/NCLaw/blob/main/nclaw/sim/mpm.py\n        state.particle_C[p] = new_C\n        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n        F_tmp = (I33 + new_F * dt) * state.particle_F[p]\n        state.particle_F_trial[p] = F_tmp\n\n        # next_state.particle_v[p] = new_v\n        # next_state.particle_C[p] = new_C\n        # next_state.particle_F_trial[p] = F_tmp\n        # wp.atomic_add(next_state.particle_x, p, dt * new_v)\n\n        if model.update_cov_with_F:\n            pass\n            # update_cov(next_state, p, new_F, dt)\n\n\nif __name__ == \"__main__\":\n    Fire(test)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/mpm_solver_warp.py",
    "content": "import sys\nimport os\n\nimport warp as wp\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)))\nfrom engine_utils import *\nfrom warp_utils import *\nfrom mpm_utils import *\n\n\nclass MPM_Simulator_WARP:\n    def __init__(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.initialize(n_particles, n_grid, grid_lim, device=device)\n        self.time_profile = {}\n\n    def initialize(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.n_particles = n_particles\n\n        self.mpm_model = MPMModelStruct()\n        # domain will be [0,grid_lim]*[0,grid_lim]*[0,grid_lim] !!!\n        # domain will be [0,grid_lim]*[0,grid_lim]*[0,grid_lim] !!!\n        # domain will be [0,grid_lim]*[0,grid_lim]*[0,grid_lim] !!!\n        self.mpm_model.grid_lim = grid_lim\n        self.mpm_model.n_grid = n_grid\n        self.mpm_model.grid_dim_x = self.mpm_model.n_grid\n        self.mpm_model.grid_dim_y = self.mpm_model.n_grid\n        self.mpm_model.grid_dim_z = self.mpm_model.n_grid\n        (\n            self.mpm_model.dx,\n            self.mpm_model.inv_dx,\n        ) = self.mpm_model.grid_lim / self.mpm_model.n_grid, float(\n            self.mpm_model.n_grid / self.mpm_model.grid_lim\n        )\n\n        self.mpm_model.E = wp.zeros(shape=n_particles, dtype=float, device=device)\n        self.mpm_model.nu = wp.zeros(shape=n_particles, dtype=float, device=device)\n        self.mpm_model.mu = wp.zeros(shape=n_particles, dtype=float, device=device)\n        self.mpm_model.lam = wp.zeros(shape=n_particles, dtype=float, device=device)\n\n        self.mpm_model.update_cov_with_F = False\n\n        # material is used to switch between different elastoplastic models. 
0 is jelly\n        self.mpm_model.material = 0\n\n        self.mpm_model.plastic_viscosity = 0.0\n        self.mpm_model.softening = 0.1\n        self.mpm_model.yield_stress = wp.zeros(\n            shape=n_particles, dtype=float, device=device\n        )\n        self.mpm_model.friction_angle = 25.0\n        sin_phi = wp.sin(self.mpm_model.friction_angle / 180.0 * 3.14159265)\n        self.mpm_model.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        self.mpm_model.gravitational_accelaration = wp.vec3(0.0, 0.0, 0.0)\n\n        self.mpm_model.rpic_damping = 0.0  # 0.0 if no damping (apic). -1 if pic\n\n        self.mpm_model.grid_v_damping_scale = 1.1  # globally applied\n\n        self.mpm_state = MPMStateStruct()\n\n        self.mpm_state.particle_x = wp.empty(\n            shape=n_particles, dtype=wp.vec3, device=device\n        )  # current position\n\n        self.mpm_state.particle_v = wp.zeros(\n            shape=n_particles, dtype=wp.vec3, device=device\n        )  # particle velocity\n\n        self.mpm_state.particle_F = wp.zeros(\n            shape=n_particles, dtype=wp.mat33, device=device\n        )  # particle F elastic\n\n        self.mpm_state.particle_R = wp.zeros(\n            shape=n_particles, dtype=wp.mat33, device=device\n        )  # particle R rotation\n\n        self.mpm_state.particle_init_cov = wp.zeros(\n            shape=n_particles * 6, dtype=float, device=device\n        )  # initial covariance matrix\n\n        self.mpm_state.particle_cov = wp.zeros(\n            shape=n_particles * 6, dtype=float, device=device\n        )  # current covariance matrix\n\n        self.mpm_state.particle_F_trial = wp.zeros(\n            shape=n_particles, dtype=wp.mat33, device=device\n        )  # apply return mapping will yield\n\n        self.mpm_state.particle_stress = wp.zeros(\n            shape=n_particles, dtype=wp.mat33, device=device\n        )\n\n        self.mpm_state.particle_vol = wp.zeros(\n            
shape=n_particles, dtype=float, device=device\n        )  # particle volume\n        self.mpm_state.particle_mass = wp.zeros(\n            shape=n_particles, dtype=float, device=device\n        )  # particle mass\n        self.mpm_state.particle_density = wp.zeros(\n            shape=n_particles, dtype=float, device=device\n        )\n        self.mpm_state.particle_C = wp.zeros(\n            shape=n_particles, dtype=wp.mat33, device=device\n        )\n        self.mpm_state.particle_Jp = wp.zeros(\n            shape=n_particles, dtype=float, device=device\n        )\n\n        self.mpm_state.particle_selection = wp.zeros(\n            shape=n_particles, dtype=int, device=device\n        )\n\n        self.mpm_state.grid_m = wp.zeros(\n            shape=(self.mpm_model.n_grid, self.mpm_model.n_grid, self.mpm_model.n_grid),\n            dtype=float,\n            device=device,\n        )\n        self.mpm_state.grid_v_in = wp.zeros(\n            shape=(self.mpm_model.n_grid, self.mpm_model.n_grid, self.mpm_model.n_grid),\n            dtype=wp.vec3,\n            device=device,\n        )\n        self.mpm_state.grid_v_out = wp.zeros(\n            shape=(self.mpm_model.n_grid, self.mpm_model.n_grid, self.mpm_model.n_grid),\n            dtype=wp.vec3,\n            device=device,\n        )\n\n        self.time = 0.0\n\n        self.grid_postprocess = []\n        self.collider_params = []\n        self.modify_bc = []\n\n        self.tailored_struct_for_bc = MPMtailoredStruct()\n        self.pre_p2g_operations = []\n        self.impulse_params = []\n\n        self.particle_velocity_modifiers = []\n        self.particle_velocity_modifier_params = []\n\n    # the h5 file should store particle initial position and volume.\n    def load_from_sampling(\n        self, sampling_h5, n_grid=100, grid_lim=1.0, device=\"cuda:0\"\n    ):\n        if not os.path.exists(sampling_h5):\n            print(\"h5 file cannot be found at \", os.getcwd() + sampling_h5)\n            exit()\n\n  
      h5file = h5py.File(sampling_h5, \"r\")\n        x, particle_volume = h5file[\"x\"], h5file[\"particle_volume\"]\n\n        x = x[()].transpose()  # np vector of x # shape now is (n_particles, dim)\n\n        self.dim, self.n_particles = x.shape[1], x.shape[0]\n\n        self.initialize(self.n_particles, n_grid, grid_lim, device=device)\n\n        print(\n            \"Sampling particles are loaded from h5 file. Simulator is re-initialized for the correct n_particles\"\n        )\n        particle_volume = np.squeeze(particle_volume, 0)\n\n        self.mpm_state.particle_x = wp.from_numpy(\n            x, dtype=wp.vec3, device=device\n        )  # initialize warp array from np\n\n        # initial velocity is default to zero\n        wp.launch(\n            kernel=set_vec3_to_zero,\n            dim=self.n_particles,\n            inputs=[self.mpm_state.particle_v],\n            device=device,\n        )\n        # initial velocity is default to zero\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=self.n_particles,\n            inputs=[self.mpm_state.particle_F_trial],\n            device=device,\n        )\n        # initial deformation gradient is set to identity\n\n        self.mpm_state.particle_vol = wp.from_numpy(\n            particle_volume, dtype=float, device=device\n        )\n\n        print(\"Particles initialized from sampling file.\")\n        print(\"Total particles: \", self.n_particles)\n\n    # shape of tensor_x is (n, 3); shape of tensor_volume is (n,)\n    def load_initial_data_from_torch(\n        self,\n        tensor_x,\n        tensor_volume,\n        tensor_cov=None,\n        tensor_velocity=None,\n        n_grid=100,\n        grid_lim=1.0,\n        device=\"cuda:0\",\n    ):\n        self.dim, self.n_particles = tensor_x.shape[1], tensor_x.shape[0]\n        assert tensor_x.shape[0] == tensor_volume.shape[0]\n        # assert tensor_x.shape[0] 
== tensor_cov.reshape(-1, 6).shape[0]\n        self.initialize(self.n_particles, n_grid, grid_lim, device=device)\n\n        self.import_particle_x_from_torch(tensor_x, device)\n        self.mpm_state.particle_vol = wp.from_numpy(\n            tensor_volume.detach().clone().cpu().numpy(), dtype=float, device=device\n        )\n        if tensor_cov is not None:\n            self.mpm_state.particle_init_cov = wp.from_numpy(\n                tensor_cov.reshape(-1).detach().clone().cpu().numpy(),\n                dtype=float,\n                device=device,\n            )\n\n            if self.mpm_model.update_cov_with_F:\n                self.mpm_state.particle_cov = self.mpm_state.particle_init_cov\n\n        # initial velocity is default to zero\n        wp.launch(\n            kernel=set_vec3_to_zero,\n            dim=self.n_particles,\n            inputs=[self.mpm_state.particle_v],\n            device=device,\n        )\n        if tensor_velocity is not None:\n            warp_velocity = torch2warp_vec3(\n                tensor_velocity.detach().clone(), dvc=device\n            )\n            self.mpm_state.particle_v = warp_velocity\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=self.n_particles,\n            inputs=[self.mpm_state.particle_F_trial],\n            device=device,\n        )\n        # initial trial deformation gradient is set to identity\n\n        print(\"Particles initialized from torch data.\")\n        print(\"Total particles: \", self.n_particles)\n\n    # must give density. 
mass will be updated as density * volume\n    def set_parameters(self, device=\"cuda:0\", **kwargs):\n        self.set_parameters_dict(device, kwargs)\n\n    def set_parameters_dict(self, kwargs={}, device=\"cuda:0\"):\n        if \"material\" in kwargs:\n            if kwargs[\"material\"] == \"jelly\":\n                self.mpm_model.material = 0\n            elif kwargs[\"material\"] == \"metal\":\n                self.mpm_model.material = 1\n            elif kwargs[\"material\"] == \"sand\":\n                self.mpm_model.material = 2\n            elif kwargs[\"material\"] == \"foam\":\n                self.mpm_model.material = 3\n            elif kwargs[\"material\"] == \"snow\":\n                self.mpm_model.material = 4\n            elif kwargs[\"material\"] == \"plasticine\":\n                self.mpm_model.material = 5\n            else:\n                raise TypeError(\"Undefined material type\")\n\n        if \"grid_lim\" in kwargs:\n            self.mpm_model.grid_lim = kwargs[\"grid_lim\"]\n        if \"n_grid\" in kwargs:\n            self.mpm_model.n_grid = kwargs[\"n_grid\"]\n        self.mpm_model.grid_dim_x = self.mpm_model.n_grid\n        self.mpm_model.grid_dim_y = self.mpm_model.n_grid\n        self.mpm_model.grid_dim_z = self.mpm_model.n_grid\n        (\n            self.mpm_model.dx,\n            self.mpm_model.inv_dx,\n        ) = self.mpm_model.grid_lim / self.mpm_model.n_grid, float(\n            self.mpm_model.n_grid / self.mpm_model.grid_lim\n        )\n        self.mpm_state.grid_m = wp.zeros(\n            shape=(self.mpm_model.n_grid, self.mpm_model.n_grid, self.mpm_model.n_grid),\n            dtype=float,\n            device=device,\n        )\n        self.mpm_state.grid_v_in = wp.zeros(\n            shape=(self.mpm_model.n_grid, self.mpm_model.n_grid, self.mpm_model.n_grid),\n            dtype=wp.vec3,\n            device=device,\n        )\n        self.mpm_state.grid_v_out = wp.zeros(\n            shape=(self.mpm_model.n_grid, 
self.mpm_model.n_grid, self.mpm_model.n_grid),\n            dtype=wp.vec3,\n            device=device,\n        )\n\n        if \"E\" in kwargs:\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[self.mpm_model.E, kwargs[\"E\"]],\n                device=device,\n            )\n        if \"nu\" in kwargs:\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[self.mpm_model.nu, kwargs[\"nu\"]],\n                device=device,\n            )\n        if \"yield_stress\" in kwargs:\n            val = kwargs[\"yield_stress\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[self.mpm_model.yield_stress, val],\n                device=device,\n            )\n        if \"hardening\" in kwargs:\n            self.mpm_model.hardening = kwargs[\"hardening\"]\n        if \"xi\" in kwargs:\n            self.mpm_model.xi = kwargs[\"xi\"]\n        if \"friction_angle\" in kwargs:\n            self.mpm_model.friction_angle = kwargs[\"friction_angle\"]\n            sin_phi = wp.sin(self.mpm_model.friction_angle / 180.0 * 3.14159265)\n            self.mpm_model.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        if \"g\" in kwargs:\n            self.mpm_model.gravitational_accelaration = wp.vec3(\n                kwargs[\"g\"][0], kwargs[\"g\"][1], kwargs[\"g\"][2]\n            )\n\n        if \"density\" in kwargs:\n            density_value = kwargs[\"density\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[self.mpm_state.particle_density, density_value],\n                device=device,\n            )\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=self.n_particles,\n    
            inputs=[\n                    self.mpm_state.particle_density,\n                    self.mpm_state.particle_vol,\n                    self.mpm_state.particle_mass,\n                ],\n                device=device,\n            )\n        if \"rpic_damping\" in kwargs:\n            self.mpm_model.rpic_damping = kwargs[\"rpic_damping\"]\n        if \"plastic_viscosity\" in kwargs:\n            self.mpm_model.plastic_viscosity = kwargs[\"plastic_viscosity\"]\n        if \"softening\" in kwargs:\n            self.mpm_model.softening = kwargs[\"softening\"]\n        if \"grid_v_damping_scale\" in kwargs:\n            self.mpm_model.grid_v_damping_scale = kwargs[\"grid_v_damping_scale\"]\n\n        if \"additional_material_params\" in kwargs:\n            for params in kwargs[\"additional_material_params\"]:\n                param_modifier = MaterialParamsModifier()\n                param_modifier.point = wp.vec3(params[\"point\"])\n                param_modifier.size = wp.vec3(params[\"size\"])\n                param_modifier.density = params[\"density\"]\n                param_modifier.E = params[\"E\"]\n                param_modifier.nu = params[\"nu\"]\n                wp.launch(\n                    kernel=apply_additional_params,\n                    dim=self.n_particles,\n                    inputs=[self.mpm_state, self.mpm_model, param_modifier],\n                    device=device,\n                )\n\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=self.n_particles,\n                inputs=[\n                    self.mpm_state.particle_density,\n                    self.mpm_state.particle_vol,\n                    self.mpm_state.particle_mass,\n                ],\n                device=device,\n            )\n\n    def finalize_mu_lam(self, device=\"cuda:0\"):\n        wp.launch(\n            kernel=compute_mu_lam_from_E_nu,\n            dim=self.n_particles,\n            inputs=[self.mpm_state, 
self.mpm_model],\n            device=device,\n        )\n\n    def p2g2p(self, step, dt, device=\"cuda:0\"):\n        grid_size = (\n            self.mpm_model.grid_dim_x,\n            self.mpm_model.grid_dim_y,\n            self.mpm_model.grid_dim_z,\n        )\n        wp.launch(\n            kernel=zero_grid,\n            dim=(grid_size),\n            inputs=[self.mpm_state, self.mpm_model],\n            device=device,\n        )\n\n        # apply pre-p2g operations on particles\n        for k in range(len(self.pre_p2g_operations)):\n            wp.launch(\n                kernel=self.pre_p2g_operations[k],\n                dim=self.n_particles,\n                inputs=[self.time, dt, self.mpm_state, self.impulse_params[k]],\n                device=device,\n            )\n        # apply dirichlet particle v modifier\n        for k in range(len(self.particle_velocity_modifiers)):\n            wp.launch(\n                kernel=self.particle_velocity_modifiers[k],\n                dim=self.n_particles,\n                inputs=[\n                    self.time,\n                    self.mpm_state,\n                    self.particle_velocity_modifier_params[k],\n                ],\n                device=device,\n            )\n\n        # compute stress = stress(returnMap(F_trial))\n        with wp.ScopedTimer(\n            \"compute_stress_from_F_trial\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=compute_stress_from_F_trial,\n                dim=self.n_particles,\n                inputs=[self.mpm_state, self.mpm_model, dt],\n                device=device,\n            )  # F and stress are updated\n\n        # p2g\n        with wp.ScopedTimer(\n            \"p2g\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=p2g_apic_with_stress,\n               
 dim=self.n_particles,\n                inputs=[self.mpm_state, self.mpm_model, dt],\n                device=device,\n            )  # apply p2g'\n\n        # grid update\n        with wp.ScopedTimer(\n            \"grid_update\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=grid_normalization_and_gravity,\n                dim=(grid_size),\n                inputs=[self.mpm_state, self.mpm_model, dt],\n                device=device,\n            )\n\n        if self.mpm_model.grid_v_damping_scale < 1.0:\n            wp.launch(\n                kernel=add_damping_via_grid,\n                dim=(grid_size),\n                inputs=[self.mpm_state, self.mpm_model.grid_v_damping_scale],\n                device=device,\n            )\n\n        # apply BC on grid\n        with wp.ScopedTimer(\n            \"apply_BC_on_grid\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            for k in range(len(self.grid_postprocess)):\n                wp.launch(\n                    kernel=self.grid_postprocess[k],\n                    dim=grid_size,\n                    inputs=[\n                        self.time,\n                        dt,\n                        self.mpm_state,\n                        self.mpm_model,\n                        self.collider_params[k],\n                    ],\n                    device=device,\n                )\n                if self.modify_bc[k] is not None:\n                    self.modify_bc[k](self.time, dt, self.collider_params[k])\n\n        # g2p\n        with wp.ScopedTimer(\n            \"g2p\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=g2p,\n                dim=self.n_particles,\n                inputs=[self.mpm_state, self.mpm_model, dt],\n                device=device,\n            )  # x, v, C, F_trial are updated\n\n        #### CFL check ####\n        # particle_v 
= self.mpm_state.particle_v.numpy()\n        # if np.max(np.abs(particle_v)) > self.mpm_model.dx / dt:\n        #     print(\"max particle v: \", np.max(np.abs(particle_v)))\n        #     print(\"max allowed  v: \", self.mpm_model.dx / dt)\n        #     print(\"does not allow v*dt>dx\")\n        #     input()\n        #### CFL check ####\n        self.time = self.time + dt\n\n    # set particle densities to all_particle_densities,\n    def reset_densities_and_update_masses(\n        self, all_particle_densities, device=\"cuda:0\"\n    ):\n        all_particle_densities = all_particle_densities.clone().detach()\n        self.mpm_state.particle_density = torch2warp_float(\n            all_particle_densities, dvc=device\n        )\n        wp.launch(\n            kernel=get_float_array_product,\n            dim=self.n_particles,\n            inputs=[\n                self.mpm_state.particle_density,\n                self.mpm_state.particle_vol,\n                self.mpm_state.particle_mass,\n            ],\n            device=device,\n        )\n\n    # clone = True makes a copy, not necessarily needed\n    def import_particle_x_from_torch(self, tensor_x, clone=True, device=\"cuda:0\"):\n        if tensor_x is not None:\n            if clone:\n                tensor_x = tensor_x.clone().detach()\n            self.mpm_state.particle_x = torch2warp_vec3(tensor_x, dvc=device)\n\n    # clone = True makes a copy, not necessarily needed\n    def import_particle_v_from_torch(self, tensor_v, clone=True, device=\"cuda:0\"):\n        if tensor_v is not None:\n            if clone:\n                tensor_v = tensor_v.clone().detach()\n            self.mpm_state.particle_v = torch2warp_vec3(tensor_v, dvc=device)\n\n    # clone = True makes a copy, not necessarily needed\n    def import_particle_F_from_torch(self, tensor_F, clone=True, device=\"cuda:0\"):\n        if tensor_F is not None:\n            if clone:\n                tensor_F = tensor_F.clone().detach()\n            
tensor_F = torch.reshape(tensor_F, (-1, 3, 3))  # arranged by rowmajor\n            self.mpm_state.particle_F = torch2warp_mat33(tensor_F, dvc=device)\n\n    # clone = True makes a copy, not necessarily needed\n    def import_particle_C_from_torch(self, tensor_C, clone=True, device=\"cuda:0\"):\n        if tensor_C is not None:\n            if clone:\n                tensor_C = tensor_C.clone().detach()\n            tensor_C = torch.reshape(tensor_C, (-1, 3, 3))  # arranged by rowmajor\n            self.mpm_state.particle_C = torch2warp_mat33(tensor_C, dvc=device)\n\n    def export_particle_x_to_torch(self):\n        return wp.to_torch(self.mpm_state.particle_x)\n\n    def export_particle_v_to_torch(self):\n        return wp.to_torch(self.mpm_state.particle_v)\n\n    def export_particle_F_to_torch(self):\n        F_tensor = wp.to_torch(self.mpm_state.particle_F)\n        F_tensor = F_tensor.reshape(-1, 9)\n        return F_tensor\n\n    def export_particle_R_to_torch(self, device=\"cuda:0\"):\n        with wp.ScopedTimer(\n            \"compute_R_from_F\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=compute_R_from_F,\n                dim=self.n_particles,\n                inputs=[self.mpm_state, self.mpm_model],\n                device=device,\n            )\n\n        R_tensor = wp.to_torch(self.mpm_state.particle_R)\n        R_tensor = R_tensor.reshape(-1, 9)\n        return R_tensor\n\n    def export_particle_C_to_torch(self):\n        C_tensor = wp.to_torch(self.mpm_state.particle_C)\n        C_tensor = C_tensor.reshape(-1, 9)\n        return C_tensor\n\n    def export_particle_cov_to_torch(self, device=\"cuda:0\"):\n        if not self.mpm_model.update_cov_with_F:\n            with wp.ScopedTimer(\n                \"compute_cov_from_F\",\n                synchronize=True,\n                print=False,\n                dict=self.time_profile,\n          
  ):\n                wp.launch(\n                    kernel=compute_cov_from_F,\n                    dim=self.n_particles,\n                    inputs=[self.mpm_state, self.mpm_model],\n                    device=device,\n                )\n\n        cov = wp.to_torch(self.mpm_state.particle_cov)\n        return cov\n\n    def print_time_profile(self):\n        print(\"MPM Time profile:\")\n        for key, value in self.time_profile.items():\n            print(key, sum(value))\n\n    # a surface specified by a point and the normal vector\n    def add_surface_collider(\n        self,\n        point,\n        normal,\n        surface=\"sticky\",\n        friction=0.0,\n        start_time=0.0,\n        end_time=999.0,\n    ):\n        point = list(point)\n        # Normalize normal\n        normal_scale = 1.0 / wp.sqrt(float(sum(x**2 for x in normal)))\n        normal = list(normal_scale * x for x in normal)\n\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n        if surface == \"sticky\" and friction != 0:\n            raise ValueError(\"friction must be 0 on sticky surfaces.\")\n        if surface == \"sticky\":\n            collider_param.surface_type = 0\n        elif surface == \"slip\":\n            collider_param.surface_type = 1\n        elif surface == \"cut\":\n            collider_param.surface_type = 11\n        else:\n            collider_param.surface_type = 2\n        # frictional\n        collider_param.friction = friction\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, 
grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                n = wp.vec3(param.normal[0], param.normal[1], param.normal[2])\n                dotproduct = wp.dot(offset, n)\n\n                if dotproduct < 0.0:\n                    if param.surface_type == 0:\n                        state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                            0.0, 0.0, 0.0\n                        )\n                    elif param.surface_type == 11:\n                        if (\n                            float(grid_z) * model.dx < 0.4\n                            or float(grid_z) * model.dx > 0.53\n                        ):\n                            state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                                0.0, 0.0, 0.0\n                            )\n                        else:\n                            v_in = state.grid_v_out[grid_x, grid_y, grid_z]\n                            state.grid_v_out[grid_x, grid_y, grid_z] = (\n                                wp.vec3(v_in[0], 0.0, v_in[2]) * 0.3\n                            )\n                    else:\n                        v = state.grid_v_out[grid_x, grid_y, grid_z]\n                        normal_component = wp.dot(v, n)\n                        if param.surface_type == 1:\n                            v = (\n                                v - normal_component * n\n                            )  # Project out all normal component\n                        else:\n                            v = (\n                                v - wp.min(normal_component, 0.0) * n\n                            )  # Project out only inward normal component\n                        if 
normal_component < 0.0 and wp.length(v) > 1e-20:\n                            v = wp.max(\n                                0.0, wp.length(v) + normal_component * param.friction\n                            ) * wp.normalize(\n                                v\n                            )  # apply friction here\n                        state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                            0.0, 0.0, 0.0\n                        )\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # a cuboid is a rectangular box\n    # centered at `point`\n    # dimension is x: point[0]±size[0]\n    #              y: point[1]±size[1]\n    #              z: point[2]±size[2]\n    # all grid nodes lying within the cuboid will have their speed set to velocity\n    # the cuboid itself is also moving with const speed = velocity\n    # set the speed to zero to fix BC\n    def set_velocity_on_cuboid(\n        self,\n        point,\n        size,\n        velocity,\n        start_time=0.0,\n        end_time=999.0,\n        reset=0,\n    ):\n        point = list(point)\n\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.size = size\n        collider_param.velocity = wp.vec3(velocity[0], velocity[1], velocity[2])\n        # collider_param.threshold = threshold\n        collider_param.reset = reset\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * 
model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                if (\n                    wp.abs(offset[0]) < param.size[0]\n                    and wp.abs(offset[1]) < param.size[1]\n                    and wp.abs(offset[2]) < param.size[2]\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = param.velocity\n            elif param.reset == 1:\n                if time < param.end_time + 15.0 * dt:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n\n        def modify(time, dt, param: Dirichlet_collider):\n            if time >= param.start_time and time < param.end_time:\n                param.point = wp.vec3(\n                    param.point[0] + dt * param.velocity[0],\n                    param.point[1] + dt * param.velocity[1],\n                    param.point[2] + dt * param.velocity[2],\n                )  # param.point + dt * param.velocity\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(modify)\n\n    def add_bounding_box(self, start_time=0.0, end_time=999.0):\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            padding = 3\n            if time >= param.start_time and time < param.end_time:\n                if grid_x < padding and state.grid_v_out[grid_x, grid_y, grid_z][0] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n                        
state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    grid_x >= model.grid_dim_x - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][0] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_y < padding and state.grid_v_out[grid_x, grid_y, grid_z][1] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    grid_y >= model.grid_dim_y - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][1] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_z < padding and state.grid_v_out[grid_x, grid_y, grid_z][2] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n                if (\n                    grid_z >= model.grid_dim_z - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][2] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                     
   state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # particle_v += force/particle_mass * dt\n    # this is applied from start_dt, ends after num_dt p2g2p's\n    # particle velocity is changed before p2g at each timestep\n    def add_impulse_on_particles(\n        self,\n        force,\n        dt,\n        point=[1, 1, 1],\n        size=[1, 1, 1],\n        num_dt=1,\n        start_time=0.0,\n        device=\"cuda:0\",\n    ):\n        impulse_param = Impulse_modifier()\n        impulse_param.start_time = start_time\n        impulse_param.end_time = start_time + dt * num_dt\n\n        impulse_param.point = wp.vec3(point[0], point[1], point[2])\n        impulse_param.size = wp.vec3(size[0], size[1], size[2])\n        impulse_param.mask = wp.zeros(shape=self.n_particles, dtype=int, device=device)\n\n        impulse_param.force = wp.vec3(\n            force[0],\n            force[1],\n            force[2],\n        )\n\n        wp.launch(\n            kernel=selection_add_impulse_on_particles,\n            dim=self.n_particles,\n            inputs=[self.mpm_state, impulse_param],\n            device=device,\n        )\n\n        self.impulse_params.append(impulse_param)\n\n        @wp.kernel\n        def apply_force(\n            time: float, dt: float, state: MPMStateStruct, param: Impulse_modifier\n        ):\n            p = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                if param.mask[p] == 1:\n                    impulse = wp.vec3(\n                        param.force[0] / state.particle_mass[p],\n                        param.force[1] / state.particle_mass[p],\n                        param.force[2] / state.particle_mass[p],\n                    )\n                    state.particle_v[p] = state.particle_v[p] 
+ impulse * dt\n\n        self.pre_p2g_operations.append(apply_force)\n\n    def enforce_particle_velocity_translation(\n        self, point, size, velocity, start_time, end_time, device=\"cuda:0\"\n    ):\n        # first select certain particles based on position\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], point[1], point[2])\n        velocity_modifier_params.size = wp.vec3(size[0], size[1], size[2])\n\n        velocity_modifier_params.velocity = wp.vec3(\n            velocity[0], velocity[1], velocity[2]\n        )\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_translation,\n            dim=self.n_particles,\n            inputs=[self.mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    state.particle_v[p] = velocity_modifier_params.velocity\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # define a cylinder with center point, half_height, radius, normal\n    # particles within the cylinder are rotating along the normal direction\n    # may also have a translational velocity along the normal direction\n    def 
enforce_particle_velocity_rotation(\n        self,\n        point,\n        normal,\n        half_height_and_radius,\n        rotation_scale,\n        translation_scale,\n        start_time,\n        end_time,\n        device=\"cuda:0\",\n    ):\n        normal_scale = 1.0 / wp.sqrt(\n            float(normal[0] ** 2 + normal[1] ** 2 + normal[2] ** 2)\n        )\n        normal = list(normal_scale * x for x in normal)\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], point[1], point[2])\n        velocity_modifier_params.half_height_and_radius = wp.vec2(\n            half_height_and_radius[0], half_height_and_radius[1]\n        )\n        velocity_modifier_params.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n        horizontal_1 = wp.vec3(1.0, 1.0, 1.0)\n        if wp.abs(wp.dot(velocity_modifier_params.normal, horizontal_1)) < 0.01:\n            horizontal_1 = wp.vec3(0.72, 0.37, -0.67)\n        horizontal_1 = (\n            horizontal_1\n            - wp.dot(horizontal_1, velocity_modifier_params.normal)\n            * velocity_modifier_params.normal\n        )\n        horizontal_1 = horizontal_1 * (1.0 / wp.length(horizontal_1))\n        horizontal_2 = wp.cross(horizontal_1, velocity_modifier_params.normal)\n\n        velocity_modifier_params.horizontal_axis_1 = horizontal_1\n        velocity_modifier_params.horizontal_axis_2 = horizontal_2\n\n        velocity_modifier_params.rotation_scale = rotation_scale\n        velocity_modifier_params.translation_scale = translation_scale\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_cylinder,\n            dim=self.n_particles,\n            inputs=[self.mpm_state, 
velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    offset = state.particle_x[p] - velocity_modifier_params.point\n                    horizontal_distance = wp.length(\n                        offset\n                        - wp.dot(offset, velocity_modifier_params.normal)\n                        * velocity_modifier_params.normal\n                    )\n                    cosine = (\n                        wp.dot(offset, velocity_modifier_params.horizontal_axis_1)\n                        / horizontal_distance\n                    )\n                    theta = wp.acos(cosine)\n                    if wp.dot(offset, velocity_modifier_params.horizontal_axis_2) > 0:\n                        theta = theta\n                    else:\n                        theta = -theta\n                    axis1_scale = (\n                        -horizontal_distance\n                        * wp.sin(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis2_scale = (\n                        horizontal_distance\n                        * wp.cos(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis_vertical_scale = translation_scale\n                    state.particle_v[p] = (\n                        axis1_scale * velocity_modifier_params.horizontal_axis_1\n                        + axis2_scale * 
velocity_modifier_params.horizontal_axis_2\n                        + axis_vertical_scale * velocity_modifier_params.normal\n                    )\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # given normal direction, say [0,0,1]\n    # gradually release grid velocities from start position to end position\n    def release_particles_sequentially(\n        self, normal, start_position, end_position, num_layers, start_time, end_time\n    ):\n        num_layers = 50\n        point = [0, 0, 0]\n        size = [0, 0, 0]\n        axis = -1\n        for i in range(3):\n            if normal[i] == 0:\n                point[i] = 1\n                size[i] = 1\n            else:\n                axis = i\n                point[i] = end_position\n\n        half_length_portion = wp.abs(start_position - end_position) / num_layers\n        end_time_portion = end_time / num_layers\n        for i in range(num_layers):\n            size[axis] = half_length_portion * (num_layers - i)\n            self.enforce_particle_velocity_translation(\n                point=point,\n                size=size,\n                velocity=[0, 0, 0],\n                start_time=start_time,\n                end_time=end_time_portion * (i + 1),\n            )\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/mpm_solver_warp_diff.py",
    "content": "import sys\nimport os\n\nimport warp as wp\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)))\nfrom engine_utils import *\nfrom diff_warp_utils import *\nfrom mpm_utils import *\nfrom typing import Optional, Union, Sequence, Any\n\n\nclass MPM_Simulator_WARPDiff(object):\n    # def __init__(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n    #     self.initialize(n_particles, n_grid, grid_lim, device=device)\n    #     self.time_profile = {}\n\n    def __init__(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.initialize(n_particles, n_grid, grid_lim, device=device)\n        self.time_profile = {}\n\n    def initialize(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.n_particles = n_particles\n\n        self.time = 0.0\n\n        self.grid_postprocess = []\n        self.collider_params = []\n        self.modify_bc = []\n\n        self.tailored_struct_for_bc = MPMtailoredStruct()\n        self.pre_p2g_operations = []\n        self.impulse_params = []\n\n        self.particle_velocity_modifiers = []\n        self.particle_velocity_modifier_params = []\n\n    # must give density. 
mass will be updated as density * volume\n    def set_parameters(self, device=\"cuda:0\", **kwargs):\n        self.set_parameters_dict(device, kwargs)\n\n    def set_parameters_dict(self, mpm_model, mpm_state, kwargs={}, device=\"cuda:0\"):\n        if \"material\" in kwargs:\n            if kwargs[\"material\"] == \"jelly\":\n                mpm_model.material = 0\n            elif kwargs[\"material\"] == \"metal\":\n                mpm_model.material = 1\n            elif kwargs[\"material\"] == \"sand\":\n                mpm_model.material = 2\n            elif kwargs[\"material\"] == \"foam\":\n                mpm_model.material = 3\n            elif kwargs[\"material\"] == \"snow\":\n                mpm_model.material = 4\n            elif kwargs[\"material\"] == \"plasticine\":\n                mpm_model.material = 5\n            else:\n                raise TypeError(\"Undefined material type\")\n\n        if \"yield_stress\" in kwargs:\n            val = kwargs[\"yield_stress\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.yield_stress, val],\n                device=device,\n            )\n        if \"hardening\" in kwargs:\n            mpm_model.hardening = kwargs[\"hardening\"]\n        if \"xi\" in kwargs:\n            mpm_model.xi = kwargs[\"xi\"]\n        if \"friction_angle\" in kwargs:\n            mpm_model.friction_angle = kwargs[\"friction_angle\"]\n            sin_phi = wp.sin(mpm_model.friction_angle / 180.0 * 3.14159265)\n            mpm_model.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        if \"g\" in kwargs:\n            mpm_model.gravitational_accelaration = wp.vec3(\n                kwargs[\"g\"][0], kwargs[\"g\"][1], kwargs[\"g\"][2]\n            )\n\n        if \"density\" in kwargs:\n            density_value = kwargs[\"density\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n        
        dim=self.n_particles,\n                inputs=[mpm_state.particle_density, density_value],\n                device=device,\n            )\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=self.n_particles,\n                inputs=[\n                    mpm_state.particle_density,\n                    mpm_state.particle_vol,\n                    mpm_state.particle_mass,\n                ],\n                device=device,\n            )\n        if \"rpic_damping\" in kwargs:\n            mpm_model.rpic_damping = kwargs[\"rpic_damping\"]\n        if \"plastic_viscosity\" in kwargs:\n            mpm_model.plastic_viscosity = kwargs[\"plastic_viscosity\"]\n        if \"softening\" in kwargs:\n            mpm_model.softening = kwargs[\"softening\"]\n        if \"grid_v_damping_scale\" in kwargs:\n            mpm_model.grid_v_damping_scale = kwargs[\"grid_v_damping_scale\"]\n\n    def set_E_nu(self, mpm_model, E: float, nu: float, device=\"cuda:0\"):\n        \n        wp.launch(\n            kernel=set_value_to_float_array,\n            dim=self.n_particles,\n            inputs=[mpm_model.E, E],\n            device=device,\n        )\n        wp.launch(\n            kernel=set_value_to_float_array,\n            dim=self.n_particles,\n            inputs=[mpm_model.nu, nu],\n            device=device,\n        )\n\n    def p2g2p(self, mpm_model, mpm_state, step, dt, device=\"cuda:0\"):\n        grid_size = (\n            mpm_model.grid_dim_x,\n            mpm_model.grid_dim_y,\n            mpm_model.grid_dim_z,\n        )\n\n        wp.launch(\n            kernel=compute_mu_lam_from_E_nu,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        # apply pre-p2g 
operations on particles\n        # apply impulse force on particles..\n        for k in range(len(self.pre_p2g_operations)):\n            wp.launch(\n                kernel=self.pre_p2g_operations[k],\n                dim=self.n_particles,\n                inputs=[self.time, dt, mpm_state, self.impulse_params[k]],\n                device=device,\n            )\n\n        # apply dirichlet particle v modifier\n        for k in range(len(self.particle_velocity_modifiers)):\n            wp.launch(\n                kernel=self.particle_velocity_modifiers[k],\n                dim=self.n_particles,\n                inputs=[\n                    self.time,\n                    mpm_state,\n                    self.particle_velocity_modifier_params[k],\n                ],\n                device=device,\n            )\n\n        # compute stress = stress(returnMap(F_trial))\n        with wp.ScopedTimer(\n            \"compute_stress_from_F_trial\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=compute_stress_from_F_trial,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # F and stress are updated\n\n        # p2g\n        with wp.ScopedTimer(\n            \"p2g\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=p2g_apic_with_stress,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # apply p2g'\n\n        # grid update\n        with wp.ScopedTimer(\n            \"grid_update\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=grid_normalization_and_gravity,\n                dim=(grid_size),\n                
inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )\n\n        if mpm_model.grid_v_damping_scale < 1.0:\n            wp.launch(\n                kernel=add_damping_via_grid,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model.grid_v_damping_scale],\n                device=device,\n            )\n\n        # apply BC on grid, collide\n        with wp.ScopedTimer(\n            \"apply_BC_on_grid\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            for k in range(len(self.grid_postprocess)):\n                wp.launch(\n                    kernel=self.grid_postprocess[k],\n                    dim=grid_size,\n                    inputs=[\n                        self.time,\n                        dt,\n                        mpm_state,\n                        mpm_model,\n                        self.collider_params[k],\n                    ],\n                    device=device,\n                )\n                if self.modify_bc[k] is not None:\n                    self.modify_bc[k](self.time, dt, self.collider_params[k])\n\n        # g2p\n        with wp.ScopedTimer(\n            \"g2p\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=g2p,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # x, v, C, F_trial are updated\n\n        #### CFL check ####\n        # particle_v = self.mpm_state.particle_v.numpy()\n        # if np.max(np.abs(particle_v)) > self.mpm_model.dx / dt:\n        #     print(\"max particle v: \", np.max(np.abs(particle_v)))\n        #     print(\"max allowed  v: \", self.mpm_model.dx / dt)\n        #     print(\"does not allow v*dt>dx\")\n        #     input()\n        #### CFL check ####\n        self.time = self.time + dt\n\n    def print_time_profile(self):\n        print(\"MPM Time profile:\")\n        
for key, value in self.time_profile.items():\n            print(key, sum(value))\n\n    # a surface specified by a point and the normal vector\n    def add_surface_collider(\n        self,\n        point,\n        normal,\n        surface=\"sticky\",\n        friction=0.0,\n        start_time=0.0,\n        end_time=999.0,\n    ):\n        point = list(point)\n        # Normalize normal\n        normal_scale = 1.0 / wp.sqrt(float(sum(x**2 for x in normal)))\n        normal = list(normal_scale * x for x in normal)\n\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n        if surface == \"sticky\" and friction != 0:\n            raise ValueError(\"friction must be 0 on sticky surfaces.\")\n        if surface == \"sticky\":\n            collider_param.surface_type = 0\n        elif surface == \"slip\":\n            collider_param.surface_type = 1\n        elif surface == \"cut\":\n            collider_param.surface_type = 11\n        else:\n            collider_param.surface_type = 2\n        # frictional\n        collider_param.friction = friction\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                n = wp.vec3(param.normal[0], param.normal[1], 
param.normal[2])\n                dotproduct = wp.dot(offset, n)\n\n                if dotproduct < 0.0:\n                    if param.surface_type == 0:\n                        state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                            0.0, 0.0, 0.0\n                        )\n                    elif param.surface_type == 11:\n                        if (\n                            float(grid_z) * model.dx < 0.4\n                            or float(grid_z) * model.dx > 0.53\n                        ):\n                            state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                                0.0, 0.0, 0.0\n                            )\n                        else:\n                            v_in = state.grid_v_out[grid_x, grid_y, grid_z]\n                            state.grid_v_out[grid_x, grid_y, grid_z] = (\n                                wp.vec3(v_in[0], 0.0, v_in[2]) * 0.3\n                            )\n                    else:\n                        v = state.grid_v_out[grid_x, grid_y, grid_z]\n                        normal_component = wp.dot(v, n)\n                        if param.surface_type == 1:\n                            v = (\n                                v - normal_component * n\n                            )  # Project out all normal component\n                        else:\n                            v = (\n                                v - wp.min(normal_component, 0.0) * n\n                            )  # Project out only inward normal component\n                        if normal_component < 0.0 and wp.length(v) > 1e-20:\n                            v = wp.max(\n                                0.0, wp.length(v) + normal_component * param.friction\n                            ) * wp.normalize(\n                                v\n                            )  # apply friction here\n                        state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                         
   0.0, 0.0, 0.0\n                        )\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # a cubiod is a rectangular cube'\n    # centered at `point`\n    # dimension is x: point[0]±size[0]\n    #              y: point[1]±size[1]\n    #              z: point[2]±size[2]\n    # all grid nodes lie within the cubiod will have their speed set to velocity\n    # the cuboid itself is also moving with const speed = velocity\n    # set the speed to zero to fix BC\n    def set_velocity_on_cuboid(\n        self,\n        point,\n        size,\n        velocity,\n        start_time=0.0,\n        end_time=999.0,\n        reset=0,\n    ):\n        point = list(point)\n\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.size = size\n        collider_param.velocity = wp.vec3(velocity[0], velocity[1], velocity[2])\n        # collider_param.threshold = threshold\n        collider_param.reset = reset\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                if (\n                    wp.abs(offset[0]) < param.size[0]\n                    and wp.abs(offset[1]) < param.size[1]\n                    and wp.abs(offset[2]) < param.size[2]\n                ):\n                    
state.grid_v_out[grid_x, grid_y, grid_z] = param.velocity\n            elif param.reset == 1:\n                if time < param.end_time + 15.0 * dt:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n\n        def modify(time, dt, param: Dirichlet_collider):\n            if time >= param.start_time and time < param.end_time:\n                param.point = wp.vec3(\n                    param.point[0] + dt * param.velocity[0],\n                    param.point[1] + dt * param.velocity[1],\n                    param.point[2] + dt * param.velocity[2],\n                )  # param.point + dt * param.velocity\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(modify)\n\n    def add_bounding_box(self, start_time=0.0, end_time=999.0):\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            padding = 3\n            if time >= param.start_time and time < param.end_time:\n                if grid_x < padding and state.grid_v_out[grid_x, grid_y, grid_z][0] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    grid_x >= model.grid_dim_x - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][0] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n              
          state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_y < padding and state.grid_v_out[grid_x, grid_y, grid_z][1] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    grid_y >= model.grid_dim_y - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][1] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_z < padding and state.grid_v_out[grid_x, grid_y, grid_z][2] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n                if (\n                    grid_z >= model.grid_dim_z - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][2] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # particle_v += force/particle_mass * dt\n    # this is applied from start_dt, ends after num_dt p2g2p's\n    # particle velocity is changed before p2g at each 
timestep\n    def add_impulse_on_particles(\n        self,\n        mpm_state,\n        force,\n        dt,\n        point=[1, 1, 1],\n        size=[1, 1, 1],\n        num_dt=1,\n        start_time=0.0,\n        device=\"cuda:0\",\n    ):\n        impulse_param = Impulse_modifier()\n        impulse_param.start_time = start_time\n        impulse_param.end_time = start_time + dt * num_dt\n\n        impulse_param.point = wp.vec3(point[0], point[1], point[2])\n        impulse_param.size = wp.vec3(size[0], size[1], size[2])\n        impulse_param.mask = wp.zeros(shape=self.n_particles, dtype=int, device=device)\n\n        impulse_param.force = wp.vec3(\n            force[0],\n            force[1],\n            force[2],\n        )\n\n        wp.launch(\n            kernel=selection_add_impulse_on_particles,\n            dim=self.n_particles,\n            inputs=[mpm_state, impulse_param],\n            device=device,\n        )\n\n        self.impulse_params.append(impulse_param)\n\n        @wp.kernel\n        def apply_force(\n            time: float, dt: float, state: MPMStateStruct, param: Impulse_modifier\n        ):\n            p = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                if param.mask[p] == 1:\n                    impulse = wp.vec3(\n                        param.force[0] / state.particle_mass[p],\n                        param.force[1] / state.particle_mass[p],\n                        param.force[2] / state.particle_mass[p],\n                    )\n                    state.particle_v[p] = state.particle_v[p] + impulse * dt\n\n        self.pre_p2g_operations.append(apply_force)\n\n    def enforce_particle_velocity_translation(\n        self, mpm_state, point, size, velocity, start_time, end_time, device=\"cuda:0\"\n    ):\n        # first select certain particles based on position\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], 
point[1], point[2])\n        velocity_modifier_params.size = wp.vec3(size[0], size[1], size[2])\n\n        velocity_modifier_params.velocity = wp.vec3(\n            velocity[0], velocity[1], velocity[2]\n        )\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_translation,\n            dim=self.n_particles,\n            inputs=[mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    state.particle_v[p] = velocity_modifier_params.velocity\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # define a cylinder with center point, half_height, radius, normal\n    # particles within the cylinder are rotating along the normal direction\n    # may also have a translational velocity along the normal direction\n    def enforce_particle_velocity_rotation(\n        self,\n        mpm_state,\n        point,\n        normal,\n        half_height_and_radius,\n        rotation_scale,\n        translation_scale,\n        start_time,\n        end_time,\n        device=\"cuda:0\",\n    ):\n        normal_scale = 1.0 / wp.sqrt(\n            float(normal[0] ** 2 + normal[1] ** 2 + normal[2] ** 2)\n        )\n        normal = 
list(normal_scale * x for x in normal)\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], point[1], point[2])\n        velocity_modifier_params.half_height_and_radius = wp.vec2(\n            half_height_and_radius[0], half_height_and_radius[1]\n        )\n        velocity_modifier_params.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n        horizontal_1 = wp.vec3(1.0, 1.0, 1.0)\n        if wp.abs(wp.dot(velocity_modifier_params.normal, horizontal_1)) < 0.01:\n            horizontal_1 = wp.vec3(0.72, 0.37, -0.67)\n        horizontal_1 = (\n            horizontal_1\n            - wp.dot(horizontal_1, velocity_modifier_params.normal)\n            * velocity_modifier_params.normal\n        )\n        horizontal_1 = horizontal_1 * (1.0 / wp.length(horizontal_1))\n        horizontal_2 = wp.cross(horizontal_1, velocity_modifier_params.normal)\n\n        velocity_modifier_params.horizontal_axis_1 = horizontal_1\n        velocity_modifier_params.horizontal_axis_2 = horizontal_2\n\n        velocity_modifier_params.rotation_scale = rotation_scale\n        velocity_modifier_params.translation_scale = translation_scale\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_cylinder,\n            dim=self.n_particles,\n            inputs=[mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n  
              time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    offset = state.particle_x[p] - velocity_modifier_params.point\n                    horizontal_distance = wp.length(\n                        offset\n                        - wp.dot(offset, velocity_modifier_params.normal)\n                        * velocity_modifier_params.normal\n                    )\n                    cosine = (\n                        wp.dot(offset, velocity_modifier_params.horizontal_axis_1)\n                        / horizontal_distance\n                    )\n                    theta = wp.acos(cosine)\n                    if wp.dot(offset, velocity_modifier_params.horizontal_axis_2) > 0:\n                        theta = theta\n                    else:\n                        theta = -theta\n                    axis1_scale = (\n                        -horizontal_distance\n                        * wp.sin(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis2_scale = (\n                        horizontal_distance\n                        * wp.cos(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis_vertical_scale = translation_scale\n                    state.particle_v[p] = (\n                        axis1_scale * velocity_modifier_params.horizontal_axis_1\n                        + axis2_scale * velocity_modifier_params.horizontal_axis_2\n                        + axis_vertical_scale * velocity_modifier_params.normal\n                    )\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # given normal direction, say [0,0,1]\n    # gradually release grid velocities from start position to end position\n    def release_particles_sequentially(\n      
  self, normal, start_position, end_position, num_layers, start_time, end_time\n    ):\n        num_layers = 50\n        point = [0, 0, 0]\n        size = [0, 0, 0]\n        axis = -1\n        for i in range(3):\n            if normal[i] == 0:\n                point[i] = 1\n                size[i] = 1\n            else:\n                axis = i\n                point[i] = end_position\n\n        half_length_portion = wp.abs(start_position - end_position) / num_layers\n        end_time_portion = end_time / num_layers\n        for i in range(num_layers):\n            size[axis] = half_length_portion * (num_layers - i)\n            self.enforce_particle_velocity_translation(\n                point=point,\n                size=size,\n                velocity=[0, 0, 0],\n                start_time=start_time,\n                end_time=end_time_portion * (i + 1),\n            )\n\n    def enforce_particle_velocity_by_mask(\n        self, mpm_state, selection_mask:torch.Tensor, velocity, start_time, end_time, device=\"cuda:0\"\n    ):\n        # first select certain particles based on position\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        \n        velocity_modifier_params.velocity = wp.vec3(\n            velocity[0], velocity[1], velocity[2]\n        )\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.from_torch(selection_mask, device=device)\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_translation,\n            dim=self.n_particles,\n            inputs=[mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n  
      ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    state.particle_v[p] = velocity_modifier_params.velocity\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/mpm_utils.py",
    "content": "import warp as wp\nfrom diff_warp_utils import *\nimport numpy as np\nimport math\n\n\n# compute stress from F\n@wp.func\ndef kirchoff_stress_FCR(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, mu: float, lam: float\n):\n    # compute kirchoff stress for FCR model (remember tau = P F^T)\n    R = U * wp.transpose(V)\n    id = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    return 2.0 * mu * (F - R) * wp.transpose(F) + id * lam * J * (J - 1.0)\n\n\n@wp.func\ndef kirchoff_stress_neoHookean(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, sig: wp.vec3, mu: float, lam: float\n):\n    # compute kirchoff stress for FCR model (remember tau = P F^T)\n    b = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])\n    b_hat = b - wp.vec3(\n        (b[0] + b[1] + b[2]) / 3.0,\n        (b[0] + b[1] + b[2]) / 3.0,\n        (b[0] + b[1] + b[2]) / 3.0,\n    )\n    tau = mu * J ** (-2.0 / 3.0) * b_hat + lam / 2.0 * (J * J - 1.0) * wp.vec3(\n        1.0, 1.0, 1.0\n    )\n    return (\n        U\n        * wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])\n        * wp.transpose(V)\n        * wp.transpose(F)\n    )\n\n\n@wp.func\ndef kirchoff_stress_StVK(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float\n):\n    sig = wp.vec3(\n        wp.max(sig[0], 0.01), wp.max(sig[1], 0.01), wp.max(sig[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    log_sig_sum = wp.log(sig[0]) + wp.log(sig[1]) + wp.log(sig[2])\n    ONE = wp.vec3(1.0, 1.0, 1.0)\n    tau = 2.0 * mu * epsilon + lam * log_sig_sum * ONE\n    return (\n        U\n        * wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])\n        * wp.transpose(V)\n        * wp.transpose(F)\n    )\n\n\n@wp.func\ndef kirchoff_stress_drucker_prager(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float\n):\n    log_sig_sum = wp.log(sig[0]) 
+ wp.log(sig[1]) + wp.log(sig[2])\n    center00 = 2.0 * mu * wp.log(sig[0]) * (1.0 / sig[0]) + lam * log_sig_sum * (\n        1.0 / sig[0]\n    )\n    center11 = 2.0 * mu * wp.log(sig[1]) * (1.0 / sig[1]) + lam * log_sig_sum * (\n        1.0 / sig[1]\n    )\n    center22 = 2.0 * mu * wp.log(sig[2]) * (1.0 / sig[2]) + lam * log_sig_sum * (\n        1.0 / sig[2]\n    )\n    center = wp.mat33(center00, 0.0, 0.0, 0.0, center11, 0.0, 0.0, 0.0, center22)\n    return U * center * wp.transpose(V) * wp.transpose(F)\n\n\n@wp.func\ndef von_mises_return_mapping(F_trial: wp.mat33, model: MPMModelStruct, p: int):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0\n\n    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (\n        epsilon[0] + epsilon[1] + epsilon[2]\n    ) * wp.vec3(1.0, 1.0, 1.0)\n    sum_tau = tau[0] + tau[1] + tau[2]\n    cond = wp.vec3(\n        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0\n    )\n    if wp.length(cond) > model.yield_stress[p]:\n        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)\n        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6\n        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])\n        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        sig_elastic = wp.mat33(\n            wp.exp(epsilon[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[2]),\n        )\n        F_elastic = U * sig_elastic * 
wp.transpose(V)\n        if model.hardening == 1:\n            model.yield_stress[p] = (\n                model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma\n            )\n        return F_elastic\n    else:\n        return F_trial\n\n\n@wp.func\ndef von_mises_return_mapping_with_damage(\n    F_trial: wp.mat33, model: MPMModelStruct, p: int\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0\n\n    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (\n        epsilon[0] + epsilon[1] + epsilon[2]\n    ) * wp.vec3(1.0, 1.0, 1.0)\n    sum_tau = tau[0] + tau[1] + tau[2]\n    cond = wp.vec3(\n        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0\n    )\n    if wp.length(cond) > model.yield_stress[p]:\n        if model.yield_stress[p] <= 0:\n            return F_trial\n        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)\n        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6\n        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])\n        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        model.yield_stress[p] = model.yield_stress[p] - model.softening * wp.length(\n            (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        )\n        if model.yield_stress[p] <= 0:\n            model.mu[p] = 0.0\n            model.lam[p] = 0.0\n        sig_elastic = wp.mat33(\n            wp.exp(epsilon[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[1]),\n            0.0,\n            0.0,\n            0.0,\n         
   wp.exp(epsilon[2]),\n        )\n        F_elastic = U * sig_elastic * wp.transpose(V)\n        if model.hardening == 1:\n            model.yield_stress[p] = (\n                model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma\n            )\n        return F_elastic\n    else:\n        return F_trial\n\n\n# for toothpaste\n@wp.func\ndef viscoplasticity_return_mapping_with_StVK(\n    F_trial: wp.mat33, model: MPMModelStruct, p: int, dt: float\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    b_trial = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    trace_epsilon = epsilon[0] + epsilon[1] + epsilon[2]\n    epsilon_hat = epsilon - wp.vec3(\n        trace_epsilon / 3.0, trace_epsilon / 3.0, trace_epsilon / 3.0\n    )\n    s_trial = 2.0 * model.mu[p] * epsilon_hat\n    s_trial_norm = wp.length(s_trial)\n    y = s_trial_norm - wp.sqrt(2.0 / 3.0) * model.yield_stress[p]\n    if y > 0:\n        mu_hat = model.mu[p] * (b_trial[0] + b_trial[1] + b_trial[2]) / 3.0\n        s_new_norm = s_trial_norm - y / (\n            1.0 + model.plastic_viscosity / (2.0 * mu_hat * dt)\n        )\n        s_new = (s_new_norm / s_trial_norm) * s_trial\n        epsilon_new = 1.0 / (2.0 * model.mu[p]) * s_new + wp.vec3(\n            trace_epsilon / 3.0, trace_epsilon / 3.0, trace_epsilon / 3.0\n        )\n        sig_elastic = wp.mat33(\n            wp.exp(epsilon_new[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon_new[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon_new[2]),\n        )\n        F_elastic = U * 
sig_elastic * wp.transpose(V)\n        return F_elastic\n    else:\n        return F_trial\n\n\n@wp.func\ndef sand_return_mapping(\n    F_trial: wp.mat33, state: MPMStateStruct, model: MPMModelStruct, p: int\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig, V)\n\n    epsilon = wp.vec3(\n        wp.log(wp.max(wp.abs(sig[0]), 1e-14)),\n        wp.log(wp.max(wp.abs(sig[1]), 1e-14)),\n        wp.log(wp.max(wp.abs(sig[2]), 1e-14)),\n    )\n    sigma_out = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    tr = epsilon[0] + epsilon[1] + epsilon[2]  # + state.particle_Jp[p]\n    epsilon_hat = epsilon - wp.vec3(tr / 3.0, tr / 3.0, tr / 3.0)\n    epsilon_hat_norm = wp.length(epsilon_hat)\n    delta_gamma = (\n        epsilon_hat_norm\n        + (3.0 * model.lam[p] + 2.0 * model.mu[p])\n        / (2.0 * model.mu[p])\n        * tr\n        * model.alpha\n    )\n\n    if delta_gamma <= 0:\n        F_elastic = F_trial\n\n    if delta_gamma > 0 and tr > 0:\n        F_elastic = U * wp.transpose(V)\n\n    if delta_gamma > 0 and tr <= 0:\n        H = epsilon - epsilon_hat * (delta_gamma / epsilon_hat_norm)\n        s_new = wp.vec3(wp.exp(H[0]), wp.exp(H[1]), wp.exp(H[2]))\n\n        F_elastic = U * wp.diag(s_new) * wp.transpose(V)\n    return F_elastic\n\n\n@wp.kernel\ndef compute_mu_lam_from_E_nu(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n    model.mu[p] = model.E[p] / (2.0 * (1.0 + model.nu[p]))\n    model.lam[p] = (\n        model.E[p] * model.nu[p] / ((1.0 + model.nu[p]) * (1.0 - 2.0 * model.nu[p]))\n    )\n\n\n@wp.kernel\ndef zero_grid(state: MPMStateStruct, model: MPMModelStruct):\n    grid_x, grid_y, grid_z = wp.tid()\n    state.grid_m[grid_x, grid_y, grid_z] = 0.0\n    state.grid_v_in[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 
0.0)\n\n\n@wp.func\ndef compute_dweight(\n    model: MPMModelStruct, w: wp.mat33, dw: wp.mat33, i: int, j: int, k: int\n):\n    dweight = wp.vec3(\n        dw[0, i] * w[1, j] * w[2, k],\n        w[0, i] * dw[1, j] * w[2, k],\n        w[0, i] * w[1, j] * dw[2, k],\n    )\n    return dweight * model.inv_dx\n\n\n@wp.func\ndef update_cov(state: MPMStateStruct, p: int, grad_v: wp.mat33, dt: float):\n    cov_n = wp.mat33(0.0)\n    cov_n[0, 0] = state.particle_cov[p * 6]\n    cov_n[0, 1] = state.particle_cov[p * 6 + 1]\n    cov_n[0, 2] = state.particle_cov[p * 6 + 2]\n    cov_n[1, 0] = state.particle_cov[p * 6 + 1]\n    cov_n[1, 1] = state.particle_cov[p * 6 + 3]\n    cov_n[1, 2] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 0] = state.particle_cov[p * 6 + 2]\n    cov_n[2, 1] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 2] = state.particle_cov[p * 6 + 5]\n\n    cov_np1 = cov_n + dt * (grad_v * cov_n + cov_n * wp.transpose(grad_v))\n\n    state.particle_cov[p * 6] = cov_np1[0, 0]\n    state.particle_cov[p * 6 + 1] = cov_np1[0, 1]\n    state.particle_cov[p * 6 + 2] = cov_np1[0, 2]\n    state.particle_cov[p * 6 + 3] = cov_np1[1, 1]\n    state.particle_cov[p * 6 + 4] = cov_np1[1, 2]\n    state.particle_cov[p * 6 + 5] = cov_np1[2, 2]\n\n\n@wp.kernel\ndef p2g_apic_with_stress(state: MPMStateStruct, model: MPMModelStruct, dt: float):\n    # input given to p2g:   particle_stress\n    #                       particle_x\n    #                       particle_v\n    #                       particle_C\n    # output:               grid_v_in, grid_m\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        stress = state.particle_stress[p]\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = 
wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    dpos = (\n                        wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    ) * model.dx\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n\n                    C = state.particle_C[p]\n                    # if model.rpic = 0, standard apic\n                    C = (1.0 - model.rpic_damping) * C + model.rpic_damping / 2.0 * (\n                        C - wp.transpose(C)\n                    )\n\n                    # C = (1.0 - model.rpic_damping) * state.particle_C[\n                    #     p\n                    # ] + model.rpic_damping / 2.0 * (\n                    #     state.particle_C[p] - wp.transpose(state.particle_C[p])\n                    # )\n\n                    if model.rpic_damping < -0.001:\n                        # standard pic\n                        C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n                    elastic_force = -state.particle_vol[p] * stress * dweight\n                    v_in_add = (\n                        weight\n                        * state.particle_mass[p]\n                        * (state.particle_v[p] + C * dpos)\n                        + dt * elastic_force\n                    )\n                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)\n               
     wp.atomic_add(\n                        state.grid_m, ix, iy, iz, weight * state.particle_mass[p]\n                    )\n\n\n# add gravity\n@wp.kernel\ndef grid_normalization_and_gravity(\n    state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n    grid_x, grid_y, grid_z = wp.tid()\n    if state.grid_m[grid_x, grid_y, grid_z] > 1e-15:\n        v_out = state.grid_v_in[grid_x, grid_y, grid_z] * (\n            1.0 / state.grid_m[grid_x, grid_y, grid_z]\n        )\n        # add gravity\n        v_out = v_out + dt * model.gravitational_accelaration\n        state.grid_v_out[grid_x, grid_y, grid_z] = v_out\n\n\n@wp.kernel\ndef g2p(state: MPMStateStruct, model: MPMModelStruct, dt: float):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_v = wp.vec3(0.0, 0.0, 0.0)\n        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * 
w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_v = new_v + grid_v * weight\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        state.particle_v[p] = new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * state.particle_v[p]\n        wp.atomic_add(state.particle_x, p, dt * state.particle_v[p])\n        state.particle_C[p] = new_C\n        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n        F_tmp = (I33 + new_F * dt) * state.particle_F[p]\n        state.particle_F_trial[p] = F_tmp\n\n        if model.update_cov_with_F:\n            update_cov(state, p, new_F, dt)\n\n\n# compute (Kirchhoff) stress = stress(returnMap(F_trial))\n@wp.kernel\ndef compute_stress_from_F_trial(\n    state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        # apply return mapping\n        if model.material == 1:  # metal\n            state.particle_F[p] = von_mises_return_mapping(\n                state.particle_F_trial[p], model, p\n            )\n        elif model.material == 2:  # sand\n            state.particle_F[p] = sand_return_mapping(\n                state.particle_F_trial[p], state, model, p\n            )\n        elif model.material == 3:  # visplas, with StVk+VM, no thickening\n            state.particle_F[p] = viscoplasticity_return_mapping_with_StVK(\n                state.particle_F_trial[p], model, p, dt\n            )\n        elif model.material == 5:\n            state.particle_F[p] = von_mises_return_mapping_with_damage(\n                state.particle_F_trial[p], model, p\n            )\n        else:  # 
elastic, jelly\n            state.particle_F[p] = state.particle_F_trial[p]\n\n        # also compute stress here\n        J = wp.determinant(state.particle_F[p])\n        U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        sig = wp.vec3(0.0)\n        stress = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        wp.svd3(state.particle_F[p], U, sig, V)\n        if model.material == 0 or model.material == 5:\n            stress = kirchoff_stress_FCR(\n                state.particle_F[p], U, V, J, model.mu[p], model.lam[p]\n            )\n        if model.material == 1:\n            stress = kirchoff_stress_StVK(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n        if model.material == 2:\n            stress = kirchoff_stress_drucker_prager(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n        if model.material == 3:\n            # temporarily use stvk, subject to change\n            stress = kirchoff_stress_StVK(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n\n        # stress = (stress + wp.transpose(stress)) / 2.0  # enfore symmetry\n        state.particle_stress[p] = (stress + wp.transpose(stress)) / 2.0\n\n\n@wp.kernel\ndef compute_cov_from_F(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n\n    F = state.particle_F_trial[p]\n\n    init_cov = wp.mat33(0.0)\n    init_cov[0, 0] = state.particle_init_cov[p * 6]\n    init_cov[0, 1] = state.particle_init_cov[p * 6 + 1]\n    init_cov[0, 2] = state.particle_init_cov[p * 6 + 2]\n    init_cov[1, 0] = state.particle_init_cov[p * 6 + 1]\n    init_cov[1, 1] = state.particle_init_cov[p * 6 + 3]\n    init_cov[1, 2] = state.particle_init_cov[p * 6 + 4]\n    init_cov[2, 0] = state.particle_init_cov[p * 6 + 2]\n    init_cov[2, 1] = state.particle_init_cov[p * 6 + 4]\n    init_cov[2, 2] = 
state.particle_init_cov[p * 6 + 5]\n\n    cov = F * init_cov * wp.transpose(F)\n\n    state.particle_cov[p * 6] = cov[0, 0]\n    state.particle_cov[p * 6 + 1] = cov[0, 1]\n    state.particle_cov[p * 6 + 2] = cov[0, 2]\n    state.particle_cov[p * 6 + 3] = cov[1, 1]\n    state.particle_cov[p * 6 + 4] = cov[1, 2]\n    state.particle_cov[p * 6 + 5] = cov[2, 2]\n\n\n@wp.kernel\ndef compute_R_from_F(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n\n    F = state.particle_F_trial[p]\n\n    # polar svd decomposition\n    U = wp.mat33(0.0)\n    V = wp.mat33(0.0)\n    sig = wp.vec3(0.0)\n    wp.svd3(F, U, sig, V)\n\n    if wp.determinant(U) < 0.0:\n        U[0, 2] = -U[0, 2]\n        U[1, 2] = -U[1, 2]\n        U[2, 2] = -U[2, 2]\n\n    if wp.determinant(V) < 0.0:\n        V[0, 2] = -V[0, 2]\n        V[1, 2] = -V[1, 2]\n        V[2, 2] = -V[2, 2]\n\n    # compute rotation matrix\n    R = U * wp.transpose(V)\n    state.particle_R[p] = wp.transpose(R)\n\n\n@wp.kernel\ndef add_damping_via_grid(state: MPMStateStruct, scale: float):\n    grid_x, grid_y, grid_z = wp.tid()\n    state.grid_v_out[grid_x, grid_y, grid_z] = (\n        state.grid_v_out[grid_x, grid_y, grid_z] * scale\n    )\n\n\n@wp.kernel\ndef apply_additional_params(\n    state: MPMStateStruct,\n    model: MPMModelStruct,\n    params_modifier: MaterialParamsModifier,\n):\n    p = wp.tid()\n    pos = state.particle_x[p]\n    if (\n        pos[0] > params_modifier.point[0] - params_modifier.size[0]\n        and pos[0] < params_modifier.point[0] + params_modifier.size[0]\n        and pos[1] > params_modifier.point[1] - params_modifier.size[1]\n        and pos[1] < params_modifier.point[1] + params_modifier.size[1]\n        and pos[2] > params_modifier.point[2] - params_modifier.size[2]\n        and pos[2] < params_modifier.point[2] + params_modifier.size[2]\n    ):\n        model.E[p] = params_modifier.E\n        model.nu[p] = params_modifier.nu\n        state.particle_density[p] = 
params_modifier.density\n\n\n@wp.kernel\ndef selection_add_impulse_on_particles(\n    state: MPMStateStruct, impulse_modifier: Impulse_modifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - impulse_modifier.point\n    if (\n        wp.abs(offset[0]) < impulse_modifier.size[0]\n        and wp.abs(offset[1]) < impulse_modifier.size[1]\n        and wp.abs(offset[2]) < impulse_modifier.size[2]\n    ):\n        impulse_modifier.mask[p] = 1\n    else:\n        impulse_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef selection_enforce_particle_velocity_translation(\n    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - velocity_modifier.point\n    if (\n        wp.abs(offset[0]) < velocity_modifier.size[0]\n        and wp.abs(offset[1]) < velocity_modifier.size[1]\n        and wp.abs(offset[2]) < velocity_modifier.size[2]\n    ):\n        velocity_modifier.mask[p] = 1\n    else:\n        velocity_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef selection_enforce_particle_velocity_cylinder(\n    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - velocity_modifier.point\n\n    vertical_distance = wp.abs(wp.dot(offset, velocity_modifier.normal))\n\n    horizontal_distance = wp.length(\n        offset - wp.dot(offset, velocity_modifier.normal) * velocity_modifier.normal\n    )\n    if (\n        vertical_distance < velocity_modifier.half_height_and_radius[0]\n        and horizontal_distance < velocity_modifier.half_height_and_radius[1]\n    ):\n        velocity_modifier.mask[p] = 1\n    else:\n        velocity_modifier.mask[p] = 0\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/run_gaussian.py",
    "content": "import time\nimport numpy as np\nfrom fire import Fire\nimport os\nimport warp as wp\nfrom mpm_solver_warp import MPM_Simulator_WARP\nfrom engine_utils import *\nimport torch\nfrom tqdm import tqdm\n\n\ndef load_gaussians(input_dir: str = None):\n    name_dict = {\n        \"position\": \"pos.npy\",  # [T, N, 3]\n        \"rotation\": \"rot.npy\",  # [T, N, 4]\n        \"cov\": \"cov.npy\",  # [T, N, 6]\n    }\n\n    assert os.path.exists(input_dir), \"Input directory does not exist\"\n\n    ret_dict = {}\n\n    for key, value in name_dict.items():\n        ret_dict[key] = np.load(os.path.join(input_dir, value))\n\n    pos_max = ret_dict[\"position\"].max()\n    pos_min = ret_dict[\"position\"].min()\n\n    # ret_dict[\"position\"] = (ret_dict[\"position\"]) / (pos_max - pos_min) * 0.5 - pos_min / (pos_max - pos_min) * 0.5 + 0.1\n    # ret_dict[\"cov\"] = ret_dict[\"cov\"] / (pos_max - pos_min) * 0.5\n    scale = (pos_max - pos_min) * 2.0\n    shift = -pos_min\n\n    ret_dict[\"position\"] = (ret_dict[\"position\"] + shift) / scale\n    ret_dict[\"cov\"] = ret_dict[\"cov\"] / scale\n\n    # pos_new = (pos + shift ) / scale\n    # pos_orign = pos_new * scale  - shift\n    return ret_dict, scale, shift\n\n\ndef init_volume(xyz, grid=[-1, 1], num_grid=20):\n    pass\n\n\ndef run_mpm_gaussian(input_dir, output_dir=None, fps=6, device=0):\n    wp.init()\n    wp.config.verify_cuda = True\n\n    device = \"cuda:{}\".format(device)\n\n    gaussian_dict, scale, shift = load_gaussians(input_dir)\n\n    velocity_scaling = 10\n    velocity = (\n        (gaussian_dict[\"position\"][1:] - gaussian_dict[\"position\"][:-1])\n        / fps\n        * velocity_scaling\n    )\n\n    velocity_abs = np.abs(velocity)\n    print(\n        \"velocity mean-max-min\",\n        velocity_abs.mean(),\n        velocity_abs.max(),\n        velocity_abs.min(),\n    )\n\n    init_velocity = velocity[0]\n    init_position = gaussian_dict[\"position\"][0]\n    init_rotation = 
gaussian_dict[\"rotation\"][0]\n    init_cov = gaussian_dict[\"cov\"][0]\n    tensor_init_pos = torch.from_numpy(init_position).float().to(device)\n    tensor_init_cov = torch.from_numpy(init_cov).float().to(device)\n    tensor_init_velocity = torch.from_numpy(init_velocity).float().to(device)\n\n    # print(tensor_init_pos.max(), tensor_init_pos.min(), tensor_init_pos.shape)\n\n    mpm_solver = MPM_Simulator_WARP(\n        10\n    )  # initialize with whatever number is fine. it will be reintialized\n\n    # TODO, Compute volume later\n    volume_tensor = (\n        torch.ones(\n            init_velocity.shape[0],\n        )\n        * 2.5e-8  # m^3\n    )\n\n    mpm_solver.load_initial_data_from_torch(\n        tensor_init_pos,\n        volume_tensor,\n        tensor_init_cov,\n        tensor_init_velocity,\n        device=device,\n    )\n    # mpm_solver.load_initial_data_from_torch(\n    #     tensor_init_pos, volume_tensor, device=device\n    # )\n\n    position_tensor = mpm_solver.export_particle_x_to_torch()\n    velo = wp.to_torch(mpm_solver.mpm_state.particle_v)\n    cov = wp.to_torch(mpm_solver.mpm_state.particle_init_cov)\n    print(\n        \"pos in box: \",\n        position_tensor.max(),\n        position_tensor.min(),\n    )\n\n    material_params = {\n        \"E\": 0.0002,  # 0.1-200 MPa\n        \"nu\": 0.4,  # > 0.35\n        \"material\": \"jelly\",\n        # \"friction_angle\": 25,\n        \"g\": [0.0, 0.0, 0],\n        \"density\": 1,  # kg / m^3\n    }\n\n    print(\"pre set\")\n    mpm_solver.set_parameters_dict(material_params)\n    print(\"set\")\n    mpm_solver.finalize_mu_lam()  # set mu and lambda from the E and nu input\n    print(\"finalize\")\n    # mpm_solver.add_surface_collider((0.0, 0.0, 0.13), (0.0,0.0,1.0), 'sticky', 0.0)\n\n    if output_dir is None:\n        output_dir = \"./gaussian_sim_results\"\n    os.makedirs(output_dir, exist_ok=True)\n\n    # save_data_at_frame(mpm_solver, output_dir, 0, save_to_ply=True, 
save_to_h5=False)\n    pos_list = []\n    pos = mpm_solver.export_particle_x_to_torch().clone()\n    pos = (pos * scale) - shift\n    pos_list.append(pos.detach().clone())\n\n    total_time = 20\n    time_step = 0.002\n    total_iters = int(total_time / time_step)\n\n    for k in tqdm(range(1, total_iters)):\n        mpm_solver.p2g2p(k, time_step, device=device)\n\n        if k % 50 == 0:\n            pos = mpm_solver.export_particle_x_to_torch().clone()\n            pos = (pos * scale) - shift\n            pos_list.append(pos.detach().clone())\n            print(k)\n            print(pos.max().item(), pos.min().item(), pos.mean().item())\n        # save_data_at_frame(mpm_solver, output_dir, k, save_to_ply=True, save_to_h5=False)\n\n    save_name = \"\"\n    for key, value in material_params.items():\n        if key == \"g\":\n            continue\n        save_name += \"{}_{}_\".format(key, value)\n\n    save_name += \"_timestep_{}_vs{}_totaltime_{}\".format(\n        time_step, velocity_scaling, total_time\n    )\n\n    render_gaussians(pos_list, save_name)\n\n\ndef code_test(input_dir, device=0):\n    device = \"cuda:{}\".format(device)\n    gaussian_dict, scale, shift = load_gaussians(input_dir)\n    pos = gaussian_dict[\"position\"]\n\n    pos = (pos * scale) - shift\n\n    pos = torch.from_numpy(pos).float().to(device)\n\n    render_gaussians(pos)\n\n\ndef render_gaussians(\n    pos_list,\n    save_name=None,\n    dataset_dir=\"../../data/physics_dreamer/llff_flower_undistorted\",\n):\n    from motionrep.data.datasets.multiview_dataset import MultiviewImageDataset\n    from motionrep.data.datasets.multiview_dataset import (\n        camera_dataset_collate_fn as camera_dataset_collate_fn_img,\n    )\n\n    from motionrep.gaussian_3d.gaussian_renderer.render import render_gaussian\n    from motionrep.gaussian_3d.scene import GaussianModel\n    from typing import NamedTuple\n\n    gaussian_path = os.path.join(dataset_dir, \"point_cloud.ply\")\n    test_dataset = 
MultiviewImageDataset(\n        dataset_dir,\n        use_white_background=False,\n        resolution=[576, 1024],\n        use_index=list(range(5, 30, 4)),\n    )\n    print(\n        \"len of train dataset\",\n        len(test_dataset),\n        \"len of test dataset\",\n        len(test_dataset),\n    )\n    test_dataloader = torch.utils.data.DataLoader(\n        test_dataset,\n        batch_size=1,\n        shuffle=False,\n        drop_last=True,\n        num_workers=0,\n        collate_fn=camera_dataset_collate_fn_img,\n    )\n\n    class RenderPipe(NamedTuple):\n        convert_SHs_python = False\n        compute_cov3D_python = False\n        debug = False\n\n    class RenderParams(NamedTuple):\n        render_pipe: RenderPipe\n        bg_color: bool\n        gaussians: GaussianModel\n        camera_list: list\n\n    gaussians = GaussianModel(3)\n    camera_list = test_dataset.camera_list\n\n    gaussians.load_ply(gaussian_path)\n    gaussians.detach_grad()\n    print(\n        \"load gaussians from: {}\".format(gaussian_path),\n        \"... 
num gaussians: \",\n        gaussians._xyz.shape[0],\n    )\n    bg_color = [1, 1, 1] if False else [0, 0, 0]\n    background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n    render_pipe = RenderPipe()\n\n    render_params = RenderParams(\n        render_pipe=render_pipe,\n        bg_color=background,\n        gaussians=gaussians,\n        camera_list=camera_list,\n    )\n\n    data = next(iter(test_dataloader))\n    cam = data[\"cam\"][0]\n\n    ret_img_list = []\n\n    for i in range(len(pos_list) + 1):\n        if i > 0:\n            xyz = pos_list[i - 1]\n            gaussians._xyz = xyz\n\n        img = render_gaussian(\n            cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n        )[\"render\"]\n\n        ret_img_list.append(img)\n\n    # [T, C, H, W]\n    video_array = torch.stack(ret_img_list, dim=0)\n    video_numpy = video_array.detach().cpu().numpy() * 255\n    video_numpy = np.clip(video_numpy, 0, 255).astype(np.uint8)\n\n    video_numpy = np.transpose(video_numpy, [0, 2, 3, 1])\n    from motionrep.utils.io_utils import save_video_imageio\n\n    if save_name is None:\n        save_path = \"test.mp4\"\n    else:\n        save_path = save_name + \".mp4\"\n    print(\"save video to \", save_path)\n    save_video_imageio(save_path, video_numpy, fps=10)\n\n\nif __name__ == \"__main__\":\n    Fire(run_mpm_gaussian)\n    # Fire(code_test)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/run_gaussian_static.py",
    "content": "import time\nimport numpy as np\nfrom fire import Fire\nimport os\nimport warp as wp\nfrom mpm_solver_warp import MPM_Simulator_WARP\nfrom engine_utils import *\nimport torch\nfrom tqdm import tqdm\nfrom motionrep.gaussian_3d.scene import GaussianModel\n\n\ndef load_gaussians(input_dir: str = None):\n    if not input_dir.endswith(\"ply\"):\n        gaussian_path = os.path.join(input_dir, \"point_cloud.ply\")\n    else:\n        gaussian_path = input_dir\n\n    gaussians = GaussianModel(3)\n\n    gaussians.load_ply(gaussian_path)\n    gaussians.detach_grad()\n\n    pos = gaussians._xyz.detach().cpu().numpy()\n\n    pos_max = pos.max()\n    pos_min = pos.min()\n\n    scale = (pos_max - pos_min) * 2.5\n    shift = -pos_min + (pos_max - pos_min) * 0.25\n\n    pos = (pos + shift) / scale\n\n    cov = gaussians.get_covariance().detach().cpu().numpy()\n    cov = cov / scale\n\n    velocity = np.zeros_like(pos)\n\n    height = pos[:, 2] - pos[:, 2].min()\n    height_thres = 10\n    velocity_mask = height > np.percentile(height, height_thres)\n\n    static_points = pos[np.logical_not(velocity_mask)]\n\n    static_points_mean = static_points.mean(axis=0)\n    static_points_dist = static_points - static_points_mean\n    max_static_offset = np.abs(static_points_dist).max(axis=0) * 0.8\n\n    # boundary condition, set velocity to 0\n\n    # x_velocity = np.sqrt(np.abs(pos[:, 0]) + 1e-8) * np.sign(pos[:, 0])\n    # x_velocity = np.sqrt(height) * 0.1\n    x_velocity = height**0.2 * 0.1\n    velocity[velocity_mask, 0] = x_velocity[velocity_mask]\n    velocity[velocity_mask, 1] = x_velocity[velocity_mask]\n\n    ret_dict = {\n        \"position\": pos,  # numpy [N, 3]\n        \"cov\": cov,  # numpy [N, 6]\n        \"velocity\": velocity,  # numpy [N, 3]\n        \"satic_center_point\": static_points_mean,  # numpy [3]\n        \"max_static_offset\": max_static_offset,  # numpy [3]\n    }\n\n    return ret_dict, scale, shift\n\n\ndef get_volume(xyzs: np.ndarray, 
resolution=128) -> np.ndarray:\n    print(\"Compute Volume for each points\")\n    voxel_counts = np.zeros((resolution, resolution, resolution))\n\n    points_xyzindex = ((xyzs + 1) / 2 * (resolution - 1)).astype(np.uint32)\n\n    for x, y, z in points_xyzindex:\n        voxel_counts[x, y, z] += 1\n\n    points_number_in_corresponding_voxel = voxel_counts[\n        points_xyzindex[:, 0], points_xyzindex[:, 1], points_xyzindex[:, 2]\n    ]\n\n    cell_volume = (2.0 / (resolution - 1)) ** 3\n\n    points_volume = cell_volume / points_number_in_corresponding_voxel\n\n    points_volume = points_volume.astype(np.float32)\n\n    print(\n        \"mean volume\",\n        points_volume.mean(),\n        \"max volume\",\n        points_volume.max(),\n        \"min volume\",\n        points_volume.min(),\n    )\n\n    return points_volume\n\n\ndef run_mpm_gaussian(input_dir, output_dir=None, fps=6, device=0):\n    wp.init()\n    wp.config.verify_cuda = True\n\n    device = \"cuda:{}\".format(device)\n\n    gaussian_dict, scale, shift = load_gaussians(input_dir)\n\n    velocity_scaling = 0.5\n\n    init_velocity = velocity_scaling * gaussian_dict[\"velocity\"]\n    init_position = gaussian_dict[\"position\"]\n    init_cov = gaussian_dict[\"cov\"]\n\n    volume_array_path = os.path.join(input_dir, \"volume_array.npy\")\n    if os.path.exists(volume_array_path):\n        volume_tensor = torch.from_numpy(np.load(volume_array_path)).float().to(device)\n    else:\n        volume_array = get_volume(init_position)\n        np.save(volume_array_path, volume_array)\n        volume_tensor = torch.from_numpy(volume_array).float().to(device)\n\n    tensor_init_pos = torch.from_numpy(init_position).float().to(device)\n    tensor_init_cov = torch.from_numpy(init_cov).float().to(device)\n    tensor_init_velocity = torch.from_numpy(init_velocity).float().to(device)\n\n    print(\n        \"init position:\",\n        tensor_init_pos.max(),\n        tensor_init_pos.min(),\n        
tensor_init_pos.shape,\n    )\n    velocity_abs = np.abs(init_velocity)\n    print(\n        \"velocity mean-max-min\",\n        velocity_abs.mean(),\n        velocity_abs.max(),\n        velocity_abs.min(),\n    )\n\n    mpm_solver = MPM_Simulator_WARP(\n        10\n    )  # initialize with whatever number is fine. it will be reinitialized\n\n    mpm_solver.load_initial_data_from_torch(\n        tensor_init_pos,\n        volume_tensor,\n        tensor_init_cov,\n        tensor_init_velocity,\n        device=device,\n    )\n    # mpm_solver.load_initial_data_from_torch(\n    #     tensor_init_pos, volume_tensor, device=device\n    # )\n\n    # set boundary conditions\n    static_center_point = (\n        torch.from_numpy(gaussian_dict[\"satic_center_point\"]).float().to(device)\n    )\n    max_static_offset = (\n        torch.from_numpy(gaussian_dict[\"max_static_offset\"]).float().to(device)\n    )\n    velocity = torch.zeros_like(static_center_point)\n    mpm_solver.enforce_particle_velocity_translation(\n        static_center_point,\n        max_static_offset,\n        velocity,\n        start_time=0,\n        end_time=1000,\n        device=device,\n    )\n\n    position_tensor = mpm_solver.export_particle_x_to_torch()\n    velo = wp.to_torch(mpm_solver.mpm_state.particle_v)\n    cov = wp.to_torch(mpm_solver.mpm_state.particle_init_cov)\n    print(\n        \"pos in box: \",\n        position_tensor.max(),\n        position_tensor.min(),\n    )\n\n    material_params = {\n        \"E\": 0.2,  # 0.1-200 MPa\n        \"nu\": 0.1,  # > 0.35\n        \"material\": \"jelly\",\n        # \"material\": \"metal\",\n        # \"friction_angle\": 25,\n        \"g\": [0.0, 0.0, 0],\n        \"density\": 0.2,  # kg / m^3\n    }\n\n    print(\"pre set\")\n    mpm_solver.set_parameters_dict(material_params)\n    print(\"set\")\n    mpm_solver.finalize_mu_lam()  # set mu and lambda from the E and nu input\n    print(\"finalize\")\n    # mpm_solver.add_surface_collider((0.0, 
0.0, 0.13), (0.0,0.0,1.0), 'sticky', 0.0)\n\n    if output_dir is None:\n        output_dir = \"../../output/gaussian_sim_results\"\n    os.makedirs(output_dir, exist_ok=True)\n\n    # save_data_at_frame(mpm_solver, output_dir, 0, save_to_ply=True, save_to_h5=False)\n    pos_list = []\n    pos = mpm_solver.export_particle_x_to_torch().clone()\n    pos = (pos * scale) - shift\n    pos_list.append(pos.detach().clone())\n\n    total_time = 10\n    time_step = 0.001\n    total_iters = int(total_time / time_step)\n\n    save_dict = {\n        \"pos_init\": mpm_solver.export_particle_x_to_torch()\n        .clone()\n        .detach()\n        .cpu()\n        .numpy(),\n        \"velo_init\": mpm_solver.export_particle_v_to_torch()\n        .clone()\n        .detach()\n        .cpu()\n        .numpy(),\n        \"pos_list\": [],\n    }\n\n    for k in tqdm(range(1, total_iters)):\n        mpm_solver.p2g2p(k, time_step, device=device)\n\n        if k < 20:\n            pos = mpm_solver.export_particle_x_to_torch().clone().detach().cpu().numpy()\n            save_dict[\"pos_list\"].append(pos)\n\n        if k % 100 == 0:\n            pos = mpm_solver.export_particle_x_to_torch().clone()\n            pos = (pos * scale) - shift\n            pos_list.append(pos.detach().clone())\n            print(k)\n            print(pos.max().item(), pos.min().item(), pos.mean().item())\n        # save_data_at_frame(mpm_solver, output_dir, k, save_to_ply=True, save_to_h5=False)\n\n    save_name = \"\"\n    for key, value in material_params.items():\n        if key == \"g\":\n            continue\n        save_name += \"{}_{}_\".format(key, value)\n\n    save_name += \"_timestep_{}_vs{}_totaltime_{}\".format(\n        time_step, velocity_scaling, total_time\n    )\n\n    render_gaussians(pos_list, save_name)\n\n    # save sim data:\n    save_path = os.path.join(output_dir, save_name + \".pkl\")\n    import pickle\n\n    with open(save_path, \"wb\") as f:\n        pickle.dump(save_dict, 
f)\n\n\ndef code_test(input_dir, device=0):\n    device = \"cuda:{}\".format(device)\n    gaussian_dict, scale, shift = load_gaussians(input_dir)\n    pos = gaussian_dict[\"position\"]\n\n    pos = (pos * scale) - shift\n\n    pos = torch.from_numpy(pos).float().to(device)\n\n    render_gaussians(pos)\n\n\ndef render_gaussians(\n    pos_list,\n    save_name=None,\n    # dataset_dir=\"../../data/physics_dreamer/llff_flower_undistorted\",\n    dataset_dir=\"../../data/physics_dreamer/ficus\",\n):\n    from motionrep.data.datasets.multiview_dataset import MultiviewImageDataset\n    from motionrep.data.datasets.multiview_dataset import (\n        camera_dataset_collate_fn as camera_dataset_collate_fn_img,\n    )\n\n    from motionrep.gaussian_3d.gaussian_renderer.render import render_gaussian\n    from motionrep.gaussian_3d.scene import GaussianModel\n    from typing import NamedTuple\n\n    gaussian_path = os.path.join(dataset_dir, \"point_cloud.ply\")\n    test_dataset = MultiviewImageDataset(\n        dataset_dir,\n        use_white_background=False,\n        resolution=[576, 1024],\n        use_index=list(range(5, 30, 4)),\n        scale_x_angle=1.5,\n    )\n    print(\n        \"len of train dataset\",\n        len(test_dataset),\n        \"len of test dataset\",\n        len(test_dataset),\n    )\n    test_dataloader = torch.utils.data.DataLoader(\n        test_dataset,\n        batch_size=1,\n        shuffle=False,\n        drop_last=True,\n        num_workers=0,\n        collate_fn=camera_dataset_collate_fn_img,\n    )\n\n    class RenderPipe(NamedTuple):\n        convert_SHs_python = False\n        compute_cov3D_python = False\n        debug = False\n\n    class RenderParams(NamedTuple):\n        render_pipe: RenderPipe\n        bg_color: bool\n        gaussians: GaussianModel\n        camera_list: list\n\n    gaussians = GaussianModel(3)\n    camera_list = test_dataset.camera_list\n\n    gaussians.load_ply(gaussian_path)\n    gaussians.detach_grad()\n    
print(\n        \"load gaussians from: {}\".format(gaussian_path),\n        \"... num gaussians: \",\n        gaussians._xyz.shape[0],\n    )\n    bg_color = [1, 1, 1] if False else [0, 0, 0]\n    background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n    render_pipe = RenderPipe()\n\n    render_params = RenderParams(\n        render_pipe=render_pipe,\n        bg_color=background,\n        gaussians=gaussians,\n        camera_list=camera_list,\n    )\n\n    data = next(iter(test_dataloader))\n    cam = data[\"cam\"][0]\n\n    ret_img_list = []\n\n    for i in range(len(pos_list) + 1):\n        if i > 0:\n            xyz = pos_list[i - 1]\n            gaussians._xyz = xyz\n\n        img = render_gaussian(\n            cam,\n            gaussians,\n            render_params.render_pipe,\n            background,\n        )[\"render\"]\n\n        ret_img_list.append(img)\n\n    # [T, C, H, W]\n    video_array = torch.stack(ret_img_list, dim=0)\n    video_numpy = video_array.detach().cpu().numpy() * 255\n    video_numpy = np.clip(video_numpy, 0, 255).astype(np.uint8)\n\n    video_numpy = np.transpose(video_numpy, [0, 2, 3, 1])\n    from motionrep.utils.io_utils import save_video_imageio\n\n    if save_name is None:\n        save_path = \"output/test.mp4\"\n    else:\n        save_path = os.path.join(\"output\", save_name + \".mp4\")\n    print(\"save video to \", save_path)\n    save_video_imageio(save_path, video_numpy, fps=10)\n\n\nif __name__ == \"__main__\":\n    Fire(run_mpm_gaussian)\n    # Fire(code_test)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/run_sand.py",
    "content": "\nimport warp as wp\nfrom mpm_solver_warp import MPM_Simulator_WARP\nfrom engine_utils import *\nimport torch\nwp.init()\nwp.config.verify_cuda = True\n\n\ndvc = \"cuda:0\"\n\nmpm_solver = MPM_Simulator_WARP(10) # initialize with whatever number is fine. it will be reinitialized\n\n\n# You can either load sampling data from an external h5 file, containing initial position (n,3) and particle_volume (n,)\nmpm_solver.load_from_sampling(\"sand_column.h5\", n_grid = 150, device=dvc)\n\n# Or load from torch tensor (also position and volume)\n# Here we borrow the data from h5, but you can use your own\n# [N]\nvolume_tensor = torch.ones(mpm_solver.n_particles) * 2.5e-8\n\n# torch.float32, [N, 3]\nposition_tensor = mpm_solver.export_particle_x_to_torch()\nprint(position_tensor.max(), position_tensor.min())\n\nmpm_solver.load_initial_data_from_torch(position_tensor, volume_tensor)\nprint(position_tensor.shape, position_tensor.dtype, volume_tensor.shape, volume_tensor.dtype)\n\n# Note: You must provide 'density=..' to set particle_mass = density * particle_volume\n\nmaterial_params = {\n    'E': 2000,\n    'nu': 0.2,\n    \"material\": \"sand\",\n    'friction_angle': 35,\n    'g': [0.0, 0.0, -4.0],\n    \"density\": 200.0\n}\nmpm_solver.set_parameters_dict(material_params)\n\nmpm_solver.finalize_mu_lam() # set mu and lambda from the E and nu input\n\nmpm_solver.add_surface_collider((0.0, 0.0, 0.13), (0.0,0.0,1.0), 'sticky', 0.0)\n\nfrom IPython import embed\n# embed()\n\ndirectory_to_save = './sim_results'\n\nsave_data_at_frame(mpm_solver, directory_to_save, 0, save_to_ply=True, save_to_h5=False)\n\nfor k in range(1,50):\n    mpm_solver.p2g2p(k, 0.002, device=dvc)\n    save_data_at_frame(mpm_solver, directory_to_save, k, save_to_ply=True, save_to_h5=False)\n\n\n\n# extract the position, make some changes, load it back\nposition = mpm_solver.export_particle_x_to_torch()\n# e.g. 
we shift the x position\nposition[:,0] = position[:,0] + 0.1\nmpm_solver.import_particle_x_from_torch(position)\n# keep running sim\nfor k in range(50,100):\n\n    mpm_solver.p2g2p(k, 0.002, device=dvc)\n    save_data_at_frame(mpm_solver, directory_to_save, k, save_to_ply=True, save_to_h5=False)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/sim_grad.py",
    "content": "import warp as wp\nimport numpy as np\nimport torch\nimport os\nfrom mpm_solver_warp_diff import MPM_Simulator_WARPDiff\nfrom run_gaussian_static import load_gaussians, get_volume\nfrom tqdm import tqdm \nfrom fire import Fire \nfrom mpm_utils import *\n\nfrom typing import Optional\n\nimport warp as wp\n\n\nclass MyTape(wp.Tape):\n\n    # returns the adjoint of a kernel parameter\n    def get_adjoint(self, a):\n        if not wp.types.is_array(a) and not isinstance(a, wp.codegen.StructInstance):\n            # if input is a simple type (e.g.: float, vec3, etc) then\n            # no gradient needed (we only return gradients through arrays and structs)\n            print(\"input is a simple type\", type(a))\n            return a\n\n        elif wp.types.is_array(a) and a.grad:\n            # keep track of all gradients used by the tape (for zeroing)\n            # ignore the scalar loss since we don't want to clear its grad\n            self.gradients[a] = a.grad\n            return a.grad\n\n        elif isinstance(a, wp.codegen.StructInstance):\n            adj = a._cls()\n            for name, _ in a._cls.ctype._fields_:\n                if name.startswith(\"_\"):\n                    continue\n                if isinstance(a._cls.vars[name].type, wp.array):\n                    arr = getattr(a, name)\n                    if arr is None:\n                        continue\n                    if arr.grad:\n                        grad = self.gradients[arr] = arr.grad\n                    else:\n                        grad = wp.zeros_like(arr)\n                    setattr(adj, name, grad)\n                else:\n                    setattr(adj, name, getattr(a, name))\n\n            self.gradients[a] = adj\n            return adj\n\n        return None\n\n\n\ndef test(input_dir, output_dir=None, fps=6, device=0):\n    wp.init()\n    wp.config.verify_cuda = True\n\n    device = \"cuda:{}\".format(device)\n\n    gaussian_dict, scale, shift = 
load_gaussians(input_dir)\n\n    velocity_scaling = 0.5\n    init_velocity = velocity_scaling * gaussian_dict[\"velocity\"]\n    init_position = gaussian_dict[\"position\"]\n    init_cov = gaussian_dict[\"cov\"]\n\n    volume_array_path = os.path.join(input_dir, \"volume_array.npy\")\n    if os.path.exists(volume_array_path):\n        volume_tensor = torch.from_numpy(np.load(volume_array_path)).float().to(device)\n    else:\n        volume_array = get_volume(init_position)\n        np.save(volume_array_path, volume_array)\n        volume_tensor = torch.from_numpy(volume_array).float().to(device)\n\n    tensor_init_pos = torch.from_numpy(init_position).float().to(device)\n    tensor_init_cov = torch.from_numpy(init_cov).float().to(device)\n    tensor_init_velocity = torch.from_numpy(init_velocity).float().to(device)\n\n    mpm_solver = MPM_Simulator_WARPDiff(10)  # initialize with whatever number is fine. it will be reintialized\n\n    tensor_init_pos.requires_grad = True\n    tensor_init_cov.requires_grad = False\n    tensor_init_velocity.requires_grad = True\n    mpm_solver.load_initial_data_from_torch(\n        tensor_init_pos,\n        volume_tensor,\n        tensor_init_cov,\n        tensor_init_velocity,\n        device=device,\n    )\n    mpm_solver.mpm_state.particle_x = wp.from_numpy(init_position, dtype=wp.vec3, requires_grad=True, device=device)\n    mpm_solver.mpm_state.particle_v = wp.from_numpy(init_velocity, dtype=wp.vec3, requires_grad=True, device=device)\n\n    # set boundary conditions\n    static_center_point = (\n        torch.from_numpy(gaussian_dict[\"satic_center_point\"]).float().to(device)\n    )\n    max_static_offset = (\n        torch.from_numpy(gaussian_dict[\"max_static_offset\"]).float().to(device)\n    )\n    velocity = torch.zeros_like(static_center_point)\n    # mpm_solver.enforce_particle_velocity_translation(static_center_point, max_static_offset, velocity, \n    #                                                  start_time=0, 
end_time=1000, device=device)\n\n    material_params = {\n        \"E\": 2.0,  # 0.1-200 MPa\n        \"nu\": 0.1,  # > 0.35\n        \"material\": \"jelly\",\n        # \"material\": \"metal\",\n        # \"friction_angle\": 25,\n        \"g\": [0.0, 0.0, 0],\n        \"density\": 0.02,  # kg / m^3\n    }\n\n    print(\"pre set\")\n    mpm_solver.set_parameters_dict(material_params)\n    print(\"set\")\n    mpm_solver.finalize_mu_lam()  # set mu and lambda from the E and nu input\n    print(\"finalize\")\n\n    total_time = 0.1\n    time_step = 0.01\n    total_iters = int(total_time / time_step)\n    total_iters = 3\n    loss = torch.zeros(1, device=device)\n    loss = wp.from_torch(loss, requires_grad=True)\n\n    E_tensor = (torch.ones(velocity.shape[0]) * 2.0).contiguous().to(device)\n    nu_tensor = (torch.ones(velocity.shape[0]) * 0.1).contiguous().to(device)\n    E_warp = wp.from_torch(E_tensor, requires_grad=True)\n    nu_warp = wp.from_torch(nu_tensor, requires_grad=True)\n\n    mpm_solver.set_require_grad()\n\n    dt = time_step\n    # from IPython import embed; embed()\n    tape = MyTape()\n    with tape:\n        # mpm_solver.reset_material(E_warp, nu_warp, device=device)\n        # for k in tqdm(range(1, total_iters)):\n        # mpm_solver.p2g2p(k, time_step, device=device)\n\n        wp.launch(\n            kernel=g2p_test,\n            dim=mpm_solver.n_particles,\n            inputs=[mpm_solver.mpm_state, mpm_solver.mpm_model, dt],\n            device=device,\n        )  # x, v, C, F_trial are updated\n        # wp.launch(position_loss_kernel, dim=mpm_solver.n_particles,  inputs=[mpm_solver.mpm_state, loss], device=device)\n        for i in range(2):\n            # wp.launch(position_loss_kernel, dim=mpm_solver.n_particles,  inputs=[mpm_solver.mpm_state, loss], device=device)\n            wp.launch(position_loss_kernel, dim=mpm_solver.n_particles,  inputs=[mpm_solver.mpm_state, loss], device=device)\n            # wp.launch(position_loss_kernel_raw, dim=mpm_solver.n_particles,  
inputs=[mpm_state.particle_x, loss], device=device)\n        \n    tape.backward(loss) # 75120.86\n    \n    print(loss)\n    # model_grad = tape.gradients[mpm_solver.mpm_model]\n    # state_grad = tape.gradients[mpm_solver.mpm_state]\n    # v_grad = state_grad.particle_v\n    # x_grad = state_grad.particle_x\n    v_grad = mpm_solver.mpm_state.particle_v.grad\n    x_grad = mpm_solver.mpm_state.particle_x.grad\n    # E_grad = wp.to_torch(tape.gradients[E_warp])\n    print(x_grad)\n    from IPython import embed; embed()\n\n\n@wp.kernel\ndef g2p_test(state: MPMStateStruct, model: MPMModelStruct, dt: float):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_v = wp.vec3(0.0, 0.0, 0.0)\n        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        \n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, 
iy, iz]\n                    new_v = new_v + grid_v * weight\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        state.particle_v[p] = new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * state.particle_v[p]\n        wp.atomic_add(state.particle_x, p, dt * state.particle_v[p])\n        state.particle_C[p] = new_C\n        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n        F_tmp = (I33 + new_F * dt) * state.particle_F[p]\n        state.particle_F_trial[p] = F_tmp\n\n        if model.update_cov_with_F:\n            update_cov(state, p, new_F, dt)\n\n\n@wp.kernel\ndef position_loss_kernel(mpm_state: MPMStateStruct, loss: wp.array(dtype=float)):\n\n    tid = wp.tid()\n\n    pos = mpm_state.particle_x[tid]\n    wp.atomic_add(loss, 0, pos[0] + pos[1] + pos[2])\n    # wp.atomic_add(loss, 0, mpm_state.particle_x[tid][0])\n\n@wp.kernel\ndef position_loss_kernel_raw(particle_x: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)):\n\n    tid = wp.tid()\n\n    pos = particle_x[tid]\n    wp.atomic_add(loss, 0, pos[0] + pos[1] + pos[2])\n    # wp.atomic_add(loss, 0, mpm_state.particle_x[tid][0])\n\n\nif __name__ == \"__main__\":\n    Fire(test)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/solver_grad_test.py",
    "content": "import warp as wp\nimport numpy as np\nimport torch\nimport os\nfrom mpm_solver_warp_diff import MPM_Simulator_WARPDiff\nfrom run_gaussian_static import load_gaussians, get_volume\nfrom tqdm import tqdm\nfrom fire import Fire\n\nfrom diff_warp_utils import MPMStateStruct, MPMModelStruct\nfrom warp_rewrite import MyTape\n\nfrom mpm_utils import *\nimport random\n\n\ndef test(input_dir, output_dir=None, fps=6, device=0):\n    seed = 42\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n\n    wp.init()\n    wp.config.verify_cuda = True\n\n    device = \"cuda:{}\".format(device)\n\n    gaussian_dict, scale, shift = load_gaussians(input_dir)\n\n    velocity_scaling = 0.5\n    init_velocity = velocity_scaling * gaussian_dict[\"velocity\"]\n    init_position = gaussian_dict[\"position\"]\n    init_cov = gaussian_dict[\"cov\"]\n\n    volume_array_path = os.path.join(input_dir, \"volume_array.npy\")\n    if os.path.exists(volume_array_path):\n        volume_array = np.load(volume_array_path)\n        volume_tensor = torch.from_numpy(volume_array).float().to(device)\n    else:\n        volume_array = get_volume(init_position)\n        np.save(volume_array_path, volume_array)\n        volume_tensor = torch.from_numpy(volume_array).float().to(device)\n\n    tensor_init_pos = torch.from_numpy(init_position).float().to(device)\n    tensor_init_cov = torch.from_numpy(init_cov).float().to(device)\n    tensor_init_velocity = torch.from_numpy(init_velocity).float().to(device)\n\n    material_params = {\n        \"E\": 2.0,  # 0.1-200 MPa\n        \"nu\": 0.1,  # > 0.35\n        \"material\": \"jelly\",\n        # \"material\": \"metal\",\n        # \"friction_angle\": 25,\n        \"g\": [0.0, 0.0, 0],\n        \"density\": 0.02,  # kg / m^3\n    }\n\n    n_particles = tensor_init_pos.shape[0]\n    mpm_state = MPMStateStruct()\n\n    mpm_state.init(init_position.shape[0], device=device, requires_grad=True)\n    mpm_state.from_torch(\n      
  tensor_init_pos,\n        volume_tensor,\n        tensor_init_cov,\n        tensor_init_velocity,\n        device=device,\n        requires_grad=True,\n        n_grid=100,\n        grid_lim=1.0,\n    )\n\n    mpm_model = MPMModelStruct()\n    mpm_model.init(n_particles, device=device, requires_grad=True)\n    mpm_model.init_other_params(n_grid=100, grid_lim=1.0, device=device)\n\n    E_tensor = (torch.ones(n_particles) * material_params[\"E\"]).contiguous().to(device)\n    nu_tensor = (\n        (torch.ones(n_particles) * material_params[\"nu\"]).contiguous().to(device)\n    )\n    mpm_model.from_torch(E_tensor, nu_tensor, device=device, requires_grad=True)\n\n    mpm_solver = MPM_Simulator_WARPDiff(\n        n_particles, n_grid=100, grid_lim=1.0, device=device\n    )\n\n    mpm_solver.set_parameters_dict(mpm_model, mpm_state, material_params)\n\n    # set boundary conditions\n    static_center_point = (\n        torch.from_numpy(gaussian_dict[\"satic_center_point\"]).float().to(device)\n    )\n    max_static_offset = (\n        torch.from_numpy(gaussian_dict[\"max_static_offset\"]).float().to(device)\n    )\n    velocity = torch.zeros_like(static_center_point)\n    mpm_solver.enforce_particle_velocity_translation(\n        mpm_state,\n        static_center_point,\n        max_static_offset,\n        velocity,\n        start_time=0,\n        end_time=1000,\n        device=device,\n    )\n\n    mpm_state.set_require_grad(True)\n\n    total_time = 0.1\n    time_step = 0.001\n    total_iters = int(total_time / time_step)\n    total_iters = 3\n    loss = torch.zeros(1, device=device)\n    loss = wp.from_torch(loss, requires_grad=True)\n\n    dt = time_step\n    tape = MyTape()  # wp.Tape()\n\n    with tape:\n        # for k in tqdm(range(1, total_iters)):\n        k = 1\n        # mpm_solver.p2g2p(k, time_step, device=device)\n        for k in range(10):\n            mpm_solver.p2g2p(mpm_model, mpm_state, k, time_step, device=device)\n\n        wp.launch(\n           
 position_loss_kernel,\n            dim=n_particles,\n            inputs=[mpm_state, loss],\n            device=device,\n        )\n\n    print(loss, \"pre backward\")\n\n    tape.backward(loss)  # 75120.86\n\n    print(loss)\n\n    v_grad = mpm_state.particle_v.grad\n    x_grad = mpm_state.particle_x.grad\n    e_grad = mpm_model.E.grad\n    e_grad_torch = wp.to_torch(e_grad)\n    grid_v_grad = mpm_state.grid_v_out.grad\n    grid_v_in_grad = mpm_state.grid_v_in.grad\n    print(x_grad)\n    from IPython import embed\n\n    embed()\n\n\n@wp.kernel\ndef position_loss_kernel(mpm_state: MPMStateStruct, loss: wp.array(dtype=float)):\n    tid = wp.tid()\n\n    pos = mpm_state.particle_x[tid]\n    wp.atomic_add(loss, 0, pos[0] + pos[1] + pos[2])\n    # wp.atomic_add(loss, 0, mpm_state.particle_x[tid][0])\n\n\nif __name__ == \"__main__\":\n    Fire(test)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/test_inverse_sim.py",
    "content": "import warp as wp\nimport numpy as np\nimport torch\nimport os\nfrom mpm_solver_warp_diff import MPM_Simulator_WARPDiff\nfrom run_gaussian_static import load_gaussians, get_volume\nfrom tqdm import tqdm\nfrom fire import Fire\n\nfrom diff_warp_utils import MPMStateStruct, MPMModelStruct\nfrom warp_rewrite import MyTape\n\nfrom mpm_utils import *\nimport random\nimport pickle\n\n\ndef test(\n    input_dir,\n    pickle_path=\"output/E_0.2_nu_0.1_material_jelly_density_0.2__timestep_0.001_vs0.5_totaltime_10.pkl\",\n    output_dir=None,\n    fps=6,\n    device=0,\n):\n    seed = 42\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n\n    wp.init()\n    wp.config.verify_cuda = True\n\n    device = \"cuda:{}\".format(device)\n\n    gaussian_dict, scale, shift = load_gaussians(input_dir)\n\n    velocity_scaling = 0.5\n\n    init_velocity = velocity_scaling * gaussian_dict[\"velocity\"]\n    init_position = gaussian_dict[\"position\"]\n    init_cov = gaussian_dict[\"cov\"]\n\n    volume_array_path = os.path.join(input_dir, \"volume_array.npy\")\n    if os.path.exists(volume_array_path):\n        volume_array = np.load(volume_array_path)\n        volume_tensor = torch.from_numpy(volume_array).float().to(device)\n    else:\n        volume_array = get_volume(init_position)\n        np.save(volume_array_path, volume_array)\n        volume_tensor = torch.from_numpy(volume_array).float().to(device)\n\n    tensor_init_pos = torch.from_numpy(init_position).float().to(device)\n    tensor_init_cov = torch.from_numpy(init_cov).float().to(device)\n    tensor_init_velocity = torch.from_numpy(init_velocity).float().to(device)\n\n    material_params = {\n        \"E\": 2.5,  # 0.1-200 MPa\n        \"nu\": 0.2,  # > 0.35\n        \"material\": \"jelly\",\n        # \"material\": \"metal\",\n        # \"friction_angle\": 25,\n        \"g\": [0.0, 0.0, 0],\n        \"density\": 0.2,  # kg / m^3\n    }\n\n    n_particles = 
tensor_init_pos.shape[0]\n    mpm_state = MPMStateStruct()\n\n    mpm_state.init(init_position.shape[0], device=device, requires_grad=True)\n    mpm_state.from_torch(\n        tensor_init_pos.clone(),\n        volume_tensor,\n        tensor_init_cov,\n        tensor_init_velocity.clone(),\n        device=device,\n        requires_grad=True,\n        n_grid=100,\n        grid_lim=1.0,\n    )\n\n    mpm_model = MPMModelStruct()\n    mpm_model.init(n_particles, device=device, requires_grad=True)\n    mpm_model.init_other_params(n_grid=100, grid_lim=1.0, device=device)\n\n    E_tensor = (torch.ones(n_particles) * material_params[\"E\"]).contiguous().to(device)\n    nu_tensor = (\n        (torch.ones(n_particles) * material_params[\"nu\"]).contiguous().to(device)\n    )\n    mpm_model.from_torch(E_tensor, nu_tensor, device=device, requires_grad=True)\n\n    mpm_solver = MPM_Simulator_WARPDiff(\n        n_particles, n_grid=100, grid_lim=1.0, device=device\n    )\n\n    mpm_solver.set_parameters_dict(mpm_model, mpm_state, material_params)\n\n    # set boundary conditions\n    static_center_point = (\n        torch.from_numpy(gaussian_dict[\"satic_center_point\"]).float().to(device)\n    )\n    max_static_offset = (\n        torch.from_numpy(gaussian_dict[\"max_static_offset\"]).float().to(device)\n    )\n    velocity = torch.zeros_like(static_center_point)\n    mpm_solver.enforce_particle_velocity_translation(\n        mpm_state,\n        static_center_point,\n        max_static_offset,\n        velocity,\n        start_time=0,\n        end_time=1000,\n        device=device,\n    )\n\n    mpm_state.set_require_grad(True)\n\n    total_time = 0.02\n    time_step = 0.001\n    total_iters = int(total_time / time_step)\n    total_iters = 10\n\n    dt = time_step\n    with open(pickle_path, \"rb\") as f:\n        gt_dict = pickle.load(f)\n\n    sim_sub_step = 10\n    gt_pos_numpy_list = gt_dict[\"pos_list\"]\n    pos_gt_1 = gt_pos_numpy_list[sim_sub_step - 1]\n    pos_gt_1_warp 
= wp.from_numpy(\n        pos_gt_1, dtype=wp.vec3, device=device, requires_grad=True\n    )\n\n    E_cur = material_params[\"E\"]\n    nu_cur = material_params[\"nu\"]\n\n    init_lr = 3e-6\n    total_train_steps = 2000\n    for train_step in range(total_train_steps):\n        learning_rate = (\n            (total_train_steps - train_step + 1) / total_train_steps * init_lr\n        )\n        tape = MyTape()  # wp.Tape()\n        with tape:\n            # for k in tqdm(range(1, total_iters)):\n            k = 0\n            mpm_solver.time = 0.0\n\n            mpm_solver.set_E_nu(mpm_model, E_cur, nu_cur, device=device)\n            for k in range(sim_sub_step):\n                mpm_solver.p2g2p(mpm_model, mpm_state, k, time_step, device=device)\n\n            loss = torch.zeros(1, device=device)\n            loss = wp.from_torch(loss, requires_grad=True)\n            wp.launch(\n                position_loss_kernel,\n                dim=n_particles,\n                inputs=[mpm_state, pos_gt_1_warp, loss],\n                device=device,\n            )\n\n        tape.backward(loss)  # 75120.86\n\n        E_grad = wp.from_torch(torch.zeros(1, device=device), requires_grad=False)\n        nu_grad = wp.from_torch(torch.zeros(1, device=device), requires_grad=False)\n\n        wp.launch(\n            aggregate_grad,\n            dim=n_particles,\n            inputs=[\n                E_grad,\n                mpm_model.E.grad,\n            ],\n            device=device,\n        )\n        wp.launch(\n            aggregate_grad,\n            dim=n_particles,\n            inputs=[nu_grad, mpm_model.nu.grad],\n            device=device,\n        )\n\n        E_cur = E_cur - wp.to_torch(E_grad).item() * learning_rate\n        nu_cur = nu_cur - wp.to_torch(nu_grad).item() * learning_rate\n        # clip:\n        E_cur = max(1e-5, min(E_cur, 200))\n        nu_cur = max(1e-2, min(nu_cur, 0.449))\n\n        tape.zero()\n        print(\n            loss,\n            \"pre 
backward\",\n            E_cur,\n            nu_cur,\n            E_grad,\n            nu_grad,\n        )\n\n        mpm_state.reset_state(\n            tensor_init_pos.clone(),\n            tensor_init_cov,\n            tensor_init_velocity.clone(),\n            device=device,\n            requires_grad=True,\n        )\n        # might need to set mpm_model.yield_stress\n\n    from IPython import embed  #     embed()\n\n\n@wp.kernel\ndef position_loss_kernel(\n    mpm_state: MPMStateStruct,\n    gt_pos: wp.array(dtype=wp.vec3),\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    pos = mpm_state.particle_x[tid]\n    pos_gt = gt_pos[tid]\n\n    # l1_diff = wp.abs(pos - pos_gt)\n    l2 = wp.length(pos - pos_gt)\n\n    wp.atomic_add(loss, 0, l2)\n\n\n@wp.kernel\ndef step_kernel(x: wp.array(dtype=float), grad: wp.array(dtype=float), alpha: float):\n    tid = wp.tid()\n\n    # gradient descent step\n    x[tid] = x[tid] - grad[tid] * alpha\n\n\n@wp.kernel\ndef aggregate_grad(x: wp.array(dtype=float), grad: wp.array(dtype=float)):\n    tid = wp.tid()\n\n    # gradient descent step\n    wp.atomic_add(x, 0, grad[tid])\n\n\nif __name__ == \"__main__\":\n    Fire(test)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/test_sim.py",
    "content": "import warp as wp \nimport numpy as np\nimport torch\n\n@wp.struct\nclass MPMStateStruct:\n    ###### essential #####\n    # particle\n    particle_x: wp.array(dtype=wp.vec3)  # current position\n    particle_v: wp.array(dtype=wp.vec3)  # particle velocity\n\n    particle_vol: wp.array(dtype=float)  # current volume\n    particle_F: wp.array(dtype=wp.mat33)  # particle elastic deformation gradient\n    grid_v_out: wp.array(\n        dtype=wp.vec3, ndim=3\n    )  # grid node momentum/velocity, after grid update\n\nclass MPM_Simulator_WARPDiff:\n    def __init__(self, x, v, vol, device):\n\n        self.mpm_state = MPMStateStruct()\n        self.mpm_state.particle_x = wp.array(x, dtype=wp.vec3, requires_grad=True, device=device)\n        self.mpm_state.particle_v = wp.array(v, dtype=wp.vec3, requires_grad=True, device=device)\n        self.mpm_state.particle_vol = wp.array(vol, dtype=float, requires_grad=False, device=device)\n        self.mpm_state.particle_F = wp.array(np.zeros((100, 3, 3), dtype=np.float32), dtype=wp.mat33, requires_grad=True, device=device)\n        self.mpm_state.grid_v_out = wp.array(np.zeros((100, 100, 100, 3), dtype=np.float32), dtype=wp.vec3, requires_grad=True, device=device)\n\n\n@wp.kernel\ndef vec3_add(mpm_state: MPMStateStruct, selection: wp.array(dtype=wp.float32), dt: float):\n\n    tid = wp.tid()\n\n    # new_v = wp.vec3(1.0, 1.0, 1.0)\n    # velocity[tid] = new_v\n    velocity = mpm_state.particle_v\n    if selection[tid] == 0: # no problem. static condition/loop no problem!\n        for i in range(2):\n            for j in range(2):\n                # x[tid] = x[tid] + velocity[tid] * dt\n                # x[tid] = wp.add(x[tid], velocity[tid]) # same as above. 
wrong gradient\n                wp.atomic_add(mpm_state.particle_x, tid, velocity[tid] * mpm_state.particle_vol[tid])\n\n                # x[tid] += velocity[tid] * dt # error, no support of +=\n    \n@wp.kernel\ndef loss_kernel(mpm_state: MPMStateStruct,  loss: wp.array(dtype=float)):\n\n    tid = wp.tid()\n\n    pos = mpm_state.particle_x[tid]\n\n    wp.atomic_add(loss, 0, pos[0])\n\n\ndef main():\n\n    wp.init()\n    wp.config.verify_cuda = True\n\n    device = 0\n    device = \"cuda:{}\".format(device)\n\n\n    x = np.random.rand(100, 3).astype(np.float32)\n    velocity = np.random.rand(100, 3).astype(np.float32)\n    dt = 0.1\n\n    \n    # mpm_state = MPMStateStruct()\n    # mpm_state.particle_x = wp.array(x, device=device, dtype=wp.vec3,  requires_grad=True)\n    # mpm_state.particle_v = wp.array(velocity, device=device, dtype=wp.vec3, requires_grad=True)\n    # mpm_state.particle_vol = wp.full(shape=100, value=1, device=device, dtype=wp.float32, requires_grad=False)\n    \n    mpm_solver = MPM_Simulator_WARPDiff(x, velocity, np.ones(100, dtype=np.float32), device=device)\n    \n    selection = wp.zeros(100, device=device, dtype=wp.float32)\n\n    loss = torch.zeros(1, device=device)\n    loss = wp.from_torch(loss, requires_grad=True)\n    tape = wp.Tape()\n\n    with tape:\n        for j in range(2):\n            wp.launch(vec3_add, dim=100, inputs=[mpm_solver.mpm_state, selection, dt], device=device)\n        wp.launch(loss_kernel, dim=100, inputs=[mpm_solver.mpm_state, loss], device=device)\n\n    tape.backward(loss) \n\n    v_grad = mpm_solver.mpm_state.particle_v.grad\n    x_grad = mpm_solver.mpm_state.particle_x.grad\n    print(v_grad, loss)\n\n    from IPython import embed; embed()\n\nif __name__ == \"__main__\":  \n    main()"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/warp_rewrite.py",
    "content": "import warp as wp\nimport ctypes\n\n\nfrom warp.torch import (\n    dtype_from_torch,\n    device_from_torch,\n    dtype_is_compatible,\n    from_torch,\n)\n\n\ndef from_torch_safe(t, dtype=None, requires_grad=None, grad=None):\n    \"\"\"Wrap a PyTorch tensor to a Warp array without copying the data.\n\n    Args:\n        t (torch.Tensor): The torch tensor to wrap.\n        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.\n        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.\n\n    Returns:\n        warp.array: The wrapped array.\n    \"\"\"\n    if dtype is None:\n        dtype = dtype_from_torch(t.dtype)\n    elif not dtype_is_compatible(t.dtype, dtype):\n        raise RuntimeError(f\"Incompatible data types: {t.dtype} and {dtype}\")\n\n    # get size of underlying data type to compute strides\n    ctype_size = ctypes.sizeof(dtype._type_)\n\n    shape = tuple(t.shape)\n    strides = tuple(s * ctype_size for s in t.stride())\n\n    # if target is a vector or matrix type\n    # then check if trailing dimensions match\n    # the target type and update the shape\n    if hasattr(dtype, \"_shape_\"):\n        dtype_shape = dtype._shape_\n        dtype_dims = len(dtype._shape_)\n        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:\n            raise RuntimeError(\n                f\"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}\"\n            )\n\n        # ensure the inner strides are contiguous\n        stride = ctype_size\n        for i in range(dtype_dims):\n            if strides[-i - 1] != stride:\n                raise RuntimeError(\n                    f\"Could not convert Torch tensor 
with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous\"\n                )\n            stride *= dtype_shape[-i - 1]\n\n        shape = tuple(shape[:-dtype_dims]) or (1,)\n        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)\n\n    requires_grad = t.requires_grad if requires_grad is None else requires_grad\n    if grad is not None:\n        if not isinstance(grad, wp.array):\n            import torch\n\n            if isinstance(grad, torch.Tensor):\n                grad = from_torch(grad, dtype=dtype)\n            else:\n                raise ValueError(f\"Invalid gradient type: {type(grad)}\")\n    elif requires_grad:\n        # wrap the tensor gradient, allocate if necessary\n        if t.grad is None:\n            # allocate a zero-filled gradient tensor if it doesn't exist\n            import torch\n\n            t.grad = torch.zeros_like(t, requires_grad=False)\n        grad = from_torch(t.grad, dtype=dtype)\n\n    a = wp.types.array(\n        ptr=t.data_ptr(),\n        dtype=dtype,\n        shape=shape,\n        strides=strides,\n        device=device_from_torch(t.device),\n        copy=False,\n        owner=False,\n        grad=grad,\n        requires_grad=requires_grad,\n    )\n\n    # save a reference to the source tensor, otherwise it will be deallocated\n    a._tensor = t\n    return a\n\n\nclass MyTape(wp.Tape):\n    # returns the adjoint of a kernel parameter\n    def get_adjoint(self, a):\n        if not wp.types.is_array(a) and not isinstance(a, wp.codegen.StructInstance):\n            # if input is a simple type (e.g.: float, vec3, etc) then\n            # no gradient needed (we only return gradients through arrays and structs)\n            return a\n\n        elif wp.types.is_array(a) and a.grad:\n            # keep track of all gradients used by the tape (for zeroing)\n            # ignore the scalar loss since we don't want to clear its grad\n            self.gradients[a] = a.grad\n  
          return a.grad\n\n        elif isinstance(a, wp.codegen.StructInstance):\n            adj = a._cls()\n            for name, _ in a._cls.ctype._fields_:\n                if name.startswith(\"_\"):\n                    continue\n                if isinstance(a._cls.vars[name].type, wp.array):\n                    arr = getattr(a, name)\n                    if arr is None:\n                        continue\n                    if arr.grad:\n                        grad = self.gradients[arr] = arr.grad\n                    else:\n                        grad = wp.zeros_like(arr)\n                    setattr(adj, name, grad)\n                else:\n                    setattr(adj, name, getattr(a, name))\n\n            self.gradients[a] = adj\n            return adj\n\n        return None\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup/warp_utils.py",
    "content": "import warp as wp\nimport warp.torch\nimport torch\n\n\n@wp.struct\nclass MPMModelStruct:\n    ####### essential #######\n    grid_lim: float\n    n_particles: int\n    n_grid: int\n    dx: float\n    inv_dx: float\n    grid_dim_x: int\n    grid_dim_y: int\n    grid_dim_z: int\n    mu: wp.array(dtype=float)\n    lam: wp.array(dtype=float)\n    E: wp.array(dtype=float)\n    nu: wp.array(dtype=float)\n    material: int\n\n    ######## for plasticity ####\n    yield_stress: wp.array(dtype=float)\n    friction_angle: float\n    alpha: float\n    gravitational_accelaration: wp.vec3\n    hardening: float\n    xi: float\n    plastic_viscosity: float\n    softening: float\n\n    ####### for damping\n    rpic_damping: float\n    grid_v_damping_scale: float\n\n    ####### for PhysGaussian: covariance\n    update_cov_with_F: int\n\n\n@wp.struct\nclass MPMStateStruct:\n    ###### essential #####\n    # particle\n    particle_x: wp.array(dtype=wp.vec3)  # current position\n    particle_v: wp.array(dtype=wp.vec3)  # particle velocity\n    particle_F: wp.array(dtype=wp.mat33)  # particle elastic deformation gradient\n    particle_init_cov: wp.array(dtype=float)  # initial covariance matrix\n    particle_cov: wp.array(dtype=float)  # current covariance matrix\n    particle_F_trial: wp.array(\n        dtype=wp.mat33\n    )  # apply return mapping on this to obtain elastic def grad\n    particle_R: wp.array(dtype=wp.mat33)  # rotation matrix\n    particle_stress: wp.array(dtype=wp.mat33)  # Kirchoff stress, elastic stress\n    particle_C: wp.array(dtype=wp.mat33)\n    particle_vol: wp.array(dtype=float)  # current volume\n    particle_mass: wp.array(dtype=float)  # mass\n    particle_density: wp.array(dtype=float)  # density\n    particle_Jp: wp.array(dtype=float)\n\n    particle_selection: wp.array(dtype=int) # only particle_selection[p] = 0 will be simulated\n\n    # grid\n    grid_m: wp.array(dtype=float, ndim=3)\n    grid_v_in: wp.array(dtype=wp.vec3, ndim=3)  # 
grid node momentum/velocity\n    grid_v_out: wp.array(\n        dtype=wp.vec3, ndim=3\n    )  # grid node momentum/velocity, after grid update\n\n\n# for various boundary conditions\n@wp.struct\nclass Dirichlet_collider:\n    point: wp.vec3\n    normal: wp.vec3\n    direction: wp.vec3\n\n    start_time: float\n    end_time: float\n\n    friction: float\n    surface_type: int\n\n    velocity: wp.vec3\n\n    threshold: float\n    reset: int\n    index: int\n\n    x_unit: wp.vec3\n    y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    height: float\n    length: float\n    R: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    horizontal_axis_2: wp.vec3\n    half_height_and_radius: wp.vec2\n    \n\n\n@wp.struct\nclass Impulse_modifier:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: float\n    force: wp.vec3\n    forceTimesDt: wp.vec3\n    numsteps: int\n\n    point: wp.vec3\n    size: wp.vec3\n    mask: wp.array(dtype=int)\n\n\n@wp.struct\nclass MPMtailoredStruct:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: float\n    friction: float\n    surface_type: int\n    velocity: wp.vec3\n    threshold: float\n    reset: int\n\n    point_rotate: wp.vec3\n    normal_rotate: wp.vec3\n    x_unit: wp.vec3\n    y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    point_plane: wp.vec3\n    normal_plane: wp.vec3\n    velocity_plane: wp.vec3\n    threshold_plane: float\n\n@wp.struct\nclass MaterialParamsModifier:\n    point: wp.vec3\n    size: wp.vec3\n    E: float\n    nu: float\n    density: float\n\n@wp.struct\nclass ParticleVelocityModifier:\n    point: wp.vec3\n    normal: wp.vec3\n    half_height_and_radius: wp.vec2\n    rotation_scale: float\n    translation_scale: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    
horizontal_axis_2: wp.vec3\n    \n    start_time: float\n\n    end_time: float\n\n    velocity: wp.vec3\n\n    mask: wp.array(dtype=int)\n\n\n\n\n@wp.kernel\ndef set_vec3_to_zero(target_array: wp.array(dtype=wp.vec3)):\n    tid = wp.tid()\n    target_array[tid] = wp.vec3(0.0, 0.0, 0.0)\n\n\n@wp.kernel\ndef set_mat33_to_identity(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n\n\n@wp.kernel\ndef add_identity_to_mat33(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.add(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef subtract_identity_to_mat33(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.sub(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef add_vec3_to_vec3(\n    first_array: wp.array(dtype=wp.vec3), second_array: wp.array(dtype=wp.vec3)\n):\n    tid = wp.tid()\n    first_array[tid] = wp.add(first_array[tid], second_array[tid])\n\n\n@wp.kernel\ndef set_value_to_float_array(target_array: wp.array(dtype=float), value: float):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef set_warpvalue_to_float_array(target_array: wp.array(dtype=float), value: warp.types.float32):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef get_float_array_product(\n    arrayA: wp.array(dtype=float),\n    arrayB: wp.array(dtype=float),\n    arrayC: wp.array(dtype=float),\n):\n    tid = wp.tid()\n    arrayC[tid] = arrayA[tid] * arrayB[tid]\n\n\ndef torch2warp_quat(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 4\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.quat,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\ndef torch2warp_float(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=warp.types.float32,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\ndef torch2warp_vec3(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.vec3,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_mat33(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.mat33,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup_jan10/gaussian_sim_utils.py",
    "content": "import numpy as np\n\ndef get_volume(xyzs: np.ndarray, resolution=128) -> np.ndarray:\n\n    # set a grid in the range of [-1, 1], with resolution\n    voxel_counts = np.zeros((resolution, resolution, resolution))\n\n    points_xyzindex = ((xyzs + 1) / 2 * (resolution - 1)).astype(np.uint32)\n    cell_volume = (2.0 / (resolution - 1)) ** 3\n\n    for x, y, z in points_xyzindex:\n        voxel_counts[x, y, z] += 1\n\n    points_number_in_corresponding_voxel = voxel_counts[\n        points_xyzindex[:, 0], points_xyzindex[:, 1], points_xyzindex[:, 2]\n    ]\n\n    points_volume = cell_volume / points_number_in_corresponding_voxel\n\n    points_volume = points_volume.astype(np.float32)\n\n    return points_volume\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup_jan10/mpm_data_structure.py",
    "content": "import warp as wp\nimport warp.torch\nimport torch\nfrom typing import Optional, Union, Sequence, Any\nfrom torch import Tensor\nimport os \nimport sys\nsys.path.append(os.path.dirname(os.path.realpath(__file__)))\nfrom warp_utils import from_torch_safe\n\n\n@wp.struct\nclass MPMStateStruct(object):\n    ###### essential #####\n    # particle\n    particle_x: wp.array(dtype=wp.vec3)  # current position\n    particle_v: wp.array(dtype=wp.vec3)  # particle velocity\n    particle_F: wp.array(dtype=wp.mat33)  # particle elastic deformation gradient\n    particle_init_cov: wp.array(dtype=float)  # initial covariance matrix\n    particle_cov: wp.array(dtype=float)  # current covariance matrix\n    particle_F_trial: wp.array(\n        dtype=wp.mat33\n    )  # apply return mapping on this to obtain elastic def grad\n    particle_R: wp.array(dtype=wp.mat33)  # rotation matrix\n    particle_stress: wp.array(dtype=wp.mat33)  # Kirchoff stress, elastic stress\n    particle_C: wp.array(dtype=wp.mat33)\n    particle_vol: wp.array(dtype=float)  # current volume\n    particle_mass: wp.array(dtype=float)  # mass\n    particle_density: wp.array(dtype=float)  # density\n    particle_Jp: wp.array(dtype=float)\n\n    particle_selection: wp.array(\n        dtype=int\n    )  # only particle_selection[p] = 0 will be simulated\n\n    # grid\n    grid_m: wp.array(dtype=float, ndim=3)\n    grid_v_in: wp.array(dtype=wp.vec3, ndim=3)  # grid node momentum/velocity\n    grid_v_out: wp.array(\n        dtype=wp.vec3, ndim=3\n    )  # grid node momentum/velocity, after grid update\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        # shape default is int. 
number of particles\n        self.particle_x = wp.empty(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_v = wp.zeros(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_F = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_init_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=False\n        )\n\n        self.particle_F_trial = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_R = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_stress = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_C = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_vol = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_mass = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_density = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_Jp = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_selection = wp.zeros(\n            shape, dtype=int, device=device, requires_grad=requires_grad\n        )\n\n        # grid: will init later\n        self.grid_m = wp.empty(\n            (10, 10, 10), dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.grid_v_in = wp.zeros(\n            
(10, 10, 10), dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.grid_v_out = wp.zeros(\n            (10, 10, 10), dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n\n    def init_grid(\n        self, grid_res: int, device: wp.context.Devicelike = None, requires_grad=False\n    ):\n        self.grid_m = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=float,\n            device=device,\n            requires_grad=False,\n        )\n        self.grid_v_in = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=wp.vec3,\n            device=device,\n            requires_grad=requires_grad,\n        )\n        self.grid_v_out = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=wp.vec3,\n            device=device,\n            requires_grad=requires_grad,\n        )\n\n    def from_torch(\n        self,\n        tensor_x: Tensor,\n        tensor_volume: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n        n_grid: int = 100,\n        grid_lim=1.0,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n        assert tensor_x.shape[0] == tensor_volume.shape[0]\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n        self.init_grid(grid_res=n_grid, device=device, requires_grad=requires_grad)\n\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_volume is not None:\n            print(self.particle_vol.shape, tensor_volume.shape)\n            volume_numpy = tensor_volume.detach().cpu().numpy()\n            self.particle_vol = wp.from_numpy(\n                volume_numpy, dtype=float, 
device=device, requires_grad=False\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, requires_grad=False\n            )\n            self.particle_init_cov = self.particle_cov\n\n        if tensor_velocity is not None:\n            self.particle_v = from_torch_safe(\n                tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        # initial trial deformation gradient is set to identity\n\n        print(\"Particles initialized from torch data.\")\n        print(\"Total particles: \", n_particles)\n\n    def reset_state(\n        self,\n        tensor_x: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n        tensor_density: Optional[Tensor] = None,\n        selection_mask: Optional[Tensor] = None,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        # reset p_c, p_v, p_C, p_F_trial\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, 
requires_grad=False\n            )\n            self.particle_cov = self.particle_init_cov\n\n        if tensor_velocity is not None:\n            self.particle_v = from_torch_safe(\n                tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n        \n        if tensor_density is not None and selection_mask is not None:\n            wp_density = from_torch_safe(\n                tensor_density.contiguous().detach().clone(),\n                dtype=wp.float32,\n                requires_grad=False,\n            )\n            # 1 indicate we need to simulate this particle\n            wp_selection_mask = from_torch_safe(\n                selection_mask.contiguous().detach().clone().type(torch.int),\n                dtype=wp.int32,\n                requires_grad=False,)\n\n            wp.launch(\n                kernel=set_float_vec_to_vec_wmask,\n                dim=n_particles,\n                inputs=[self.particle_density, wp_density, wp_selection_mask],\n                device=device,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_zero,\n            dim=n_particles,\n            inputs=[self.particle_C],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_zero,\n            dim=n_particles,\n            inputs=[self.particle_stress],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            
inputs=[self.particle_R],\n            device=device,\n        )\n    def set_require_grad(self, requires_grad=True):\n        self.particle_x.requires_grad = requires_grad\n        self.particle_v.requires_grad = requires_grad\n        self.particle_F.requires_grad = requires_grad\n        self.particle_F_trial.requires_grad = requires_grad\n        self.particle_stress.requires_grad = requires_grad\n\n        self.grid_v_out.requires_grad = requires_grad\n        self.grid_v_in.requires_grad = requires_grad\n\n    def reset_density(self, tensor_density: Tensor,\n        selection_mask: Optional[Tensor] = None,\n        device=\"cuda:0\",\n        requires_grad=True,):\n\n        n_particles = tensor_density.shape[0]\n        if tensor_density is not None and selection_mask is not None:\n            wp_density = from_torch_safe(\n                tensor_density.contiguous().detach().clone(),\n                dtype=wp.float32,\n                requires_grad=False,\n            )\n            # 1 indicate we need to simulate this particle\n            wp_selection_mask = from_torch_safe(\n                selection_mask.contiguous().detach().clone().type(torch.int),\n                dtype=wp.int32,\n                requires_grad=False,)\n\n            wp.launch(\n                kernel=set_float_vec_to_vec_wmask,\n                dim=n_particles,\n                inputs=[self.particle_density, wp_density, wp_selection_mask],\n                device=device,\n            )\n\n\n@wp.struct\nclass ParticleStateStruct(object):\n    ###### essential #####\n    # particle\n    particle_x: wp.array(dtype=wp.vec3)  # current position\n    particle_v: wp.array(dtype=wp.vec3)  # particle velocity\n    particle_F: wp.array(dtype=wp.mat33)  # particle elastic deformation gradient\n    particle_init_cov: wp.array(dtype=float)  # initial covariance matrix\n    particle_cov: wp.array(dtype=float)  # current covariance matrix\n    particle_F_trial: wp.array(\n        dtype=wp.mat33\n  
  )  # apply return mapping on this to obtain elastic def grad\n    particle_C: wp.array(dtype=wp.mat33)\n    particle_vol: wp.array(dtype=float)  # current volume\n\n    particle_selection: wp.array(\n        dtype=int\n    )  # only particle_selection[p] = 0 will be simulated\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        # shape default is int. number of particles\n        self.particle_x = wp.empty(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_v = wp.zeros(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_F = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_init_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.particle_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_F_trial = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_stress = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_C = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_vol = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_selection = wp.zeros(\n            shape, dtype=int, device=device, requires_grad=requires_grad\n        )\n\n    def from_torch(\n        self,\n        tensor_x: Tensor,\n        tensor_volume: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n     
   n_grid: int = 100,\n        grid_lim=1.0,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n        assert tensor_x.shape[0] == tensor_volume.shape[0]\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n\n        if tensor_x is not None:\n            # print(self.particle_x.shape, tensor_x.shape)\n            # print(tensor_x.grad)\n            if tensor_x.requires_grad:\n                # tensor_x.grad = torch.zeros_like(tensor_x, requires_grad=False)\n                raise RuntimeError(\"tensor_x requires grad\")\n\n            # x_numpy = tensor_x.detach().clone().cpu().numpy()\n            # self.particle_x = wp.from_numpy(x_numpy, dtype=wp.vec3, requires_grad=True, device=device)\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_volume is not None:\n            print(self.particle_vol.shape, tensor_volume.shape)\n            volume_numpy = tensor_volume.detach().cpu().numpy()\n            # self.particle_vol = wp.from_torch(tensor_volume.contiguous(), dtype=float, device=device, requires_grad=requires_grad)\n            # self.particle_vol = wp.from_torch(tensor_volume.contiguous(), dtype=float, requires_grad=False)\n            self.particle_vol = wp.from_numpy(\n                volume_numpy, dtype=float, device=device, requires_grad=False\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, requires_grad=False\n            )\n            self.particle_cov = self.particle_init_cov\n\n        if tensor_velocity is not None:\n            if tensor_velocity.requires_grad:\n                tensor_velocity.grad = 
torch.zeros_like(\n                    tensor_velocity, requires_grad=False\n                )\n                raise RuntimeError(\"tensor_x requires grad\")\n            self.particle_v = from_torch_safe(\n                tensor_velocity.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        # initial trial deformation gradient is set to identity\n\n        print(\"Particles initialized from torch data.\")\n        print(\"Total particles: \", n_particles)\n\n    def set_require_grad(self, requires_grad=True):\n        self.particle_x.requires_grad = requires_grad\n        self.particle_v.requires_grad = requires_grad\n        self.particle_F.requires_grad = requires_grad\n        self.particle_F_trial.requires_grad = requires_grad\n        self.particle_stress.requires_grad = requires_grad\n\n\n@wp.struct\nclass MPMModelStruct(object):\n    ####### essential #######\n    grid_lim: float\n    n_particles: int\n    n_grid: int\n    dx: float\n    inv_dx: float\n    grid_dim_x: int\n    grid_dim_y: int\n    grid_dim_z: int\n    mu: wp.array(dtype=float)\n    lam: wp.array(dtype=float)\n    E: wp.array(dtype=float)\n    nu: wp.array(dtype=float)\n    material: int\n\n    ######## for plasticity ####\n    yield_stress: wp.array(dtype=float)\n    friction_angle: float\n    alpha: float\n    gravitational_accelaration: wp.vec3\n    hardening: float\n    xi: float\n    plastic_viscosity: float\n    softening: float\n\n    ####### for damping\n    rpic_damping: float\n    grid_v_damping_scale: float\n\n    ####### for PhysGaussian: covariance\n    update_cov_with_F: int\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: 
wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        self.E = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )  # young's modulus\n        self.nu = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )  # poisson's ratio\n\n        self.mu = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.lam = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.yield_stress = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n    def finalize_mu_lam(self, n_particles, device=\"cuda:0\"):\n        wp.launch(\n            kernel=compute_mu_lam_from_E_nu_clean,\n            dim=n_particles,\n            inputs=[self.mu, self.lam, self.E, self.nu],\n            device=device,\n        )\n\n    def init_other_params(self, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.grid_lim = grid_lim\n        self.n_grid = n_grid\n        self.grid_dim_x = n_grid\n        self.grid_dim_y = n_grid\n        self.grid_dim_z = n_grid\n        (\n            self.dx,\n            self.inv_dx,\n        ) = self.grid_lim / self.n_grid, float(\n            n_grid / grid_lim\n        )  # [0-1]?\n\n        self.update_cov_with_F = False\n\n        # material is used to switch between different elastoplastic models. 0 is jelly\n        self.material = 0\n\n        self.plastic_viscosity = 0.0\n        self.softening = 0.1\n        self.friction_angle = 25.0\n        sin_phi = wp.sin(self.friction_angle / 180.0 * 3.14159265)\n        self.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        self.gravitational_accelaration = wp.vec3(0.0, 0.0, 0.0)\n\n        self.rpic_damping = 0.0  # 0.0 if no damping (apic). 
-1 if pic\n\n        self.grid_v_damping_scale = 1.1  # globally applied\n\n    def from_torch(\n        self, tensor_E: Tensor, tensor_nu: Tensor, device=\"cuda:0\", requires_grad=False\n    ):\n        self.E = wp.from_torch(tensor_E.contiguous(), requires_grad=requires_grad)\n        self.nu = wp.from_torch(tensor_nu.contiguous(), requires_grad=requires_grad)\n        n_particles = tensor_E.shape[0]\n        self.finalize_mu_lam(n_particles=n_particles, device=device)\n\n    def set_require_grad(self, requires_grad=True):\n        self.E.requires_grad = requires_grad\n        self.nu.requires_grad = requires_grad\n        self.mu.requires_grad = requires_grad\n        self.lam.requires_grad = requires_grad\n\n\n# for various boundary conditions\n@wp.struct\nclass Dirichlet_collider:\n    point: wp.vec3\n    normal: wp.vec3\n    direction: wp.vec3\n\n    start_time: float\n    end_time: float\n\n    friction: float\n    surface_type: int\n\n    velocity: wp.vec3\n\n    threshold: float\n    reset: int\n    index: int\n\n    x_unit: wp.vec3\n    y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    height: float\n    length: float\n    R: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    horizontal_axis_2: wp.vec3\n    half_height_and_radius: wp.vec2\n\n\n@wp.struct\nclass Impulse_modifier:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: float\n    force: wp.vec3\n    forceTimesDt: wp.vec3\n    numsteps: int\n\n    point: wp.vec3\n    size: wp.vec3\n    mask: wp.array(dtype=int)\n\n\n@wp.struct\nclass MPMtailoredStruct:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: float\n    friction: float\n    surface_type: int\n    velocity: wp.vec3\n    threshold: float\n    reset: int\n\n    point_rotate: wp.vec3\n    normal_rotate: wp.vec3\n    x_unit: wp.vec3\n   
 y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    point_plane: wp.vec3\n    normal_plane: wp.vec3\n    velocity_plane: wp.vec3\n    threshold_plane: float\n\n\n@wp.struct\nclass MaterialParamsModifier:\n    point: wp.vec3\n    size: wp.vec3\n    E: float\n    nu: float\n    density: float\n\n\n@wp.struct\nclass ParticleVelocityModifier:\n    point: wp.vec3\n    normal: wp.vec3\n    half_height_and_radius: wp.vec2\n    rotation_scale: float\n    translation_scale: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    horizontal_axis_2: wp.vec3\n\n    start_time: float\n\n    end_time: float\n\n    velocity: wp.vec3\n\n    mask: wp.array(dtype=int)\n\n\n@wp.kernel\ndef compute_mu_lam_from_E_nu_clean(\n    mu: wp.array(dtype=float),\n    lam: wp.array(dtype=float),\n    E: wp.array(dtype=float),\n    nu: wp.array(dtype=float),\n):\n    p = wp.tid()\n    mu[p] = E[p] / (2.0 * (1.0 + nu[p]))\n    lam[p] = E[p] * nu[p] / ((1.0 + nu[p]) * (1.0 - 2.0 * nu[p]))\n\n\n@wp.kernel\ndef set_vec3_to_zero(target_array: wp.array(dtype=wp.vec3)):\n    tid = wp.tid()\n    target_array[tid] = wp.vec3(0.0, 0.0, 0.0)\n\n@wp.kernel\ndef set_vec3_to_vec3(source_array: wp.array(dtype=wp.vec3), target_array: wp.array(dtype=wp.vec3)):\n    tid = wp.tid()\n    source_array[tid] = target_array[tid]\n\n@wp.kernel\ndef set_float_vec_to_vec_wmask(source_array: wp.array(dtype=float), target_array: wp.array(dtype=float), selection_mask: wp.array(dtype=int)):\n    tid = wp.tid()\n    if selection_mask[tid] == 1:\n        source_array[tid] = target_array[tid]\n\n@wp.kernel\ndef set_float_vec_to_vec(source_array: wp.array(dtype=float), target_array: wp.array(dtype=float)):\n    tid = wp.tid()\n    source_array[tid] = target_array[tid]\n\n\n\n@wp.kernel\ndef set_mat33_to_identity(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n\n\n@wp.kernel\ndef 
set_mat33_to_zero(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n\n@wp.kernel\ndef add_identity_to_mat33(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.add(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef subtract_identity_to_mat33(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.sub(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef add_vec3_to_vec3(\n    first_array: wp.array(dtype=wp.vec3), second_array: wp.array(dtype=wp.vec3)\n):\n    tid = wp.tid()\n    first_array[tid] = wp.add(first_array[tid], second_array[tid])\n\n\n@wp.kernel\ndef set_value_to_float_array(target_array: wp.array(dtype=float), value: float):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef set_warpvalue_to_float_array(\n    target_array: wp.array(dtype=float), value: warp.types.float32\n):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef get_float_array_product(\n    arrayA: wp.array(dtype=float),\n    arrayB: wp.array(dtype=float),\n    arrayC: wp.array(dtype=float),\n):\n    tid = wp.tid()\n    arrayC[tid] = arrayA[tid] * arrayB[tid]\n\n\ndef torch2warp_quat(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 4\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.quat,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_float(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=warp.types.float32,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_vec3(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.vec3,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_mat33(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.mat33,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup_jan10/mpm_solver_diff.py",
    "content": "import sys\nimport os\n\nimport warp as wp\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)))\nfrom mpm_data_structure import *\nfrom mpm_utils import *\nfrom typing import Optional, Union, Sequence, Any\n\n\nclass MPMWARPDiff(object):\n    # def __init__(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n    #     self.initialize(n_particles, n_grid, grid_lim, device=device)\n    #     self.time_profile = {}\n\n    def __init__(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.initialize(n_particles, n_grid, grid_lim, device=device)\n        self.time_profile = {}\n\n    def initialize(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.n_particles = n_particles\n\n        self.time = 0.0\n\n        self.grid_postprocess = []\n        self.collider_params = []\n        self.modify_bc = []\n\n        self.tailored_struct_for_bc = MPMtailoredStruct()\n        self.pre_p2g_operations = []\n        self.impulse_params = []\n\n        self.particle_velocity_modifiers = []\n        self.particle_velocity_modifier_params = []\n\n    # must give density. 
mass will be updated as density * volume\n    def set_parameters(self, device=\"cuda:0\", **kwargs):\n        self.set_parameters_dict(device, kwargs)\n\n    def set_parameters_dict(self, mpm_model, mpm_state, kwargs={}, device=\"cuda:0\"):\n        if \"material\" in kwargs:\n            if kwargs[\"material\"] == \"jelly\":\n                mpm_model.material = 0\n            elif kwargs[\"material\"] == \"metal\":\n                mpm_model.material = 1\n            elif kwargs[\"material\"] == \"sand\":\n                mpm_model.material = 2\n            elif kwargs[\"material\"] == \"foam\":\n                mpm_model.material = 3\n            elif kwargs[\"material\"] == \"snow\":\n                mpm_model.material = 4\n            elif kwargs[\"material\"] == \"plasticine\":\n                mpm_model.material = 5\n            else:\n                raise TypeError(\"Undefined material type\")\n\n        if \"yield_stress\" in kwargs:\n            val = kwargs[\"yield_stress\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.yield_stress, val],\n                device=device,\n            )\n        if \"hardening\" in kwargs:\n            mpm_model.hardening = kwargs[\"hardening\"]\n        if \"xi\" in kwargs:\n            mpm_model.xi = kwargs[\"xi\"]\n        if \"friction_angle\" in kwargs:\n            mpm_model.friction_angle = kwargs[\"friction_angle\"]\n            sin_phi = wp.sin(mpm_model.friction_angle / 180.0 * 3.14159265)\n            mpm_model.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        if \"g\" in kwargs:\n            mpm_model.gravitational_accelaration = wp.vec3(\n                kwargs[\"g\"][0], kwargs[\"g\"][1], kwargs[\"g\"][2]\n            )\n\n        if \"density\" in kwargs:\n            density_value = kwargs[\"density\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n        
        dim=self.n_particles,\n                inputs=[mpm_state.particle_density, density_value],\n                device=device,\n            )\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=self.n_particles,\n                inputs=[\n                    mpm_state.particle_density,\n                    mpm_state.particle_vol,\n                    mpm_state.particle_mass,\n                ],\n                device=device,\n            )\n        if \"rpic_damping\" in kwargs:\n            mpm_model.rpic_damping = kwargs[\"rpic_damping\"]\n        if \"plastic_viscosity\" in kwargs:\n            mpm_model.plastic_viscosity = kwargs[\"plastic_viscosity\"]\n        if \"softening\" in kwargs:\n            mpm_model.softening = kwargs[\"softening\"]\n        if \"grid_v_damping_scale\" in kwargs:\n            mpm_model.grid_v_damping_scale = kwargs[\"grid_v_damping_scale\"]\n\n    def set_E_nu(self, mpm_model, E: float, nu: float, device=\"cuda:0\"):\n\n        if isinstance(E, float):\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.E, E],\n                device=device,\n            )\n        else: # E is warp array\n            wp.launch(\n                kernel=set_float_vec_to_vec,\n                dim=self.n_particles,\n                inputs=[mpm_model.E, E],\n                device=device,\n            ) \n\n        if isinstance(nu, float):\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.nu, nu],\n                device=device,\n            )\n        else:\n            wp.launch(\n                kernel=set_float_vec_to_vec,\n                dim=self.n_particles,\n                inputs=[mpm_model.nu, nu],\n                device=device,\n            )\n\n    def p2g2p(self, mpm_model, mpm_state, step, dt, 
device=\"cuda:0\"):\n        grid_size = (\n            mpm_model.grid_dim_x,\n            mpm_model.grid_dim_y,\n            mpm_model.grid_dim_z,\n        )\n\n        # TODO, move this outside of the loop!!\n        wp.launch(\n            kernel=compute_mu_lam_from_E_nu,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n        wp.launch(\n            kernel=zero_grid,  # gradient might be gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        # apply pre-p2g operations on particles\n        # apply impulse force on particles..\n        for k in range(len(self.pre_p2g_operations)):\n            wp.launch(\n                kernel=self.pre_p2g_operations[k],\n                dim=self.n_particles,\n                inputs=[self.time, dt, mpm_state, self.impulse_params[k]],\n                device=device,\n            )\n\n        # apply dirichlet particle v modifier\n        for k in range(len(self.particle_velocity_modifiers)):\n            wp.launch(\n                kernel=self.particle_velocity_modifiers[k],\n                dim=self.n_particles,\n                inputs=[\n                    self.time,\n                    mpm_state,\n                    self.particle_velocity_modifier_params[k],\n                ],\n                device=device,\n            )\n\n        # compute stress = stress(returnMap(F_trial))\n        # F_trial => F                    # TODO: this is overwritten.. \n        # F, SVD(F), lam, mu => Stress.   # TODO: this is overwritten.. 
\n            \n        with wp.ScopedTimer(\n            \"compute_stress_from_F_trial\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=compute_stress_from_F_trial,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # F and stress are updated\n\n        # p2g\n        with wp.ScopedTimer(\n            \"p2g\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=p2g_apic_with_stress,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # apply p2g'\n\n        # grid update\n        with wp.ScopedTimer(\n            \"grid_update\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=grid_normalization_and_gravity,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )\n\n        if mpm_model.grid_v_damping_scale < 1.0:\n            wp.launch(\n                kernel=add_damping_via_grid,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model.grid_v_damping_scale],\n                device=device,\n            )\n\n        # apply BC on grid, collide\n        with wp.ScopedTimer(\n            \"apply_BC_on_grid\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            for k in range(len(self.grid_postprocess)):\n                wp.launch(\n                    kernel=self.grid_postprocess[k],\n                    dim=grid_size,\n                    inputs=[\n                        self.time,\n                        dt,\n                        mpm_state,\n                        mpm_model,\n              
          self.collider_params[k],\n                    ],\n                    device=device,\n                )\n                if self.modify_bc[k] is not None:\n                    self.modify_bc[k](self.time, dt, self.collider_params[k])\n\n        # g2p\n        with wp.ScopedTimer(\n            \"g2p\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=g2p,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # x, v, C, F_trial are updated\n\n        #### CFL check ####\n        # particle_v = self.mpm_state.particle_v.numpy()\n        # if np.max(np.abs(particle_v)) > self.mpm_model.dx / dt:\n        #     print(\"max particle v: \", np.max(np.abs(particle_v)))\n        #     print(\"max allowed  v: \", self.mpm_model.dx / dt)\n        #     print(\"does not allow v*dt>dx\")\n        #     input()\n        #### CFL check ####\n        with wp.ScopedTimer(\n            \"clip_particle_x\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=clip_particle_x,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model],\n                device=device,\n            )\n\n        self.time = self.time + dt\n\n    def print_time_profile(self):\n        print(\"MPM Time profile:\")\n        for key, value in self.time_profile.items():\n            print(key, sum(value))\n\n    # a surface specified by a point and the normal vector\n    def add_surface_collider(\n        self,\n        point,\n        normal,\n        surface=\"sticky\",\n        friction=0.0,\n        start_time=0.0,\n        end_time=999.0,\n    ):\n        point = list(point)\n        # Normalize normal\n        normal_scale = 1.0 / wp.sqrt(float(sum(x**2 for x in normal)))\n        normal = list(normal_scale * x for x in normal)\n\n        collider_param = 
Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n        if surface == \"sticky\" and friction != 0:\n            raise ValueError(\"friction must be 0 on sticky surfaces.\")\n        if surface == \"sticky\":\n            collider_param.surface_type = 0\n        elif surface == \"slip\":\n            collider_param.surface_type = 1\n        elif surface == \"cut\":\n            collider_param.surface_type = 11\n        else:\n            collider_param.surface_type = 2\n        # frictional\n        collider_param.friction = friction\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                n = wp.vec3(param.normal[0], param.normal[1], param.normal[2])\n                dotproduct = wp.dot(offset, n)\n\n                if dotproduct < 0.0:\n                    if param.surface_type == 0:\n                        state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                            0.0, 0.0, 0.0\n                        )\n                    elif param.surface_type == 11:\n                        if (\n                            float(grid_z) * model.dx < 0.4\n                            or float(grid_z) * model.dx > 0.53\n                        ):\n             
               state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                                0.0, 0.0, 0.0\n                            )\n                        else:\n                            v_in = state.grid_v_out[grid_x, grid_y, grid_z]\n                            state.grid_v_out[grid_x, grid_y, grid_z] = (\n                                wp.vec3(v_in[0], 0.0, v_in[2]) * 0.3\n                            )\n                    else:\n                        v = state.grid_v_out[grid_x, grid_y, grid_z]\n                        normal_component = wp.dot(v, n)\n                        if param.surface_type == 1:\n                            v = (\n                                v - normal_component * n\n                            )  # Project out all normal component\n                        else:\n                            v = (\n                                v - wp.min(normal_component, 0.0) * n\n                            )  # Project out only inward normal component\n                        if normal_component < 0.0 and wp.length(v) > 1e-20:\n                            v = wp.max(\n                                0.0, wp.length(v) + normal_component * param.friction\n                            ) * wp.normalize(\n                                v\n                            )  # apply friction here\n                        state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                            0.0, 0.0, 0.0\n                        )\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # a cubiod is a rectangular cube'\n    # centered at `point`\n    # dimension is x: point[0]±size[0]\n    #              y: point[1]±size[1]\n    #              z: point[2]±size[2]\n    # all grid nodes lie within the cubiod will have their speed set to velocity\n    # the cuboid itself is also moving with const speed = velocity\n    # set the speed to zero to fix BC\n    def set_velocity_on_cuboid(\n        
self,\n        point,\n        size,\n        velocity,\n        start_time=0.0,\n        end_time=999.0,\n        reset=0,\n    ):\n        point = list(point)\n\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.size = size\n        collider_param.velocity = wp.vec3(velocity[0], velocity[1], velocity[2])\n        # collider_param.threshold = threshold\n        collider_param.reset = reset\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                if (\n                    wp.abs(offset[0]) < param.size[0]\n                    and wp.abs(offset[1]) < param.size[1]\n                    and wp.abs(offset[2]) < param.size[2]\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = param.velocity\n            elif param.reset == 1:\n                if time < param.end_time + 15.0 * dt:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n\n        def modify(time, dt, param: Dirichlet_collider):\n            if time >= param.start_time and time < param.end_time:\n                param.point = wp.vec3(\n                    param.point[0] + dt * param.velocity[0],\n                    param.point[1] + dt * param.velocity[1],\n                    
param.point[2] + dt * param.velocity[2],\n                )  # param.point + dt * param.velocity\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(modify)\n\n    def add_bounding_box(self, start_time=0.0, end_time=999.0):\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            padding = 3\n            if time >= param.start_time and time < param.end_time:\n                if grid_x < padding and state.grid_v_out[grid_x, grid_y, grid_z][0] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    grid_x >= model.grid_dim_x - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][0] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_y < padding and state.grid_v_out[grid_x, grid_y, grid_z][1] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    
grid_y >= model.grid_dim_y - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][1] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_z < padding and state.grid_v_out[grid_x, grid_y, grid_z][2] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n                if (\n                    grid_z >= model.grid_dim_z - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][2] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # particle_v += force/particle_mass * dt\n    # this is applied from start_dt, ends after num_dt p2g2p's\n    # particle velocity is changed before p2g at each timestep\n    def add_impulse_on_particles(\n        self,\n        mpm_state,\n        force,\n        dt,\n        point=[1, 1, 1],\n        size=[1, 1, 1],\n        num_dt=1,\n        start_time=0.0,\n        device=\"cuda:0\",\n    ):\n        impulse_param = Impulse_modifier()\n        impulse_param.start_time = start_time\n        impulse_param.end_time = start_time + dt * num_dt\n\n        impulse_param.point = wp.vec3(point[0], point[1], point[2])\n        impulse_param.size = wp.vec3(size[0], size[1], size[2])\n        impulse_param.mask 
= wp.zeros(shape=self.n_particles, dtype=int, device=device)\n\n        impulse_param.force = wp.vec3(\n            force[0],\n            force[1],\n            force[2],\n        )\n\n        wp.launch(\n            kernel=selection_add_impulse_on_particles,\n            dim=self.n_particles,\n            inputs=[mpm_state, impulse_param],\n            device=device,\n        )\n\n        self.impulse_params.append(impulse_param)\n\n        @wp.kernel\n        def apply_force(\n            time: float, dt: float, state: MPMStateStruct, param: Impulse_modifier\n        ):\n            p = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                if param.mask[p] == 1:\n                    impulse = wp.vec3(\n                        param.force[0] / state.particle_mass[p],\n                        param.force[1] / state.particle_mass[p],\n                        param.force[2] / state.particle_mass[p],\n                    )\n                    state.particle_v[p] = state.particle_v[p] + impulse * dt\n\n        self.pre_p2g_operations.append(apply_force)\n\n    def enforce_particle_velocity_translation(\n        self, mpm_state, point, size, velocity, start_time, end_time, device=\"cuda:0\"\n    ):\n        # first select certain particles based on position\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], point[1], point[2])\n        velocity_modifier_params.size = wp.vec3(size[0], size[1], size[2])\n\n        velocity_modifier_params.velocity = wp.vec3(\n            velocity[0], velocity[1], velocity[2]\n        )\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_translation,\n            
dim=self.n_particles,\n            inputs=[mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    state.particle_v[p] = velocity_modifier_params.velocity\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # define a cylinder with center point, half_height, radius, normal\n    # particles within the cylinder are rotating along the normal direction\n    # may also have a translational velocity along the normal direction\n    def enforce_particle_velocity_rotation(\n        self,\n        mpm_state,\n        point,\n        normal,\n        half_height_and_radius,\n        rotation_scale,\n        translation_scale,\n        start_time,\n        end_time,\n        device=\"cuda:0\",\n    ):\n        normal_scale = 1.0 / wp.sqrt(\n            float(normal[0] ** 2 + normal[1] ** 2 + normal[2] ** 2)\n        )\n        normal = list(normal_scale * x for x in normal)\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], point[1], point[2])\n        velocity_modifier_params.half_height_and_radius = wp.vec2(\n            half_height_and_radius[0], half_height_and_radius[1]\n        )\n        velocity_modifier_params.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n        horizontal_1 = wp.vec3(1.0, 1.0, 1.0)\n        if wp.abs(wp.dot(velocity_modifier_params.normal, horizontal_1)) < 0.01:\n            
horizontal_1 = wp.vec3(0.72, 0.37, -0.67)\n        horizontal_1 = (\n            horizontal_1\n            - wp.dot(horizontal_1, velocity_modifier_params.normal)\n            * velocity_modifier_params.normal\n        )\n        horizontal_1 = horizontal_1 * (1.0 / wp.length(horizontal_1))\n        horizontal_2 = wp.cross(horizontal_1, velocity_modifier_params.normal)\n\n        velocity_modifier_params.horizontal_axis_1 = horizontal_1\n        velocity_modifier_params.horizontal_axis_2 = horizontal_2\n\n        velocity_modifier_params.rotation_scale = rotation_scale\n        velocity_modifier_params.translation_scale = translation_scale\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_cylinder,\n            dim=self.n_particles,\n            inputs=[mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    offset = state.particle_x[p] - velocity_modifier_params.point\n                    horizontal_distance = wp.length(\n                        offset\n                        - wp.dot(offset, velocity_modifier_params.normal)\n                        * velocity_modifier_params.normal\n                    )\n                    cosine = (\n     
                   wp.dot(offset, velocity_modifier_params.horizontal_axis_1)\n                        / horizontal_distance\n                    )\n                    theta = wp.acos(cosine)\n                    if wp.dot(offset, velocity_modifier_params.horizontal_axis_2) > 0:\n                        theta = theta\n                    else:\n                        theta = -theta\n                    axis1_scale = (\n                        -horizontal_distance\n                        * wp.sin(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis2_scale = (\n                        horizontal_distance\n                        * wp.cos(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis_vertical_scale = translation_scale\n                    state.particle_v[p] = (\n                        axis1_scale * velocity_modifier_params.horizontal_axis_1\n                        + axis2_scale * velocity_modifier_params.horizontal_axis_2\n                        + axis_vertical_scale * velocity_modifier_params.normal\n                    )\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # given normal direction, say [0,0,1]\n    # gradually release grid velocities from start position to end position\n    def release_particles_sequentially(\n        self, normal, start_position, end_position, num_layers, start_time, end_time\n    ):\n        num_layers = 50\n        point = [0, 0, 0]\n        size = [0, 0, 0]\n        axis = -1\n        for i in range(3):\n            if normal[i] == 0:\n                point[i] = 1\n                size[i] = 1\n            else:\n                axis = i\n                point[i] = end_position\n\n        half_length_portion = wp.abs(start_position - end_position) / num_layers\n        end_time_portion = end_time / num_layers\n        for i in 
range(num_layers):\n            size[axis] = half_length_portion * (num_layers - i)\n            self.enforce_particle_velocity_translation(\n                point=point,\n                size=size,\n                velocity=[0, 0, 0],\n                start_time=start_time,\n                end_time=end_time_portion * (i + 1),\n            )\n\n    def enforce_particle_velocity_by_mask(\n        self,\n        mpm_state,\n        selection_mask: torch.Tensor,\n        velocity,\n        start_time,\n        end_time,\n    ):\n        # first select certain particles based on position\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.velocity = wp.vec3(\n            velocity[0], velocity[1], velocity[2]\n        )\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.from_torch(selection_mask)\n\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    state.particle_v[p] = velocity_modifier_params.velocity\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    def restart_and_compute_F_C(self, mpm_model, mpm_state, target_pos, device):\n        grid_size = (\n            mpm_model.grid_dim_x,\n            mpm_model.grid_dim_y,\n            mpm_model.grid_dim_z,\n        )\n\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            
inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        wp.launch(\n            set_F_C_p2g,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model, target_pos],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=grid_normalization_and_gravity,\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model, 0],\n            device=device,\n        )\n\n        wp.launch(\n            set_F_C_g2p,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        # set position to target_pos\n        wp.launch(\n            kernel=set_vec3_to_vec3,\n            dim=self.n_particles,\n            inputs=[mpm_state.particle_x, target_pos],\n            device=device,\n        )"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup_jan10/mpm_utils.py",
    "content": "import warp as wp\nfrom mpm_data_structure import *\nimport numpy as np\nimport math\n\n\n# compute stress from F\n@wp.func\ndef kirchoff_stress_FCR(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, mu: float, lam: float\n):\n    # compute kirchoff stress for FCR model (remember tau = P F^T)\n    R = U * wp.transpose(V)\n    id = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    return 2.0 * mu * (F - R) * wp.transpose(F) + id * lam * J * (J - 1.0)\n\n\n@wp.func\ndef kirchoff_stress_neoHookean(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, sig: wp.vec3, mu: float, lam: float\n):\n    # compute kirchoff stress for FCR model (remember tau = P F^T)\n    b = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])\n    b_hat = b - wp.vec3(\n        (b[0] + b[1] + b[2]) / 3.0,\n        (b[0] + b[1] + b[2]) / 3.0,\n        (b[0] + b[1] + b[2]) / 3.0,\n    )\n    tau = mu * J ** (-2.0 / 3.0) * b_hat + lam / 2.0 * (J * J - 1.0) * wp.vec3(\n        1.0, 1.0, 1.0\n    )\n    return (\n        U\n        * wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])\n        * wp.transpose(V)\n        * wp.transpose(F)\n    )\n\n\n@wp.func\ndef kirchoff_stress_StVK(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float\n):\n    sig = wp.vec3(\n        wp.max(sig[0], 0.01), wp.max(sig[1], 0.01), wp.max(sig[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    log_sig_sum = wp.log(sig[0]) + wp.log(sig[1]) + wp.log(sig[2])\n    ONE = wp.vec3(1.0, 1.0, 1.0)\n    tau = 2.0 * mu * epsilon + lam * log_sig_sum * ONE\n    return (\n        U\n        * wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])\n        * wp.transpose(V)\n        * wp.transpose(F)\n    )\n\n\n@wp.func\ndef kirchoff_stress_drucker_prager(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float\n):\n    log_sig_sum = 
wp.log(sig[0]) + wp.log(sig[1]) + wp.log(sig[2])\n    center00 = 2.0 * mu * wp.log(sig[0]) * (1.0 / sig[0]) + lam * log_sig_sum * (\n        1.0 / sig[0]\n    )\n    center11 = 2.0 * mu * wp.log(sig[1]) * (1.0 / sig[1]) + lam * log_sig_sum * (\n        1.0 / sig[1]\n    )\n    center22 = 2.0 * mu * wp.log(sig[2]) * (1.0 / sig[2]) + lam * log_sig_sum * (\n        1.0 / sig[2]\n    )\n    center = wp.mat33(center00, 0.0, 0.0, 0.0, center11, 0.0, 0.0, 0.0, center22)\n    return U * center * wp.transpose(V) * wp.transpose(F)\n\n\n@wp.func\ndef von_mises_return_mapping(F_trial: wp.mat33, model: MPMModelStruct, p: int):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0\n\n    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (\n        epsilon[0] + epsilon[1] + epsilon[2]\n    ) * wp.vec3(1.0, 1.0, 1.0)\n    sum_tau = tau[0] + tau[1] + tau[2]\n    cond = wp.vec3(\n        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0\n    )\n    if wp.length(cond) > model.yield_stress[p]:\n        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)\n        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6\n        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])\n        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        sig_elastic = wp.mat33(\n            wp.exp(epsilon[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[2]),\n        )\n        F_elastic = U * 
sig_elastic * wp.transpose(V)\n        if model.hardening == 1:\n            model.yield_stress[p] = (\n                model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma\n            )\n        return F_elastic\n    else:\n        return F_trial\n\n\n@wp.func\ndef von_mises_return_mapping_with_damage(\n    F_trial: wp.mat33, model: MPMModelStruct, p: int\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0\n\n    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (\n        epsilon[0] + epsilon[1] + epsilon[2]\n    ) * wp.vec3(1.0, 1.0, 1.0)\n    sum_tau = tau[0] + tau[1] + tau[2]\n    cond = wp.vec3(\n        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0\n    )\n    if wp.length(cond) > model.yield_stress[p]:\n        if model.yield_stress[p] <= 0:\n            return F_trial\n        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)\n        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6\n        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])\n        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        model.yield_stress[p] = model.yield_stress[p] - model.softening * wp.length(\n            (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        )\n        if model.yield_stress[p] <= 0:\n            model.mu[p] = 0.0\n            model.lam[p] = 0.0\n        sig_elastic = wp.mat33(\n            wp.exp(epsilon[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[1]),\n            0.0,\n            0.0,\n            
0.0,\n            wp.exp(epsilon[2]),\n        )\n        F_elastic = U * sig_elastic * wp.transpose(V)\n        if model.hardening == 1:\n            model.yield_stress[p] = (\n                model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma\n            )\n        return F_elastic\n    else:\n        return F_trial\n\n\n# for toothpaste\n@wp.func\ndef viscoplasticity_return_mapping_with_StVK(\n    F_trial: wp.mat33, model: MPMModelStruct, p: int, dt: float\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    b_trial = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    trace_epsilon = epsilon[0] + epsilon[1] + epsilon[2]\n    epsilon_hat = epsilon - wp.vec3(\n        trace_epsilon / 3.0, trace_epsilon / 3.0, trace_epsilon / 3.0\n    )\n    s_trial = 2.0 * model.mu[p] * epsilon_hat\n    s_trial_norm = wp.length(s_trial)\n    y = s_trial_norm - wp.sqrt(2.0 / 3.0) * model.yield_stress[p]\n    if y > 0:\n        mu_hat = model.mu[p] * (b_trial[0] + b_trial[1] + b_trial[2]) / 3.0\n        s_new_norm = s_trial_norm - y / (\n            1.0 + model.plastic_viscosity / (2.0 * mu_hat * dt)\n        )\n        s_new = (s_new_norm / s_trial_norm) * s_trial\n        epsilon_new = 1.0 / (2.0 * model.mu[p]) * s_new + wp.vec3(\n            trace_epsilon / 3.0, trace_epsilon / 3.0, trace_epsilon / 3.0\n        )\n        sig_elastic = wp.mat33(\n            wp.exp(epsilon_new[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon_new[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon_new[2]),\n        )\n        
F_elastic = U * sig_elastic * wp.transpose(V)\n        return F_elastic\n    else:\n        return F_trial\n\n\n@wp.func\ndef sand_return_mapping(\n    F_trial: wp.mat33, state: MPMStateStruct, model: MPMModelStruct, p: int\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig, V)\n\n    epsilon = wp.vec3(\n        wp.log(wp.max(wp.abs(sig[0]), 1e-14)),\n        wp.log(wp.max(wp.abs(sig[1]), 1e-14)),\n        wp.log(wp.max(wp.abs(sig[2]), 1e-14)),\n    )\n    sigma_out = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    tr = epsilon[0] + epsilon[1] + epsilon[2]  # + state.particle_Jp[p]\n    epsilon_hat = epsilon - wp.vec3(tr / 3.0, tr / 3.0, tr / 3.0)\n    epsilon_hat_norm = wp.length(epsilon_hat)\n    delta_gamma = (\n        epsilon_hat_norm\n        + (3.0 * model.lam[p] + 2.0 * model.mu[p])\n        / (2.0 * model.mu[p])\n        * tr\n        * model.alpha\n    )\n\n    if delta_gamma <= 0:\n        F_elastic = F_trial\n\n    if delta_gamma > 0 and tr > 0:\n        F_elastic = U * wp.transpose(V)\n\n    if delta_gamma > 0 and tr <= 0:\n        H = epsilon - epsilon_hat * (delta_gamma / epsilon_hat_norm)\n        s_new = wp.vec3(wp.exp(H[0]), wp.exp(H[1]), wp.exp(H[2]))\n\n        F_elastic = U * wp.diag(s_new) * wp.transpose(V)\n    return F_elastic\n\n\n@wp.kernel\ndef compute_mu_lam_from_E_nu(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n    model.mu[p] = model.E[p] / (2.0 * (1.0 + model.nu[p]))\n    model.lam[p] = (\n        model.E[p] * model.nu[p] / ((1.0 + model.nu[p]) * (1.0 - 2.0 * model.nu[p]))\n    )\n\n\n@wp.kernel\ndef zero_grid(state: MPMStateStruct, model: MPMModelStruct):\n    grid_x, grid_y, grid_z = wp.tid()\n    state.grid_m[grid_x, grid_y, grid_z] = 0.0\n    state.grid_v_in[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 
0.0, 0.0)\n\n\n@wp.func\ndef compute_dweight(\n    model: MPMModelStruct, w: wp.mat33, dw: wp.mat33, i: int, j: int, k: int\n):\n    dweight = wp.vec3(\n        dw[0, i] * w[1, j] * w[2, k],\n        w[0, i] * dw[1, j] * w[2, k],\n        w[0, i] * w[1, j] * dw[2, k],\n    )\n    return dweight * model.inv_dx\n\n\n@wp.func\ndef update_cov(state: MPMStateStruct, p: int, grad_v: wp.mat33, dt: float):\n    cov_n = wp.mat33(0.0)\n    cov_n[0, 0] = state.particle_cov[p * 6]\n    cov_n[0, 1] = state.particle_cov[p * 6 + 1]\n    cov_n[0, 2] = state.particle_cov[p * 6 + 2]\n    cov_n[1, 0] = state.particle_cov[p * 6 + 1]\n    cov_n[1, 1] = state.particle_cov[p * 6 + 3]\n    cov_n[1, 2] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 0] = state.particle_cov[p * 6 + 2]\n    cov_n[2, 1] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 2] = state.particle_cov[p * 6 + 5]\n\n    cov_np1 = cov_n + dt * (grad_v * cov_n + cov_n * wp.transpose(grad_v))\n\n    state.particle_cov[p * 6] = cov_np1[0, 0]\n    state.particle_cov[p * 6 + 1] = cov_np1[0, 1]\n    state.particle_cov[p * 6 + 2] = cov_np1[0, 2]\n    state.particle_cov[p * 6 + 3] = cov_np1[1, 1]\n    state.particle_cov[p * 6 + 4] = cov_np1[1, 2]\n    state.particle_cov[p * 6 + 5] = cov_np1[2, 2]\n\n\n@wp.kernel\ndef p2g_apic_with_stress(state: MPMStateStruct, model: MPMModelStruct, dt: float):\n    # input given to p2g:   particle_stress\n    #                       particle_x\n    #                       particle_v\n    #                       particle_C\n    # output:               grid_v_in, grid_m\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        stress = state.particle_stress[p]\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = 
wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    dpos = (\n                        wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    ) * model.dx\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n\n                    C = state.particle_C[p]\n                    # if model.rpic = 0, standard apic\n                    C = (1.0 - model.rpic_damping) * C + model.rpic_damping / 2.0 * (\n                        C - wp.transpose(C)\n                    )\n\n                    # C = (1.0 - model.rpic_damping) * state.particle_C[\n                    #     p\n                    # ] + model.rpic_damping / 2.0 * (\n                    #     state.particle_C[p] - wp.transpose(state.particle_C[p])\n                    # )\n\n                    if model.rpic_damping < -0.001:\n                        # standard pic\n                        C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n                    elastic_force = -state.particle_vol[p] * stress * dweight\n                    v_in_add = (\n                        weight\n                        * state.particle_mass[p]\n                        * (state.particle_v[p] + C * dpos)\n                        + dt * elastic_force\n                    )\n                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)\n               
     wp.atomic_add(\n                        state.grid_m, ix, iy, iz, weight * state.particle_mass[p]\n                    )\n\n\n# add gravity\n@wp.kernel\ndef grid_normalization_and_gravity(\n    state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n    grid_x, grid_y, grid_z = wp.tid()\n    if state.grid_m[grid_x, grid_y, grid_z] > 1e-15:\n        v_out = state.grid_v_in[grid_x, grid_y, grid_z] * (\n            1.0 / state.grid_m[grid_x, grid_y, grid_z]\n        )\n        # add gravity\n        v_out = v_out + dt * model.gravitational_accelaration\n        state.grid_v_out[grid_x, grid_y, grid_z] = v_out\n\n\n@wp.kernel\ndef g2p(state: MPMStateStruct, model: MPMModelStruct, dt: float):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_v = wp.vec3(0.0, 0.0, 0.0)\n        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * 
w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_v = new_v + grid_v * weight\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        state.particle_v[p] = new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * state.particle_v[p]\n\n        # wp.atomic_add(state.particle_x, p, dt * state.particle_v[p]) # old one is this.. \n        wp.atomic_add(state.particle_x, p, dt * new_v) # debug\n        # new_x = state.particle_x[p] + dt * state.particle_v[p]\n        # state.particle_x[p] = new_x\n\n        state.particle_C[p] = new_C\n\n        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n        F_tmp = (I33 + new_F * dt) * state.particle_F[p]\n        state.particle_F_trial[p] = F_tmp\n        # debug for jelly\n        # wp.atomic_add(state.particle_F_trial, p, new_F * dt * state.particle_F[p])\n\n        if model.update_cov_with_F:\n            update_cov(state, p, new_F, dt)\n\n\n@wp.kernel\ndef clip_particle_x(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n\n    posx = state.particle_x[p]\n    if state.particle_selection[p] == 0:\n        dx = 1.0 / model.inv_dx\n        a_min = dx * 2.0\n        a_max = model.grid_lim - dx * 2.0\n        new_x = wp.vec3(wp.clamp(posx[0], a_min, a_max),\n                        wp.clamp(posx[1], a_min, a_max), \n                        wp.clamp(posx[2], a_min, a_max))\n\n        delta_x = new_x - posx\n\n        wp.atomic_add(state.particle_x, p, delta_x)\n\n\n\n# compute (Kirchhoff) stress = stress(returnMap(F_trial))\n@wp.kernel\ndef compute_stress_from_F_trial(\n    state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n   
 p = wp.tid()\n    if state.particle_selection[p] == 0:\n        # apply return mapping\n        if model.material == 1:  # metal\n            state.particle_F[p] = von_mises_return_mapping(\n                state.particle_F_trial[p], model, p\n            )\n        elif model.material == 2:  # sand\n            state.particle_F[p] = sand_return_mapping(\n                state.particle_F_trial[p], state, model, p\n            )\n        elif model.material == 3:  # visplas, with StVk+VM, no thickening\n            state.particle_F[p] = viscoplasticity_return_mapping_with_StVK(\n                state.particle_F_trial[p], model, p, dt\n            )\n        elif model.material == 5:\n            state.particle_F[p] = von_mises_return_mapping_with_damage(\n                state.particle_F_trial[p], model, p\n            )\n        else:  # elastic, jelly\n            state.particle_F[p] = state.particle_F_trial[p]\n\n        # also compute stress here\n        J = wp.determinant(state.particle_F[p])\n        U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        sig = wp.vec3(0.0)\n        stress = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        wp.svd3(state.particle_F[p], U, sig, V)\n        if model.material == 0 or model.material == 5:\n            stress = kirchoff_stress_FCR(\n                state.particle_F[p], U, V, J, model.mu[p], model.lam[p]\n            )\n        if model.material == 1:\n            stress = kirchoff_stress_StVK(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n        if model.material == 2:\n            stress = kirchoff_stress_drucker_prager(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n        if model.material == 3:\n            # temporarily use stvk, subject to change\n            stress = kirchoff_stress_StVK(\n                state.particle_F[p], 
U, V, sig, model.mu[p], model.lam[p]\n            )\n\n        # stress = (stress + wp.transpose(stress)) / 2.0  # enfore symmetry\n        state.particle_stress[p] = (stress + wp.transpose(stress)) / 2.0\n\n\n@wp.kernel\ndef compute_cov_from_F(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n\n    F = state.particle_F_trial[p]\n\n    init_cov = wp.mat33(0.0)\n    init_cov[0, 0] = state.particle_init_cov[p * 6]\n    init_cov[0, 1] = state.particle_init_cov[p * 6 + 1]\n    init_cov[0, 2] = state.particle_init_cov[p * 6 + 2]\n    init_cov[1, 0] = state.particle_init_cov[p * 6 + 1]\n    init_cov[1, 1] = state.particle_init_cov[p * 6 + 3]\n    init_cov[1, 2] = state.particle_init_cov[p * 6 + 4]\n    init_cov[2, 0] = state.particle_init_cov[p * 6 + 2]\n    init_cov[2, 1] = state.particle_init_cov[p * 6 + 4]\n    init_cov[2, 2] = state.particle_init_cov[p * 6 + 5]\n\n    cov = F * init_cov * wp.transpose(F)\n\n    state.particle_cov[p * 6] = cov[0, 0]\n    state.particle_cov[p * 6 + 1] = cov[0, 1]\n    state.particle_cov[p * 6 + 2] = cov[0, 2]\n    state.particle_cov[p * 6 + 3] = cov[1, 1]\n    state.particle_cov[p * 6 + 4] = cov[1, 2]\n    state.particle_cov[p * 6 + 5] = cov[2, 2]\n\n\n@wp.kernel\ndef compute_R_from_F(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n\n    F = state.particle_F_trial[p]\n\n    # polar svd decomposition\n    U = wp.mat33(0.0)\n    V = wp.mat33(0.0)\n    sig = wp.vec3(0.0)\n    wp.svd3(F, U, sig, V)\n\n    if wp.determinant(U) < 0.0:\n        U[0, 2] = -U[0, 2]\n        U[1, 2] = -U[1, 2]\n        U[2, 2] = -U[2, 2]\n\n    if wp.determinant(V) < 0.0:\n        V[0, 2] = -V[0, 2]\n        V[1, 2] = -V[1, 2]\n        V[2, 2] = -V[2, 2]\n\n    # compute rotation matrix\n    R = U * wp.transpose(V)\n    state.particle_R[p] = wp.transpose(R)\n\n\n@wp.kernel\ndef add_damping_via_grid(state: MPMStateStruct, scale: float):\n    grid_x, grid_y, grid_z = wp.tid()\n    # state.grid_v_out[grid_x, grid_y, grid_z] = (\n   
 #     state.grid_v_out[grid_x, grid_y, grid_z] * scale\n    # )\n    wp.atomic_sub(state.grid_v_out, grid_x, grid_y, grid_z, (1.0 - scale) * state.grid_v_out[grid_x, grid_y, grid_z])\n\n\n@wp.kernel\ndef apply_additional_params(\n    state: MPMStateStruct,\n    model: MPMModelStruct,\n    params_modifier: MaterialParamsModifier,\n):\n    p = wp.tid()\n    pos = state.particle_x[p]\n    if (\n        pos[0] > params_modifier.point[0] - params_modifier.size[0]\n        and pos[0] < params_modifier.point[0] + params_modifier.size[0]\n        and pos[1] > params_modifier.point[1] - params_modifier.size[1]\n        and pos[1] < params_modifier.point[1] + params_modifier.size[1]\n        and pos[2] > params_modifier.point[2] - params_modifier.size[2]\n        and pos[2] < params_modifier.point[2] + params_modifier.size[2]\n    ):\n        model.E[p] = params_modifier.E\n        model.nu[p] = params_modifier.nu\n        state.particle_density[p] = params_modifier.density\n\n\n@wp.kernel\ndef selection_add_impulse_on_particles(\n    state: MPMStateStruct, impulse_modifier: Impulse_modifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - impulse_modifier.point\n    if (\n        wp.abs(offset[0]) < impulse_modifier.size[0]\n        and wp.abs(offset[1]) < impulse_modifier.size[1]\n        and wp.abs(offset[2]) < impulse_modifier.size[2]\n    ):\n        impulse_modifier.mask[p] = 1\n    else:\n        impulse_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef selection_enforce_particle_velocity_translation(\n    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - velocity_modifier.point\n    if (\n        wp.abs(offset[0]) < velocity_modifier.size[0]\n        and wp.abs(offset[1]) < velocity_modifier.size[1]\n        and wp.abs(offset[2]) < velocity_modifier.size[2]\n    ):\n        velocity_modifier.mask[p] = 1\n    else:\n        velocity_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef 
selection_enforce_particle_velocity_cylinder(\n    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - velocity_modifier.point\n\n    vertical_distance = wp.abs(wp.dot(offset, velocity_modifier.normal))\n\n    horizontal_distance = wp.length(\n        offset - wp.dot(offset, velocity_modifier.normal) * velocity_modifier.normal\n    )\n    if (\n        vertical_distance < velocity_modifier.half_height_and_radius[0]\n        and horizontal_distance < velocity_modifier.half_height_and_radius[1]\n    ):\n        velocity_modifier.mask[p] = 1\n    else:\n        velocity_modifier.mask[p] = 0\n\n@wp.kernel\ndef compute_position_l2_loss(\n    mpm_state: MPMStateStruct,\n    gt_pos: wp.array(dtype=wp.vec3),\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    pos = mpm_state.particle_x[tid]\n    pos_gt = gt_pos[tid]\n\n    # l1_diff = wp.abs(pos - pos_gt)\n    l2 = wp.length(pos - pos_gt)\n\n    wp.atomic_add(loss, 0, l2)\n\n@wp.kernel\ndef aggregate_grad(x: wp.array(dtype=float), grad: wp.array(dtype=float)):\n    tid = wp.tid()\n\n    # gradient descent step\n    wp.atomic_add(x, 0, grad[tid])\n\n\n@wp.kernel\ndef set_F_C_p2g(state: MPMStateStruct, model: MPMModelStruct, target_pos: wp.array(dtype=wp.vec3)):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        # p2g for 
displacement\n        particle_disp = target_pos[p] - state.particle_x[p]\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    v_in_add = weight * state.particle_mass[p] * particle_disp\n                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)\n                    wp.atomic_add(\n                        state.grid_m, ix, iy, iz, weight * state.particle_mass[p]\n                    )\n\n\n\n@wp.kernel\ndef set_F_C_g2p(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        # g2p for C and F\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n      
              weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        \n        # C should still be zero..\n        # state.particle_C[p] = new_C        \n        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n        F_tmp = (I33 + new_F)\n        state.particle_F_trial[p] = F_tmp\n\n        if model.update_cov_with_F:\n            update_cov(state, p, new_F, 1.0)\n\n\n@wp.kernel\ndef compute_posloss_with_grad(\n    mpm_state: MPMStateStruct,\n    gt_pos: wp.array(dtype=wp.vec3),\n    grad: wp.array(dtype=wp.vec3),\n    dt: float,\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    pos = mpm_state.particle_x[tid]\n    pos_gt = gt_pos[tid]\n\n    # l1_diff = wp.abs(pos - pos_gt)\n    l2 = wp.length(pos - (pos_gt - grad[tid] * dt))\n\n    wp.atomic_add(loss, 0, l2)"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/backup_jan10/warp_utils.py",
    "content": "import warp as wp\nimport ctypes\nfrom typing import Optional\n\nfrom warp.torch import (\n    dtype_from_torch,\n    device_from_torch,\n    dtype_is_compatible,\n    from_torch,\n)\n\n\ndef from_torch_safe(t, dtype=None, requires_grad=None, grad=None):\n    \"\"\"Wrap a PyTorch tensor to a Warp array without copying the data.\n\n    Args:\n        t (torch.Tensor): The torch tensor to wrap.\n        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.\n        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.\n\n    Returns:\n        warp.array: The wrapped array.\n    \"\"\"\n    if dtype is None:\n        dtype = dtype_from_torch(t.dtype)\n    elif not dtype_is_compatible(t.dtype, dtype):\n        raise RuntimeError(f\"Incompatible data types: {t.dtype} and {dtype}\")\n\n    # get size of underlying data type to compute strides\n    ctype_size = ctypes.sizeof(dtype._type_)\n\n    shape = tuple(t.shape)\n    strides = tuple(s * ctype_size for s in t.stride())\n\n    # if target is a vector or matrix type\n    # then check if trailing dimensions match\n    # the target type and update the shape\n    if hasattr(dtype, \"_shape_\"):\n        dtype_shape = dtype._shape_\n        dtype_dims = len(dtype._shape_)\n        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:\n            raise RuntimeError(\n                f\"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}\"\n            )\n\n        # ensure the inner strides are contiguous\n        stride = ctype_size\n        for i in range(dtype_dims):\n            if strides[-i - 1] != stride:\n                raise RuntimeError(\n                    f\"Could 
not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous\"\n                )\n            stride *= dtype_shape[-i - 1]\n\n        shape = tuple(shape[:-dtype_dims]) or (1,)\n        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)\n\n    requires_grad = t.requires_grad if requires_grad is None else requires_grad\n    if grad is not None:\n        if not isinstance(grad, wp.array):\n            import torch\n\n            if isinstance(grad, torch.Tensor):\n                grad = from_torch(grad, dtype=dtype)\n            else:\n                raise ValueError(f\"Invalid gradient type: {type(grad)}\")\n    elif requires_grad:\n        # wrap the tensor gradient, allocate if necessary\n        if t.grad is None:\n            # allocate a zero-filled gradient tensor if it doesn't exist\n            import torch\n\n            t.grad = torch.zeros_like(t, requires_grad=False)\n        grad = from_torch(t.grad, dtype=dtype)\n\n    a = wp.types.array(\n        ptr=t.data_ptr(),\n        dtype=dtype,\n        shape=shape,\n        strides=strides,\n        device=device_from_torch(t.device),\n        copy=False,\n        owner=False,\n        grad=grad,\n        requires_grad=requires_grad,\n    )\n\n    # save a reference to the source tensor, otherwise it will be deallocated\n    a._tensor = t\n    return a\n\n\nclass MyTape(wp.Tape):\n    # returns the adjoint of a kernel parameter\n    def get_adjoint(self, a):\n        if not wp.types.is_array(a) and not isinstance(a, wp.codegen.StructInstance):\n            # if input is a simple type (e.g.: float, vec3, etc) then\n            # no gradient needed (we only return gradients through arrays and structs)\n            return a\n\n        elif wp.types.is_array(a) and a.grad:\n            # keep track of all gradients used by the tape (for zeroing)\n            # ignore the scalar loss since we don't want to clear its grad\n            
self.gradients[a] = a.grad\n            return a.grad\n\n        elif isinstance(a, wp.codegen.StructInstance):\n            adj = a._cls()\n            for name, _ in a._cls.ctype._fields_:\n                if name.startswith(\"_\"):\n                    continue\n                if isinstance(a._cls.vars[name].type, wp.array):\n                    arr = getattr(a, name)\n                    if arr is None:\n                        continue\n                    if arr.grad:\n                        grad = self.gradients[arr] = arr.grad\n                    else:\n                        grad = wp.zeros_like(arr)\n                    setattr(adj, name, grad)\n                else:\n                    setattr(adj, name, getattr(a, name))\n\n            self.gradients[a] = adj\n            return adj\n\n        return None\n\n\n# from https://github.com/PingchuanMa/NCLaw/blob/main/nclaw/warp/tape.py\nclass CondTape(object):\n    def __init__(self, tape: Optional[MyTape], cond: bool = True) -> None:\n        self.tape = tape\n        self.cond = cond\n\n    def __enter__(self):\n        if self.tape is not None and self.cond:\n            self.tape.__enter__()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.tape is not None and self.cond:\n            self.tape.__exit__(exc_type, exc_value, traceback)"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/gaussian_sim_utils.py",
    "content": "import numpy as np\n\n\ndef get_volume(xyzs: np.ndarray, resolution=128) -> np.ndarray:\n\n    # set a grid in the range of [-1, 1], with resolution\n    voxel_counts = np.zeros((resolution, resolution, resolution))\n\n    points_xyzindex = ((xyzs + 1) / 2 * (resolution - 1)).astype(np.uint32)\n    cell_volume = (2.0 / (resolution - 1)) ** 3\n\n    for x, y, z in points_xyzindex:\n        voxel_counts[x, y, z] += 1\n\n    points_number_in_corresponding_voxel = voxel_counts[\n        points_xyzindex[:, 0], points_xyzindex[:, 1], points_xyzindex[:, 2]\n    ]\n\n    points_volume = cell_volume / points_number_in_corresponding_voxel\n\n    points_volume = points_volume.astype(np.float32)\n\n    # some statistics\n    num_non_empyt_voxels = np.sum(voxel_counts > 0)\n    max_points_in_voxel = np.max(voxel_counts)\n    min_points_in_voxel = np.min(voxel_counts)\n    avg_points_in_voxel = np.sum(voxel_counts) / num_non_empyt_voxels\n    print(\"Number of non-empty voxels: \", num_non_empyt_voxels)\n    print(\"Max points in voxel: \", max_points_in_voxel)\n    print(\"Min points in voxel: \", min_points_in_voxel)\n    print(\"Avg points in voxel: \", avg_points_in_voxel)\n\n    return points_volume\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/mpm_data_structure.py",
    "content": "import warp as wp\nimport warp.torch\nimport torch\nfrom typing import Optional, Union, Sequence, Any\nfrom torch import Tensor\nimport os\nimport sys\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)))\nfrom warp_utils import from_torch_safe\n\n\n@wp.struct\nclass MPMStateStruct(object):\n    ###### essential #####\n    # particle\n    particle_x: wp.array(dtype=wp.vec3)  # current position\n    particle_v: wp.array(dtype=wp.vec3)  # particle velocity\n    particle_F: wp.array(dtype=wp.mat33)  # particle elastic deformation gradient\n    particle_cov: wp.array(dtype=float)  # current covariance matrix\n    particle_F_trial: wp.array(\n        dtype=wp.mat33\n    )  # apply return mapping on this to obtain elastic def grad\n    particle_stress: wp.array(dtype=wp.mat33)  # Kirchoff stress, elastic stress\n    particle_C: wp.array(dtype=wp.mat33)\n    particle_vol: wp.array(dtype=float)  # current volume\n    particle_mass: wp.array(dtype=float)  # mass\n    particle_density: wp.array(dtype=float)  # density\n\n    particle_selection: wp.array(\n        dtype=int\n    )  # only particle_selection[p] = 0 will be simulated\n\n    # grid\n    grid_m: wp.array(dtype=float, ndim=3)\n    grid_v_in: wp.array(dtype=wp.vec3, ndim=3)  # grid node momentum/velocity\n    grid_v_out: wp.array(\n        dtype=wp.vec3, ndim=3\n    )  # grid node momentum/velocity, after grid update\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        # shape default is int. 
number of particles\n        self.particle_x = wp.zeros(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_v = wp.zeros(\n            shape, dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.particle_F = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_cov = wp.zeros(\n            shape * 6, dtype=float, device=device, requires_grad=False\n        )\n\n        self.particle_F_trial = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_stress = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n        self.particle_C = wp.zeros(\n            shape, dtype=wp.mat33, device=device, requires_grad=requires_grad\n        )\n\n        self.particle_vol = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_mass = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n        self.particle_density = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=False\n        )\n\n        self.particle_selection = wp.zeros(\n            shape, dtype=int, device=device, requires_grad=False\n        )\n\n        # grid: will init later\n        self.grid_m = wp.zeros(\n            (10, 10, 10), dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.grid_v_in = wp.zeros(\n            (10, 10, 10), dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n        self.grid_v_out = wp.zeros(\n            (10, 10, 10), dtype=wp.vec3, device=device, requires_grad=requires_grad\n        )\n\n    def init_grid(\n        self, grid_res: int, device: wp.context.Devicelike = None, requires_grad=False\n    ):\n        self.grid_m = wp.zeros(\n            
(grid_res, grid_res, grid_res),\n            dtype=float,\n            device=device,\n            requires_grad=False,\n        )\n        self.grid_v_in = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=wp.vec3,\n            device=device,\n            requires_grad=requires_grad,\n        )\n        self.grid_v_out = wp.zeros(\n            (grid_res, grid_res, grid_res),\n            dtype=wp.vec3,\n            device=device,\n            requires_grad=requires_grad,\n        )\n\n    def from_torch(\n        self,\n        tensor_x: Tensor,\n        tensor_volume: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n        n_grid: int = 100,\n        grid_lim=1.0,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n        assert tensor_x.shape[0] == tensor_volume.shape[0]\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n        self.init_grid(grid_res=n_grid, device=device, requires_grad=requires_grad)\n\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_volume is not None:\n            print(self.particle_vol.shape, tensor_volume.shape)\n            volume_numpy = tensor_volume.detach().cpu().numpy()\n            self.particle_vol = wp.from_numpy(\n                volume_numpy, dtype=float, device=device, requires_grad=False\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, requires_grad=False\n            )\n\n        if tensor_velocity is not None:\n            self.particle_v = 
from_torch_safe(\n                tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        # initial trial deformation gradient is set to identity\n\n        print(\"Particles initialized from torch data.\")\n        print(\"Total particles: \", n_particles)\n\n    def reset_state(\n        self,\n        tensor_x: Tensor,\n        tensor_cov: Optional[Tensor] = None,\n        tensor_velocity: Optional[Tensor] = None,\n        tensor_density: Optional[Tensor] = None,\n        selection_mask: Optional[Tensor] = None,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        # reset p_c, p_v, p_C, p_F_trial\n        num_dim, n_particles = tensor_x.shape[1], tensor_x.shape[0]\n\n        # assert tensor_x.shape[0] == tensor_cov.reshape(-1, 6).shape[0]\n\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_cov is not None:\n            cov_numpy = tensor_cov.reshape(-1).detach().clone().cpu().numpy()\n            self.particle_cov = wp.from_numpy(\n                cov_numpy, dtype=float, device=device, requires_grad=False\n            )\n\n        if tensor_velocity is not None:\n            self.particle_v = from_torch_safe(\n                tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_density is not None and selection_mask is not None:\n            wp_density = from_torch_safe(\n                
tensor_density.contiguous().detach().clone(),\n                dtype=wp.float32,\n                requires_grad=False,\n            )\n            # 1 indicate we need to simulate this particle\n            wp_selection_mask = from_torch_safe(\n                selection_mask.contiguous().detach().clone().type(torch.int),\n                dtype=wp.int32,\n                requires_grad=False,\n            )\n\n            wp.launch(\n                kernel=set_float_vec_to_vec_wmask,\n                dim=n_particles,\n                inputs=[self.particle_density, wp_density, wp_selection_mask],\n                device=device,\n            )\n\n        # initial deformation gradient is set to identity\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F_trial],\n            device=device,\n        )\n        wp.launch(\n            kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[self.particle_F],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_zero,\n            dim=n_particles,\n            inputs=[self.particle_C],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=set_mat33_to_zero,\n            dim=n_particles,\n            inputs=[self.particle_stress],\n            device=device,\n        )\n\n    def continue_from_torch(\n        self,\n        tensor_x: Tensor,\n        tensor_velocity: Optional[Tensor] = None,\n        tensor_F: Optional[Tensor] = None,\n        tensor_C: Optional[Tensor] = None,\n        device=\"cuda:0\",\n        requires_grad=True,\n    ):\n        if tensor_x is not None:\n            self.particle_x = from_torch_safe(\n                tensor_x.contiguous().detach(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_velocity is not None:\n            self.particle_v = from_torch_safe(\n      
          tensor_velocity.contiguous().detach().clone(),\n                dtype=wp.vec3,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_F is not None:\n            self.particle_F_trial = from_torch_safe(\n                tensor_F.contiguous().detach().clone(),\n                dtype=wp.mat33,\n                requires_grad=requires_grad,\n            )\n\n        if tensor_C is not None:\n            self.particle_C = from_torch_safe(\n                tensor_C.contiguous().detach().clone(),\n                dtype=wp.mat33,\n                requires_grad=requires_grad,\n            )\n\n    def set_require_grad(self, requires_grad=True):\n        self.particle_x.requires_grad = requires_grad\n        self.particle_v.requires_grad = requires_grad\n        self.particle_F.requires_grad = requires_grad\n        self.particle_F_trial.requires_grad = requires_grad\n        self.particle_stress.requires_grad = requires_grad\n        self.particle_C.requires_grad = requires_grad\n\n        self.grid_v_out.requires_grad = requires_grad\n        self.grid_v_in.requires_grad = requires_grad\n\n    def reset_density(\n        self,\n        tensor_density: Tensor,\n        selection_mask: Optional[Tensor] = None,\n        device=\"cuda:0\",\n        requires_grad=True,\n        update_mass=False,\n    ):\n        n_particles = tensor_density.shape[0]\n        if tensor_density is not None:\n            wp_density = from_torch_safe(\n                tensor_density.contiguous().detach().clone(),\n                dtype=wp.float32,\n                requires_grad=False,\n            )\n        \n        if selection_mask is not None:\n            # 1 indicate we need to simulate this particle\n            wp_selection_mask = from_torch_safe(\n                selection_mask.contiguous().detach().clone().type(torch.int),\n                dtype=wp.int32,\n                requires_grad=False,\n            )\n\n            wp.launch(\n                
kernel=set_float_vec_to_vec_wmask,\n                dim=n_particles,\n                inputs=[self.particle_density, wp_density, wp_selection_mask],\n                device=device,\n            )\n        else:\n            wp.launch(\n                kernel=set_float_vec_to_vec,\n                dim=n_particles,\n                inputs=[self.particle_density, wp_density],\n                device=device,\n            )\n\n        if update_mass:\n            num_particles = self.particle_x.shape[0]\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=num_particles,\n                inputs=[\n                    self.particle_density,\n                    self.particle_vol,\n                    self.particle_mass,\n                ],\n                device=device,\n            )\n\n    def partial_clone(self, device=\"cuda:0\", requires_grad=True):\n        new_state = MPMStateStruct()\n        n_particles = self.particle_x.shape[0]\n        new_state.init(n_particles, device=device, requires_grad=requires_grad)\n\n        # clone section:\n        # new_state.particle_vol = wp.clone(self.particle_vol, requires_grad=False)\n        # new_state.particle_density = wp.clone(self.particle_density, requires_grad=False)\n        # new_state.particle_mass = wp.clone(self.particle_mass, requires_grad=False)\n\n        # new_state.particle_selection = wp.clone(self.particle_selection, requires_grad=False)\n\n        wp.copy(new_state.particle_vol, self.particle_vol)\n        wp.copy(new_state.particle_density, self.particle_density)\n        wp.copy(new_state.particle_mass, self.particle_mass)\n        wp.copy(new_state.particle_selection, self.particle_selection)\n\n        # init grid to zero with grid res.\n        new_state.init_grid(\n            grid_res=self.grid_v_in.shape[0], device=device, requires_grad=requires_grad\n        )\n\n        # init some matrix to identity\n        wp.launch(\n            
kernel=set_mat33_to_identity,\n            dim=n_particles,\n            inputs=[new_state.particle_F_trial],\n            device=device,\n        )\n\n        new_state.set_require_grad(requires_grad=requires_grad)\n        return new_state\n\n\n@wp.struct\nclass MPMModelStruct(object):\n    ####### essential #######\n    grid_lim: float\n    n_particles: int\n    n_grid: int\n    dx: float\n    inv_dx: float\n    grid_dim_x: int\n    grid_dim_y: int\n    grid_dim_z: int\n    mu: wp.array(dtype=float)\n    lam: wp.array(dtype=float)\n    E: wp.array(dtype=float)\n    nu: wp.array(dtype=float)\n    material: int\n\n    ######## for plasticity ####\n    yield_stress: wp.array(dtype=float)\n    friction_angle: float\n    alpha: float\n    gravitational_accelaration: wp.vec3\n    hardening: float\n    xi: float\n    plastic_viscosity: float\n    softening: float\n\n    ####### for damping\n    rpic_damping: float\n    grid_v_damping_scale: float\n\n    ####### for PhysGaussian: covariance\n    update_cov_with_F: int\n\n    def init(\n        self,\n        shape: Union[Sequence[int], int],\n        device: wp.context.Devicelike = None,\n        requires_grad=False,\n    ) -> None:\n        self.E = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )  # young's modulus\n        self.nu = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )  # poisson's ratio\n\n        self.mu = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n        self.lam = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n        self.yield_stress = wp.zeros(\n            shape, dtype=float, device=device, requires_grad=requires_grad\n        )\n\n    def finalize_mu_lam(self, n_particles, device=\"cuda:0\"):\n        wp.launch(\n            kernel=compute_mu_lam_from_E_nu_clean,\n            dim=n_particles,\n  
          inputs=[self.mu, self.lam, self.E, self.nu],\n            device=device,\n        )\n\n    def init_other_params(self, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.grid_lim = grid_lim\n        self.n_grid = n_grid\n        self.grid_dim_x = n_grid\n        self.grid_dim_y = n_grid\n        self.grid_dim_z = n_grid\n        (\n            self.dx,\n            self.inv_dx,\n        ) = self.grid_lim / self.n_grid, float(\n            n_grid / grid_lim\n        )  # [0-1]?\n\n        self.update_cov_with_F = False\n\n        # material is used to switch between different elastoplastic models. 0 is jelly\n        self.material = 0\n\n        self.plastic_viscosity = 0.0\n        self.softening = 0.1\n        self.friction_angle = 25.0\n        sin_phi = wp.sin(self.friction_angle / 180.0 * 3.14159265)\n        self.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        self.gravitational_accelaration = wp.vec3(0.0, 0.0, 0.0)\n\n        self.rpic_damping = 0.0  # 0.0 if no damping (apic). 
-1 if pic\n\n        self.grid_v_damping_scale = 1.1  # globally applied\n\n    def from_torch(\n        self, tensor_E: Tensor, tensor_nu: Tensor, device=\"cuda:0\", requires_grad=False\n    ):\n        self.E = wp.from_torch(tensor_E.contiguous(), requires_grad=requires_grad)\n        self.nu = wp.from_torch(tensor_nu.contiguous(), requires_grad=requires_grad)\n        n_particles = tensor_E.shape[0]\n        self.finalize_mu_lam(n_particles=n_particles, device=device)\n\n    def set_require_grad(self, requires_grad=True):\n        self.E.requires_grad = requires_grad\n        self.nu.requires_grad = requires_grad\n        self.mu.requires_grad = requires_grad\n        self.lam.requires_grad = requires_grad\n\n\n# for various boundary conditions\n@wp.struct\nclass Dirichlet_collider:\n    point: wp.vec3\n    normal: wp.vec3\n    direction: wp.vec3\n\n    start_time: float\n    end_time: float\n\n    friction: float\n    surface_type: int\n\n    velocity: wp.vec3\n\n    threshold: float\n    reset: int\n    index: int\n\n    x_unit: wp.vec3\n    y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    height: float\n    length: float\n    R: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    horizontal_axis_2: wp.vec3\n    half_height_and_radius: wp.vec2\n\n\n@wp.struct\nclass GridCollider:\n    point: wp.vec3\n    normal: wp.vec3\n    direction: wp.vec3\n\n    start_time: float\n    end_time: float\n    mask: wp.array(dtype=int, ndim=3)\n\n\n@wp.struct\nclass Impulse_modifier:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: float\n    force: wp.vec3\n    forceTimesDt: wp.vec3\n    numsteps: int\n\n    point: wp.vec3\n    size: wp.vec3\n    mask: wp.array(dtype=int)\n\n\n@wp.struct\nclass MPMtailoredStruct:\n    # this needs to be changed for each different BC!\n    point: wp.vec3\n    normal: wp.vec3\n    start_time: float\n    end_time: 
float\n    friction: float\n    surface_type: int\n    velocity: wp.vec3\n    threshold: float\n    reset: int\n\n    point_rotate: wp.vec3\n    normal_rotate: wp.vec3\n    x_unit: wp.vec3\n    y_unit: wp.vec3\n    radius: float\n    v_scale: float\n    width: float\n    point_plane: wp.vec3\n    normal_plane: wp.vec3\n    velocity_plane: wp.vec3\n    threshold_plane: float\n\n\n@wp.struct\nclass MaterialParamsModifier:\n    point: wp.vec3\n    size: wp.vec3\n    E: float\n    nu: float\n    density: float\n\n\n@wp.struct\nclass ParticleVelocityModifier:\n    point: wp.vec3\n    normal: wp.vec3\n    half_height_and_radius: wp.vec2\n    rotation_scale: float\n    translation_scale: float\n\n    size: wp.vec3\n\n    horizontal_axis_1: wp.vec3\n    horizontal_axis_2: wp.vec3\n\n    start_time: float\n\n    end_time: float\n\n    velocity: wp.vec3\n\n    mask: wp.array(dtype=int)\n\n\n@wp.kernel\ndef compute_mu_lam_from_E_nu_clean(\n    mu: wp.array(dtype=float),\n    lam: wp.array(dtype=float),\n    E: wp.array(dtype=float),\n    nu: wp.array(dtype=float),\n):\n    p = wp.tid()\n    mu[p] = E[p] / (2.0 * (1.0 + nu[p]))\n    lam[p] = E[p] * nu[p] / ((1.0 + nu[p]) * (1.0 - 2.0 * nu[p]))\n\n\n@wp.kernel\ndef set_vec3_to_zero(target_array: wp.array(dtype=wp.vec3)):\n    tid = wp.tid()\n    target_array[tid] = wp.vec3(0.0, 0.0, 0.0)\n\n\n@wp.kernel\ndef set_vec3_to_vec3(\n    source_array: wp.array(dtype=wp.vec3), target_array: wp.array(dtype=wp.vec3)\n):\n    tid = wp.tid()\n    source_array[tid] = target_array[tid]\n\n\n@wp.kernel\ndef set_float_vec_to_vec_wmask(\n    source_array: wp.array(dtype=float),\n    target_array: wp.array(dtype=float),\n    selection_mask: wp.array(dtype=int),\n):\n    tid = wp.tid()\n    if selection_mask[tid] == 1:\n        source_array[tid] = target_array[tid]\n\n\n@wp.kernel\ndef set_float_vec_to_vec(\n    source_array: wp.array(dtype=float), target_array: wp.array(dtype=float)\n):\n    tid = wp.tid()\n    source_array[tid] = 
target_array[tid]\n\n\n@wp.kernel\ndef set_mat33_to_identity(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n\n\n@wp.kernel\ndef set_mat33_to_zero(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n\n@wp.kernel\ndef add_identity_to_mat33(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.add(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef subtract_identity_to_mat33(target_array: wp.array(dtype=wp.mat33)):\n    tid = wp.tid()\n    target_array[tid] = wp.sub(\n        target_array[tid], wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    )\n\n\n@wp.kernel\ndef add_vec3_to_vec3(\n    first_array: wp.array(dtype=wp.vec3), second_array: wp.array(dtype=wp.vec3)\n):\n    tid = wp.tid()\n    first_array[tid] = wp.add(first_array[tid], second_array[tid])\n\n\n@wp.kernel\ndef set_value_to_float_array(target_array: wp.array(dtype=float), value: float):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef set_warpvalue_to_float_array(\n    target_array: wp.array(dtype=float), value: warp.types.float32\n):\n    tid = wp.tid()\n    target_array[tid] = value\n\n\n@wp.kernel\ndef get_float_array_product(\n    arrayA: wp.array(dtype=float),\n    arrayB: wp.array(dtype=float),\n    arrayC: wp.array(dtype=float),\n):\n    tid = wp.tid()\n    arrayC[tid] = arrayA[tid] * arrayB[tid]\n\n\ndef torch2warp_quat(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 4\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.quat,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_float(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=warp.types.float32,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_vec3(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.vec3,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n\n\ndef torch2warp_mat33(t, copy=False, dtype=warp.types.float32, dvc=\"cuda:0\"):\n    assert t.is_contiguous()\n    if t.dtype != torch.float32 and t.dtype != torch.int32:\n        raise RuntimeError(\n            \"Error aliasing Torch tensor to Warp array. 
Torch tensor must be float32 or int32 type\"\n        )\n    assert t.shape[1] == 3\n    a = warp.types.array(\n        ptr=t.data_ptr(),\n        dtype=wp.mat33,\n        shape=t.shape[0],\n        copy=False,\n        owner=False,\n        requires_grad=t.requires_grad,\n        # device=t.device.type)\n        device=dvc,\n    )\n    a.tensor = t\n    return a\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/mpm_solver_diff.py",
    "content": "import sys\nimport os\n\nimport warp as wp\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)))\nfrom mpm_data_structure import *\nfrom mpm_utils import *\nfrom typing import Optional, Union, Sequence, Any, Tuple\nfrom jaxtyping import Float, Int, Shaped\n\n\nclass MPMWARPDiff(object):\n    # def __init__(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n    #     self.initialize(n_particles, n_grid, grid_lim, device=device)\n    #     self.time_profile = {}\n\n    def __init__(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.initialize(n_particles, n_grid, grid_lim, device=device)\n        self.time_profile = {}\n\n    def initialize(self, n_particles, n_grid=100, grid_lim=1.0, device=\"cuda:0\"):\n        self.n_particles = n_particles\n\n        self.time = 0.0\n\n        self.grid_postprocess = []\n        self.collider_params = []\n        self.modify_bc = []\n\n        self.tailored_struct_for_bc = MPMtailoredStruct()\n        self.pre_p2g_operations = []\n        self.impulse_params = []\n\n        self.particle_velocity_modifiers = []\n        self.particle_velocity_modifier_params = []\n\n    # must give density. 
mass will be updated as density * volume\n    def set_parameters(self, device=\"cuda:0\", **kwargs):\n        self.set_parameters_dict(device, kwargs)\n\n    def set_parameters_dict(self, mpm_model, mpm_state, kwargs={}, device=\"cuda:0\"):\n        if \"material\" in kwargs:\n            if kwargs[\"material\"] == \"jelly\":\n                mpm_model.material = 0\n            elif kwargs[\"material\"] == \"metal\":\n                mpm_model.material = 1\n            elif kwargs[\"material\"] == \"sand\":\n                mpm_model.material = 2\n            elif kwargs[\"material\"] == \"foam\":\n                mpm_model.material = 3\n            elif kwargs[\"material\"] == \"snow\":\n                mpm_model.material = 4\n            elif kwargs[\"material\"] == \"plasticine\":\n                mpm_model.material = 5\n            elif kwargs[\"material\"] == \"neo-hookean\":\n                mpm_model.material = 6\n            else:\n                raise TypeError(\"Undefined material type\")\n\n        if \"yield_stress\" in kwargs:\n            val = kwargs[\"yield_stress\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.yield_stress, val],\n                device=device,\n            )\n        if \"hardening\" in kwargs:\n            mpm_model.hardening = kwargs[\"hardening\"]\n        if \"xi\" in kwargs:\n            mpm_model.xi = kwargs[\"xi\"]\n        if \"friction_angle\" in kwargs:\n            mpm_model.friction_angle = kwargs[\"friction_angle\"]\n            sin_phi = wp.sin(mpm_model.friction_angle / 180.0 * 3.14159265)\n            mpm_model.alpha = wp.sqrt(2.0 / 3.0) * 2.0 * sin_phi / (3.0 - sin_phi)\n\n        if \"g\" in kwargs:\n            mpm_model.gravitational_accelaration = wp.vec3(\n                kwargs[\"g\"][0], kwargs[\"g\"][1], kwargs[\"g\"][2]\n            )\n\n        if \"density\" in kwargs:\n            density_value = 
kwargs[\"density\"]\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_state.particle_density, density_value],\n                device=device,\n            )\n            wp.launch(\n                kernel=get_float_array_product,\n                dim=self.n_particles,\n                inputs=[\n                    mpm_state.particle_density,\n                    mpm_state.particle_vol,\n                    mpm_state.particle_mass,\n                ],\n                device=device,\n            )\n        if \"rpic_damping\" in kwargs:\n            mpm_model.rpic_damping = kwargs[\"rpic_damping\"]\n        if \"plastic_viscosity\" in kwargs:\n            mpm_model.plastic_viscosity = kwargs[\"plastic_viscosity\"]\n        if \"softening\" in kwargs:\n            mpm_model.softening = kwargs[\"softening\"]\n        if \"grid_v_damping_scale\" in kwargs:\n            mpm_model.grid_v_damping_scale = kwargs[\"grid_v_damping_scale\"]\n\n    def set_E_nu(self, mpm_model, E: float, nu: float, device=\"cuda:0\"):\n        if isinstance(E, float):\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.E, E],\n                device=device,\n            )\n        else:  # E is warp array\n            wp.launch(\n                kernel=set_float_vec_to_vec,\n                dim=self.n_particles,\n                inputs=[mpm_model.E, E],\n                device=device,\n            )\n\n        if isinstance(nu, float):\n            wp.launch(\n                kernel=set_value_to_float_array,\n                dim=self.n_particles,\n                inputs=[mpm_model.nu, nu],\n                device=device,\n            )\n        else:\n            wp.launch(\n                kernel=set_float_vec_to_vec,\n                dim=self.n_particles,\n                inputs=[mpm_model.nu, nu],\n        
        device=device,\n            )\n\n    def set_E_nu_from_torch(\n        self,\n        mpm_model,\n        E: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        nu: Float[Tensor, \"n\"] | Float[Tensor, \"1\"],\n        device=\"cuda:0\",\n    ):\n        if E.ndim == 0:\n            E_inp = E.item()  # float\n        else:\n            E_inp = from_torch_safe(E, dtype=wp.float32, requires_grad=True)\n\n        if nu.ndim == 0:\n            nu_inp = nu.item()  # float\n        else:\n            nu_inp = from_torch_safe(nu, dtype=wp.float32, requires_grad=True)\n\n        self.set_E_nu(mpm_model, E_inp, nu_inp, device=device)\n\n    def prepare_mu_lam(self, mpm_model, mpm_state, device=\"cuda:0\"):\n        # compute mu and lam from E and nu\n        wp.launch(\n            kernel=compute_mu_lam_from_E_nu,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n    def p2g2p_differentiable(\n        self, mpm_model, mpm_state, next_state, dt, device=\"cuda:0\"\n    ):\n        \"\"\"\n        Some boundary conditions, might not give gradient,\n        see kernels in\n            self.pre_p2g_operations,    Usually None.\n            self.particle_velocity_modifiers.   
Mostly used to freeze points\n            self.grid_postprocess,      Should apply BC here\n        \"\"\"\n        grid_size = (\n            mpm_model.grid_dim_x,\n            mpm_model.grid_dim_y,\n            mpm_model.grid_dim_z,\n        )\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        # apply pre-p2g operations on particles\n        # apply impulse force on particles..\n        for k in range(len(self.pre_p2g_operations)):\n            wp.launch(\n                kernel=self.pre_p2g_operations[k],\n                dim=self.n_particles,\n                inputs=[self.time, dt, mpm_state, self.impulse_params[k]],\n                device=device,\n            )\n\n        # apply dirichlet particle v modifier\n        for k in range(len(self.particle_velocity_modifiers)):\n            wp.launch(\n                kernel=self.particle_velocity_modifiers[k],\n                dim=self.n_particles,\n                inputs=[\n                    self.time,\n                    mpm_state,\n                    self.particle_velocity_modifier_params[k],\n                ],\n                device=device,\n            )\n\n        # compute stress = stress(returnMap(F_trial))\n        # F_trail => F                    # TODO: this is overite..\n        # F, SVD(F), lam, mu => Stress.   
# TODO: this is overite..\n\n        with wp.ScopedTimer(\n            \"compute_stress_from_F_trial\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=compute_stress_from_F_trial,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # F and stress are updated\n\n        # p2g\n        with wp.ScopedTimer(\n            \"p2g\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=p2g_apic_with_stress,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # apply p2g'\n\n        # grid update\n        with wp.ScopedTimer(\n            \"grid_update\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=grid_normalization_and_gravity,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )\n\n        if mpm_model.grid_v_damping_scale < 1.0:\n            wp.launch(\n                kernel=add_damping_via_grid,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model.grid_v_damping_scale],\n                device=device,\n            )\n\n        # apply BC on grid, collide\n        with wp.ScopedTimer(\n            \"apply_BC_on_grid\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            for k in range(len(self.grid_postprocess)):\n                wp.launch(\n                    kernel=self.grid_postprocess[k],\n                    dim=grid_size,\n                    inputs=[\n                        self.time,\n                        dt,\n                        mpm_state,\n                        mpm_model,\n 
                       self.collider_params[k],\n                    ],\n                    device=device,\n                )\n                if self.modify_bc[k] is not None:\n                    self.modify_bc[k](self.time, dt, self.collider_params[k])\n\n        # g2p\n        with wp.ScopedTimer(\n            \"g2p\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=g2p_differentiable,\n                dim=self.n_particles,\n                inputs=[mpm_state, next_state, mpm_model, dt],\n                device=device,\n            )  # x, v, C, F_trial are updated\n\n        self.time = self.time + dt\n\n    def p2g2p(self, mpm_model, mpm_state, step, dt, device=\"cuda:0\"):\n        grid_size = (\n            mpm_model.grid_dim_x,\n            mpm_model.grid_dim_y,\n            mpm_model.grid_dim_z,\n        )\n\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        # apply pre-p2g operations on particles\n        # apply impulse force on particles..\n        for k in range(len(self.pre_p2g_operations)):\n            wp.launch(\n                kernel=self.pre_p2g_operations[k],\n                dim=self.n_particles,\n                inputs=[self.time, dt, mpm_state, self.impulse_params[k]],\n                device=device,\n            )\n\n        # apply dirichlet particle v modifier\n        for k in range(len(self.particle_velocity_modifiers)):\n            wp.launch(\n                kernel=self.particle_velocity_modifiers[k],\n                dim=self.n_particles,\n                inputs=[\n                    self.time,\n                    mpm_state,\n                    self.particle_velocity_modifier_params[k],\n                ],\n                device=device,\n            )\n\n        # compute stress = stress(returnMap(F_trial))\n        # 
F_trail => F                    # TODO: this is overite..\n        # F, SVD(F), lam, mu => Stress.   # TODO: this is overite..\n\n        with wp.ScopedTimer(\n            \"compute_stress_from_F_trial\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=compute_stress_from_F_trial,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # F and stress are updated\n\n        # p2g\n        with wp.ScopedTimer(\n            \"p2g\",\n            synchronize=True,\n            print=False,\n            dict=self.time_profile,\n        ):\n            wp.launch(\n                kernel=p2g_apic_with_stress,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # apply p2g'\n\n        # grid update\n        with wp.ScopedTimer(\n            \"grid_update\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=grid_normalization_and_gravity,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )\n\n        if mpm_model.grid_v_damping_scale < 1.0:\n            wp.launch(\n                kernel=add_damping_via_grid,\n                dim=(grid_size),\n                inputs=[mpm_state, mpm_model.grid_v_damping_scale],\n                device=device,\n            )\n\n        # apply BC on grid, collide\n        with wp.ScopedTimer(\n            \"apply_BC_on_grid\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            for k in range(len(self.grid_postprocess)):\n                wp.launch(\n                    kernel=self.grid_postprocess[k],\n                    dim=grid_size,\n                    inputs=[\n                        self.time,\n 
                       dt,\n                        mpm_state,\n                        mpm_model,\n                        self.collider_params[k],\n                    ],\n                    device=device,\n                )\n                if self.modify_bc[k] is not None:\n                    self.modify_bc[k](self.time, dt, self.collider_params[k])\n\n        # g2p\n        with wp.ScopedTimer(\n            \"g2p\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=g2p,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model, dt],\n                device=device,\n            )  # x, v, C, F_trial are updated\n\n        #### CFL check ####\n        # particle_v = self.mpm_state.particle_v.numpy()\n        # if np.max(np.abs(particle_v)) > self.mpm_model.dx / dt:\n        #     print(\"max particle v: \", np.max(np.abs(particle_v)))\n        #     print(\"max allowed  v: \", self.mpm_model.dx / dt)\n        #     print(\"does not allow v*dt>dx\")\n        #     input()\n        #### CFL check ####\n        with wp.ScopedTimer(\n            \"clip_particle_x\", synchronize=True, print=False, dict=self.time_profile\n        ):\n            wp.launch(\n                kernel=clip_particle_x,\n                dim=self.n_particles,\n                inputs=[mpm_state, mpm_model],\n                device=device,\n            )\n\n        self.time = self.time + dt\n\n    def print_time_profile(self):\n        print(\"MPM Time profile:\")\n        for key, value in self.time_profile.items():\n            print(key, sum(value))\n\n    # a surface specified by a point and the normal vector\n    def add_surface_collider(\n        self,\n        point,\n        normal,\n        surface=\"sticky\",\n        friction=0.0,\n        start_time=0.0,\n        end_time=999.0,\n    ):\n        point = list(point)\n        # Normalize normal\n        normal_scale = 1.0 / 
wp.sqrt(float(sum(x**2 for x in normal)))\n        normal = list(normal_scale * x for x in normal)\n\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n        if surface == \"sticky\" and friction != 0:\n            raise ValueError(\"friction must be 0 on sticky surfaces.\")\n        if surface == \"sticky\":\n            collider_param.surface_type = 0\n        elif surface == \"slip\":\n            collider_param.surface_type = 1\n        elif surface == \"cut\":\n            collider_param.surface_type = 11\n        else:\n            collider_param.surface_type = 2\n        # frictional\n        collider_param.friction = friction\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                n = wp.vec3(param.normal[0], param.normal[1], param.normal[2])\n                dotproduct = wp.dot(offset, n)\n\n                if dotproduct < 0.0:\n                    if param.surface_type == 0:\n                        state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                            0.0, 0.0, 0.0\n                        )\n                    elif param.surface_type == 11:\n                        if (\n                            
float(grid_z) * model.dx < 0.4\n                            or float(grid_z) * model.dx > 0.53\n                        ):\n                            state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                                0.0, 0.0, 0.0\n                            )\n                        else:\n                            v_in = state.grid_v_out[grid_x, grid_y, grid_z]\n                            state.grid_v_out[grid_x, grid_y, grid_z] = (\n                                wp.vec3(v_in[0], 0.0, v_in[2]) * 0.3\n                            )\n                    else:\n                        v = state.grid_v_out[grid_x, grid_y, grid_z]\n                        normal_component = wp.dot(v, n)\n                        if param.surface_type == 1:\n                            v = (\n                                v - normal_component * n\n                            )  # Project out all normal component\n                        else:\n                            v = (\n                                v - wp.min(normal_component, 0.0) * n\n                            )  # Project out only inward normal component\n                        if normal_component < 0.0 and wp.length(v) > 1e-20:\n                            v = wp.max(\n                                0.0, wp.length(v) + normal_component * param.friction\n                            ) * wp.normalize(\n                                v\n                            )  # apply friction here\n                        state.grid_v_out[grid_x, grid_y, grid_z] = v\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # a cuboid is a rectangular cube\n    # centered at `point`\n    # dimension is x: point[0]±size[0]\n    #              y: point[1]±size[1]\n    #              z: point[2]±size[2]\n    # all grid nodes lie within the cuboid will have their speed set to velocity\n    # the 
cuboid itself is also moving with const speed = velocity\n    # set the speed to zero to fix BC\n    def set_velocity_on_cuboid(\n        self,\n        point,\n        size,\n        velocity,\n        start_time=0.0,\n        end_time=999.0,\n        reset=0,\n    ):\n        point = list(point)\n\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n        collider_param.point = wp.vec3(point[0], point[1], point[2])\n        collider_param.size = size\n        collider_param.velocity = wp.vec3(velocity[0], velocity[1], velocity[2])\n        # collider_param.threshold = threshold\n        collider_param.reset = reset\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                offset = wp.vec3(\n                    float(grid_x) * model.dx - param.point[0],\n                    float(grid_y) * model.dx - param.point[1],\n                    float(grid_z) * model.dx - param.point[2],\n                )\n                if (\n                    wp.abs(offset[0]) < param.size[0]\n                    and wp.abs(offset[1]) < param.size[1]\n                    and wp.abs(offset[2]) < param.size[2]\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = param.velocity\n            elif param.reset == 1:\n                if time < param.end_time + 15.0 * dt:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n\n        def modify(time, dt, param: Dirichlet_collider):\n            if time >= param.start_time and time < param.end_time:\n                param.point = wp.vec3(\n             
       param.point[0] + dt * param.velocity[0],\n                    param.point[1] + dt * param.velocity[1],\n                    param.point[2] + dt * param.velocity[2],\n                )  # param.point + dt * param.velocity\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(modify)\n\n    def add_bounding_box(self, start_time=0.0, end_time=999.0):\n        collider_param = Dirichlet_collider()\n        collider_param.start_time = start_time\n        collider_param.end_time = end_time\n\n        self.collider_params.append(collider_param)\n\n        @wp.kernel\n        def collide(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            param: Dirichlet_collider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n            padding = 3\n            if time >= param.start_time and time < param.end_time:\n                if grid_x < padding and state.grid_v_out[grid_x, grid_y, grid_z][0] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    grid_x >= model.grid_dim_x - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][0] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_y < padding and state.grid_v_out[grid_x, grid_y, grid_z][1] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n    
                    state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n                if (\n                    grid_y >= model.grid_dim_y - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][1] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        0.0,\n                        state.grid_v_out[grid_x, grid_y, grid_z][2],\n                    )\n\n                if grid_z < padding and state.grid_v_out[grid_x, grid_y, grid_z][2] < 0:\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n                if (\n                    grid_z >= model.grid_dim_z - padding\n                    and state.grid_v_out[grid_x, grid_y, grid_z][2] > 0\n                ):\n                    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(\n                        state.grid_v_out[grid_x, grid_y, grid_z][0],\n                        state.grid_v_out[grid_x, grid_y, grid_z][1],\n                        0.0,\n                    )\n\n        self.grid_postprocess.append(collide)\n        self.modify_bc.append(None)\n\n    # particle_v += force/particle_mass * dt\n    # this is applied from start_dt, ends after num_dt p2g2p's\n    # particle velocity is changed before p2g at each timestep\n    def add_impulse_on_particles(\n        self,\n        mpm_state,\n        force,\n        dt,\n        point=[1, 1, 1],\n        size=[1, 1, 1],\n        num_dt=1,\n        start_time=0.0,\n        device=\"cuda:0\",\n    ):\n        impulse_param = Impulse_modifier()\n        impulse_param.start_time = start_time\n        impulse_param.end_time = start_time + dt * num_dt\n\n        impulse_param.point = 
wp.vec3(point[0], point[1], point[2])\n        impulse_param.size = wp.vec3(size[0], size[1], size[2])\n        impulse_param.mask = wp.zeros(shape=self.n_particles, dtype=int, device=device)\n\n        impulse_param.force = wp.vec3(\n            force[0],\n            force[1],\n            force[2],\n        )\n\n        wp.launch(\n            kernel=selection_add_impulse_on_particles,\n            dim=self.n_particles,\n            inputs=[mpm_state, impulse_param],\n            device=device,\n        )\n\n        self.impulse_params.append(impulse_param)\n\n        @wp.kernel\n        def apply_force(\n            time: float, dt: float, state: MPMStateStruct, param: Impulse_modifier\n        ):\n            p = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                if param.mask[p] == 1:\n                    impulse = wp.vec3(\n                        param.force[0] / state.particle_mass[p],\n                        param.force[1] / state.particle_mass[p],\n                        param.force[2] / state.particle_mass[p],\n                    )\n                    state.particle_v[p] = state.particle_v[p] + impulse * dt\n\n        self.pre_p2g_operations.append(apply_force)\n\n    def enforce_particle_velocity_translation(\n        self, mpm_state, point, size, velocity, start_time, end_time, device=\"cuda:0\"\n    ):\n        # first select certain particles based on position\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], point[1], point[2])\n        velocity_modifier_params.size = wp.vec3(size[0], size[1], size[2])\n\n        velocity_modifier_params.velocity = wp.vec3(\n            velocity[0], velocity[1], velocity[2]\n        )\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, 
dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_translation,\n            dim=self.n_particles,\n            inputs=[mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    state.particle_v[p] = velocity_modifier_params.velocity\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # define a cylinder with center point, half_height, radius, normal\n    # particles within the cylinder are rotating along the normal direction\n    # may also have a translational velocity along the normal direction\n    def enforce_particle_velocity_rotation(\n        self,\n        mpm_state,\n        point,\n        normal,\n        half_height_and_radius,\n        rotation_scale,\n        translation_scale,\n        start_time,\n        end_time,\n        device=\"cuda:0\",\n    ):\n        normal_scale = 1.0 / wp.sqrt(\n            float(normal[0] ** 2 + normal[1] ** 2 + normal[2] ** 2)\n        )\n        normal = list(normal_scale * x for x in normal)\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.point = wp.vec3(point[0], point[1], point[2])\n        velocity_modifier_params.half_height_and_radius = wp.vec2(\n            half_height_and_radius[0], half_height_and_radius[1]\n        )\n        velocity_modifier_params.normal = wp.vec3(normal[0], normal[1], normal[2])\n\n 
       horizontal_1 = wp.vec3(1.0, 1.0, 1.0)\n        if wp.abs(wp.dot(velocity_modifier_params.normal, horizontal_1)) < 0.01:\n            horizontal_1 = wp.vec3(0.72, 0.37, -0.67)\n        horizontal_1 = (\n            horizontal_1\n            - wp.dot(horizontal_1, velocity_modifier_params.normal)\n            * velocity_modifier_params.normal\n        )\n        horizontal_1 = horizontal_1 * (1.0 / wp.length(horizontal_1))\n        horizontal_2 = wp.cross(horizontal_1, velocity_modifier_params.normal)\n\n        velocity_modifier_params.horizontal_axis_1 = horizontal_1\n        velocity_modifier_params.horizontal_axis_2 = horizontal_2\n\n        velocity_modifier_params.rotation_scale = rotation_scale\n        velocity_modifier_params.translation_scale = translation_scale\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.zeros(\n            shape=self.n_particles, dtype=int, device=device\n        )\n\n        wp.launch(\n            kernel=selection_enforce_particle_velocity_cylinder,\n            dim=self.n_particles,\n            inputs=[mpm_state, velocity_modifier_params],\n            device=device,\n        )\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    offset = state.particle_x[p] - velocity_modifier_params.point\n                    horizontal_distance = wp.length(\n                        offset\n                        - wp.dot(offset, 
velocity_modifier_params.normal)\n                        * velocity_modifier_params.normal\n                    )\n                    cosine = (\n                        wp.dot(offset, velocity_modifier_params.horizontal_axis_1)\n                        / horizontal_distance\n                    )\n                    theta = wp.acos(cosine)\n                    if wp.dot(offset, velocity_modifier_params.horizontal_axis_2) > 0:\n                        theta = theta\n                    else:\n                        theta = -theta\n                    axis1_scale = (\n                        -horizontal_distance\n                        * wp.sin(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis2_scale = (\n                        horizontal_distance\n                        * wp.cos(theta)\n                        * velocity_modifier_params.rotation_scale\n                    )\n                    axis_vertical_scale = translation_scale\n                    state.particle_v[p] = (\n                        axis1_scale * velocity_modifier_params.horizontal_axis_1\n                        + axis2_scale * velocity_modifier_params.horizontal_axis_2\n                        + axis_vertical_scale * velocity_modifier_params.normal\n                    )\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    # given normal direction, say [0,0,1]\n    # gradually release grid velocities from start position to end position\n    def release_particles_sequentially(\n        self, normal, start_position, end_position, num_layers, start_time, end_time\n    ):\n        num_layers = 50\n        point = [0, 0, 0]\n        size = [0, 0, 0]\n        axis = -1\n        for i in range(3):\n            if normal[i] == 0:\n                point[i] = 1\n                size[i] = 1\n            else:\n                axis = i\n                point[i] = end_position\n\n        
half_length_portion = wp.abs(start_position - end_position) / num_layers\n        end_time_portion = end_time / num_layers\n        for i in range(num_layers):\n            size[axis] = half_length_portion * (num_layers - i)\n            self.enforce_particle_velocity_translation(\n                point=point,\n                size=size,\n                velocity=[0, 0, 0],\n                start_time=start_time,\n                end_time=end_time_portion * (i + 1),\n            )\n\n    def enforce_particle_velocity_by_mask(\n        self,\n        mpm_state,\n        selection_mask: torch.Tensor,\n        velocity,\n        start_time,\n        end_time,\n    ):\n        # first select certain particles based on position\n\n        velocity_modifier_params = ParticleVelocityModifier()\n\n        velocity_modifier_params.velocity = wp.vec3(\n            velocity[0],\n            velocity[1],\n            velocity[2],\n        )\n\n        velocity_modifier_params.start_time = start_time\n        velocity_modifier_params.end_time = end_time\n\n        velocity_modifier_params.mask = wp.from_torch(selection_mask)\n\n        self.particle_velocity_modifier_params.append(velocity_modifier_params)\n\n        @wp.kernel\n        def modify_particle_v_before_p2g(\n            time: float,\n            state: MPMStateStruct,\n            velocity_modifier_params: ParticleVelocityModifier,\n        ):\n            p = wp.tid()\n            if (\n                time >= velocity_modifier_params.start_time\n                and time < velocity_modifier_params.end_time\n            ):\n                if velocity_modifier_params.mask[p] == 1:\n                    state.particle_v[p] = velocity_modifier_params.velocity\n\n        self.particle_velocity_modifiers.append(modify_particle_v_before_p2g)\n\n    def restart_and_compute_F_C(self, mpm_model, mpm_state, target_pos, device):\n        grid_size = (\n            mpm_model.grid_dim_x,\n            mpm_model.grid_dim_y,\n     
       mpm_model.grid_dim_z,\n        )\n\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        wp.launch(\n            set_F_C_p2g,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model, target_pos],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=grid_normalization_and_gravity,\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model, 0],\n            device=device,\n        )\n\n        wp.launch(\n            set_F_C_g2p,\n            dim=self.n_particles,\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        wp.launch(\n            kernel=zero_grid,  # gradient might gone\n            dim=(grid_size),\n            inputs=[mpm_state, mpm_model],\n            device=device,\n        )\n\n        # set position to target_pos\n        wp.launch(\n            kernel=set_vec3_to_vec3,\n            dim=self.n_particles,\n            inputs=[mpm_state.particle_x, target_pos],\n            device=device,\n        )\n\n    def enforce_grid_velocity_by_mask(\n        self,\n        selection_mask: torch.Tensor,  # should be int\n    ):\n\n        grid_modifier_params = GridCollider()\n\n        grid_modifier_params.mask = wp.from_torch(selection_mask)\n\n        self.collider_params.append(grid_modifier_params)\n\n        @wp.kernel\n        def modify_grid_v_before_g2p(\n            time: float,\n            dt: float,\n            state: MPMStateStruct,\n            model: MPMModelStruct,\n            grid_modifier_params: GridCollider,\n        ):\n            grid_x, grid_y, grid_z = wp.tid()\n\n            if grid_modifier_params.mask[grid_x, grid_y, grid_z] >= 1:\n                state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n\n        self.grid_postprocess.append(modify_grid_v_before_g2p)\n     
   self.modify_bc.append(None)\n\n    # particle_v += force/particle_mass * dt\n    # this is applied from start_dt, ends after num_dt p2g2p's\n    # particle velocity is changed before p2g at each timestep\n    def add_impulse_on_particles_with_mask(\n        self,\n        mpm_state,\n        force,\n        dt,\n        particle_mask,  # 1 for selected particles, 0 for others\n        point=[1, 1, 1],\n        size=[1, 1, 1],\n        end_time=1,\n        start_time=0.0,\n        device=\"cuda:0\",\n    ):\n        assert (\n            len(particle_mask) == self.n_particles\n        ), \"mask should have n_particles elements\"\n        impulse_param = Impulse_modifier()\n        impulse_param.start_time = start_time\n        impulse_param.end_time = end_time\n        impulse_param.mask = wp.from_torch(particle_mask)\n\n        impulse_param.point = wp.vec3(point[0], point[1], point[2])\n        impulse_param.size = wp.vec3(size[0], size[1], size[2])\n\n        impulse_param.force = wp.vec3(\n            force[0],\n            force[1],\n            force[2],\n        )\n\n        wp.launch(\n            kernel=selection_add_impulse_on_particles,\n            dim=self.n_particles,\n            inputs=[mpm_state, impulse_param],\n            device=device,\n        )\n\n        self.impulse_params.append(impulse_param)\n\n        @wp.kernel\n        def apply_force(\n            time: float, dt: float, state: MPMStateStruct, param: Impulse_modifier\n        ):\n            p = wp.tid()\n            if time >= param.start_time and time < param.end_time:\n                if param.mask[p] >= 1:\n                    # impulse = wp.vec3(\n                    #     param.force[0] / state.particle_mass[p],\n                    #     param.force[1] / state.particle_mass[p],\n                    #     param.force[2] / state.particle_mass[p],\n                    # )\n                    impulse = wp.vec3(\n                        param.force[0],\n                        
param.force[1],\n                        param.force[2],\n                    )\n                    state.particle_v[p] = state.particle_v[p] + impulse * dt\n\n        self.pre_p2g_operations.append(apply_force)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/mpm_utils.py",
    "content": "import warp as wp\nfrom mpm_data_structure import *\nimport numpy as np\nimport math\n\n\n# compute stress from F\n@wp.func\ndef kirchoff_stress_FCR(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, mu: float, lam: float\n):\n    # compute kirchoff stress for FCR model (remember tau = P F^T)\n    R = U * wp.transpose(V)\n    id = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    return 2.0 * mu * (F - R) * wp.transpose(F) + id * lam * J * (J - 1.0)\n\n\n@wp.func\ndef kirchoff_stress_neoHookean(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, J: float, sig: wp.vec3, mu: float, lam: float\n):\n    \"\"\"\n    B = F * wp.transpose(F)\n    dev(B) = B - (1/3) * tr(B) * I\n\n    For a compressible Rivlin neo-Hookean materia, the cauchy stress is given by:\n    mu * J^(-2/3) * dev(B) + lam * J (J - 1) * I\n    see: https://en.wikipedia.org/wiki/Neo-Hookean_solid\n    \"\"\"\n\n    # compute kirchoff stress for FCR model (remember tau = P F^T)\n    b = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])\n    b_hat = b - wp.vec3(\n        (b[0] + b[1] + b[2]) / 3.0,\n        (b[0] + b[1] + b[2]) / 3.0,\n        (b[0] + b[1] + b[2]) / 3.0,\n    )\n    tau = mu * J ** (-2.0 / 3.0) * b_hat + lam / 2.0 * (J * J - 1.0) * wp.vec3(\n        1.0, 1.0, 1.0\n    )\n\n    return (\n        U\n        * wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])\n        * wp.transpose(V)\n        * wp.transpose(F)\n    )\n\n\n@wp.func\ndef kirchoff_stress_StVK(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float\n):\n    sig = wp.vec3(\n        wp.max(sig[0], 0.01), wp.max(sig[1], 0.01), wp.max(sig[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    log_sig_sum = wp.log(sig[0]) + wp.log(sig[1]) + wp.log(sig[2])\n    ONE = wp.vec3(1.0, 1.0, 1.0)\n    tau = 2.0 * mu * epsilon + lam * log_sig_sum * ONE\n    return (\n        U\n        * 
wp.mat33(tau[0], 0.0, 0.0, 0.0, tau[1], 0.0, 0.0, 0.0, tau[2])\n        * wp.transpose(V)\n        * wp.transpose(F)\n    )\n\n\n@wp.func\ndef kirchoff_stress_drucker_prager(\n    F: wp.mat33, U: wp.mat33, V: wp.mat33, sig: wp.vec3, mu: float, lam: float\n):\n    log_sig_sum = wp.log(sig[0]) + wp.log(sig[1]) + wp.log(sig[2])\n    center00 = 2.0 * mu * wp.log(sig[0]) * (1.0 / sig[0]) + lam * log_sig_sum * (\n        1.0 / sig[0]\n    )\n    center11 = 2.0 * mu * wp.log(sig[1]) * (1.0 / sig[1]) + lam * log_sig_sum * (\n        1.0 / sig[1]\n    )\n    center22 = 2.0 * mu * wp.log(sig[2]) * (1.0 / sig[2]) + lam * log_sig_sum * (\n        1.0 / sig[2]\n    )\n    center = wp.mat33(center00, 0.0, 0.0, 0.0, center11, 0.0, 0.0, 0.0, center22)\n    return U * center * wp.transpose(V) * wp.transpose(F)\n\n\n@wp.func\ndef von_mises_return_mapping(F_trial: wp.mat33, model: MPMModelStruct, p: int):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0\n\n    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (\n        epsilon[0] + epsilon[1] + epsilon[2]\n    ) * wp.vec3(1.0, 1.0, 1.0)\n    sum_tau = tau[0] + tau[1] + tau[2]\n    cond = wp.vec3(\n        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0\n    )\n    if wp.length(cond) > model.yield_stress[p]:\n        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)\n        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6\n        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])\n        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        
sig_elastic = wp.mat33(\n            wp.exp(epsilon[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[2]),\n        )\n        F_elastic = U * sig_elastic * wp.transpose(V)\n        if model.hardening == 1:\n            model.yield_stress[p] = (\n                model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma\n            )\n        return F_elastic\n    else:\n        return F_trial\n\n\n@wp.func\ndef von_mises_return_mapping_with_damage(\n    F_trial: wp.mat33, model: MPMModelStruct, p: int\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    temp = (epsilon[0] + epsilon[1] + epsilon[2]) / 3.0\n\n    tau = 2.0 * model.mu[p] * epsilon + model.lam[p] * (\n        epsilon[0] + epsilon[1] + epsilon[2]\n    ) * wp.vec3(1.0, 1.0, 1.0)\n    sum_tau = tau[0] + tau[1] + tau[2]\n    cond = wp.vec3(\n        tau[0] - sum_tau / 3.0, tau[1] - sum_tau / 3.0, tau[2] - sum_tau / 3.0\n    )\n    if wp.length(cond) > model.yield_stress[p]:\n        if model.yield_stress[p] <= 0:\n            return F_trial\n        epsilon_hat = epsilon - wp.vec3(temp, temp, temp)\n        epsilon_hat_norm = wp.length(epsilon_hat) + 1e-6\n        delta_gamma = epsilon_hat_norm - model.yield_stress[p] / (2.0 * model.mu[p])\n        epsilon = epsilon - (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        model.yield_stress[p] = model.yield_stress[p] - model.softening * wp.length(\n            (delta_gamma / epsilon_hat_norm) * epsilon_hat\n        )\n        if model.yield_stress[p] <= 
0:\n            model.mu[p] = 0.0\n            model.lam[p] = 0.0\n        sig_elastic = wp.mat33(\n            wp.exp(epsilon[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon[2]),\n        )\n        F_elastic = U * sig_elastic * wp.transpose(V)\n        if model.hardening == 1:\n            model.yield_stress[p] = (\n                model.yield_stress[p] + 2.0 * model.mu[p] * model.xi * delta_gamma\n            )\n        return F_elastic\n    else:\n        return F_trial\n\n\n# for toothpaste\n@wp.func\ndef viscoplasticity_return_mapping_with_StVK(\n    F_trial: wp.mat33, model: MPMModelStruct, p: int, dt: float\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig_old = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig_old, V)\n\n    sig = wp.vec3(\n        wp.max(sig_old[0], 0.01), wp.max(sig_old[1], 0.01), wp.max(sig_old[2], 0.01)\n    )  # add this to prevent NaN in extrem cases\n    b_trial = wp.vec3(sig[0] * sig[0], sig[1] * sig[1], sig[2] * sig[2])\n    epsilon = wp.vec3(wp.log(sig[0]), wp.log(sig[1]), wp.log(sig[2]))\n    trace_epsilon = epsilon[0] + epsilon[1] + epsilon[2]\n    epsilon_hat = epsilon - wp.vec3(\n        trace_epsilon / 3.0, trace_epsilon / 3.0, trace_epsilon / 3.0\n    )\n    s_trial = 2.0 * model.mu[p] * epsilon_hat\n    s_trial_norm = wp.length(s_trial)\n    y = s_trial_norm - wp.sqrt(2.0 / 3.0) * model.yield_stress[p]\n    if y > 0:\n        mu_hat = model.mu[p] * (b_trial[0] + b_trial[1] + b_trial[2]) / 3.0\n        s_new_norm = s_trial_norm - y / (\n            1.0 + model.plastic_viscosity / (2.0 * mu_hat * dt)\n        )\n        s_new = (s_new_norm / s_trial_norm) * s_trial\n        epsilon_new = 1.0 / (2.0 * model.mu[p]) * s_new + wp.vec3(\n            trace_epsilon / 3.0, trace_epsilon / 3.0, trace_epsilon / 3.0\n        )\n     
   sig_elastic = wp.mat33(\n            wp.exp(epsilon_new[0]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon_new[1]),\n            0.0,\n            0.0,\n            0.0,\n            wp.exp(epsilon_new[2]),\n        )\n        F_elastic = U * sig_elastic * wp.transpose(V)\n        return F_elastic\n    else:\n        return F_trial\n\n\n@wp.func\ndef sand_return_mapping(\n    F_trial: wp.mat33, state: MPMStateStruct, model: MPMModelStruct, p: int\n):\n    U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n    sig = wp.vec3(0.0)\n    wp.svd3(F_trial, U, sig, V)\n\n    epsilon = wp.vec3(\n        wp.log(wp.max(wp.abs(sig[0]), 1e-14)),\n        wp.log(wp.max(wp.abs(sig[1]), 1e-14)),\n        wp.log(wp.max(wp.abs(sig[2]), 1e-14)),\n    )\n    sigma_out = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n    tr = epsilon[0] + epsilon[1] + epsilon[2]  # + state.particle_Jp[p]\n    epsilon_hat = epsilon - wp.vec3(tr / 3.0, tr / 3.0, tr / 3.0)\n    epsilon_hat_norm = wp.length(epsilon_hat)\n    delta_gamma = (\n        epsilon_hat_norm\n        + (3.0 * model.lam[p] + 2.0 * model.mu[p])\n        / (2.0 * model.mu[p])\n        * tr\n        * model.alpha\n    )\n\n    if delta_gamma <= 0:\n        F_elastic = F_trial\n\n    if delta_gamma > 0 and tr > 0:\n        F_elastic = U * wp.transpose(V)\n\n    if delta_gamma > 0 and tr <= 0:\n        H = epsilon - epsilon_hat * (delta_gamma / epsilon_hat_norm)\n        s_new = wp.vec3(wp.exp(H[0]), wp.exp(H[1]), wp.exp(H[2]))\n\n        F_elastic = U * wp.diag(s_new) * wp.transpose(V)\n    return F_elastic\n\n\n@wp.kernel\ndef compute_mu_lam_from_E_nu(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n    model.mu[p] = model.E[p] / (2.0 * (1.0 + model.nu[p]))\n    model.lam[p] = (\n        model.E[p] * model.nu[p] / ((1.0 + model.nu[p]) * (1.0 - 2.0 * model.nu[p]))\n    )\n\n\n@wp.kernel\ndef 
zero_grid(state: MPMStateStruct, model: MPMModelStruct):\n    grid_x, grid_y, grid_z = wp.tid()\n    state.grid_m[grid_x, grid_y, grid_z] = 0.0\n    state.grid_v_in[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n    state.grid_v_out[grid_x, grid_y, grid_z] = wp.vec3(0.0, 0.0, 0.0)\n\n\n@wp.func\ndef compute_dweight(\n    model: MPMModelStruct, w: wp.mat33, dw: wp.mat33, i: int, j: int, k: int\n):\n    dweight = wp.vec3(\n        dw[0, i] * w[1, j] * w[2, k],\n        w[0, i] * dw[1, j] * w[2, k],\n        w[0, i] * w[1, j] * dw[2, k],\n    )\n    return dweight * model.inv_dx\n\n\n@wp.func\ndef update_cov(state: MPMStateStruct, p: int, grad_v: wp.mat33, dt: float):\n    cov_n = wp.mat33(0.0)\n    cov_n[0, 0] = state.particle_cov[p * 6]\n    cov_n[0, 1] = state.particle_cov[p * 6 + 1]\n    cov_n[0, 2] = state.particle_cov[p * 6 + 2]\n    cov_n[1, 0] = state.particle_cov[p * 6 + 1]\n    cov_n[1, 1] = state.particle_cov[p * 6 + 3]\n    cov_n[1, 2] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 0] = state.particle_cov[p * 6 + 2]\n    cov_n[2, 1] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 2] = state.particle_cov[p * 6 + 5]\n\n    cov_np1 = cov_n + dt * (grad_v * cov_n + cov_n * wp.transpose(grad_v))\n\n    state.particle_cov[p * 6] = cov_np1[0, 0]\n    state.particle_cov[p * 6 + 1] = cov_np1[0, 1]\n    state.particle_cov[p * 6 + 2] = cov_np1[0, 2]\n    state.particle_cov[p * 6 + 3] = cov_np1[1, 1]\n    state.particle_cov[p * 6 + 4] = cov_np1[1, 2]\n    state.particle_cov[p * 6 + 5] = cov_np1[2, 2]\n\n\n@wp.func\ndef update_cov_differentiable(\n    state: MPMStateStruct,\n    next_state: MPMStateStruct,\n    p: int,\n    grad_v: wp.mat33,\n    dt: float,\n):\n    cov_n = wp.mat33(0.0)\n    cov_n[0, 0] = state.particle_cov[p * 6]\n    cov_n[0, 1] = state.particle_cov[p * 6 + 1]\n    cov_n[0, 2] = state.particle_cov[p * 6 + 2]\n    cov_n[1, 0] = state.particle_cov[p * 6 + 1]\n    cov_n[1, 1] = state.particle_cov[p * 6 + 3]\n    cov_n[1, 2] = state.particle_cov[p * 6 + 
4]\n    cov_n[2, 0] = state.particle_cov[p * 6 + 2]\n    cov_n[2, 1] = state.particle_cov[p * 6 + 4]\n    cov_n[2, 2] = state.particle_cov[p * 6 + 5]\n\n    cov_np1 = cov_n + dt * (grad_v * cov_n + cov_n * wp.transpose(grad_v))\n\n    next_state.particle_cov[p * 6] = cov_np1[0, 0]\n    next_state.particle_cov[p * 6 + 1] = cov_np1[0, 1]\n    next_state.particle_cov[p * 6 + 2] = cov_np1[0, 2]\n    next_state.particle_cov[p * 6 + 3] = cov_np1[1, 1]\n    next_state.particle_cov[p * 6 + 4] = cov_np1[1, 2]\n    next_state.particle_cov[p * 6 + 5] = cov_np1[2, 2]\n\n\n@wp.kernel\ndef p2g_apic_with_stress(state: MPMStateStruct, model: MPMModelStruct, dt: float):\n    # input given to p2g:   particle_stress\n    #                       particle_x\n    #                       particle_v\n    #                       particle_C\n    # output:               grid_v_in, grid_m\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        stress = state.particle_stress[p]\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    dpos = (\n                        wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    ) * model.dx\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n  
                  iz = base_pos_z + k\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n\n                    C = state.particle_C[p]\n                    # if model.rpic = 0, standard apic\n                    C = (1.0 - model.rpic_damping) * C + model.rpic_damping / 2.0 * (\n                        C - wp.transpose(C)\n                    )\n\n                    # C = (1.0 - model.rpic_damping) * state.particle_C[\n                    #     p\n                    # ] + model.rpic_damping / 2.0 * (\n                    #     state.particle_C[p] - wp.transpose(state.particle_C[p])\n                    # )\n\n                    if model.rpic_damping < -0.001:\n                        # standard pic\n                        C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n                    elastic_force = -state.particle_vol[p] * stress * dweight\n                    v_in_add = (\n                        weight\n                        * state.particle_mass[p]\n                        * (state.particle_v[p] + C * dpos)\n                        + dt * elastic_force\n                    )\n                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)\n                    wp.atomic_add(\n                        state.grid_m, ix, iy, iz, weight * state.particle_mass[p]\n                    )\n\n\n# add gravity\n@wp.kernel\ndef grid_normalization_and_gravity(\n    state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n    grid_x, grid_y, grid_z = wp.tid()\n    if state.grid_m[grid_x, grid_y, grid_z] > 1e-15:\n        v_out = state.grid_v_in[grid_x, grid_y, grid_z] * (\n            1.0 / state.grid_m[grid_x, grid_y, grid_z]\n        )\n        # add gravity\n        v_out = v_out + dt * model.gravitational_accelaration\n        state.grid_v_out[grid_x, grid_y, grid_z] = v_out\n\n\n@wp.kernel\ndef g2p(state: MPMStateStruct, model: 
MPMModelStruct, dt: float):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_v = wp.vec3(0.0, 0.0, 0.0)\n        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_v = new_v + grid_v * weight\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        state.particle_v[p] = new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * new_v\n        # state.particle_x[p] = state.particle_x[p] + dt * state.particle_v[p]\n\n        # wp.atomic_add(state.particle_x, p, dt * state.particle_v[p]) # 
old one is this..\n        wp.atomic_add(state.particle_x, p, dt * new_v)  # debug\n        # new_x = state.particle_x[p] + dt * state.particle_v[p]\n        # state.particle_x[p] = new_x\n\n        state.particle_C[p] = new_C\n\n        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n        F_tmp = (I33 + new_F * dt) * state.particle_F[p]\n        state.particle_F_trial[p] = F_tmp\n        # debug for jelly\n        # wp.atomic_add(state.particle_F_trial, p, new_F * dt * state.particle_F[p])\n\n        if model.update_cov_with_F:\n            update_cov(state, p, new_F, dt)\n\n\n@wp.kernel\ndef g2p_differentiable(\n    state: MPMStateStruct, next_state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n    \"\"\"\n    Compute:\n        next_state.particle_v, next_state.particle_x, next_state.particle_C, next_state.particle_F_trial\n    \"\"\"\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_v = wp.vec3(0.0, 0.0, 0.0)\n        # new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_C = wp.mat33(new_v, new_v, new_v)\n        \n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = 
base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_v = (\n                        new_v + grid_v * weight\n                    )  # TODO, check gradient from static loop\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        next_state.particle_v[p] = new_v\n\n        # add clip here:\n        new_x = state.particle_x[p] + dt * new_v\n        dx = 1.0 / model.inv_dx\n        a_min = dx * 2.0\n        a_max = model.grid_lim - dx * 2.0\n\n        new_x_clamped = wp.vec3(\n            wp.clamp(new_x[0], a_min, a_max),\n            wp.clamp(new_x[1], a_min, a_max),\n            wp.clamp(new_x[2], a_min, a_max),\n        )\n        next_state.particle_x[p] = new_x_clamped\n\n        # next_state.particle_x[p] = new_x\n\n        next_state.particle_C[p] = new_C\n\n        I33_1 = wp.vec3(1.0, 0.0, 0.0)\n        I33_2 = wp.vec3(0.0, 1.0, 0.0)\n        I33_3 = wp.vec3(0.0, 0.0, 1.0)\n        I33 = wp.mat33(I33_1, I33_2, I33_3)\n        F_tmp = (I33 + new_F * dt) * state.particle_F[p]\n        next_state.particle_F_trial[p] = F_tmp\n\n        if 0:\n            update_cov_differentiable(state, next_state, p, new_F, dt)\n\n\n@wp.kernel\ndef clip_particle_x(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n\n    posx = state.particle_x[p]\n    if state.particle_selection[p] == 0:\n        dx = 1.0 / model.inv_dx\n        a_min = dx * 2.0\n        a_max = model.grid_lim - dx * 2.0\n        new_x = wp.vec3(\n            wp.clamp(posx[0], a_min, 
a_max),\n            wp.clamp(posx[1], a_min, a_max),\n            wp.clamp(posx[2], a_min, a_max),\n        )\n\n        state.particle_x[\n            p\n        ] = new_x  # Warn: this gives wrong gradient, don't use this for backward\n\n\n# compute (Kirchhoff) stress = stress(returnMap(F_trial))\n@wp.kernel\ndef compute_stress_from_F_trial(\n    state: MPMStateStruct, model: MPMModelStruct, dt: float\n):\n    \"\"\"\n    state.particle_F_trial => state.particle_F   # return mapping\n    state.particle_F => state.particle_stress    # stress-strain\n\n    TODO: check the gradient of SVD!  is wp.svd3 differentiable? I guess so\n    \"\"\"\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        # apply return mapping\n        if model.material == 1:  # metal\n            state.particle_F[p] = von_mises_return_mapping(\n                state.particle_F_trial[p], model, p\n            )\n        elif model.material == 2:  # sand\n            state.particle_F[p] = sand_return_mapping(\n                state.particle_F_trial[p], state, model, p\n            )\n        elif model.material == 3:  # visplas, with StVk+VM, no thickening\n            state.particle_F[p] = viscoplasticity_return_mapping_with_StVK(\n                state.particle_F_trial[p], model, p, dt\n            )\n        elif model.material == 5:\n            state.particle_F[p] = von_mises_return_mapping_with_damage(\n                state.particle_F_trial[p], model, p\n            )\n        else:  # elastic, jelly, or neo-hookean\n            state.particle_F[p] = state.particle_F_trial[p]\n\n        # also compute stress here\n        J = wp.determinant(state.particle_F[p])\n        U = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        V = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        sig = wp.vec3(0.0)\n        stress = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        wp.svd3(state.particle_F[p], U, sig, V)\n        if model.material == 0 or 
model.material == 5:\n            stress = kirchoff_stress_FCR(\n                state.particle_F[p], U, V, J, model.mu[p], model.lam[p]\n            )\n        if model.material == 1:\n            stress = kirchoff_stress_StVK(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n        if model.material == 2:\n            stress = kirchoff_stress_drucker_prager(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n        if model.material == 3:\n            # temporarily use stvk, subject to change\n            stress = kirchoff_stress_StVK(\n                state.particle_F[p], U, V, sig, model.mu[p], model.lam[p]\n            )\n\n        if model.material == 6:\n            stress = kirchoff_stress_neoHookean(\n                state.particle_F[p], U, V, J, sig, model.mu[p], model.lam[p]\n            )\n        # stress = (stress + wp.transpose(stress)) / 2.0  # enfore symmetry\n        state.particle_stress[p] = (stress + wp.transpose(stress)) / 2.0\n\n\n# @wp.kernel\n# def compute_cov_from_F(state: MPMStateStruct, model: MPMModelStruct):\n#     p = wp.tid()\n\n#     F = state.particle_F_trial[p]\n\n#     init_cov = wp.mat33(0.0)\n#     init_cov[0, 0] = state.particle_init_cov[p * 6]\n#     init_cov[0, 1] = state.particle_init_cov[p * 6 + 1]\n#     init_cov[0, 2] = state.particle_init_cov[p * 6 + 2]\n#     init_cov[1, 0] = state.particle_init_cov[p * 6 + 1]\n#     init_cov[1, 1] = state.particle_init_cov[p * 6 + 3]\n#     init_cov[1, 2] = state.particle_init_cov[p * 6 + 4]\n#     init_cov[2, 0] = state.particle_init_cov[p * 6 + 2]\n#     init_cov[2, 1] = state.particle_init_cov[p * 6 + 4]\n#     init_cov[2, 2] = state.particle_init_cov[p * 6 + 5]\n\n#     cov = F * init_cov * wp.transpose(F)\n\n#     state.particle_cov[p * 6] = cov[0, 0]\n#     state.particle_cov[p * 6 + 1] = cov[0, 1]\n#     state.particle_cov[p * 6 + 2] = cov[0, 2]\n#     state.particle_cov[p * 6 + 3] = cov[1, 
1]\n#     state.particle_cov[p * 6 + 4] = cov[1, 2]\n#     state.particle_cov[p * 6 + 5] = cov[2, 2]\n\n\n# @wp.kernel\n# def compute_R_from_F(state: MPMStateStruct, model: MPMModelStruct):\n#     p = wp.tid()\n\n#     F = state.particle_F_trial[p]\n\n#     # polar svd decomposition\n#     U = wp.mat33(0.0)\n#     V = wp.mat33(0.0)\n#     sig = wp.vec3(0.0)\n#     wp.svd3(F, U, sig, V)\n\n#     if wp.determinant(U) < 0.0:\n#         U[0, 2] = -U[0, 2]\n#         U[1, 2] = -U[1, 2]\n#         U[2, 2] = -U[2, 2]\n\n#     if wp.determinant(V) < 0.0:\n#         V[0, 2] = -V[0, 2]\n#         V[1, 2] = -V[1, 2]\n#         V[2, 2] = -V[2, 2]\n\n#     # compute rotation matrix\n#     R = U * wp.transpose(V)\n#     state.particle_R[p] = wp.transpose(R) # particle R is removed\n\n\n@wp.kernel\ndef add_damping_via_grid(state: MPMStateStruct, scale: float):\n    grid_x, grid_y, grid_z = wp.tid()\n    # state.grid_v_out[grid_x, grid_y, grid_z] = (\n    #     state.grid_v_out[grid_x, grid_y, grid_z] * scale\n    # )\n    wp.atomic_sub(\n        state.grid_v_out,\n        grid_x,\n        grid_y,\n        grid_z,\n        (1.0 - scale) * state.grid_v_out[grid_x, grid_y, grid_z],\n    )\n\n\n@wp.kernel\ndef apply_additional_params(\n    state: MPMStateStruct,\n    model: MPMModelStruct,\n    params_modifier: MaterialParamsModifier,\n):\n    p = wp.tid()\n    pos = state.particle_x[p]\n    if (\n        pos[0] > params_modifier.point[0] - params_modifier.size[0]\n        and pos[0] < params_modifier.point[0] + params_modifier.size[0]\n        and pos[1] > params_modifier.point[1] - params_modifier.size[1]\n        and pos[1] < params_modifier.point[1] + params_modifier.size[1]\n        and pos[2] > params_modifier.point[2] - params_modifier.size[2]\n        and pos[2] < params_modifier.point[2] + params_modifier.size[2]\n    ):\n        model.E[p] = params_modifier.E\n        model.nu[p] = params_modifier.nu\n        state.particle_density[p] = 
params_modifier.density\n\n\n@wp.kernel\ndef selection_add_impulse_on_particles(\n    state: MPMStateStruct, impulse_modifier: Impulse_modifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - impulse_modifier.point\n    if (\n        wp.abs(offset[0]) < impulse_modifier.size[0]\n        and wp.abs(offset[1]) < impulse_modifier.size[1]\n        and wp.abs(offset[2]) < impulse_modifier.size[2]\n    ):\n        impulse_modifier.mask[p] = 1\n    else:\n        impulse_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef selection_enforce_particle_velocity_translation(\n    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - velocity_modifier.point\n    if (\n        wp.abs(offset[0]) < velocity_modifier.size[0]\n        and wp.abs(offset[1]) < velocity_modifier.size[1]\n        and wp.abs(offset[2]) < velocity_modifier.size[2]\n    ):\n        velocity_modifier.mask[p] = 1\n    else:\n        velocity_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef selection_enforce_particle_velocity_cylinder(\n    state: MPMStateStruct, velocity_modifier: ParticleVelocityModifier\n):\n    p = wp.tid()\n    offset = state.particle_x[p] - velocity_modifier.point\n\n    vertical_distance = wp.abs(wp.dot(offset, velocity_modifier.normal))\n\n    horizontal_distance = wp.length(\n        offset - wp.dot(offset, velocity_modifier.normal) * velocity_modifier.normal\n    )\n    if (\n        vertical_distance < velocity_modifier.half_height_and_radius[0]\n        and horizontal_distance < velocity_modifier.half_height_and_radius[1]\n    ):\n        velocity_modifier.mask[p] = 1\n    else:\n        velocity_modifier.mask[p] = 0\n\n\n@wp.kernel\ndef compute_position_l2_loss(\n    mpm_state: MPMStateStruct,\n    gt_pos: wp.array(dtype=wp.vec3),\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    pos = mpm_state.particle_x[tid]\n    pos_gt = gt_pos[tid]\n\n    # l1_diff = wp.abs(pos - pos_gt)\n    l2 = wp.length(pos - 
pos_gt)\n\n    wp.atomic_add(loss, 0, l2)\n\n\n@wp.kernel\ndef aggregate_grad(x: wp.array(dtype=float), grad: wp.array(dtype=float)):\n    tid = wp.tid()\n\n    # gradient descent step\n    wp.atomic_add(x, 0, grad[tid])\n\n\n@wp.kernel\ndef set_F_C_p2g(\n    state: MPMStateStruct, model: MPMModelStruct, target_pos: wp.array(dtype=wp.vec3)\n):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        # p2g for displacement\n        particle_disp = target_pos[p] - state.particle_x[p]\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    v_in_add = weight * state.particle_mass[p] * particle_disp\n                    wp.atomic_add(state.grid_v_in, ix, iy, iz, v_in_add)\n                    wp.atomic_add(\n                        state.grid_m, ix, iy, iz, weight * state.particle_mass[p]\n                    )\n\n\n@wp.kernel\ndef set_F_C_g2p(state: MPMStateStruct, model: MPMModelStruct):\n    p = wp.tid()\n    if state.particle_selection[p] == 0:\n        grid_pos = state.particle_x[p] * model.inv_dx\n        base_pos_x = wp.int(grid_pos[0] - 0.5)\n        base_pos_y = wp.int(grid_pos[1] - 
0.5)\n        base_pos_z = wp.int(grid_pos[2] - 0.5)\n        fx = grid_pos - wp.vec3(\n            wp.float(base_pos_x), wp.float(base_pos_y), wp.float(base_pos_z)\n        )\n        wa = wp.vec3(1.5) - fx\n        wb = fx - wp.vec3(1.0)\n        wc = fx - wp.vec3(0.5)\n        w = wp.mat33(\n            wp.cw_mul(wa, wa) * 0.5,\n            wp.vec3(0.0, 0.0, 0.0) - wp.cw_mul(wb, wb) + wp.vec3(0.75),\n            wp.cw_mul(wc, wc) * 0.5,\n        )\n        dw = wp.mat33(fx - wp.vec3(1.5), -2.0 * (fx - wp.vec3(1.0)), fx - wp.vec3(0.5))\n        new_C = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n        new_F = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)\n\n        # g2p for C and F\n        for i in range(0, 3):\n            for j in range(0, 3):\n                for k in range(0, 3):\n                    ix = base_pos_x + i\n                    iy = base_pos_y + j\n                    iz = base_pos_z + k\n                    dpos = wp.vec3(wp.float(i), wp.float(j), wp.float(k)) - fx\n                    weight = w[0, i] * w[1, j] * w[2, k]  # tricubic interpolation\n                    grid_v = state.grid_v_out[ix, iy, iz]\n                    new_C = new_C + wp.outer(grid_v, dpos) * (\n                        weight * model.inv_dx * 4.0\n                    )\n                    dweight = compute_dweight(model, w, dw, i, j, k)\n                    new_F = new_F + wp.outer(grid_v, dweight)\n\n        # C should still be zero..\n        # state.particle_C[p] = new_C\n        I33 = wp.mat33(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)\n        F_tmp = I33 + new_F\n        state.particle_F_trial[p] = F_tmp\n\n        if model.update_cov_with_F:\n            update_cov(state, p, new_F, 1.0)\n\n\n@wp.kernel\ndef compute_posloss_with_grad(\n    mpm_state: MPMStateStruct,\n    gt_pos: wp.array(dtype=wp.vec3),\n    grad: wp.array(dtype=wp.vec3),\n    dt: float,\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    pos = 
mpm_state.particle_x[tid]\n    pos_gt = gt_pos[tid]\n\n    # l1_diff = wp.abs(pos - pos_gt)\n    # l2 = wp.length(pos - (pos_gt - grad[tid] * dt))\n    diff = pos - (pos_gt - grad[tid] * dt)\n    l2 = wp.dot(diff, diff)\n    wp.atomic_add(loss, 0, l2)\n\n\n@wp.kernel\ndef compute_veloloss_with_grad(\n    mpm_state: MPMStateStruct,\n    gt_pos: wp.array(dtype=wp.vec3),\n    grad: wp.array(dtype=wp.vec3),\n    dt: float,\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    pos = mpm_state.particle_v[tid]\n    pos_gt = gt_pos[tid]\n\n    # l1_diff = wp.abs(pos - pos_gt)\n    # l2 = wp.length(pos - (pos_gt - grad[tid] * dt))\n\n    diff = pos - (pos_gt - grad[tid] * dt)\n    l2 = wp.dot(diff, diff)\n    wp.atomic_add(loss, 0, l2)\n\n\n@wp.kernel\ndef compute_Floss_with_grad(\n    mpm_state: MPMStateStruct,\n    gt_mat: wp.array(dtype=wp.mat33),\n    grad: wp.array(dtype=wp.mat33),\n    dt: float,\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    mat_ = mpm_state.particle_F_trial[tid]\n    mat_gt = gt_mat[tid]\n\n    mat_gt = mat_gt - grad[tid] * dt\n    # l1_diff = wp.abs(pos - pos_gt)\n    mat_diff = mat_ - mat_gt\n\n    l2 = wp.ddot(mat_diff, mat_diff)\n    # l2 = wp.sqrt(\n    #     mat_diff[0, 0] ** 2.0\n    #     + mat_diff[0, 1] ** 2.0\n    #     + mat_diff[0, 2] ** 2.0\n    #     + mat_diff[1, 0] ** 2.0\n    #     + mat_diff[1, 1] ** 2.0\n    #     + mat_diff[1, 2] ** 2.0\n    #     + mat_diff[2, 0] ** 2.0\n    #     + mat_diff[2, 1] ** 2.0\n    #     + mat_diff[2, 2] ** 2.0\n    # )\n\n    wp.atomic_add(loss, 0, l2)\n\n\n@wp.kernel\ndef compute_Closs_with_grad(\n    mpm_state: MPMStateStruct,\n    gt_mat: wp.array(dtype=wp.mat33),\n    grad: wp.array(dtype=wp.mat33),\n    dt: float,\n    loss: wp.array(dtype=float),\n):\n    tid = wp.tid()\n\n    mat_ = mpm_state.particle_C[tid]\n    mat_gt = gt_mat[tid]\n\n    mat_gt = mat_gt - grad[tid] * dt\n    # l1_diff = wp.abs(pos - pos_gt)\n\n    mat_diff = mat_ - mat_gt\n    l2 = 
wp.ddot(mat_diff, mat_diff)\n\n    wp.atomic_add(loss, 0, l2)\n"
  },
  {
    "path": "projects/uncleaned_train/thirdparty_code/warp_mpm/warp_utils.py",
    "content": "import warp as wp\nimport ctypes\nfrom typing import Optional\n\nfrom warp.torch import (\n    dtype_from_torch,\n    device_from_torch,\n    dtype_is_compatible,\n    from_torch,\n)\n\n\ndef from_torch_safe(t, dtype=None, requires_grad=None, grad=None):\n    \"\"\"Wrap a PyTorch tensor to a Warp array without copying the data.\n\n    Args:\n        t (torch.Tensor): The torch tensor to wrap.\n        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.\n        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.\n\n    Returns:\n        warp.array: The wrapped array.\n    \"\"\"\n    if dtype is None:\n        dtype = dtype_from_torch(t.dtype)\n    elif not dtype_is_compatible(t.dtype, dtype):\n        raise RuntimeError(f\"Incompatible data types: {t.dtype} and {dtype}\")\n\n    # get size of underlying data type to compute strides\n    ctype_size = ctypes.sizeof(dtype._type_)\n\n    shape = tuple(t.shape)\n    strides = tuple(s * ctype_size for s in t.stride())\n\n    # if target is a vector or matrix type\n    # then check if trailing dimensions match\n    # the target type and update the shape\n    if hasattr(dtype, \"_shape_\"):\n        dtype_shape = dtype._shape_\n        dtype_dims = len(dtype._shape_)\n        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:\n            raise RuntimeError(\n                f\"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}\"\n            )\n\n        # ensure the inner strides are contiguous\n        stride = ctype_size\n        for i in range(dtype_dims):\n            if strides[-i - 1] != stride:\n                raise RuntimeError(\n                    f\"Could 
not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous\"\n                )\n            stride *= dtype_shape[-i - 1]\n\n        shape = tuple(shape[:-dtype_dims]) or (1,)\n        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)\n\n    requires_grad = t.requires_grad if requires_grad is None else requires_grad\n    if grad is not None:\n        if not isinstance(grad, wp.array):\n            import torch\n\n            if isinstance(grad, torch.Tensor):\n                grad = from_torch(grad, dtype=dtype)\n            else:\n                raise ValueError(f\"Invalid gradient type: {type(grad)}\")\n    elif requires_grad:\n        # wrap the tensor gradient, allocate if necessary\n        if t.grad is None:\n            # allocate a zero-filled gradient tensor if it doesn't exist\n            import torch\n\n            t.grad = torch.zeros_like(t, requires_grad=False)\n        grad = from_torch(t.grad, dtype=dtype)\n\n    a = wp.types.array(\n        ptr=t.data_ptr(),\n        dtype=dtype,\n        shape=shape,\n        strides=strides,\n        device=device_from_torch(t.device),\n        copy=False,\n        owner=False,\n        grad=grad,\n        requires_grad=requires_grad,\n    )\n\n    # save a reference to the source tensor, otherwise it will be deallocated\n    a._tensor = t\n    return a\n\n\nclass MyTape(wp.Tape):\n    # returns the adjoint of a kernel parameter\n    def get_adjoint(self, a):\n        if not wp.types.is_array(a) and not isinstance(a, wp.codegen.StructInstance):\n            # if input is a simple type (e.g.: float, vec3, etc) then\n            # no gradient needed (we only return gradients through arrays and structs)\n            return a\n\n        elif wp.types.is_array(a) and a.grad:\n            # keep track of all gradients used by the tape (for zeroing)\n            # ignore the scalar loss since we don't want to clear its grad\n            
self.gradients[a] = a.grad\n            return a.grad\n\n        elif isinstance(a, wp.codegen.StructInstance):\n            adj = a._cls()\n            for name, _ in a._cls.ctype._fields_:\n                if name.startswith(\"_\"):\n                    continue\n                if isinstance(a._cls.vars[name].type, wp.array):\n                    arr = getattr(a, name)\n                    if arr is None:\n                        continue\n                    if arr.grad:\n                        grad = self.gradients[arr] = arr.grad\n                    else:\n                        grad = wp.zeros_like(arr)\n                    setattr(adj, name, grad)\n                else:\n                    setattr(adj, name, getattr(a, name))\n\n            self.gradients[a] = adj\n            return adj\n\n        return None\n\n\n# from https://github.com/PingchuanMa/NCLaw/blob/main/nclaw/warp/tape.py\nclass CondTape(object):\n    def __init__(self, tape: Optional[MyTape], cond: bool = True) -> None:\n        self.tape = tape\n        self.cond = cond\n\n    def __enter__(self):\n        if self.tape is not None and self.cond:\n            self.tape.__enter__()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.tape is not None and self.cond:\n            self.tape.__exit__(exc_type, exc_value, traceback)"
  },
  {
    "path": "requirements.txt",
    "content": "accelerate==0.25.0\ndecord==0.6.0\neinops==0.7.0\nfire==0.5.0\nimageio==2.34.0\nipython==8.12.3\nipython==8.18.1\njaxtyping==0.2.28\nkmeans_gpu==0.0.5\nmatplotlib==3.7.2\nmediapy==1.2.0\nnumpy==1.24.2\nomegaconf==2.1.1\nopen3d==0.18.0\nopencv_python==4.6.0.66\nopencv_python_headless==4.9.0.80\nPillow==9.5.0\nPillow==10.3.0\nplyfile==1.0.3\npoint_cloud_utils==0.30.2\npyfqmr==0.2.0\npygltflib==1.16.2\nPyMCubes==0.1.4\npymeshlab==2023.12\nsafetensors==0.3.3\nscikit_learn==1.3.2\nscipy==1.13.0\nsimple_knn==0.0.0\ntorch==2.2.2+cu121\ntorchvision==0.17.2\ntqdm==4.65.0\ntrimesh==4.0.8\nwarp_lang==0.10.1\nxatlas==0.0.9\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nsetup(\n    name=\"physdreamer\",\n    version=\"0.0.1\",\n    packages=find_packages(),\n)\n"
  }
]