[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.yaml",
    "content": "name: Bug Report\ndescription: File a bug report\ntitle: \"<title>\"\nlabels: [\"bug\"]\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        Before filing a bug report, [search for an existing issue](https://github.com/ashawkey/stable-dreamfusion/issues).\n        \n        Also, ensure you are running the latest version.\n  - type: textarea\n    id: description\n    attributes:\n      label: Description\n      description: Provide a clear and concise description of what the bug is.\n      placeholder: Description\n    validations:\n      required: true\n  - type: textarea\n    id: steps\n    attributes:\n      label: Steps to Reproduce\n      description: List the steps needed to reproduce the issue.\n      placeholder: |\n        1. Go to '...'\n        2. Click on '...'\n    validations:\n      required: true\n  - type: textarea\n    id: expected-behavior\n    attributes:\n      label: Expected Behavior\n      description: Describe what you expected to happen.\n      placeholder: |\n        The 'action' would do 'some amazing thing'.\n    validations:\n      required: true\n  - type: textarea\n    id: environment\n    attributes:\n      label: Environment\n      description: Describe your environment.\n      placeholder: |\n        Ubuntu 22.04, PyTorch 1.13, CUDA 11.6\n    validations:\n      required: true\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: enhancement\nassignees: ''\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__/\nbuild/\n*.egg-info/\n*.so\nvenv_*/\n\ntmp*\n# data/\nldm/data/\ndata2\nscripts2\ntrial*/\n.vs/\n\nTOKEN\n*.ckpt\n\ndensegridencoder\ntets/256_tets.npz\n\n.vscode/launch.json\n\ndata2\ndata/car*\ndata/chair*\ndata/warrior*\ndata/wd*\ndata/space*\ndata/corgi*\ndata/turtle*\n\n# Only keep the original image, not the automatically-generated depth, normals, rgba\ndata/baby_phoenix_on_ice_*\ndata/bollywood_actress_*\ndata/beach_house_1_*\ndata/beach_house_2_*\ndata/mona_lisa_*\ndata/futuristic_car_*\ndata/church_ruins_*\n\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "activation.py",
    "content": "import torch\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd \n\nclass _trunc_exp(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float)\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return torch.exp(x)\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, g):\n        x = ctx.saved_tensors[0]\n        return g * torch.exp(x.clamp(max=15))\n\ntrunc_exp = _trunc_exp.apply\n\ndef biased_softplus(x, bias=0):\n    return torch.nn.functional.softplus(x - bias)"
  },
  {
    "path": "assets/advanced.md",
    "content": "\n# Code organization & Advanced tips\n\nThis is a brief description of the most important implementation details.\nIf you are interested in improving this repo, this might be a good starting point.\nAny contribution would be greatly appreciated!\n\n* The SDS loss is located at `./guidance/sd_utils.py > StableDiffusion > train_step`:\n```python\n## 1. we need to interpolate the NeRF rendering to 512x512, to feed it to SD's VAE.\npred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)\n## 2. image (512x512) --- VAE --> latents (64x64), this is SD's difference from Imagen.\nlatents = self.encode_imgs(pred_rgb_512)\n... # timestep sampling, noise adding and UNet noise predicting\n## 3. the SDS loss\nw = (1 - self.alphas[t])\ngrad = w * (noise_pred - noise)\n# since the UNet part is ignored in the gradient, we cannot simply use autodiff; there are several ways to set the grad:\n# 3.1. call backward and set the grad now (need to retain graph since we will call a second backward for the other losses later)\nlatents.backward(gradient=grad, retain_graph=True)\nreturn 0 # dummy loss\n\n# 3.2. use a custom function to set a hook in backward, so we only call backward once (credits to @elliottzheng)\nclass SpecifyGradient(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, input_tensor, gt_grad):\n        ctx.save_for_backward(gt_grad)\n        # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.\n        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_scale):\n        gt_grad, = ctx.saved_tensors\n        gt_grad = gt_grad * grad_scale\n        return gt_grad, None\n\nloss = SpecifyGradient.apply(latents, grad)\nreturn loss # functional loss\n\n# 3.3. 
reparameterization (credits to @Xallt)\n# d(loss)/d(latents) = grad, since grad is already detached, it's this simple.\nloss = (grad * latents).sum()\nreturn loss\n\n# 3.4. reparameterization (credits to threestudio)\n# this is the same as 3.3, but the loss value only reflects the magnitude of grad, which is more informative.\ntargets = (latents - grad).detach()\nloss = 0.5 * F.mse_loss(latents, targets, reduction='sum')\nreturn loss\n```\n* Other regularizations are in `./nerf/utils.py > Trainer > train_step`.\n    * The generation seems quite sensitive to regularizations on weights_sum (alphas for each ray). The original opacity loss tends to make the NeRF disappear (zero density everywhere), so we replace it with an entropy loss for now (it encourages alpha to be either 0 or 1).\n* Core NeRF rendering function: `./nerf/renderer.py > NeRFRenderer > run & run_cuda`.\n* Shading & normal evaluation: `./nerf/network*.py > NeRFNetwork > forward`.\n    * light direction: the current implementation uses a plane light source instead of a point light source.\n* View-dependent prompting: `./nerf/provider.py > get_view_direction`.\n    * use `--angle_overhead` and `--angle_front` to set the borders.\n* The network backbone (`./nerf/network*.py`) can be chosen with the `--backbone` option.\n* Spatial density bias (density blob): `./nerf/network*.py > NeRFNetwork > density_blob`.\n\n\n# Debugging\n\n`debugpy-run` is a convenient way to remotely debug this project. Simply replace a command like this one:\n\n```bash\npython main.py --text \"a hamburger\" --workspace trial -O --vram_O\n```\n\n... with:\n\n```bash\ndebugpy-run main.py -- --text \"a hamburger\" --workspace trial -O --vram_O\n```\n\nFor more details: https://github.com/bulletmark/debugpy-run\n\n# Axes and directions of polar, azimuth, etc. in NeRF and Zero123\n\n<img width=\"1119\" alt=\"NeRF_Zero123\" src=\"https://github.com/ashawkey/stable-dreamfusion/assets/22424247/a0f432ff-2d08-45a4-a390-bda64f5cbc94\">\n\nThis code uses theta for the polar angle and phi for the azimuth angle.\n\n"
  },
  {
    "path": "assets/update_logs.md",
    "content": "### 2023.4.19\n* Fix depth supervision; migrate the depth estimation model to omnidata.\n* Add normal supervision (also by omnidata).\n\nhttps://user-images.githubusercontent.com/25863658/232403294-b77409bf-ddc7-4bb8-af32-ee0cc123825a.mp4\n\n### 2023.4.7\nImproved mesh quality & DMTet finetuning support.\n\nhttps://user-images.githubusercontent.com/25863658/230535363-298c960e-bf9c-4906-8b96-cd60edcb24dd.mp4\n\n### 2023.3.30\n* Adopt ideas from [Fantasia3D](https://fantasia3d.github.io/) to concatenate the normal and mask as the latent code in a warm-up stage, which leads to faster convergence of the shape.\n\nhttps://user-images.githubusercontent.com/25863658/230535373-6ee28f16-bb21-4ec4-bc86-d46597361a04.mp4\n\n### 2023.1.30\n* Use an MLP to predict the surface normals as in Magic3D, avoiding finite differences / second-order gradients; generation quality is greatly improved.\n* More efficient two-pass raymarching in training, inspired by nerfacc.\n\nhttps://user-images.githubusercontent.com/25863658/215996308-9fd959f5-b5c7-4a8e-a241-0fe63ec86a4a.mp4\n\n### 2022.12.3\n* Support Stable Diffusion 2.0 base.\n\n### 2022.11.15\n* Add the vanilla backbone, which is pure PyTorch.\n\n### 2022.10.9\n* The shading (partially) starts to work; at least it no longer makes the scene empty. For some prompts, it shows better results (a less severe Janus problem). The textureless rendering mode is still disabled.\n* Enable shading by default (--latent_iter_ratio 1000).\n\n### 2022.10.5\n* Basic reproduction finished.\n* Non --cuda_ray and --tcnn modes are not working yet; need to fix.\n* Shading is not working and is disabled in utils.py for now. Surface normals are bad.\n* Use an entropy loss to regularize weights_sum (alpha); the original L2 reg always leads to degenerate geometry...\n\nhttps://user-images.githubusercontent.com/25863658/194241493-f3e68f78-aefe-479e-a4a8-001424a61b37.mp4\n"
  },
  {
    "path": "config/anya.csv",
    "content": "zero123_weight, radius, polar, azimuth, image\n1, 3, 90, 0, data/anya_front_rgba.png\n1, 3, 90, 180, data/anya_back_rgba.png"
  },
  {
    "path": "config/car.csv",
    "content": "zero123_weight, radius, polar, azimuth, image\n4, 3.2, 90, 0, data/car_left_rgba.png\n1, 3, 90, 90, data/car_front_rgba.png\n4, 3.2, 90, 180, data/car_right_rgba.png\n1, 3, 90, -90, data/car_back_rgba.png"
  },
  {
    "path": "config/corgi.csv",
    "content": "zero123_weight, radius, polar, azimuth, image\n1, 3.2, 90, 0, data/corgi_puppy_sitting_looking_up_rgba.png"
  },
  {
    "path": "docker/Dockerfile",
    "content": "FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04\n\n# Remove any third-party apt sources to avoid issues with expiring keys.\nRUN rm -f /etc/apt/sources.list.d/*.list\n\nRUN apt-get update\n\nRUN DEBIAN_FRONTEND=noninteractive TZ=Europe/Madrid apt-get install -y tzdata\n\n# Install some basic utilities\nRUN apt-get install -y \\\n    curl \\\n    ca-certificates \\\n    sudo \\\n    git \\\n    bzip2 \\\n    libx11-6 \\\n    python3 \\\n    python3-pip \\\n    libglfw3-dev \\\n    libgles2-mesa-dev \\\n    libglib2.0-0 \\\n && rm -rf /var/lib/apt/lists/*\n\n\n# Create a working directory\nRUN mkdir /app\nWORKDIR /app\n\nRUN git clone https://github.com/ashawkey/stable-dreamfusion.git\n\n\nRUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n\nWORKDIR /app/stable-dreamfusion\n\nRUN pip3 install -r requirements.txt\nRUN pip3 install git+https://github.com/NVlabs/nvdiffrast/\n\n# Needs the NVIDIA runtime; if you get the \"No CUDA runtime is found\" error, see the first answer at https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime\nRUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch\n\nRUN pip3 install git+https://github.com/openai/CLIP.git\nRUN bash scripts/install_ext.sh\n\n# Set the default command to python3\n#CMD [\"python3\"]\n"
  },
  {
    "path": "docker/README.md",
    "content": "# Docker installation\n\n## Build image\nTo build the Docker image on your own machine (this may take 15-30 minutes):\n```\ndocker build -t stable-dreamfusion:latest .\n```\n\nIf you get the error **No CUDA runtime is found** when building the wheels for tiny-cuda-nn, you need to set up the NVIDIA runtime for Docker.\n```\nsudo apt-get install nvidia-container-runtime\n```\nThen edit `/etc/docker/daemon.json` and add the default runtime:\n```\n{\n    \"runtimes\": {\n        \"nvidia\": {\n            \"path\": \"nvidia-container-runtime\",\n            \"runtimeArgs\": []\n        }\n    },\n    \"default-runtime\": \"nvidia\"\n}\n```\nAnd restart Docker:\n```\nsudo systemctl restart docker\n```\nNow you can build tiny-cuda-nn inside Docker.\n\n## Download image\nTo download the prebuilt image (~6GB) instead:\n```\ndocker pull supercabb/stable-dreamfusion:3080_0.0.1\ndocker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion\n```\n\n## Use image\n\nYou can launch an interactive shell inside the container:\n\n```\ndocker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash\n```\nFrom this shell, all the code in the repo should work.\n\nTo run any single command `<command...>` inside the Docker container:\n```\ndocker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c \"<command...>\"\n```\nTo train:\n```\nexport TOKEN=\"#HUGGING FACE ACCESS TOKEN#\"\ndocker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c \"echo ${TOKEN} > TOKEN \\\n&& python3 main.py --text \\\"a hamburger\\\" --workspace trial -O\"\n\n```\nRun test without GUI:\n```\nexport PATH_TO_WORKSPACE=\"#PATH_TO_WORKSPACE#\"\ndocker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \\\n-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c \"python3 \\\nmain.py --workspace trial -O --test\"\n```\nRun test with GUI:\n```\nexport PATH_TO_WORKSPACE=\"#PATH_TO_WORKSPACE#\"\nxhost +\ndocker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \\\n-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c \"python3 \\\nmain.py --workspace trial -O --test --gui\"\nxhost -\n```\n"
  },
  {
    "path": "dpt.py",
    "content": "import math\nimport types\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport timm\n\nclass BaseModel(torch.nn.Module):\n    def load(self, path):\n        \"\"\"Load model from file.\n        Args:\n            path (str): file path\n        \"\"\"\n        parameters = torch.load(path, map_location=torch.device('cpu'))\n\n        if \"optimizer\" in parameters:\n            parameters = parameters[\"model\"]\n\n        self.load_state_dict(parameters)\n\n\ndef unflatten_with_named_tensor(input, dim, sizes):\n    \"\"\"Workaround for unflattening with named tensor.\"\"\"\n    # tracer acts up with unflatten. See https://github.com/pytorch/pytorch/issues/49538\n    new_shape = list(input.shape)[:dim] + list(sizes) + list(input.shape)[dim+1:]\n    return input.view(*new_shape)\n    \nclass Slice(nn.Module):\n    def __init__(self, start_index=1):\n        super(Slice, self).__init__()\n        self.start_index = start_index\n\n    def forward(self, x):\n        return x[:, self.start_index :]\n\n\nclass AddReadout(nn.Module):\n    def __init__(self, start_index=1):\n        super(AddReadout, self).__init__()\n        self.start_index = start_index\n\n    def forward(self, x):\n        if self.start_index == 2:\n            readout = (x[:, 0] + x[:, 1]) / 2\n        else:\n            readout = x[:, 0]\n        return x[:, self.start_index :] + readout.unsqueeze(1)\n\n\nclass ProjectReadout(nn.Module):\n    def __init__(self, in_features, start_index=1):\n        super(ProjectReadout, self).__init__()\n        self.start_index = start_index\n\n        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())\n\n    def forward(self, x):\n        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])\n        features = torch.cat((x[:, self.start_index :], readout), -1)\n\n        return self.project(features)\n\n\nclass Transpose(nn.Module):\n    def __init__(self, dim0, dim1):\n        
super(Transpose, self).__init__()\n        self.dim0 = dim0\n        self.dim1 = dim1\n\n    def forward(self, x):\n        x = x.transpose(self.dim0, self.dim1)\n        return x\n\n\ndef forward_vit(pretrained, x):\n    b, c, h, w = x.shape\n\n    glob = pretrained.model.forward_flex(x)\n\n    layer_1 = pretrained.activations[\"1\"]\n    layer_2 = pretrained.activations[\"2\"]\n    layer_3 = pretrained.activations[\"3\"]\n    layer_4 = pretrained.activations[\"4\"]\n\n    layer_1 = pretrained.act_postprocess1[0:2](layer_1)\n    layer_2 = pretrained.act_postprocess2[0:2](layer_2)\n    layer_3 = pretrained.act_postprocess3[0:2](layer_3)\n    layer_4 = pretrained.act_postprocess4[0:2](layer_4)\n\n\n    unflattened_dim = 2\n    unflattened_size = (\n        int(torch.div(h, pretrained.model.patch_size[1], rounding_mode='floor')),\n        int(torch.div(w, pretrained.model.patch_size[0], rounding_mode='floor')),\n    )\n    unflatten = nn.Sequential(nn.Unflatten(unflattened_dim, unflattened_size))\n    \n\n    if layer_1.ndim == 3:\n        layer_1 = unflatten(layer_1)\n    if layer_2.ndim == 3:\n        layer_2 = unflatten(layer_2)\n    if layer_3.ndim == 3:\n        layer_3 = unflatten_with_named_tensor(layer_3, unflattened_dim, unflattened_size)\n    if layer_4.ndim == 3:\n        layer_4 = unflatten_with_named_tensor(layer_4, unflattened_dim, unflattened_size)\n\n    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)\n    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)\n    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)\n    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)\n\n    return layer_1, layer_2, layer_3, layer_4\n\n\ndef _resize_pos_embed(self, posemb, gs_h, gs_w):\n    posemb_tok, posemb_grid = (\n        posemb[:, : self.start_index],\n        posemb[0, self.start_index :],\n    )\n\n    gs_old = 
int(math.sqrt(posemb_grid.shape[0]))\n\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode=\"bilinear\")\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)\n\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)\n\n    return posemb\n\n\ndef forward_flex(self, x):\n    b, c, h, w = x.shape\n\n    pos_embed = self._resize_pos_embed(\n        self.pos_embed, torch.div(h, self.patch_size[1], rounding_mode='floor'), torch.div(w, self.patch_size[0], rounding_mode='floor')\n    )\n\n    B = x.shape[0]\n\n    if hasattr(self.patch_embed, \"backbone\"):\n        x = self.patch_embed.backbone(x)\n        if isinstance(x, (list, tuple)):\n            x = x[-1]  # last feature if backbone outputs list/tuple of features\n\n    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)\n\n    if getattr(self, \"dist_token\", None) is not None:\n        cls_tokens = self.cls_token.expand(\n            B, -1, -1\n        )  # stole cls_tokens impl from Phil Wang, thanks\n        dist_token = self.dist_token.expand(B, -1, -1)\n        x = torch.cat((cls_tokens, dist_token, x), dim=1)\n    else:\n        cls_tokens = self.cls_token.expand(\n            B, -1, -1\n        )  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n\n    x = x + pos_embed\n    x = self.pos_drop(x)\n\n    for blk in self.blocks:\n        x = blk(x)\n\n    x = self.norm(x)\n\n    return x\n\n\nactivations = {}\n\n\ndef get_activation(name):\n    def hook(model, input, output):\n        activations[name] = output\n\n    return hook\n\n\ndef get_readout_oper(vit_features, features, use_readout, start_index=1):\n    if use_readout == \"ignore\":\n        readout_oper = [Slice(start_index)] * len(features)\n    elif use_readout == \"add\":\n        readout_oper = [AddReadout(start_index)] * len(features)\n    elif use_readout == 
\"project\":\n        readout_oper = [\n            ProjectReadout(vit_features, start_index) for out_feat in features\n        ]\n    else:\n        assert (\n            False\n        ), \"wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'\"\n\n    return readout_oper\n\n\ndef _make_vit_b16_backbone(\n    model,\n    features=[96, 192, 384, 768],\n    size=[384, 384],\n    hooks=[2, 5, 8, 11],\n    vit_features=768,\n    use_readout=\"ignore\",\n    start_index=1,\n):\n    pretrained = nn.Module()\n\n    pretrained.model = model\n    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation(\"1\"))\n    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation(\"2\"))\n    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation(\"3\"))\n    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation(\"4\"))\n\n    pretrained.activations = activations\n\n    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)\n\n    # 32, 48, 136, 384\n    pretrained.act_postprocess1 = nn.Sequential(\n        readout_oper[0],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[0],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.ConvTranspose2d(\n            in_channels=features[0],\n            out_channels=features[0],\n            kernel_size=4,\n            stride=4,\n            padding=0,\n            bias=True,\n            dilation=1,\n            groups=1,\n        ),\n    )\n\n    pretrained.act_postprocess2 = nn.Sequential(\n        readout_oper[1],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[1],\n            kernel_size=1,\n    
        stride=1,\n            padding=0,\n        ),\n        nn.ConvTranspose2d(\n            in_channels=features[1],\n            out_channels=features[1],\n            kernel_size=2,\n            stride=2,\n            padding=0,\n            bias=True,\n            dilation=1,\n            groups=1,\n        ),\n    )\n\n    pretrained.act_postprocess3 = nn.Sequential(\n        readout_oper[2],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[2],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n    )\n\n    pretrained.act_postprocess4 = nn.Sequential(\n        readout_oper[3],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[3],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.Conv2d(\n            in_channels=features[3],\n            out_channels=features[3],\n            kernel_size=3,\n            stride=2,\n            padding=1,\n        ),\n    )\n\n    pretrained.model.start_index = start_index\n    pretrained.model.patch_size = [16, 16]\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)\n    pretrained.model._resize_pos_embed = types.MethodType(\n        _resize_pos_embed, pretrained.model\n    )\n\n    return pretrained\n\n\ndef _make_pretrained_vitl16_384(pretrained, use_readout=\"ignore\", hooks=None):\n    model = timm.create_model(\"vit_large_patch16_384\", pretrained=pretrained)\n\n    hooks = [5, 11, 17, 23] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n   
     model,\n        features=[256, 512, 1024, 1024],\n        hooks=hooks,\n        vit_features=1024,\n        use_readout=use_readout,\n    )\n\n\ndef _make_pretrained_vitb16_384(pretrained, use_readout=\"ignore\", hooks=None):\n    model = timm.create_model(\"vit_base_patch16_384\", pretrained=pretrained)\n\n    hooks = [2, 5, 8, 11] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout\n    )\n\n\ndef _make_pretrained_deitb16_384(pretrained, use_readout=\"ignore\", hooks=None):\n    model = timm.create_model(\"vit_deit_base_patch16_384\", pretrained=pretrained)\n\n    hooks = [2, 5, 8, 11] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout\n    )\n\n\ndef _make_pretrained_deitb16_distil_384(pretrained, use_readout=\"ignore\", hooks=None):\n    model = timm.create_model(\n        \"vit_deit_base_distilled_patch16_384\", pretrained=pretrained\n    )\n\n    hooks = [2, 5, 8, 11] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model,\n        features=[96, 192, 384, 768],\n        hooks=hooks,\n        use_readout=use_readout,\n        start_index=2,\n    )\n\n\ndef _make_vit_b_rn50_backbone(\n    model,\n    features=[256, 512, 768, 768],\n    size=[384, 384],\n    hooks=[0, 1, 8, 11],\n    vit_features=768,\n    use_vit_only=False,\n    use_readout=\"ignore\",\n    start_index=1,\n):\n    pretrained = nn.Module()\n\n    pretrained.model = model\n\n    if use_vit_only == True:\n        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation(\"1\"))\n        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation(\"2\"))\n    else:\n        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(\n            get_activation(\"1\")\n        )\n        
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(\n            get_activation(\"2\")\n        )\n\n    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation(\"3\"))\n    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation(\"4\"))\n\n    pretrained.activations = activations\n\n    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)\n\n    if use_vit_only == True:\n        pretrained.act_postprocess1 = nn.Sequential(\n            readout_oper[0],\n            Transpose(1, 2),\n            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n            nn.Conv2d(\n                in_channels=vit_features,\n                out_channels=features[0],\n                kernel_size=1,\n                stride=1,\n                padding=0,\n            ),\n            nn.ConvTranspose2d(\n                in_channels=features[0],\n                out_channels=features[0],\n                kernel_size=4,\n                stride=4,\n                padding=0,\n                bias=True,\n                dilation=1,\n                groups=1,\n            ),\n        )\n\n        pretrained.act_postprocess2 = nn.Sequential(\n            readout_oper[1],\n            Transpose(1, 2),\n            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n            nn.Conv2d(\n                in_channels=vit_features,\n                out_channels=features[1],\n                kernel_size=1,\n                stride=1,\n                padding=0,\n            ),\n            nn.ConvTranspose2d(\n                in_channels=features[1],\n                out_channels=features[1],\n                kernel_size=2,\n                stride=2,\n                padding=0,\n                bias=True,\n                dilation=1,\n                groups=1,\n            ),\n        )\n    else:\n        pretrained.act_postprocess1 = nn.Sequential(\n            nn.Identity(), nn.Identity(), 
nn.Identity()\n        )\n        pretrained.act_postprocess2 = nn.Sequential(\n            nn.Identity(), nn.Identity(), nn.Identity()\n        )\n\n    pretrained.act_postprocess3 = nn.Sequential(\n        readout_oper[2],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[2],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n    )\n\n    pretrained.act_postprocess4 = nn.Sequential(\n        readout_oper[3],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[3],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.Conv2d(\n            in_channels=features[3],\n            out_channels=features[3],\n            kernel_size=3,\n            stride=2,\n            padding=1,\n        ),\n    )\n\n    pretrained.model.start_index = start_index\n    pretrained.model.patch_size = [16, 16]\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model._resize_pos_embed = types.MethodType(\n        _resize_pos_embed, pretrained.model\n    )\n\n    return pretrained\n\n\ndef _make_pretrained_vitb_rn50_384(\n    pretrained, use_readout=\"ignore\", hooks=None, use_vit_only=False\n):\n    model = timm.create_model(\"vit_base_resnet50_384\", pretrained=pretrained)\n\n    hooks = [0, 1, 8, 11] if hooks == None else hooks\n    return 
_make_vit_b_rn50_backbone(\n        model,\n        features=[256, 512, 768, 768],\n        size=[384, 384],\n        hooks=hooks,\n        use_vit_only=use_vit_only,\n        use_readout=use_readout,\n    )\n\n\ndef _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout=\"ignore\"):\n    if backbone == \"vitl16_384\":\n        pretrained = _make_pretrained_vitl16_384(\n            use_pretrained, hooks=hooks, use_readout=use_readout\n        )\n        scratch = _make_scratch(\n            [256, 512, 1024, 1024], features, groups=groups, expand=expand\n        )  # ViT-L/16 - 85.0% Top1 (backbone)\n    elif backbone == \"vitb_rn50_384\":\n        pretrained = _make_pretrained_vitb_rn50_384(\n            use_pretrained,\n            hooks=hooks,\n            use_vit_only=use_vit_only,\n            use_readout=use_readout,\n        )\n        scratch = _make_scratch(\n            [256, 512, 768, 768], features, groups=groups, expand=expand\n        )  # ViT-Hybrid: ResNet-50 + ViT-B/16 (backbone)\n    elif backbone == \"vitb16_384\":\n        pretrained = _make_pretrained_vitb16_384(\n            use_pretrained, hooks=hooks, use_readout=use_readout\n        )\n        scratch = _make_scratch(\n            [96, 192, 384, 768], features, groups=groups, expand=expand\n        )  # ViT-B/16 - 84.6% Top1 (backbone)\n    elif backbone == \"resnext101_wsl\":\n        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)\n        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl\n    elif backbone == \"efficientnet_lite3\":\n        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)\n        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3\n    else:\n        raise NotImplementedError(f\"Backbone '{backbone}' not implemented\")\n
\n    return pretrained, scratch\n\n\ndef _make_scratch(in_shape, out_shape, groups=1, expand=False):\n    scratch = nn.Module()\n\n    out_shape1 = out_shape\n    out_shape2 = out_shape\n    out_shape3 = out_shape\n    out_shape4 = out_shape\n    if expand==True:\n        out_shape1 = out_shape\n        out_shape2 = out_shape*2\n        out_shape3 = out_shape*4\n        out_shape4 = out_shape*8\n\n    scratch.layer1_rn = nn.Conv2d(\n        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer2_rn = nn.Conv2d(\n        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer3_rn = nn.Conv2d(\n        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer4_rn = nn.Conv2d(\n        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n\n    return scratch\n\n\ndef _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):\n    efficientnet = torch.hub.load(\n        \"rwightman/gen-efficientnet-pytorch\",\n        \"tf_efficientnet_lite3\",\n        pretrained=use_pretrained,\n        exportable=exportable\n    )\n    return _make_efficientnet_backbone(efficientnet)\n\n\ndef _make_efficientnet_backbone(effnet):\n    pretrained = nn.Module()\n\n    pretrained.layer1 = nn.Sequential(\n        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]\n    )\n    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])\n    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])\n    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])\n\n    return pretrained\n    \n\ndef _make_resnet_backbone(resnet):\n    pretrained = nn.Module()\n    pretrained.layer1 = nn.Sequential(\n        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1\n    )\n\n    pretrained.layer2 = resnet.layer2\n    pretrained.layer3 = resnet.layer3\n    
pretrained.layer4 = resnet.layer4\n\n    return pretrained\n\n\ndef _make_pretrained_resnext101_wsl(use_pretrained):\n    resnet = torch.hub.load(\"facebookresearch/WSL-Images\", \"resnext101_32x8d_wsl\")\n    return _make_resnet_backbone(resnet)\n\n\n\nclass Interpolate(nn.Module):\n    \"\"\"Interpolation module.\n    \"\"\"\n\n    def __init__(self, scale_factor, mode, align_corners=False):\n        \"\"\"Init.\n        Args:\n            scale_factor (float): scaling\n            mode (str): interpolation mode\n        \"\"\"\n        super(Interpolate, self).__init__()\n\n        self.interp = nn.functional.interpolate\n        self.scale_factor = scale_factor\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n        Args:\n            x (tensor): input\n        Returns:\n            tensor: interpolated data\n        \"\"\"\n\n        x = self.interp(\n            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners\n        )\n\n        return x\n\n\nclass ResidualConvUnit(nn.Module):\n    \"\"\"Residual convolution module.\n    \"\"\"\n\n    def __init__(self, features):\n        \"\"\"Init.\n        Args:\n            features (int): number of features\n        \"\"\"\n        super().__init__()\n\n        self.conv1 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True\n        )\n\n        self.conv2 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True\n        )\n\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n        Args:\n            x (tensor): input\n        Returns:\n            tensor: output\n        \"\"\"\n        out = self.relu(x)\n        out = self.conv1(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n\n        return out + x\n\n\nclass FeatureFusionBlock(nn.Module):\n    
\"\"\"Feature fusion block.\n    \"\"\"\n\n    def __init__(self, features):\n        \"\"\"Init.\n        Args:\n            features (int): number of features\n        \"\"\"\n        super(FeatureFusionBlock, self).__init__()\n\n        self.resConfUnit1 = ResidualConvUnit(features)\n        self.resConfUnit2 = ResidualConvUnit(features)\n\n    def forward(self, *xs):\n        \"\"\"Forward pass.\n        Returns:\n            tensor: output\n        \"\"\"\n        output = xs[0]\n\n        if len(xs) == 2:\n            output += self.resConfUnit1(xs[1])\n\n        output = self.resConfUnit2(output)\n\n        output = nn.functional.interpolate(\n            output, scale_factor=2, mode=\"bilinear\", align_corners=True\n        )\n\n        return output\n\n\n\n\nclass ResidualConvUnit_custom(nn.Module):\n    \"\"\"Residual convolution module.\n    \"\"\"\n\n    def __init__(self, features, activation, bn):\n        \"\"\"Init.\n        Args:\n            features (int): number of features\n        \"\"\"\n        super().__init__()\n\n        self.bn = bn\n\n        self.groups=1\n\n        self.conv1 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups\n        )\n        \n        self.conv2 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups\n        )\n\n        if self.bn==True:\n            self.bn1 = nn.BatchNorm2d(features)\n            self.bn2 = nn.BatchNorm2d(features)\n\n        self.activation = activation\n\n        self.skip_add = nn.quantized.FloatFunctional()\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n        Args:\n            x (tensor): input\n        Returns:\n            tensor: output\n        \"\"\"\n        \n        out = self.activation(x)\n        out = self.conv1(out)\n        if self.bn==True:\n            out = self.bn1(out)\n       \n        out = self.activation(out)\n        out = 
self.conv2(out)\n        if self.bn==True:\n            out = self.bn2(out)\n\n        if self.groups > 1:\n            out = self.conv_merge(out)\n\n        return self.skip_add.add(out, x)\n\n        # return out + x\n\n\nclass FeatureFusionBlock_custom(nn.Module):\n    \"\"\"Feature fusion block.\n    \"\"\"\n\n    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):\n        \"\"\"Init.\n        Args:\n            features (int): number of features\n        \"\"\"\n        super(FeatureFusionBlock_custom, self).__init__()\n\n        self.deconv = deconv\n        self.align_corners = align_corners\n\n        self.groups=1\n\n        self.expand = expand\n        out_features = features\n        if self.expand==True:\n            out_features = features//2\n        \n        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)\n\n        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)\n        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)\n        \n        self.skip_add = nn.quantized.FloatFunctional()\n\n    def forward(self, *xs):\n        \"\"\"Forward pass.\n        Returns:\n            tensor: output\n        \"\"\"\n        output = xs[0]\n\n        if len(xs) == 2:\n            res = self.resConfUnit1(xs[1])\n            output = self.skip_add.add(output, res)\n            # output += res\n\n        output = self.resConfUnit2(output)\n\n        output = nn.functional.interpolate(\n            output, scale_factor=2, mode=\"bilinear\", align_corners=self.align_corners\n        )\n\n        output = self.out_conv(output)\n\n        return output        \n\n\n\ndef _make_fusion_block(features, use_bn):\n    return FeatureFusionBlock_custom(\n        features,\n        nn.ReLU(False),\n        deconv=False,\n        bn=use_bn,\n        expand=False,\n        align_corners=True,\n    )\n\n\nclass DPT(BaseModel):\n  
  def __init__(\n        self,\n        head,\n        features=256,\n        backbone=\"vitb_rn50_384\",\n        readout=\"project\",\n        channels_last=False,\n        use_bn=False,\n    ):\n\n        super(DPT, self).__init__()\n\n        self.channels_last = channels_last\n\n        hooks = {\n            \"vitb_rn50_384\": [0, 1, 8, 11],\n            \"vitb16_384\": [2, 5, 8, 11],\n            \"vitl16_384\": [5, 11, 17, 23],\n        }\n\n        # Instantiate backbone and reassemble blocks\n        self.pretrained, self.scratch = _make_encoder(\n            backbone,\n            features,\n            True,  # use_pretrained: load ImageNet weights for the backbone (set to False to train from scratch)\n            groups=1,\n            expand=False,\n            exportable=False,\n            hooks=hooks[backbone],\n            use_readout=readout,\n        )\n\n        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)\n        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)\n        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)\n        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)\n\n        self.scratch.output_conv = head\n\n\n    def forward(self, x):\n        if self.channels_last:\n            # contiguous() does not modify x in place, so the result must be reassigned\n            x = x.contiguous(memory_format=torch.channels_last)\n\n        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)\n\n        layer_1_rn = self.scratch.layer1_rn(layer_1)\n        layer_2_rn = self.scratch.layer2_rn(layer_2)\n        layer_3_rn = self.scratch.layer3_rn(layer_3)\n        layer_4_rn = self.scratch.layer4_rn(layer_4)\n\n        path_4 = self.scratch.refinenet4(layer_4_rn)\n        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)\n        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)\n        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)\n\n        out = self.scratch.output_conv(path_1)\n\n        return out\n\n\nclass DPTDepthModel(DPT):\n    def __init__(self, 
path=None, non_negative=True, num_channels=1, **kwargs):\n        features = kwargs[\"features\"] if \"features\" in kwargs else 256\n\n        head = nn.Sequential(\n            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),\n            Interpolate(scale_factor=2, mode=\"bilinear\", align_corners=True),\n            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),\n            nn.ReLU(True),\n            nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0),\n            nn.ReLU(True) if non_negative else nn.Identity(),\n            nn.Identity(),\n        )\n\n        super().__init__(head, **kwargs)\n\n        if path is not None:\n            self.load(path)\n\n    def forward(self, x):\n        return super().forward(x).squeeze(dim=1)"
  },
  {
    "path": "encoding.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass FreqEncoder_torch(nn.Module):\n    def __init__(self, input_dim, max_freq_log2, N_freqs,\n                 log_sampling=True, include_input=True,\n                 periodic_fns=(torch.sin, torch.cos)):\n\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.include_input = include_input\n        self.periodic_fns = periodic_fns\n        self.N_freqs = N_freqs\n\n        self.output_dim = 0\n        if self.include_input:\n            self.output_dim += self.input_dim\n\n        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)\n\n        if log_sampling:\n            self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)\n        else:\n            self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)\n\n        self.freq_bands = self.freq_bands.numpy().tolist()\n\n    def forward(self, input, max_level=None, **kwargs):\n\n        # max_level is the fraction of frequency bands to enable (None enables all of them)\n        if max_level is None:\n            max_level = self.N_freqs\n        else:\n            max_level = int(max_level * self.N_freqs)\n\n        out = []\n        if self.include_input:\n            out.append(input)\n\n        for i in range(max_level):\n            freq = self.freq_bands[i]\n            for p_fn in self.periodic_fns:\n                out.append(p_fn(input * freq))\n\n        # Zero-fill the disabled bands so the output always has output_dim channels\n        if self.N_freqs - max_level > 0:\n            out.append(torch.zeros(*input.shape[:-1], (self.N_freqs - max_level) * len(self.periodic_fns) * input.shape[-1], device=input.device, dtype=input.dtype))\n\n        out = torch.cat(out, dim=-1)\n\n        return out\n\ndef get_encoder(encoding, input_dim=3,\n                multires=6,\n                degree=4,\n                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, interpolation='linear',\n                **kwargs):\n\n    if encoding == 'None':\n        return lambda x, **kwargs: x, input_dim\n\n    elif encoding == 'frequency_torch':\n        encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)\n\n    elif encoding == 'frequency': # CUDA implementation, faster than torch.\n        from freqencoder import FreqEncoder\n        encoder = FreqEncoder(input_dim=input_dim, degree=multires)\n\n    elif encoding == 'sphere_harmonics':\n        from shencoder import SHEncoder\n        encoder = SHEncoder(input_dim=input_dim, degree=degree)\n\n    elif encoding == 'hashgrid':\n        from gridencoder import GridEncoder\n        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation)\n\n    elif encoding == 'tiledgrid':\n        from gridencoder import GridEncoder\n        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation)\n\n    elif encoding == 'hashgrid_taichi':\n        from taichi_modules.hash_encoder import HashEncoderTaichi\n        encoder = HashEncoderTaichi(batch_size=4096) # TODO: batch size is hard-coded\n\n    else:\n        raise NotImplementedError('Unknown encoding mode, choose from [None, frequency_torch, frequency, sphere_harmonics, hashgrid, tiledgrid, hashgrid_taichi]')\n\n    return encoder, encoder.output_dim"
  },
  {
    "path": "evaluation/Prompt.py",
    "content": "import textwrap\nfrom transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification\nfrom transformers import pipeline\nimport argparse\nimport sys\nimport warnings\nwarnings.filterwarnings(\"ignore\", category=UserWarning)\n\n\n# Usage: python Prompt.py --text \"a dog is in front of a rabbit\" --model vlt5\n\n\nif __name__ == '__main__':\n\n    # Parse the text prompt and choose the keyword-extraction model\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--text', default=\"\", type=str, help=\"text prompt\")\n    #parser.add_argument('--workspace', default=\"trial\", type=str, help=\"workspace\")\n    parser.add_argument('--model', default='vlt5', type=str, choices=['vlt5', 'bert', 'XLNet'], help=\"keyword-extraction model (case-sensitive): vlt5, bert, or XLNet\")\n\n    opt = parser.parse_args()\n\n    if opt.model == \"vlt5\":\n        tokenizer = AutoTokenizer.from_pretrained(\"Voicelab/vlt5-base-keywords\")\n        model = AutoModelForSeq2SeqLM.from_pretrained(\"Voicelab/vlt5-base-keywords\")\n\n        task_prefix = \"Keywords: \"\n        inputs = [opt.text]\n\n        for sample in inputs:\n            input_sequences = [task_prefix + sample]\n            input_ids = tokenizer(\n                input_sequences, return_tensors=\"pt\", truncation=True\n            ).input_ids\n            output = model.generate(input_ids, no_repeat_ngram_size=3, num_beams=4)\n            output_text = tokenizer.decode(output[0], skip_special_tokens=True)\n\n    elif opt.model == \"bert\":\n        tokenizer = AutoTokenizer.from_pretrained(\"yanekyuk/bert-uncased-keyword-extractor\")\n        model = AutoModelForTokenClassification.from_pretrained(\"yanekyuk/bert-uncased-keyword-extractor\")\n\n        text = opt.text\n        input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors=\"pt\")\n\n        # Classify tokens; keep those with a non-zero label as keywords\n        outputs = model(input_ids)\n        predictions = outputs.logits.detach().numpy()[0]\n        labels = predictions.argmax(axis=1)\n        labels = labels[1:-1]\n\n        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])\n        tokens = tokens[1:-1]\n        output_tokens = [tokens[i] for i in range(len(tokens)) if labels[i] != 0]\n        output_text = tokenizer.convert_tokens_to_string(output_tokens)\n\n    elif opt.model == \"XLNet\":\n        tokenizer = AutoTokenizer.from_pretrained(\"jasminejwebb/KeywordIdentifier\")\n        model = AutoModelForTokenClassification.from_pretrained(\"jasminejwebb/KeywordIdentifier\")\n\n        text = opt.text\n        input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors=\"pt\")\n\n        # Classify tokens; keep those with a non-zero label as keywords\n        outputs = model(input_ids)\n        predictions = outputs.logits.detach().numpy()[0]\n        labels = predictions.argmax(axis=1)\n        labels = labels[1:-1]\n\n        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])\n        tokens = tokens[1:-1]\n        output_tokens = [tokens[i] for i in range(len(tokens)) if labels[i] != 0]\n        output_text = tokenizer.convert_tokens_to_string(output_tokens)\n\n    # Print the extracted keywords in a 50-character-wide box\n    wrapped_text = textwrap.fill(output_text, width=50)\n\n    print('+' + '-'*52 + '+')\n    for line in wrapped_text.split('\\n'):\n        print('| {} |'.format(line.ljust(50)))\n    print('+' + '-'*52 + '+')\n"
  },
  {
    "path": "evaluation/mesh_to_video.py",
    "content": "import subprocess\nimport numpy as np\nimport trimesh\nimport argparse\nfrom pathlib import Path\nfrom tqdm import tqdm\nimport pyvista as pv\n\ndef render_video(anim_mesh, out_dir, output_file, n_frames=360, fps=30):\n    center = anim_mesh.center_mass\n    plotter = pv.Plotter(off_screen=True)\n    plotter.add_mesh(anim_mesh)\n\n    # Orbit the camera around the z axis: one full turn over n_frames frames\n    radius = 10\n    angle_step = 2 * np.pi / n_frames\n    for i in tqdm(range(n_frames)):\n        camera_pos = [center[0] + radius * np.cos(i*angle_step), center[1] + radius * np.sin(i*angle_step), center[2]]\n        plotter.camera_position = (camera_pos, center, (0, 0, 1))\n        plotter.show(screenshot=str(Path(out_dir) / f'frame_{i}.png'), auto_close=False)\n    plotter.close()\n\n    # Encode the saved frames into an H.264 MP4 at output_file\n    subprocess.run(['ffmpeg', '-y', '-r', str(fps), '-f', 'image2', '-i', str(Path(out_dir) / 'frame_%d.png'),\n                    '-vcodec', 'libx264', '-crf', '25', '-pix_fmt', 'yuv420p', str(output_file)], check=True)\n\n\ndef generate_mesh(obj1, obj2, transform_vector):\n\n    # Read 2 objects\n    filename1 = obj1 # Central Object\n    filename2 = obj2 # Surrounding Object\n    mesh1 = trimesh.load_mesh(filename1)\n    mesh2 = trimesh.load_mesh(filename2)\n\n    extents1 = mesh1.extents\n    extents2 = mesh2.extents\n\n    radius1 = sum(extents1) / 3.0\n    radius2 = sum(extents2) / 3.0\n\n    center1 = mesh1.center_mass\n    center2 = mesh2.center_mass\n\n    # Move the central object to the origin\n    T1 = -center1\n    # Parse a vector such as \"[1,0,0]\" and scale it by the central object's mean extent\n    values = str(transform_vector).strip('[]() ').split(',')\n    transform_vector = np.array([float(v) * radius1 for v in values])\n    print(T1, transform_vector, radius1)\n    T2 = -center2 + transform_vector\n\n    # Transform\n    mesh1.apply_translation(T1)\n    mesh2.apply_translation(T2)\n\n    # merge mesh\n    merged_mesh = trimesh.util.concatenate((mesh1, mesh2))\n\n    # save mesh\n    merged_mesh.export('merged_mesh.obj')\n    print(\"----> merge mesh done\")\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Generate rotating mesh animation.')\n    parser.add_argument('--center_obj', type=str, help='Central object (OBJ file).')\n    parser.add_argument('--surround_obj', type=str, help='Surrounding object (OBJ file).')\n    parser.add_argument('--transform_vector', help='Offset [x,y,z] of the surrounding object, in units of the central object mean extent.')\n    parser.add_argument('--output_file', type=str, default=\"result/Demo.mp4\", help='Output MP4 file.')\n    parser.add_argument('--num_frames', type=int, default=100, help='Number of frames to render.')\n    args = parser.parse_args()\n\n    generate_mesh(args.center_obj, args.surround_obj, args.transform_vector)\n\n    input_file = Path(\"merged_mesh.obj\")\n    output_file = Path(args.output_file)\n\n    out_dir = output_file.parent.joinpath('frames')\n    out_dir.mkdir(parents=True, exist_ok=True)\n\n    anim_mesh = trimesh.load_mesh(str(input_file))\n\n    render_video(anim_mesh, out_dir, output_file, n_frames=args.num_frames)\n"
  },
  {
    "path": "evaluation/r_precision.py",
    "content": "from sentence_transformers import SentenceTransformer, util\nfrom PIL import Image\nimport argparse\nimport sys\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--text', default=\"\", type=str, help=\"text prompt\")\n    parser.add_argument('--workspace', default=\"trial\", type=str, help=\"workspace folder of the stable-dreamfusion run\")\n    parser.add_argument('--latest', default='ep0001', type=str, help=\"epoch whose validation image is evaluated, e.g. ep0100\")\n    parser.add_argument('--mode', default='rgb', type=str, help=\"result mode: rgb (color) or depth (textureless)\")\n    parser.add_argument('--clip', default=\"clip-ViT-B-32\", type=str, help=\"CLIP model used to encode the image and the prompt\")\n\n    opt = parser.parse_args()\n\n    # Load the CLIP model\n    model = SentenceTransformer(opt.clip)\n\n    # Encode the rendered validation image\n    img_emb = model.encode(Image.open(f'../results/{opt.workspace}/validation/df_{opt.latest}_0005_{opt.mode}.png'))\n\n    # Encode the text prompt\n    text_emb = model.encode([opt.text])\n\n    # Cosine similarity between the image and prompt embeddings\n    cos_scores = util.cos_sim(img_emb, text_emb)\n    print(\"CLIP image-text cosine similarity:\", cos_scores[0][0].cpu().numpy())\n"
  },
  {
    "path": "evaluation/readme.md",
    "content": "### Improvement:\n\n- Usage\n\n  - r_precision.py <br>\n  For CLIP R-Precision evaluation <br>\n  --text is the text prompt, written the same way as for Stable DreamFusion <br>\n  --workspace is the workspace folder that Stable DreamFusion creates for each prompt <br>\n  --latest selects the checkpoint to evaluate. Stable DreamFusion records results for every epoch; this is normally ep0100 unless training was stopped early or extended <br>\n  --mode is either rgb or depth, corresponding to the color and textureless results in Figure 5 (Qualitative comparison with baselines) of the original paper <br>\n  --clip is one of clip-ViT-B-32, clip-ViT-B-16, or clip-ViT-L-14 (CLIP B/32, B/16, and L/14, as in the original paper) <br>\n\n      ```bash\n      python r_precision.py --text \"matte painting of a castle made of cheesecake surrounded by a moat made of ice cream\" --workspace ../castle --latest ep0100 --mode rgb --clip clip-ViT-B-32\n      ```\n\n  - Prompt.py (model names are case sensitive) <br>\n  For prompt separation <br> <br>\n  --text is the text prompt, written the same way as for Stable DreamFusion <br>\n  --model selects the pretrained model (vlt5, bert, or XLNet) <br>\n\n      ```bash\n      python Prompt.py --text \"a dog is in front of a rabbit\" --model vlt5\n      python Prompt.py --text \"a dog is in front of a rabbit\" --model bert\n      python Prompt.py --text \"a dog is in front of a rabbit\" --model XLNet\n      ```\n\n\n  - mesh_to_video.py <br>\n  --center_obj is the center object <br>\n  --surround_obj is the surrounding object, which is subject to change <br>\n  --transform_vector is the X Y Z 3D vector used for the transform <br>\n\n      ```bash\n      python mesh_to_video.py --center_obj 'mesh_whiterabbit/mesh.obj' --surround_obj 'mesh_snake/mesh.obj' --transform_vector '[1,0,0]'\n      ```\n"
  },
  {
    "path": "freqencoder/__init__.py",
    "content": "from .freq import FreqEncoder"
  },
  {
    "path": "freqencoder/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    '-O3', '-std=c++14',\n    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',\n    '-use_fast_math'\n]\n\nif os.name == \"posix\":\n    c_flags = ['-O3', '-std=c++14']\nelif os.name == \"nt\":\n    c_flags = ['/O2', '/std:c++17']\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n        for program_files in [r\"C:\\\\Program Files (x86)\", r\"C:\\\\Program Files\"]:\n            for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n                paths = sorted(glob.glob(r\"%s\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\" % (program_files, edition)), reverse=True)\n                if paths:\n                    return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(name='_freqencoder',\n                extra_cflags=c_flags,\n                extra_cuda_cflags=nvcc_flags,\n                sources=[os.path.join(_src_path, 'src', f) for f in [\n                    'freqencoder.cu',\n                    'bindings.cpp',\n                ]],\n                )\n\n__all__ = ['_backend']"
  },
  {
    "path": "freqencoder/freq.py",
    "content": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\nfrom torch.cuda.amp import custom_bwd, custom_fwd \n\ntry:\n    import _freqencoder as _backend\nexcept ImportError:\n    from .backend import _backend\n\n\nclass _freq_encoder(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision\n    def forward(ctx, inputs, degree, output_dim):\n        # inputs: [B, input_dim], float \n        # RETURN: [B, output_dim], float\n\n        if not inputs.is_cuda: inputs = inputs.cuda()\n        inputs = inputs.contiguous()\n\n        B, input_dim = inputs.shape # batch size, coord dim\n        \n        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)\n\n        _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)\n\n        ctx.save_for_backward(inputs, outputs)\n        ctx.dims = [B, input_dim, degree, output_dim]\n\n        return outputs\n    \n    @staticmethod\n    #@once_differentiable\n    @custom_bwd\n    def backward(ctx, grad):\n        # grad: [B, output_dim]\n\n        grad = grad.contiguous()\n        inputs, outputs = ctx.saved_tensors\n        B, input_dim, degree, output_dim = ctx.dims\n\n        grad_inputs = torch.zeros_like(inputs)\n        _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)\n\n        return grad_inputs, None, None\n    \n\nfreq_encode = _freq_encoder.apply\n\n\nclass FreqEncoder(nn.Module):\n    def __init__(self, input_dim=3, degree=4):\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.degree = degree\n        self.output_dim = input_dim + input_dim * 2 * degree\n        \n    def __repr__(self):\n        return f\"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}\"\n    \n    def forward(self, inputs, **kwargs):\n        # inputs: [..., input_dim]\n        # return: [..., output_dim]\n\n        prefix_shape = list(inputs.shape[:-1])\n        inputs = inputs.reshape(-1, self.input_dim)\n\n        outputs = freq_encode(inputs, self.degree, self.output_dim)\n\n        outputs = outputs.reshape(prefix_shape + [self.output_dim])\n\n        return outputs"
  },
  {
    "path": "freqencoder/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    '-O3', '-std=c++14',\n    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',\n    '-use_fast_math'\n]\n\nif os.name == \"posix\":\n    c_flags = ['-O3', '-std=c++14']\nelif os.name == \"nt\":\n    c_flags = ['/O2', '/std:c++17']\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n        for program_files in [r\"C:\\\\Program Files (x86)\", r\"C:\\\\Program Files\"]:\n            for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n                paths = sorted(glob.glob(r\"%s\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\" % (program_files, edition)), reverse=True)\n                if paths:\n                    return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n        os.environ[\"PATH\"] += \";\" + cl_path\n\nsetup(\n    name='freqencoder', # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name='_freqencoder', # extension name, import this to use CUDA API\n            sources=[os.path.join(_src_path, 'src', f) for f in [\n                'freqencoder.cu',\n                'bindings.cpp',\n            ]],\n            extra_compile_args={\n                'cxx': c_flags,\n                'nvcc': nvcc_flags,\n            }\n        ),\n    ],\n    cmdclass={\n        'build_ext': BuildExtension,\n    }\n)"
  },
  {
    "path": "freqencoder/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"freqencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"freq_encode_forward\", &freq_encode_forward, \"freq encode forward (CUDA)\");\n    m.def(\"freq_encode_backward\", &freq_encode_backward, \"freq encode backward (CUDA)\");\n}"
  },
  {
    "path": "freqencoder/src/freqencoder.cu",
    "content": "#include <stdint.h>\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <stdexcept>\n\n#include <cstdio>\n\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x \" must be a floating tensor\")\n\ninline constexpr __device__ float PI() { return 3.141592653589793f; }\n\ntemplate <typename T>\n__host__ __device__ T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\n// inputs: [B, D]\n// outputs: [B, C], C = D + D * deg * 2\n__global__ void kernel_freq(\n    const float * __restrict__ inputs, \n    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,\n    float * outputs\n) {\n    // parallel on per-element\n    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;\n    if (t >= B * C) return;\n\n    // get index\n    const uint32_t b = t / C;\n    const uint32_t c = t - b * C; // t % C;\n\n    // locate\n    inputs += b * D;\n    outputs += t;\n\n    // write self\n    if (c < D) {\n        outputs[0] = inputs[c];\n    // write freq\n    } else {\n        const uint32_t col = c / D - 1;\n        const uint32_t d = c % D;\n        const uint32_t freq = col / 2;\n        const float phase_shift = (col % 2) * (PI() / 2);\n        outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);\n    }\n}\n\n// grad: [B, C], C = D + D * deg * 2\n// outputs: [B, C]\n// grad_inputs: [B, D]\n__global__ void kernel_freq_backward(\n    const float * __restrict__ grad,\n    const float * __restrict__ outputs,\n  
  uint32_t B, uint32_t D, uint32_t deg, uint32_t C,\n    float * grad_inputs\n) {\n    // parallel on per-element\n    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;\n    if (t >= B * D) return;\n\n    const uint32_t b = t / D;\n    const uint32_t d = t - b * D; // t % D;\n\n    // locate\n    grad += b * C;\n    outputs += b * C;\n    grad_inputs += t;\n\n    // register \n    float result = grad[d];\n    grad += D;\n    outputs += D;\n\n    for (uint32_t f = 0; f < deg; f++) {\n        result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);\n        grad += 2 * D;\n        outputs += 2 * D;\n    }\n\n    // write\n    grad_inputs[0] = result;\n}\n\n\nvoid freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(outputs);\n    \n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(outputs);\n\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(outputs);\n\n    static constexpr uint32_t N_THREADS = 128;\n\n    kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());\n}\n\n\nvoid freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {\n    CHECK_CUDA(grad);\n    CHECK_CUDA(outputs);\n    CHECK_CUDA(grad_inputs);\n    \n    CHECK_CONTIGUOUS(grad);\n    CHECK_CONTIGUOUS(outputs);\n    CHECK_CONTIGUOUS(grad_inputs);\n\n    CHECK_IS_FLOATING(grad);\n    CHECK_IS_FLOATING(outputs);\n    CHECK_IS_FLOATING(grad_inputs);\n\n    static constexpr uint32_t N_THREADS = 128;\n\n    kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());\n}"
  },
  {
    "path": "freqencoder/src/freqencoder.h",
    "content": "# pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)\nvoid freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);\n\n// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)\nvoid freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);"
  },
  {
    "path": "gridencoder/__init__.py",
    "content": "from .grid import GridEncoder"
  },
  {
    "path": "gridencoder/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    '-O3', '-std=c++14',\n    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',\n]\n\nif os.name == \"posix\":\n    c_flags = ['-O3', '-std=c++14']\nelif os.name == \"nt\":\n    c_flags = ['/O2', '/std:c++17']\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n        for program_files in [r\"C:\\\\Program Files (x86)\", r\"C:\\\\Program Files\"]:\n            for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n                paths = sorted(glob.glob(r\"%s\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\" % (program_files, edition)), reverse=True)\n                if paths:\n                    return paths[0]\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(name='_grid_encoder',\n                extra_cflags=c_flags,\n                extra_cuda_cflags=nvcc_flags,\n                sources=[os.path.join(_src_path, 'src', f) for f in [\n                    'gridencoder.cu',\n                    'bindings.cpp',\n                ]],\n                )\n\n__all__ = ['_backend']"
  },
  {
    "path": "gridencoder/grid.py",
    "content": "import math\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\nfrom torch.cuda.amp import custom_bwd, custom_fwd \n\ntry:\n    import _gridencoder as _backend\nexcept ImportError:\n    from .backend import _backend\n\n_gridtype_to_id = {\n    'hash': 0,\n    'tiled': 1,\n}\n\n_interp_to_id = {\n    'linear': 0,\n    'smoothstep': 1,\n}\n\nclass _grid_encode(Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0, max_level=None):\n        # inputs: [B, D], float in [0, 1]\n        # embeddings: [sO, C], float\n        # offsets: [L + 1], int\n        # RETURN: [B, F], float\n\n        inputs = inputs.contiguous()\n\n        B, D = inputs.shape # batch size, coord dim\n        L = offsets.shape[0] - 1 # level\n        C = embeddings.shape[1] # embedding dim for each level\n        S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f\n        H = base_resolution # base resolution\n\n        max_level = L if max_level is None else max(min(int(math.ceil(max_level * L)), L), 1)\n\n        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)\n        # if C % 2 != 0, force float, since half for atomicAdd is very slow.\n        if torch.is_autocast_enabled() and C % 2 == 0:\n            embeddings = embeddings.to(torch.half)\n\n        # L first, optimize cache for cuda kernel, but needs an extra permute later\n        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)\n\n        # zero init if we only calculate partial levels\n        if max_level < L: outputs.zero_()\n\n        if calc_grad_inputs:\n            dy_dx = torch.empty(B, L * D * C, device=inputs.device, 
dtype=embeddings.dtype)\n            if max_level < L: dy_dx.zero_()\n        else:\n            dy_dx = None\n\n        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interpolation)\n\n        # permute back to [B, L * C]\n        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)\n\n        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)\n        ctx.dims = [B, D, C, L, S, H, gridtype, interpolation, max_level]\n        ctx.align_corners = align_corners\n\n        return outputs\n    \n    @staticmethod\n    #@once_differentiable\n    @custom_bwd\n    def backward(ctx, grad):\n\n        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors\n        B, D, C, L, S, H, gridtype, interpolation, max_level = ctx.dims\n        align_corners = ctx.align_corners\n\n        # grad: [B, L * C] --> [L, B, C]\n        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()\n\n        grad_embeddings = torch.zeros_like(embeddings)\n\n        if dy_dx is not None:\n            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)\n        else:\n            grad_inputs = None\n\n        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)\n\n        if dy_dx is not None:\n            grad_inputs = grad_inputs.to(inputs.dtype)\n\n        return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None\n        \n\n\ngrid_encode = _grid_encode.apply\n\n\nclass GridEncoder(nn.Module):\n    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):\n        super().__init__()\n\n        # the finest resolution desired at the last level, if provided, overridee per_level_scale\n        if 
desired_resolution is not None:\n            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))\n\n        self.input_dim = input_dim # coord dims, 2 or 3\n        self.num_levels = num_levels # num levels, each level multiply resolution by 2\n        self.level_dim = level_dim # encode channels per level\n        self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.\n        self.log2_hashmap_size = log2_hashmap_size\n        self.base_resolution = base_resolution\n        self.output_dim = num_levels * level_dim\n        self.gridtype = gridtype\n        self.gridtype_id = _gridtype_to_id[gridtype] # \"tiled\" or \"hash\"\n        self.interpolation = interpolation\n        self.interp_id = _interp_to_id[interpolation] # \"linear\" or \"smoothstep\"\n        self.align_corners = align_corners\n\n        # allocate parameters\n        offsets = []\n        offset = 0\n        self.max_params = 2 ** log2_hashmap_size\n        for i in range(num_levels):\n            resolution = int(np.ceil(base_resolution * per_level_scale ** i))\n            params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number\n            params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible\n            offsets.append(offset)\n            offset += params_in_level\n        offsets.append(offset)\n        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))\n        self.register_buffer('offsets', offsets)\n        \n        self.n_params = offsets[-1] * level_dim\n\n        # parameters\n        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))\n\n        self.reset_parameters()\n    \n    def reset_parameters(self):\n        std = 1e-4\n        self.embeddings.data.uniform_(-std, std)\n\n    def __repr__(self):\n        return f\"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} 
resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}\"\n    \n    def forward(self, inputs, bound=1, max_level=None):\n        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]\n        # max_level: only calculate first max_level levels (None will use all levels)\n        # return: [..., num_levels * level_dim]\n\n        inputs = (inputs + bound) / (2 * bound) # map to [0, 1]\n        \n        #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())\n\n        prefix_shape = list(inputs.shape[:-1])\n        inputs = inputs.view(-1, self.input_dim)\n\n        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id, max_level)\n        outputs = outputs.view(prefix_shape + [self.output_dim])\n\n        #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())\n\n        return outputs\n\n    # always run in float precision!\n    @torch.cuda.amp.autocast(enabled=False)\n    def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):\n        # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss.\n        \n        D = self.input_dim\n        C = self.embeddings.shape[1] # embedding dim for each level\n        L = self.offsets.shape[0] - 1 # level\n        S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f\n        H = self.base_resolution # base resolution\n\n        if inputs is None:\n            # randomized in [0, 1]\n            inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)\n        else:\n     
       inputs = (inputs + bound) / (2 * bound) # map to [0, 1]\n            inputs = inputs.view(-1, self.input_dim)\n            B = inputs.shape[0]\n\n        if self.embeddings.grad is None:\n            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')\n\n        _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)\n    \n    @torch.cuda.amp.autocast(enabled=False)\n    def grad_weight_decay(self, weight=0.1):\n        # level-wise mean weight decay (ref: zip-nerf)\n        \n        B = self.embeddings.shape[0] # number of embedding entries\n        C = self.embeddings.shape[1] # embedding dim for each level\n        L = self.offsets.shape[0] - 1 # level\n        \n        if self.embeddings.grad is None:\n            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')\n\n        _backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L)"
  },
  {
    "path": "gridencoder/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    '-O3', '-std=c++14',\n    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',\n]\n\nif os.name == \"posix\":\n    c_flags = ['-O3', '-std=c++14']\nelif os.name == \"nt\":\n    c_flags = ['/O2', '/std:c++17']\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n        for program_files in [r\"C:\\\\Program Files (x86)\", r\"C:\\\\Program Files\"]:\n            for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n                paths = sorted(glob.glob(r\"%s\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\" % (program_files, edition)), reverse=True)\n                if paths:\n                    return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n        os.environ[\"PATH\"] += \";\" + cl_path\n\nsetup(\n    name='gridencoder', # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name='_gridencoder', # extension name, import this to use CUDA API\n            sources=[os.path.join(_src_path, 'src', f) for f in [\n                'gridencoder.cu',\n                'bindings.cpp',\n            ]],\n            extra_compile_args={\n                'cxx': c_flags,\n                'nvcc': nvcc_flags,\n            }\n        ),\n    ],\n    cmdclass={\n        'build_ext': BuildExtension,\n    }\n)"
  },
  {
    "path": "gridencoder/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"gridencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"grid_encode_forward\", &grid_encode_forward, \"grid_encode_forward (CUDA)\");\n    m.def(\"grid_encode_backward\", &grid_encode_backward, \"grid_encode_backward (CUDA)\");\n    m.def(\"grad_total_variation\", &grad_total_variation, \"grad_total_variation (CUDA)\");\n    m.def(\"grad_weight_decay\", &grad_weight_decay, \"grad_weight_decay (CUDA)\");\n}"
  },
  {
    "path": "gridencoder/src/gridencoder.cu",
    "content": "#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <stdexcept>\n\n#include <stdint.h>\n#include <cstdio>\n\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x \" must be a floating tensor\")\n\n\n// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!\n__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {\n  // requires CUDA >= 10 and ARCH >= 70\n  // this is very slow compared to float or __half2, never use it.\n  //return atomicAdd(reinterpret_cast<__half*>(address), val);\n}\n\n\ntemplate <typename T>\n__host__ __device__ inline T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\ntemplate <typename T>\n__device__ inline T smoothstep(T val) {\n\treturn val*val*(3.0f - 2.0f * val);\n}\n\ntemplate <typename T>\n__device__ inline T smoothstep_derivative(T val) {\n\treturn 6*val*(1.0f - val);\n}\n\n\ntemplate <uint32_t D>\n__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {\n    \n    // coherent type of hashing\n    constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };\n\n    uint32_t result = 0;\n    #pragma unroll\n    for (uint32_t i = 0; i < D; ++i) {\n        result ^= pos_grid[i] * primes[i];\n    }\n\n    return result;\n}\n\n\ntemplate <uint32_t D, uint32_t C>\n__device__ uint32_t get_grid_index(const uint32_t gridtype, const uint32_t ch, 
const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {\n    uint32_t stride = 1;\n    uint32_t index = 0;\n\n    #pragma unroll\n    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {\n        index += pos_grid[d] * stride;\n        stride *= resolution;\n    }\n\n    // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.\n    // gridtype: 0 == hash, 1 == tiled\n    if (gridtype == 0 && stride > hashmap_size) {\n        index = fast_hash<D>(pos_grid);\n    }\n\n    return (index % hashmap_size) * C + ch;\n}\n\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C>\n__global__ void kernel_grid(\n    const float * __restrict__ inputs, \n    const scalar_t * __restrict__ grid, \n    const int * __restrict__ offsets, \n    scalar_t * __restrict__ outputs, \n    const uint32_t B, const uint32_t L, const float S, const uint32_t H,\n    scalar_t * __restrict__ dy_dx,\n    const uint32_t gridtype,\n    const bool align_corners,\n    const uint32_t interp\n) {\n    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;\n    \n    if (b >= B) return;\n\n    const uint32_t level = blockIdx.y;\n    \n    // locate\n    grid += (uint32_t)offsets[level] * C;\n    inputs += b * D;\n    outputs += level * B * C + b * C;\n\n    // check input range (should be in [0, 1])\n    bool flag_oob = false;\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        if (inputs[d] < 0 || inputs[d] > 1) {\n            flag_oob = true;\n        }\n    }\n    // if input out of bound, just set output to 0\n    if (flag_oob) {\n        #pragma unroll\n        for (uint32_t ch = 0; ch < C; ch++) {\n            outputs[ch] = 0; \n        }\n        if (dy_dx) {\n            dy_dx += b * D * L * C + level * D * C; // B L D C\n            #pragma unroll\n            for (uint32_t d = 0; d < D; d++) {\n                #pragma unroll\n                for (uint32_t ch = 0; ch < C; ch++) {\n         
           dy_dx[d * C + ch] = 0; \n                }       \n            }\n        }\n        return;\n    }\n\n    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];\n    const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);\n    \n    // calculate coordinate (always use float for precision!)\n    float pos[D];\n    float pos_deriv[D];\n    uint32_t pos_grid[D];\n\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        \n        // align_corners\n        if (align_corners) {\n            pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]\n            pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]\n        } else {\n            pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]\n            pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]\n        }\n        pos[d] -= (float)pos_grid[d];\n\n        // smoothstep instead of linear\n        if (interp == 1) {\n            pos_deriv[d] = smoothstep_derivative(pos[d]);\n            pos[d] = smoothstep(pos[d]);\n        } else {\n            pos_deriv[d] = 1.0f;\n        }\n    }\n\n    // verification of alignment\n    // if (level == L - 1 && b < 4) {\n    //     printf(\"[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\\n\", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);\n    // }\n\n    // interpolate\n    scalar_t results[C] = {0}; // temp results in register\n\n    #pragma unroll\n    for (uint32_t idx = 0; idx < (1 << D); idx++) {\n        float w = 1;\n        uint32_t pos_grid_local[D];\n\n        #pragma unroll\n        for (uint32_t d = 0; d < D; d++) {\n            if ((idx & (1 << d)) == 0) {\n                w *= 1 - pos[d];\n                pos_grid_local[d] = pos_grid[d];\n            } else {\n                w *= pos[d];\n                pos_grid_local[d] = min(pos_grid[d] + 1, 
resolution - 1);\n            }\n        }\n\n        uint32_t index = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);\n\n        // writing to register (fast)\n        #pragma unroll\n        for (uint32_t ch = 0; ch < C; ch++) {\n            results[ch] += w * grid[index + ch];\n        }\n\n        //printf(\"[b=%d, l=%d] int %d, idx %d, w %f, val %f\\n\", b, level, idx, index, w, grid[index]);\n    }    \n\n    // writing to global memory (slow)\n    #pragma unroll\n    for (uint32_t ch = 0; ch < C; ch++) {\n        outputs[ch] = results[ch]; \n    }\n\n    // prepare dy_dx\n    // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9\n    if (dy_dx) {\n\n        dy_dx += b * D * L * C + level * D * C; // B L D C\n\n        #pragma unroll\n        for (uint32_t gd = 0; gd < D; gd++) {\n\n            scalar_t results_grad[C] = {0};\n\n            #pragma unroll\n            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {\n                float w = (float)(align_corners ? resolution - 1 : resolution);\n                uint32_t pos_grid_local[D];\n\n                #pragma unroll\n                for (uint32_t nd = 0; nd < D - 1; nd++) {\n                    const uint32_t d = (nd >= gd) ? 
(nd + 1) : nd;\n\n                    if ((idx & (1 << nd)) == 0) {\n                        w *= 1 - pos[d];\n                        pos_grid_local[d] = pos_grid[d];\n                    } else {\n                        w *= pos[d];\n                        pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);\n                    }\n                }\n\n                pos_grid_local[gd] = pos_grid[gd];\n                uint32_t index_left = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);\n                pos_grid_local[gd] = min(pos_grid[gd] + 1, resolution - 1);\n                uint32_t index_right = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);\n\n                #pragma unroll\n                for (uint32_t ch = 0; ch < C; ch++) {\n                    results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];\n                }\n            }\n\n            #pragma unroll\n            for (uint32_t ch = 0; ch < C; ch++) {\n                dy_dx[gd * C + ch] = results_grad[ch];\n            }\n        }\n    }\n}\n\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>\n__global__ void kernel_grid_backward(\n    const scalar_t * __restrict__ grad,\n    const float * __restrict__ inputs, \n    const scalar_t * __restrict__ grid, \n    const int * __restrict__ offsets, \n    scalar_t * __restrict__ grad_grid, \n    const uint32_t B, const uint32_t L, const float S, const uint32_t H,\n    const uint32_t gridtype,\n    const bool align_corners,\n    const uint32_t interp\n) {\n    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;\n    if (b >= B) return;\n\n    const uint32_t level = blockIdx.y;\n    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;\n\n    // locate\n    grad_grid += offsets[level] * C;\n    inputs += b * D;\n    grad += level * B * C + b * C + ch; // L, B, C\n\n    const uint32_t 
hashmap_size = offsets[level + 1] - offsets[level];\n    const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);\n\n    // check input range (should be in [0, 1])\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        if (inputs[d] < 0 || inputs[d] > 1) {\n            return; // grad is init as 0, so we simply return.\n        }\n    }\n\n    // calculate coordinate\n    float pos[D];\n    uint32_t pos_grid[D];\n\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        // align_corners\n        if (align_corners) {\n            pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]\n            pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]\n        } else {\n            pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]\n            pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]\n        }\n        pos[d] -= (float)pos_grid[d];\n        // smoothstep instead of linear\n        if (interp == 1) {\n            pos[d] = smoothstep(pos[d]);\n        }\n    }\n\n    scalar_t grad_cur[N_C] = {0}; // fetch to register\n    #pragma unroll\n    for (uint32_t c = 0; c < N_C; c++) {\n        grad_cur[c] = grad[c];\n    }\n\n    // interpolate\n    #pragma unroll\n    for (uint32_t idx = 0; idx < (1 << D); idx++) {\n        float w = 1;\n        uint32_t pos_grid_local[D];\n\n        #pragma unroll\n        for (uint32_t d = 0; d < D; d++) {\n            if ((idx & (1 << d)) == 0) {\n                w *= 1 - pos[d];\n                pos_grid_local[d] = pos_grid[d];\n            } else {\n                w *= pos[d];\n                pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);\n            }\n        }\n\n        uint32_t index = get_grid_index<D, C>(gridtype, ch, hashmap_size, resolution, pos_grid_local);\n\n        // atomicAdd for 
__half is slow (especially for large values), so we use __half2 if N_C % 2 == 0\n        // TODO: use float which is better than __half, if N_C % 2 != 0\n        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {\n            #pragma unroll\n            for (uint32_t c = 0; c < N_C; c += 2) {\n                // process two __half at once (by interpreting as a __half2)\n                __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};\n                atomicAdd((__half2*)&grad_grid[index + c], v);\n            }\n        // float, or __half when N_C % 2 != 0 (which means C == 1)\n        } else {\n            #pragma unroll\n            for (uint32_t c = 0; c < N_C; c++) {\n                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);\n            }\n        }\n    }    \n}\n\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C>\n__global__ void kernel_input_backward(\n    const scalar_t * __restrict__ grad,\n    const scalar_t * __restrict__ dy_dx,  \n    scalar_t * __restrict__ grad_inputs, \n    uint32_t B, uint32_t L\n) {\n    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;\n    if (t >= B * D) return;\n\n    const uint32_t b = t / D;\n    const uint32_t d = t - b * D;\n\n    dy_dx += b * L * D * C;\n\n    scalar_t result = 0;\n    \n    # pragma unroll\n    for (int l = 0; l < L; l++) {\n        # pragma unroll\n        for (int ch = 0; ch < C; ch++) {\n            result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];\n        }\n    }\n\n    grad_inputs[t] = result;\n}\n\n\ntemplate <typename scalar_t, uint32_t D>\nvoid kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {\n    static constexpr uint32_t N_THREAD = 512;\n    const dim3 
blocks_hashgrid = { div_round_up(B, N_THREAD), max_level, 1 };\n    switch (C) {\n        case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;\n        case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;\n        case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;\n        case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;\n        case 16: kernel_grid<scalar_t, D, 16><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;\n        case 32: kernel_grid<scalar_t, D, 32><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;\n        default: throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, 8, 16 or 32.\"};\n    }\n}\n\n// inputs: [B, D], float, in [0, 1]\n// embeddings: [sO, C], float\n// offsets: [L + 1], uint32_t\n// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)\n// H: base resolution\n// dy_dx: [B, L * D * C]\ntemplate <typename scalar_t>\nvoid grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {\n    switch (D) {\n        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, 
dy_dx, gridtype, align_corners, interp); break;\n        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;\n        case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;\n        case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;\n        default: throw std::runtime_error{\"GridEncoding: D must be 2, 3, 4 or 5.\"};\n    }   \n}\n\ntemplate <typename scalar_t, uint32_t D>\nvoid kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {\n    static constexpr uint32_t N_THREAD = 256;\n    const uint32_t N_C = std::min(2u, C); // n_features_per_thread\n    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), max_level, 1 };\n    switch (C) {\n        case 1: \n            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); \n            if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 2: \n            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);\n            if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            
break;\n        case 4: \n            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);\n            if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 8: \n            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);\n            if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 16: \n            kernel_grid_backward<scalar_t, D, 16, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);\n            if (dy_dx) kernel_input_backward<scalar_t, D, 16><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 32: \n            kernel_grid_backward<scalar_t, D, 32, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);\n            if (dy_dx) kernel_input_backward<scalar_t, D, 32><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            break;\n        default: throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, 8, 16 or 32.\"};\n    }\n}\n\n\n// grad: [L, B, C], float\n// inputs: [B, D], float, in [0, 1]\n// embeddings: [sO, C], float\n// offsets: [L + 1], uint32_t\n// grad_embeddings: [sO, C]\n// H: base resolution\ntemplate <typename scalar_t>\nvoid grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const 
uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {\n    switch (D) {\n        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;\n        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;\n        case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;\n        case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;\n        default: throw std::runtime_error{\"GridEncoding: D must be 2, 3, 4 or 5.\"};\n    }\n}\n\n\n\nvoid grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(embeddings);\n    CHECK_CUDA(offsets);\n    CHECK_CUDA(outputs);\n    // CHECK_CUDA(dy_dx);\n    \n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(embeddings);\n    CHECK_CONTIGUOUS(offsets);\n    CHECK_CONTIGUOUS(outputs);\n    // CHECK_CONTIGUOUS(dy_dx);\n\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(embeddings);\n    CHECK_IS_INT(offsets);\n    CHECK_IS_FLOATING(outputs);\n    // CHECK_IS_FLOATING(dy_dx);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    embeddings.scalar_type(), 
\"grid_encode_forward\", ([&] {\n        grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);\n    }));\n}\n\nvoid grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {\n    CHECK_CUDA(grad);\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(embeddings);\n    CHECK_CUDA(offsets);\n    CHECK_CUDA(grad_embeddings);\n    // CHECK_CUDA(dy_dx);\n    // CHECK_CUDA(grad_inputs);\n    \n    CHECK_CONTIGUOUS(grad);\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(embeddings);\n    CHECK_CONTIGUOUS(offsets);\n    CHECK_CONTIGUOUS(grad_embeddings);\n    // CHECK_CONTIGUOUS(dy_dx);\n    // CHECK_CONTIGUOUS(grad_inputs);\n\n    CHECK_IS_FLOATING(grad);\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(embeddings);\n    CHECK_IS_INT(offsets);\n    CHECK_IS_FLOATING(grad_embeddings);\n    // CHECK_IS_FLOATING(dy_dx);\n    // CHECK_IS_FLOATING(grad_inputs);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    grad.scalar_type(), \"grid_encode_backward\", ([&] {\n        grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? 
grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);\n    }));\n    \n}\n\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C>\n__global__ void kernel_grad_tv(\n    const scalar_t * __restrict__ inputs,\n    const scalar_t * __restrict__ grid, \n    scalar_t * __restrict__ grad, \n    const int * __restrict__ offsets, \n    const float weight,\n    const uint32_t B, const uint32_t L, const float S, const uint32_t H,\n    const uint32_t gridtype,\n    const bool align_corners\n) {\n    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;\n    \n    if (b >= B) return;\n\n    const uint32_t level = blockIdx.y;\n    \n    // locate\n    inputs += b * D;\n    grid += (uint32_t)offsets[level] * C;\n    grad += (uint32_t)offsets[level] * C;\n\n    // check input range (should be in [0, 1])\n    bool flag_oob = false;\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        if (inputs[d] < 0 || inputs[d] > 1) {\n            flag_oob = true;\n        }\n    }\n\n    // if input out of bound, do nothing\n    if (flag_oob) return;\n\n    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];\n    const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);\n    \n    // calculate coordinate\n    float pos[D];\n    uint32_t pos_grid[D]; // [0, resolution]\n\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        // align_corners\n        if (align_corners) {\n            pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]\n            pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]\n        } else {\n            pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]\n            pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]\n        }\n    }\n\n    //printf(\"[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\\n\", b, 
level, pos[0], pos[1], pos_grid[0], pos_grid[1]);\n\n    // total variation on pos_grid\n    scalar_t results[C] = {0}; // temp results in register\n    scalar_t idelta[C] = {0};\n\n    uint32_t index = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);\n\n    scalar_t w = weight / (2 * D);\n\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n\n        uint32_t cur_d = pos_grid[d];\n        scalar_t grad_val;\n\n        // right side\n        if (cur_d < resolution) {\n            pos_grid[d] = cur_d + 1;\n            uint32_t index_right = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);\n\n            #pragma unroll\n            for (uint32_t ch = 0; ch < C; ch++) {\n                grad_val = (grid[index + ch] - grid[index_right + ch]);\n                results[ch] += grad_val;\n                idelta[ch] += grad_val * grad_val;\n            }\n        }\n\n        // left side\n        if (cur_d > 0) {\n            pos_grid[d] = cur_d - 1;\n            uint32_t index_left = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);\n\n            #pragma unroll\n            for (uint32_t ch = 0; ch < C; ch++) {\n                grad_val = (grid[index + ch] - grid[index_left + ch]);\n                results[ch] += grad_val;\n                idelta[ch] += grad_val * grad_val;\n            }\n        }\n\n        // reset\n        pos_grid[d] = cur_d;\n    }\n\n    // writing to global memory (slow)\n    #pragma unroll\n    for (uint32_t ch = 0; ch < C; ch++) {\n        // index may collide, so use atomic!\n        atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));\n    }\n\n}\n\n\ntemplate <typename scalar_t, uint32_t D>\nvoid kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool 
align_corners) {\n    static constexpr uint32_t N_THREAD = 512;\n    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };\n    switch (C) {\n        case 1: kernel_grad_tv<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;\n        case 2: kernel_grad_tv<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;\n        case 4: kernel_grad_tv<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;\n        case 8: kernel_grad_tv<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;\n        case 16: kernel_grad_tv<scalar_t, D, 16><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;\n        case 32: kernel_grad_tv<scalar_t, D, 32><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;\n        default: throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, 8, 16 or 32.\"};\n    }\n}\n\n\ntemplate <typename scalar_t>\nvoid grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {\n    switch (D) {\n        case 2: kernel_grad_tv_wrapper<scalar_t, 2>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;\n        case 3: kernel_grad_tv_wrapper<scalar_t, 3>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;\n        case 4: kernel_grad_tv_wrapper<scalar_t, 4>(inputs, embeddings, grad, offsets, weight, 
B, C, L, S, H, gridtype, align_corners); break;\n        case 5: kernel_grad_tv_wrapper<scalar_t, 5>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;\n        default: throw std::runtime_error{\"GridEncoding: D must be 2, 3, 4, or 5.\"};\n    }   \n}\n\n\nvoid grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    embeddings.scalar_type(), \"grad_total_variation\", ([&] {\n        grad_total_variation_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, D, C, L, S, H, gridtype, align_corners);\n    }));\n}\n\ntemplate <typename scalar_t>\n__global__ void kernel_grad_wd(\n    const scalar_t * __restrict__ grid, \n    scalar_t * __restrict__ grad, \n    const int * __restrict__ offsets, \n    const float weight,\n    const uint32_t B, const uint32_t L, const uint32_t C\n) {\n    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;\n    \n    if (b >= B * C) return;\n\n    // locate\n    grid += b;\n    grad += b;\n\n    // decide in which level is this thread... 
\n    uint32_t level = 0;\n    const uint32_t n = b / C;\n    // binary search b in offsets\n    uint32_t l = 0, r = L;\n    while (l < r) {\n        uint32_t m = (l + r) / 2;\n        if (offsets[m] <= n) {\n            level = m;\n            l = m + 1;\n        } else {\n            r = m;\n        }\n    }\n\n    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];\n    grad[0] += 2 * weight * grid[0] / hashmap_size;\n}\n\nvoid grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L) {\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    embeddings.scalar_type(), \"grad_weight_decay\", ([&] {\n        static constexpr uint32_t N_THREAD = 1024;\n        const dim3 blocks_hashgrid = { div_round_up(B * C, N_THREAD), 1, 1 };\n        kernel_grad_wd<scalar_t><<<blocks_hashgrid, N_THREAD>>>(embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, L, C);\n    }));\n}"
  },
  {
    "path": "gridencoder/src/gridencoder.h",
    "content": "#ifndef _HASH_ENCODE_H\n#define _HASH_ENCODE_H\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// inputs: [B, D], float, in [0, 1]\n// embeddings: [sO, C], float\n// offsets: [L + 1], uint32_t\n// outputs: [L, B, C], float\n// H: base resolution\nvoid grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);\nvoid grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);\n\nvoid grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);\nvoid grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L);\n\n#endif"
  },
  {
    "path": "guidance/clip_utils.py",
    "content": "import torch\nimport torch.nn as nn\n\nimport torchvision.transforms as T\nimport torchvision.transforms.functional as TF\n\nimport clip\n\nclass CLIP(nn.Module):\n    def __init__(self, device, **kwargs):\n        super().__init__()\n\n        self.device = device\n        self.clip_model, self.clip_preprocess = clip.load(\"ViT-B/16\", device=self.device, jit=False)\n\n        self.aug = T.Compose([\n            T.Resize((224, 224)),\n            T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n        ])\n    \n    def get_text_embeds(self, prompt, **kwargs):\n\n        text = clip.tokenize(prompt).to(self.device)\n        text_z = self.clip_model.encode_text(text)\n        text_z = text_z / text_z.norm(dim=-1, keepdim=True)\n\n        return text_z\n\n    def get_img_embeds(self, image, **kwargs):\n\n        image_z = self.clip_model.encode_image(self.aug(image))\n        image_z = image_z / image_z.norm(dim=-1, keepdim=True)\n\n        return image_z\n\n    \n    def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs):\n        \"\"\"\n            Args:\n                grad_scale: scalar or 1-tensor of size [B], i.e. 1 grad_scale per batch item. \n        \"\"\"\n        # TODO: resize the image from NeRF-rendered resolution (e.g. 128x128) to what CLIP expects (224x224), to prevent PyTorch warning about `antialias=None`.\n        image_z = self.clip_model.encode_image(self.aug(pred_rgb))\n        image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features\n\n        loss = 0\n        if 'image' in clip_z:\n            loss -= ((image_z * clip_z['image']).sum(-1) * grad_scale).mean()\n        \n        if 'text' in clip_z:\n            loss -= ((image_z * clip_z['text']).sum(-1) * grad_scale).mean()\n\n        return loss\n\n"
  },
  {
    "path": "guidance/if_utils.py",
    "content": "from transformers import logging\nfrom diffusers import IFPipeline, DDPMScheduler\n\n# suppress partial model loading warning\nlogging.set_verbosity_error()\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom .perpneg_utils import weighted_perpendicular_aggregator\n\n\ndef seed_everything(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    #torch.backends.cudnn.deterministic = True\n    #torch.backends.cudnn.benchmark = True\n\n\nclass IF(nn.Module):\n    def __init__(self, device, vram_O, t_range=[0.02, 0.98]):\n        super().__init__()\n\n        self.device = device\n\n        print(f'[INFO] loading DeepFloyd IF-I-XL...')\n\n        model_key = \"DeepFloyd/IF-I-XL-v1.0\"\n\n        is_torch2 = torch.__version__[0] == '2'\n\n        # Create model\n        pipe = IFPipeline.from_pretrained(model_key, variant=\"fp16\", torch_dtype=torch.float16)\n        if not is_torch2:\n            pipe.enable_xformers_memory_efficient_attention()\n\n        if vram_O:\n            pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            pipe.enable_model_cpu_offload()\n        else:\n            pipe.to(device)\n\n        self.unet = pipe.unet\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n        self.unet = pipe.unet\n        self.scheduler = pipe.scheduler\n\n        self.pipe = pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience\n\n        print(f'[INFO] loaded DeepFloyd IF-I-XL!')\n\n    @torch.no_grad()\n    def get_text_embeds(self, prompt):\n        # prompt: [str]\n\n        # TODO: should I add the 
preprocessing at https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py#LL486C10-L486C28\n        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)\n        inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt')\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n\n        return embeddings\n\n\n    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1):\n\n        # [0, 1] to [-1, 1] and make sure shape is [64, 64]\n        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1\n\n        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level\n        t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)\n\n        # predict the noise residual with unet, NO grad!\n        with torch.no_grad():\n            # add noise\n            noise = torch.randn_like(images)\n            images_noisy = self.scheduler.add_noise(images, noise, t)\n\n            # pred noise\n            model_input = torch.cat([images_noisy] * 2)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            tt = torch.cat([t] * 2)\n            noise_pred = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample\n            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)\n            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # TODO: how to use the variance here?\n            # noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)\n\n        # w(t), sigma_t^2\n        w = (1 - self.alphas[t])\n      
  grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)\n        grad = torch.nan_to_num(grad)\n\n        targets = (images - grad).detach()\n        loss = 0.5 * F.mse_loss(images.float(), targets, reduction='sum') / images.shape[0]\n\n        return loss\n\n    def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, grad_scale=1):\n\n        B = pred_rgb.shape[0]\n        K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts        \n\n        # [0, 1] to [-1, 1] and make sure shape is [64, 64]\n        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1\n\n        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level\n        t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)\n\n        # predict the noise residual with unet, NO grad!\n        with torch.no_grad():\n            # add noise\n            noise = torch.randn_like(images)\n            images_noisy = self.scheduler.add_noise(images, noise, t)\n\n            # pred noise\n            model_input = torch.cat([images_noisy] * (1 + K))\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            tt = torch.cat([t] * (1 + K))\n            unet_output = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample\n            noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:]\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)\n            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)\n            # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n            delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)\n            noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B)\n\n\n\n        # 
w(t), sigma_t^2\n        w = (1 - self.alphas[t])\n        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)\n        grad = torch.nan_to_num(grad)\n\n        targets = (images - grad).detach()\n        loss = 0.5 * F.mse_loss(images.float(), targets, reduction='sum') / images.shape[0]\n\n        return loss\n\n    @torch.no_grad()\n    def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5):\n\n        images = torch.randn((1, 3, height, width), device=text_embeddings.device, dtype=text_embeddings.dtype)\n        images = images * self.scheduler.init_noise_sigma\n\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.\n            model_input = torch.cat([images] * 2)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n\n            # predict the noise residual\n            noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance\n            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)\n            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            images = self.scheduler.step(noise_pred, t, images).prev_sample\n\n        images = (images + 1) / 2\n\n        return images\n\n\n    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):\n\n        if isinstance(prompts, str):\n   
         prompts = [prompts]\n\n        if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768]\n        neg_embeds = self.get_text_embeds(negative_prompts)\n        text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768]\n\n        # Text embeds -> img\n        imgs = self.produce_imgs(text_embeds, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]\n\n        # Img to Numpy\n        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()\n        imgs = (imgs * 255).round().astype('uint8')\n\n        return imgs\n\n\nif __name__ == '__main__':\n\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('prompt', type=str)\n    parser.add_argument('--negative', default='', type=str)\n    parser.add_argument('--vram_O', action='store_true', help=\"optimization for low VRAM usage\")\n    parser.add_argument('-H', type=int, default=64)\n    parser.add_argument('-W', type=int, default=64)\n    parser.add_argument('--seed', type=int, default=0)\n    parser.add_argument('--steps', type=int, default=50)\n    opt = parser.parse_args()\n\n    seed_everything(opt.seed)\n\n    device = torch.device('cuda')\n\n    sd = IF(device, opt.vram_O)\n\n    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)\n\n    # visualize image\n    plt.imshow(imgs[0])\n    plt.show()\n\n\n\n\n"
  },
  {
    "path": "guidance/perpneg_utils.py",
    "content": "import torch\n\n# Please refer to the https://perp-neg.github.io/ for details about the paper and algorithm\ndef get_perpendicular_component(x, y):\n    assert x.shape == y.shape\n    return x - ((torch.mul(x, y).sum())/max(torch.norm(y)**2, 1e-6)) * y\n\n\ndef batch_get_perpendicular_component(x, y):\n    assert x.shape == y.shape\n    result = []\n    for i in range(x.shape[0]):\n        result.append(get_perpendicular_component(x[i], y[i]))\n    return torch.stack(result)\n\n\ndef weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size):\n    \"\"\" \n    Notes: \n     - weights: an array with the weights for combining the noise predictions\n     - delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir\n    \"\"\"\n    delta_noise_preds = delta_noise_preds.split(batch_size, dim=0) # K x [B, 4, 64, 64]\n    weights = weights.split(batch_size, dim=0) # K x [B]\n    # print(f\"{weights[0].shape = } {weights = }\")\n\n    assert torch.all(weights[0] == 1.0)\n\n    main_positive = delta_noise_preds[0] # [B, 4, 64, 64]\n\n    accumulated_output = torch.zeros_like(main_positive)\n    for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1):\n        # print(f\"\\n{i = }, {weights[i] = }, {weights[i].shape = }\\n\")\n\n        idx_non_zero = torch.abs(weights[i]) > 1e-4\n        \n        # print(f\"{idx_non_zero.shape = }, {idx_non_zero = }\")\n        # print(f\"{weights[i][idx_non_zero].shape = }, {weights[i][idx_non_zero] = }\")\n        # print(f\"{complementary_noise_pred.shape = }, {complementary_noise_pred[idx_non_zero].shape = }\")\n        # print(f\"{main_positive.shape = }, {main_positive[idx_non_zero].shape = }\")\n        if sum(idx_non_zero) == 0:\n            continue\n        accumulated_output[idx_non_zero] += weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * batch_get_perpendicular_component(complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero])\n    \n    assert 
accumulated_output.shape == main_positive.shape, f\"{accumulated_output.shape = }, {main_positive.shape = }\"\n\n\n    return accumulated_output + main_positive"
  },
  {
    "path": "guidance/sd_utils.py",
    "content": "from transformers import CLIPTextModel, CLIPTokenizer, logging\nfrom diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom os.path import isfile\nfrom pathlib import Path\n\n# suppress partial model loading warning\nlogging.set_verbosity_error()\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision.utils import save_image\n\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom .perpneg_utils import weighted_perpendicular_aggregator\n\n\ndef seed_everything(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    #torch.backends.cudnn.deterministic = True\n    #torch.backends.cudnn.benchmark = True\n\nclass StableDiffusion(nn.Module):\n    def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None, t_range=[0.02, 0.98]):\n        super().__init__()\n\n        self.device = device\n        self.sd_version = sd_version\n\n        print(f'[INFO] loading stable diffusion...')\n\n        if hf_key is not None:\n            print(f'[INFO] using hugging face custom model key: {hf_key}')\n            model_key = hf_key\n        elif self.sd_version == '2.1':\n            model_key = \"stabilityai/stable-diffusion-2-1-base\"\n        elif self.sd_version == '2.0':\n            model_key = \"stabilityai/stable-diffusion-2-base\"\n        elif self.sd_version == '1.5':\n            model_key = \"runwayml/stable-diffusion-v1-5\"\n        else:\n            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')\n\n        self.precision_t = torch.float16 if fp16 else torch.float32\n\n        # Create model\n        pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t)\n\n        if vram_O:\n            pipe.enable_sequential_cpu_offload()\n            pipe.enable_vae_slicing()\n            
pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            # pipe.enable_model_cpu_offload()\n        else:\n            pipe.to(device)\n\n        self.vae = pipe.vae\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n        self.unet = pipe.unet\n\n        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder=\"scheduler\", torch_dtype=self.precision_t)\n\n        del pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience\n\n        print(f'[INFO] loaded stable diffusion!')\n\n    @torch.no_grad()\n    def get_text_embeds(self, prompt):\n        # prompt: [str]\n\n        inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n\n        return embeddings\n\n\n    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1,\n                   save_guidance_path:Path=None):\n\n        if as_latent:\n            latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1\n        else:\n            # interp to 512x512 to be fed into vae.\n            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)\n            # encode image into latents with vae, requires grad!\n            latents = self.encode_imgs(pred_rgb_512)\n\n        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level\n        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)\n\n        # predict the noise residual with unet, NO grad!\n        
with torch.no_grad():\n            # add noise\n            noise = torch.randn_like(latents)\n            latents_noisy = self.scheduler.add_noise(latents, noise, t)\n            # pred noise\n            latent_model_input = torch.cat([latents_noisy] * 2)\n            tt = torch.cat([t] * 2)\n            noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance (high scale from paper!)\n            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)\n\n        # import kiui\n        # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device)\n        # latents_tmp = latents_tmp.detach()\n        # kiui.lo(latents_tmp)\n        # self.scheduler.set_timesteps(30)\n        # for i, t in enumerate(self.scheduler.timesteps):\n        #     latent_model_input = torch.cat([latents_tmp] * 3)\n        #     noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']\n        #     noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)\n        #     noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond)\n        #     latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample']\n        # imgs = self.decode_latents(latents_tmp)\n        # kiui.vis.plot_image(imgs)\n\n        # w(t), sigma_t^2\n        w = (1 - self.alphas[t])\n        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)\n        grad = torch.nan_to_num(grad)\n\n        if save_guidance_path:\n            with torch.no_grad():\n                if as_latent:\n                    pred_rgb_512 = self.decode_latents(latents)\n\n                # visualize predicted denoised image\n                # The following block of code is equivalent to `predict_start_from_noise`...\n                # see zero123_utils.py's version for a simpler implementation.\n  
              # self.alphas holds alphas_cumprod, so x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)\n                b = len(noise_pred)\n                a_t = self.alphas[t].reshape(b, 1, 1, 1)\n                pred_x0 = (latents_noisy - (1 - a_t).sqrt() * noise_pred) / a_t.sqrt() # current prediction for x_0\n                result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(self.precision_t))\n\n                # visualize noisier image\n                result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))\n\n                # TODO: also denoise all-the-way\n\n                # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]\n                viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)\n                save_image(viz_images, save_guidance_path)\n\n        targets = (latents - grad).detach()\n        loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]\n\n        return loss\n    \n\n    def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1,\n                   save_guidance_path:Path=None):\n\n        B = pred_rgb.shape[0]\n        K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts       \n\n        if as_latent:\n            latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1\n        else:\n            # interp to 512x512 to be fed into vae.\n            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)\n            # encode image into latents with vae, 
requires grad!\n            latents = self.encode_imgs(pred_rgb_512)\n\n        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level\n        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)\n\n        # predict the noise residual with unet, NO grad!\n        with torch.no_grad():\n            # add noise\n            noise = torch.randn_like(latents)\n            latents_noisy = self.scheduler.add_noise(latents, noise, t)\n            # pred noise\n            latent_model_input = torch.cat([latents_noisy] * (1 + K))\n            tt = torch.cat([t] * (1 + K))\n            unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance (high scale from paper!)\n            noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:]\n            delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)\n            noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B)            \n\n        # import kiui\n        # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device)\n        # latents_tmp = latents_tmp.detach()\n        # kiui.lo(latents_tmp)\n        # self.scheduler.set_timesteps(30)\n        # for i, t in enumerate(self.scheduler.timesteps):\n        #     latent_model_input = torch.cat([latents_tmp] * 3)\n        #     noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']\n        #     noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)\n        #     noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond)\n        #     latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample']\n        # imgs = self.decode_latents(latents_tmp)\n        # kiui.vis.plot_image(imgs)\n\n        # w(t), sigma_t^2\n        w = (1 - self.alphas[t])\n        grad = 
grad_scale * w[:, None, None, None] * (noise_pred - noise)\n        grad = torch.nan_to_num(grad)\n\n        if save_guidance_path:\n            with torch.no_grad():\n                if as_latent:\n                    pred_rgb_512 = self.decode_latents(latents)\n\n                # visualize predicted denoised image\n                # The following block of code is equivalent to `predict_start_from_noise`...\n                # see zero123_utils.py's version for a simpler implementation.\n                # self.alphas holds alphas_cumprod, so x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)\n                b = len(noise_pred)\n                a_t = self.alphas[t].reshape(b, 1, 1, 1)\n                pred_x0 = (latents_noisy - (1 - a_t).sqrt() * noise_pred) / a_t.sqrt() # current prediction for x_0\n                result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(self.precision_t))\n\n                # visualize noisier image\n                result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))\n\n                # all 3 input images are [1, 3, H, W], e.g. 
[1, 3, 512, 512]\n                viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)\n                save_image(viz_images, save_guidance_path)\n\n        targets = (latents - grad).detach()\n        loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]\n\n        return loss\n\n\n    @torch.no_grad()\n    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):\n\n        if latents is None:\n            latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)\n\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.\n            latent_model_input = torch.cat([latents] * 2)\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']\n\n            # perform guidance\n            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']\n\n        return latents\n\n    def decode_latents(self, latents):\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        imgs = self.vae.decode(latents).sample\n        imgs = (imgs / 2 + 0.5).clamp(0, 1)\n\n        return imgs\n\n    def encode_imgs(self, imgs):\n        # imgs: [B, 3, H, W]\n\n        imgs = 2 * imgs - 1\n\n        posterior = self.vae.encode(imgs).latent_dist\n        latents = posterior.sample() * self.vae.config.scaling_factor\n\n        return latents\n\n    def 
prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):\n\n        if isinstance(prompts, str):\n            prompts = [prompts]\n\n        if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768]\n        neg_embeds = self.get_text_embeds(negative_prompts)\n        text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768]\n\n        # Text embeds -> img latents\n        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]\n\n        # Img latents -> imgs\n        imgs = self.decode_latents(latents) # [1, 3, 512, 512]\n\n        # Img to Numpy\n        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()\n        imgs = (imgs * 255).round().astype('uint8')\n\n        return imgs\n\n\nif __name__ == '__main__':\n\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('prompt', type=str)\n    parser.add_argument('--negative', default='', type=str)\n    parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help=\"stable diffusion version\")\n    parser.add_argument('--hf_key', type=str, default=None, help=\"hugging face Stable diffusion model key\")\n    parser.add_argument('--fp16', action='store_true', help=\"use float16 for training\")\n    parser.add_argument('--vram_O', action='store_true', help=\"optimization for low VRAM usage\")\n    parser.add_argument('-H', type=int, default=512)\n    parser.add_argument('-W', type=int, default=512)\n    parser.add_argument('--seed', type=int, default=0)\n    parser.add_argument('--steps', type=int, default=50)\n    opt = parser.parse_args()\n\n    
seed_everything(opt.seed)\n\n    device = torch.device('cuda')\n\n    sd = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key)\n\n    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)\n\n    # visualize image\n    plt.imshow(imgs[0])\n    plt.show()\n\n\n\n\n"
  },
  {
    "path": "guidance/zero123_utils.py",
    "content": "import math\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom pathlib import Path\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.cuda.amp import custom_bwd, custom_fwd\nfrom torchvision.utils import save_image\n\nfrom diffusers import DDIMScheduler\n\nimport sys\nfrom os import path\nsys.path.append(path.dirname(path.dirname(path.abspath(__file__))))\n\nfrom ldm.util import instantiate_from_config\n\n\n# load model\ndef load_model_from_config(config, ckpt, device, vram_O=False, verbose=False):\n\n    pl_sd = torch.load(ckpt, map_location='cpu')\n\n    if 'global_step' in pl_sd and verbose:\n        print(f'[INFO] Global Step: {pl_sd[\"global_step\"]}')\n\n    sd = pl_sd['state_dict']\n\n    model = instantiate_from_config(config.model)\n    m, u = model.load_state_dict(sd, strict=False)\n\n    if len(m) > 0 and verbose:\n        print('[INFO] missing keys: \\n', m)\n    if len(u) > 0 and verbose:\n        print('[INFO] unexpected keys: \\n', u)\n\n    # manually load ema and delete it to save GPU memory\n    if model.use_ema:\n        if verbose:\n            print('[INFO] loading EMA...')\n        model.model_ema.copy_to(model.model)\n        del model.model_ema\n\n    if vram_O:\n        # we don't need decoder\n        del model.first_stage_model.decoder\n\n    torch.cuda.empty_cache()\n\n    model.eval().to(device)\n\n    return model\n\nclass Zero123(nn.Module):\n    def __init__(self, device, fp16,\n                 config='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml',\n                 ckpt='./pretrained/zero123/zero123-xl.ckpt', vram_O=False, t_range=[0.02, 0.98], opt=None):\n        super().__init__()\n\n        self.device = device\n        self.fp16 = fp16\n        self.vram_O = vram_O\n        self.t_range = t_range\n        self.opt = opt\n\n        self.config = OmegaConf.load(config)\n        # TODO: seems it cannot load into fp16...\n        self.model = 
load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O)\n\n        # timesteps: use diffuser for convenience... hope it's alright.\n        self.num_train_timesteps = self.config.model.params.timesteps\n\n        self.scheduler = DDIMScheduler(\n            self.num_train_timesteps,\n            self.config.model.params.linear_start,\n            self.config.model.params.linear_end,\n            beta_schedule='scaled_linear',\n            clip_sample=False,\n            set_alpha_to_one=False,\n            steps_offset=1,\n        )\n\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience\n\n    @torch.no_grad()\n    def get_img_embeds(self, x):\n        # x: image tensor [B, 3, 256, 256] in [0, 1]\n        x = x * 2 - 1\n        c = [self.model.get_learned_conditioning(xx.unsqueeze(0)) for xx in x] #.tile(n_samples, 1, 1)\n        v = [self.model.encode_first_stage(xx.unsqueeze(0)).mode() for xx in x]\n        return c, v\n\n    def angle_between(self, sph_v1, sph_v2):\n        def sph2cart(sv):\n            r, theta, phi = sv[0], sv[1], sv[2]\n            return torch.tensor([r * torch.sin(theta) * torch.cos(phi), r * torch.sin(theta) * torch.sin(phi), r * torch.cos(theta)])\n        def unit_vector(v):\n            return v / torch.linalg.norm(v)\n        def angle_between_2_sph(sv1, sv2):\n            v1, v2 = sph2cart(sv1), sph2cart(sv2)\n            v1_u, v2_u = unit_vector(v1), unit_vector(v2)\n            return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))\n        angles = torch.empty(len(sph_v1), len(sph_v2))\n        for i, sv1 in enumerate(sph_v1):\n            for j, sv2 in enumerate(sph_v2):\n                angles[i][j] = angle_between_2_sph(sv1, sv2)\n        return angles\n\n    def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, 
guidance_scale=3, as_latent=False, grad_scale=1, save_guidance_path:Path=None):\n        # pred_rgb: tensor [1, 3, H, W] in [0, 1]\n\n        # adjust SDS scale based on how far the novel view is from the known view\n        ref_radii = embeddings['ref_radii']\n        ref_polars = embeddings['ref_polars']\n        ref_azimuths = embeddings['ref_azimuths']\n        v1 = torch.stack([radius + ref_radii[0], torch.deg2rad(polar + ref_polars[0]), torch.deg2rad(azimuth + ref_azimuths[0])], dim=-1)   # polar,azimuth,radius are all actually delta wrt default\n        v2 = torch.stack([torch.tensor(ref_radii), torch.deg2rad(torch.tensor(ref_polars)), torch.deg2rad(torch.tensor(ref_azimuths))], dim=-1)\n        angles = torch.rad2deg(self.angle_between(v1, v2)).to(self.device)\n        if self.opt.zero123_grad_scale == 'angle':\n            grad_scale = (angles.min(dim=1)[0] / (180/len(ref_azimuths))) * grad_scale  # rethink 180/len(ref_azimuths) # claforte: try inverting grad_scale or just fixing it to 1.0\n        elif self.opt.zero123_grad_scale == 'None':\n            grad_scale = 1.0 # claforte: I think this might converge faster...?\n        else:\n            assert False, f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}'\n        \n        if as_latent:\n            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1\n        else:\n            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)\n            latents = self.encode_imgs(pred_rgb_256)\n\n        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)\n\n        # Set weights acc to closeness in angle\n        if len(ref_azimuths) > 1:\n            inv_angles = 1/angles\n            inv_angles[inv_angles > 100] = 100\n            inv_angles /= inv_angles.max(dim=-1, keepdim=True)[0]\n            inv_angles[inv_angles < 0.1] = 0\n        else:\n            
inv_angles = torch.tensor([1.]).to(self.device)\n\n        # Multiply closeness-weight by user-given weights\n        zero123_ws = torch.tensor(embeddings['zero123_ws'])[None, :].to(self.device) * inv_angles\n        zero123_ws /= zero123_ws.max(dim=-1, keepdim=True)[0]\n        zero123_ws[zero123_ws < 0.1] = 0\n\n        with torch.no_grad():\n            noise = torch.randn_like(latents)\n            latents_noisy = self.scheduler.add_noise(latents, noise, t)\n\n            x_in = torch.cat([latents_noisy] * 2)\n            t_in = torch.cat([t] * 2)\n\n            noise_preds = []\n            # Loop through each ref image\n            for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T,\n                                                                                              embeddings['c_crossattn'], embeddings['c_concat'],\n                                                                                              ref_polars, ref_azimuths, ref_radii):\n                # polar,azimuth,radius are all actually delta wrt default\n                p = polar + ref_polars[0] - ref_polar\n                a = azimuth + ref_azimuths[0] - ref_azimuth\n                a[a > 180] -= 360 # range in [-180, 180]\n                r = radius + ref_radii[0] - ref_radius\n                # T = torch.tensor([math.radians(p), math.sin(math.radians(-a)), math.cos(math.radians(a)), r])\n                # T = T[None, None, :].to(self.device)\n                T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :]\n                cond = {}\n                clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1))\n                cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]\n                cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), 
c_concat.repeat(len(T), 1, 1, 1)], dim=0)]\n                noise_pred = self.model.apply_model(x_in, t_in, cond)\n                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n                noise_preds.append(zero123_w[:, None, None, None] * noise_pred)\n\n        noise_pred = torch.stack(noise_preds).sum(dim=0) / zero123_ws.sum(dim=-1)[:, None, None, None]\n\n        w = (1 - self.alphas[t])\n        grad = (grad_scale * w)[:, None, None, None] * (noise_pred - noise)\n        grad = torch.nan_to_num(grad)\n\n        # import kiui\n        # if not as_latent:\n        #     kiui.vis.plot_image(pred_rgb_256)\n        # kiui.vis.plot_matrix(latents)\n        # kiui.vis.plot_matrix(grad)\n\n        # import kiui\n        # latents = torch.randn((1, 4, 32, 32), device=self.device)\n        # kiui.lo(latents)\n        # self.scheduler.set_timesteps(30)\n        # with torch.no_grad():\n        #     for i, t in enumerate(self.scheduler.timesteps):\n        #         x_in = torch.cat([latents] * 2)\n        #         t_in = torch.cat([t.view(1)] * 2).to(self.device)\n\n        #         noise_pred = self.model.apply_model(x_in, t_in, cond)\n        #         noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)\n        #         noise_pred = noise_pred_uncond + 3 * (noise_pred_cond - noise_pred_uncond)\n\n        #         latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']\n        # imgs = self.decode_latents(latents)\n        # print(polar, azimuth, radius)\n        # kiui.vis.plot_image(pred_rgb_256, imgs)\n\n        if save_guidance_path:\n            with torch.no_grad():\n                if as_latent:\n                    pred_rgb_256 = self.decode_latents(latents) # claforte: test!\n\n                # visualize predicted denoised image\n                result_hopefully_less_noisy_image = 
self.decode_latents(self.model.predict_start_from_noise(latents_noisy, t, noise_pred))\n\n                # visualize noisier image\n                result_noisier_image = self.decode_latents(latents_noisy)\n\n                # TODO: also denoise all-the-way\n\n                # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]\n                viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1)\n                save_image(viz_images, save_guidance_path)\n\n        targets = (latents - grad).detach()\n        loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]\n\n        return loss\n\n    # verification\n    @torch.no_grad()\n    def __call__(self,\n            image, # image tensor [1, 3, H, W] in [0, 1]\n            polar=0, azimuth=0, radius=0, # new view params\n            scale=3, ddim_steps=50, ddim_eta=1, h=256, w=256, # diffusion params\n            c_crossattn=None, c_concat=None, post_process=True,\n        ):\n\n        # compute image embeddings whenever either conditioning tensor is missing\n        if c_crossattn is None or c_concat is None:\n            embeddings = self.get_img_embeds(image)\n\n        T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius])\n        T = T[None, None, :].to(self.device)\n\n        cond = {}\n        clip_emb = self.model.cc_projection(torch.cat([embeddings['c_crossattn'] if c_crossattn is None else c_crossattn, T], dim=-1))\n        cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]\n        cond['c_concat'] = [torch.cat([torch.zeros_like(embeddings['c_concat']).to(self.device), embeddings['c_concat']], dim=0)] if c_concat is None else [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]\n\n        # produce latents loop\n        latents = torch.randn((1, 4, h // 8, w // 8), device=self.device)\n        self.scheduler.set_timesteps(ddim_steps)\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            x_in = torch.cat([latents] * 2)\n            t_in = torch.cat([t.view(1)] * 2).to(self.device)\n\n            noise_pred = self.model.apply_model(x_in, t_in, cond)\n            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)\n\n            latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']\n\n        imgs = self.decode_latents(latents)\n        imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs\n\n        return imgs\n\n    def decode_latents(self, latents):\n        # latents: [B, 4, 32, 32] Latent space image\n        # with self.model.ema_scope():\n        imgs = self.model.decode_first_stage(latents)\n        imgs = (imgs / 2 + 0.5).clamp(0, 1)\n\n        return imgs # [B, 3, 256, 256] RGB space image\n\n    def encode_imgs(self, imgs):\n        # imgs: [B, 3, 256, 256] RGB space image\n        # with self.model.ema_scope():\n        imgs = imgs * 2 - 1\n        latents = torch.cat([self.model.get_first_stage_encoding(self.model.encode_first_stage(img.unsqueeze(0))) for img in imgs], dim=0)\n        return latents # [B, 4, 32, 32] Latent space image\n\n\nif __name__ == '__main__':\n    import cv2\n    import argparse\n    import numpy as np\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('input', type=str)\n    parser.add_argument('--fp16', action='store_true', help=\"use float16 for training\") # no use now, can only run in fp32\n\n    parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')\n    parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')\n    parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')\n\n    opt = parser.parse_args()\n\n    device = 
torch.device('cuda')\n\n    print(f'[INFO] loading image from {opt.input} ...')\n    image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)\n    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)\n    image = image.astype(np.float32) / 255.0\n    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)\n\n    print(f'[INFO] loading model ...')\n    zero123 = Zero123(device, opt.fp16, opt=opt)\n\n    print(f'[INFO] running model ...')\n    outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius)\n    plt.imshow(outputs[0])\n    plt.show()"
  },
  {
    "path": "ldm/extras.py",
    "content": "from pathlib import Path\nfrom omegaconf import OmegaConf\nimport torch\nfrom ldm.util import instantiate_from_config\nimport logging\nfrom contextlib import contextmanager\n\n\n@contextmanager\ndef all_logging_disabled(highest_level=logging.CRITICAL):\n    \"\"\"\n    A context manager that will prevent any logging messages\n    triggered during the body from being processed.\n\n    :param highest_level: the maximum logging level in use.\n      This would only need to be changed if a custom level greater than CRITICAL\n      is defined.\n\n    https://gist.github.com/simon-weber/7853144\n    \"\"\"\n    # two kind-of hacks here:\n    #    * can't get the highest logging level in effect => delegate to the user\n    #    * can't get the current module-level override => use an undocumented\n    #       (but non-private!) interface\n\n    previous_level = logging.root.manager.disable\n\n    logging.disable(highest_level)\n\n    try:\n        yield\n    finally:\n        logging.disable(previous_level)\n\ndef load_training_dir(train_dir, device, epoch=\"last\"):\n    \"\"\"Load a checkpoint and config from training directory\"\"\"\n    train_dir = Path(train_dir)\n    ckpt = list(train_dir.rglob(f\"*{epoch}.ckpt\"))\n    assert len(ckpt) == 1, f\"found {len(ckpt)} matching ckpt files\"\n    config = list(train_dir.rglob(\"*-project.yaml\"))\n    assert len(config) > 0, f\"didn't find any config in {train_dir}\"\n    if len(config) > 1:\n        print(f\"found {len(config)} matching config files\")\n        config = sorted(config)[-1]\n        print(f\"selecting {config}\")\n    else:\n        config = config[0]\n\n\n    config = OmegaConf.load(config)\n    return load_model_from_config(config, ckpt[0], device)\n\ndef load_model_from_config(config, ckpt, device=\"cpu\", verbose=False):\n    \"\"\"Loads a model from config and a ckpt\n    if config is a path will use omegaconf to load\n    \"\"\"\n    if isinstance(config, (str, Path)):\n        config = OmegaConf.load(config)\n\n    with all_logging_disabled():\n        print(f\"Loading model from {ckpt}\")\n        pl_sd = torch.load(ckpt, map_location=\"cpu\")\n        global_step = pl_sd[\"global_step\"]\n        sd = pl_sd[\"state_dict\"]\n        model = instantiate_from_config(config.model)\n        m, u = model.load_state_dict(sd, strict=False)\n        if len(m) > 0 and verbose:\n            print(\"missing keys:\")\n            print(m)\n        if len(u) > 0 and verbose:\n            print(\"unexpected keys:\")\n            print(u)\n        model.to(device)\n        model.eval()\n        model.cond_stage_model.device = device\n        return model"
  },
  {
    "path": "ldm/guidance.py",
    "content": "from typing import List, Tuple\nfrom scipy import interpolate\nimport numpy as np\nimport torch\nimport matplotlib.pyplot as plt\nfrom IPython.display import clear_output\nimport abc\n\n\nclass GuideModel(torch.nn.Module, abc.ABC):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @abc.abstractmethod\n    def preprocess(self, x_img):\n        pass\n\n    @abc.abstractmethod\n    def compute_loss(self, inp):\n        pass\n\n\nclass Guider(torch.nn.Module):\n    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):\n        \"\"\"Apply classifier guidance\n\n        Specify a guidance scale as either a scalar\n        Or a schedule as a list of tuples t = 0->1 and scale, e.g.\n        [(0, 10), (0.5, 20), (1, 50)]\n        \"\"\"\n        super().__init__()\n        self.sampler = sampler\n        self.index = 0\n        self.show = verbose\n        self.guide_model = guide_model\n        self.history = []\n\n        if isinstance(scale, (Tuple, List)):\n            times = np.array([x[0] for x in scale])\n            values = np.array([x[1] for x in scale])\n            self.scale_schedule = {\"times\": times, \"values\": values}\n        else:\n            self.scale_schedule = float(scale)\n\n        self.ddim_timesteps = sampler.ddim_timesteps\n        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps\n\n\n    def get_scales(self):\n        if isinstance(self.scale_schedule, float):\n            return len(self.ddim_timesteps)*[self.scale_schedule]\n\n        interpolater = interpolate.interp1d(self.scale_schedule[\"times\"], self.scale_schedule[\"values\"])\n        fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps\n        return interpolater(fractional_steps)\n\n    def modify_score(self, model, e_t, x, t, c):\n\n        # TODO look up index by t\n        scale = self.get_scales()[self.index]\n\n        if (scale == 0):\n            return e_t\n\n        sqrt_1ma = 
self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)\n        with torch.enable_grad():\n            x_in = x.detach().requires_grad_(True)\n            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)\n            x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)\n\n            inp = self.guide_model.preprocess(x_img)\n            loss = self.guide_model.compute_loss(inp)\n            grads = torch.autograd.grad(loss.sum(), x_in)[0]\n            correction = grads * scale\n\n            if self.show:\n                clear_output(wait=True)\n                print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())\n                self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])\n                plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)\n                plt.axis('off')\n                plt.show()\n                plt.imshow(correction[0][0].detach().cpu())\n                plt.axis('off')\n                plt.show()\n\n\n        e_t_mod = e_t - sqrt_1ma*correction\n        if self.show:\n            fig, axs = plt.subplots(1, 3)\n            axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)\n            axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)\n            axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)\n            plt.show()\n        self.index += 1\n        return e_t_mod"
  },
  {
    "path": "ldm/lr_scheduler.py",
    "content": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n    \"\"\"\n    note: use with a base_lr of 1.0\n    \"\"\"\n    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):\n        self.lr_warm_up_steps = warm_up_steps\n        self.lr_start = lr_start\n        self.lr_min = lr_min\n        self.lr_max = lr_max\n        self.lr_max_decay_steps = max_decay_steps\n        self.last_lr = 0.\n        self.verbosity_interval = verbosity_interval\n\n    def schedule(self, n, **kwargs):\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0: print(f\"current step: {n}, recent lr-multiplier: {self.last_lr}\")\n        if n < self.lr_warm_up_steps:\n            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start\n            self.last_lr = lr\n            return lr\n        else:\n            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)\n            t = min(t, 1.0)\n            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (\n                    1 + np.cos(t * np.pi))\n            self.last_lr = lr\n            return lr\n\n    def __call__(self, n, **kwargs):\n        return self.schedule(n,**kwargs)\n\n\nclass LambdaWarmUpCosineScheduler2:\n    \"\"\"\n    supports repeated iterations, configurable via lists\n    note: use with a base_lr of 1.0.\n    \"\"\"\n    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):\n        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)\n        self.lr_warm_up_steps = warm_up_steps\n        self.f_start = f_start\n        self.f_min = f_min\n        self.f_max = f_max\n        self.cycle_lengths = cycle_lengths\n        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))\n        self.last_f = 0.\n        self.verbosity_interval = verbosity_interval\n\n    def 
find_in_interval(self, n):\n        interval = 0\n        for cl in self.cum_cycles[1:]:\n            if n <= cl:\n                return interval\n            interval += 1\n\n    def schedule(self, n, **kwargs):\n        cycle = self.find_in_interval(n)\n        n = n - self.cum_cycles[cycle]\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0: print(f\"current step: {n}, recent lr-multiplier: {self.last_f}, \"\n                                                       f\"current cycle {cycle}\")\n        if n < self.lr_warm_up_steps[cycle]:\n            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]\n            self.last_f = f\n            return f\n        else:\n            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])\n            t = min(t, 1.0)\n            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (\n                    1 + np.cos(t * np.pi))\n            self.last_f = f\n            return f\n\n    def __call__(self, n, **kwargs):\n        return self.schedule(n, **kwargs)\n\n\nclass LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):\n\n    def schedule(self, n, **kwargs):\n        cycle = self.find_in_interval(n)\n        n = n - self.cum_cycles[cycle]\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0: print(f\"current step: {n}, recent lr-multiplier: {self.last_f}, \"\n                                                       f\"current cycle {cycle}\")\n\n        if n < self.lr_warm_up_steps[cycle]:\n            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]\n            self.last_f = f\n            return f\n        else:\n            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])\n            self.last_f 
= f\n            return f\n\n"
  },
  {
    "path": "ldm/models/autoencoder.py",
    "content": "import torch\nimport numpy as np\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom contextlib import contextmanager\nfrom packaging import version\nfrom torch.optim.lr_scheduler import LambdaLR\n\nfrom taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer\n\nfrom ldm.modules.diffusionmodules.model import Encoder, Decoder\nfrom ldm.modules.distributions.distributions import DiagonalGaussianDistribution\nfrom ldm.modules.ema import LitEma\n\nfrom ldm.util import instantiate_from_config\n\n\nclass VQModel(pl.LightningModule):\n    def __init__(self,\n                 ddconfig,\n                 lossconfig,\n                 n_embed,\n                 embed_dim,\n                 ckpt_path=None,\n                 ignore_keys=[],\n                 image_key=\"image\",\n                 colorize_nlabels=None,\n                 monitor=None,\n                 batch_resize_range=None,\n                 scheduler_config=None,\n                 lr_g_factor=1.0,\n                 remap=None,\n                 sane_index_shape=False, # tell vector quantizer to return indices as bhw\n                 use_ema=False\n                 ):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.n_embed = n_embed\n        self.image_key = image_key\n        self.encoder = Encoder(**ddconfig)\n        self.decoder = Decoder(**ddconfig)\n        self.loss = instantiate_from_config(lossconfig)\n        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,\n                                        remap=remap,\n                                        sane_index_shape=sane_index_shape)\n        self.quant_conv = torch.nn.Conv2d(ddconfig[\"z_channels\"], embed_dim, 1)\n        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig[\"z_channels\"], 1)\n        if colorize_nlabels is not None:\n            assert type(colorize_nlabels)==int\n            self.register_buffer(\"colorize\", torch.randn(3, colorize_nlabels, 1, 1))\n        if monitor is not None:\n            self.monitor = monitor\n        self.batch_resize_range = 
batch_resize_range\n        if self.batch_resize_range is not None:\n            print(f\"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.\")\n\n        self.use_ema = use_ema\n        if self.use_ema:\n            self.model_ema = LitEma(self)\n            print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)\n        self.scheduler_config = scheduler_config\n        self.lr_g_factor = lr_g_factor\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.parameters())\n            self.model_ema.copy_to(self)\n            if context is not None:\n                print(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.parameters())\n                if context is not None:\n                    print(f\"{context}: Restored training weights\")\n\n    def init_from_ckpt(self, path, ignore_keys=list()):\n        sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del sd[k]\n        missing, unexpected = self.load_state_dict(sd, strict=False)\n        print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def on_train_batch_end(self, *args, **kwargs):\n        if self.use_ema:\n            self.model_ema(self)\n\n    def encode(self, x):\n        h = self.encoder(x)\n        h = self.quant_conv(h)\n        quant, 
emb_loss, info = self.quantize(h)\n        return quant, emb_loss, info\n\n    def encode_to_prequant(self, x):\n        h = self.encoder(x)\n        h = self.quant_conv(h)\n        return h\n\n    def decode(self, quant):\n        quant = self.post_quant_conv(quant)\n        dec = self.decoder(quant)\n        return dec\n\n    def decode_code(self, code_b):\n        quant_b = self.quantize.embed_code(code_b)\n        dec = self.decode(quant_b)\n        return dec\n\n    def forward(self, input, return_pred_indices=False):\n        quant, diff, (_,_,ind) = self.encode(input)\n        dec = self.decode(quant)\n        if return_pred_indices:\n            return dec, diff, ind\n        return dec, diff\n\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., None]\n        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()\n        if self.batch_resize_range is not None:\n            lower_size = self.batch_resize_range[0]\n            upper_size = self.batch_resize_range[1]\n            if self.global_step <= 4:\n                # do the first few batches with max size to avoid later oom\n                new_resize = upper_size\n            else:\n                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))\n            if new_resize != x.shape[2]:\n                x = F.interpolate(x, size=new_resize, mode=\"bicubic\")\n            x = x.detach()\n        return x\n\n    def training_step(self, batch, batch_idx, optimizer_idx):\n        # https://github.com/pytorch/pytorch/issues/37142\n        # try not to fool the heuristics\n        x = self.get_input(batch, self.image_key)\n        xrec, qloss, ind = self(x, return_pred_indices=True)\n\n        if optimizer_idx == 0:\n            # autoencode\n            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,\n                                            
last_layer=self.get_last_layer(), split=\"train\",\n                                            predicted_indices=ind)\n\n            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)\n            return aeloss\n\n        if optimizer_idx == 1:\n            # discriminator\n            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,\n                                            last_layer=self.get_last_layer(), split=\"train\")\n            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)\n            return discloss\n\n    def validation_step(self, batch, batch_idx):\n        log_dict = self._validation_step(batch, batch_idx)\n        with self.ema_scope():\n            log_dict_ema = self._validation_step(batch, batch_idx, suffix=\"_ema\")\n        return log_dict\n\n    def _validation_step(self, batch, batch_idx, suffix=\"\"):\n        x = self.get_input(batch, self.image_key)\n        xrec, qloss, ind = self(x, return_pred_indices=True)\n        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,\n                                        self.global_step,\n                                        last_layer=self.get_last_layer(),\n                                        split=\"val\"+suffix,\n                                        predicted_indices=ind\n                                        )\n\n        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,\n                                            self.global_step,\n                                            last_layer=self.get_last_layer(),\n                                            split=\"val\"+suffix,\n                                            predicted_indices=ind\n                                            )\n        rec_loss = log_dict_ae[f\"val{suffix}/rec_loss\"]\n        self.log(f\"val{suffix}/rec_loss\", rec_loss,\n                   prog_bar=True, logger=True, on_step=False, 
on_epoch=True, sync_dist=True)\n        self.log(f\"val{suffix}/aeloss\", aeloss,\n                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)\n        if version.parse(pl.__version__) >= version.parse('1.4.0'):\n            del log_dict_ae[f\"val{suffix}/rec_loss\"]\n        self.log_dict(log_dict_ae)\n        self.log_dict(log_dict_disc)\n        return self.log_dict\n\n    def configure_optimizers(self):\n        lr_d = self.learning_rate\n        lr_g = self.lr_g_factor*self.learning_rate\n        print(\"lr_d\", lr_d)\n        print(\"lr_g\", lr_g)\n        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+\n                                  list(self.decoder.parameters())+\n                                  list(self.quantize.parameters())+\n                                  list(self.quant_conv.parameters())+\n                                  list(self.post_quant_conv.parameters()),\n                                  lr=lr_g, betas=(0.5, 0.9))\n        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),\n                                    lr=lr_d, betas=(0.5, 0.9))\n\n        if self.scheduler_config is not None:\n            scheduler = instantiate_from_config(self.scheduler_config)\n\n            print(\"Setting up LambdaLR scheduler...\")\n            scheduler = [\n                {\n                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),\n                    'interval': 'step',\n                    'frequency': 1\n                },\n                {\n                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),\n                    'interval': 'step',\n                    'frequency': 1\n                },\n            ]\n            return [opt_ae, opt_disc], scheduler\n        return [opt_ae, opt_disc], []\n\n    def get_last_layer(self):\n        return self.decoder.conv_out.weight\n\n    def log_images(self, batch, only_inputs=False, 
plot_ema=False, **kwargs):\n        log = dict()\n        x = self.get_input(batch, self.image_key)\n        x = x.to(self.device)\n        if only_inputs:\n            log[\"inputs\"] = x\n            return log\n        xrec, _ = self(x)\n        if x.shape[1] > 3:\n            # colorize with random projection\n            assert xrec.shape[1] > 3\n            x = self.to_rgb(x)\n            xrec = self.to_rgb(xrec)\n        log[\"inputs\"] = x\n        log[\"reconstructions\"] = xrec\n        if plot_ema:\n            with self.ema_scope():\n                xrec_ema, _ = self(x)\n                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)\n                log[\"reconstructions_ema\"] = xrec_ema\n        return log\n\n    def to_rgb(self, x):\n        assert self.image_key == \"segmentation\"\n        if not hasattr(self, \"colorize\"):\n            self.register_buffer(\"colorize\", torch.randn(3, x.shape[1], 1, 1).to(x))\n        x = F.conv2d(x, weight=self.colorize)\n        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.\n        return x\n\n\nclass VQModelInterface(VQModel):\n    def __init__(self, embed_dim, *args, **kwargs):\n        super().__init__(embed_dim=embed_dim, *args, **kwargs)\n        self.embed_dim = embed_dim\n\n    def encode(self, x):\n        h = self.encoder(x)\n        h = self.quant_conv(h)\n        return h\n\n    def decode(self, h, force_not_quantize=False):\n        # also go through quantization layer\n        if not force_not_quantize:\n            quant, emb_loss, info = self.quantize(h)\n        else:\n            quant = h\n        quant = self.post_quant_conv(quant)\n        dec = self.decoder(quant)\n        return dec\n\n\nclass AutoencoderKL(pl.LightningModule):\n    def __init__(self,\n                 ddconfig,\n                 lossconfig,\n                 embed_dim,\n                 ckpt_path=None,\n                 ignore_keys=[],\n                 image_key=\"image\",\n                 colorize_nlabels=None,\n    
             monitor=None,\n                 ):\n        super().__init__()\n        self.image_key = image_key\n        self.encoder = Encoder(**ddconfig)\n        self.decoder = Decoder(**ddconfig)\n        self.loss = instantiate_from_config(lossconfig)\n        assert ddconfig[\"double_z\"]\n        self.quant_conv = torch.nn.Conv2d(2*ddconfig[\"z_channels\"], 2*embed_dim, 1)\n        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig[\"z_channels\"], 1)\n        self.embed_dim = embed_dim\n        if colorize_nlabels is not None:\n            assert type(colorize_nlabels)==int\n            self.register_buffer(\"colorize\", torch.randn(3, colorize_nlabels, 1, 1))\n        if monitor is not None:\n            self.monitor = monitor\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)\n\n    def init_from_ckpt(self, path, ignore_keys=list()):\n        sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del sd[k]\n        self.load_state_dict(sd, strict=False)\n        print(f\"Restored from {path}\")\n\n    def encode(self, x):\n        h = self.encoder(x)\n        moments = self.quant_conv(h)\n        posterior = DiagonalGaussianDistribution(moments)\n        return posterior\n\n    def decode(self, z):\n        z = self.post_quant_conv(z)\n        dec = self.decoder(z)\n        return dec\n\n    def forward(self, input, sample_posterior=True):\n        posterior = self.encode(input)\n        if sample_posterior:\n            z = posterior.sample()\n        else:\n            z = posterior.mode()\n        dec = self.decode(z)\n        return dec, posterior\n\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., 
None]\n        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()\n        return x\n\n    def training_step(self, batch, batch_idx, optimizer_idx):\n        inputs = self.get_input(batch, self.image_key)\n        reconstructions, posterior = self(inputs)\n\n        if optimizer_idx == 0:\n            # train encoder+decoder+logvar\n            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,\n                                            last_layer=self.get_last_layer(), split=\"train\")\n            self.log(\"aeloss\", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)\n            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)\n            return aeloss\n\n        if optimizer_idx == 1:\n            # train the discriminator\n            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,\n                                                last_layer=self.get_last_layer(), split=\"train\")\n\n            self.log(\"discloss\", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)\n            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)\n            return discloss\n\n    def validation_step(self, batch, batch_idx):\n        inputs = self.get_input(batch, self.image_key)\n        reconstructions, posterior = self(inputs)\n        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,\n                                        last_layer=self.get_last_layer(), split=\"val\")\n\n        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,\n                                            last_layer=self.get_last_layer(), split=\"val\")\n\n        self.log(\"val/rec_loss\", log_dict_ae[\"val/rec_loss\"])\n        self.log_dict(log_dict_ae)\n        
self.log_dict(log_dict_disc)\n        return self.log_dict\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+\n                                  list(self.decoder.parameters())+\n                                  list(self.quant_conv.parameters())+\n                                  list(self.post_quant_conv.parameters()),\n                                  lr=lr, betas=(0.5, 0.9))\n        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),\n                                    lr=lr, betas=(0.5, 0.9))\n        return [opt_ae, opt_disc], []\n\n    def get_last_layer(self):\n        return self.decoder.conv_out.weight\n\n    @torch.no_grad()\n    def log_images(self, batch, only_inputs=False, **kwargs):\n        log = dict()\n        x = self.get_input(batch, self.image_key)\n        x = x.to(self.device)\n        if not only_inputs:\n            xrec, posterior = self(x)\n            if x.shape[1] > 3:\n                # colorize with random projection\n                assert xrec.shape[1] > 3\n                x = self.to_rgb(x)\n                xrec = self.to_rgb(xrec)\n            log[\"samples\"] = self.decode(torch.randn_like(posterior.sample()))\n            log[\"reconstructions\"] = xrec\n        log[\"inputs\"] = x\n        return log\n\n    def to_rgb(self, x):\n        assert self.image_key == \"segmentation\"\n        if not hasattr(self, \"colorize\"):\n            self.register_buffer(\"colorize\", torch.randn(3, x.shape[1], 1, 1).to(x))\n        x = F.conv2d(x, weight=self.colorize)\n        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.\n        return x\n\n\nclass IdentityFirstStage(torch.nn.Module):\n    def __init__(self, *args, vq_interface=False, **kwargs):\n        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff\n        super().__init__()\n\n    def encode(self, x, *args, **kwargs):\n        
return x\n\n    def decode(self, x, *args, **kwargs):\n        return x\n\n    def quantize(self, x, *args, **kwargs):\n        if self.vq_interface:\n            return x, None, [None, None, None]\n        return x\n\n    def forward(self, x, *args, **kwargs):\n        return x\n"
  },
  {
    "path": "ldm/models/diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "ldm/models/diffusion/classifier.py",
    "content": "import os\nimport torch\nimport pytorch_lightning as pl\nfrom omegaconf import OmegaConf\nfrom torch.nn import functional as F\nfrom torch.optim import AdamW\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom copy import deepcopy\nfrom einops import rearrange\nfrom glob import glob\nfrom natsort import natsorted\n\nfrom ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel\nfrom ldm.util import log_txt_as_img, default, ismap, instantiate_from_config\n\n__models__ = {\n    'class_label': EncoderUNetModel,\n    'segmentation': UNetModel\n}\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\nclass NoisyLatentImageClassifier(pl.LightningModule):\n\n    def __init__(self,\n                 diffusion_path,\n                 num_classes,\n                 ckpt_path=None,\n                 pool='attention',\n                 label_key=None,\n                 diffusion_ckpt_path=None,\n                 scheduler_config=None,\n                 weight_decay=1.e-2,\n                 log_steps=10,\n                 monitor='val/loss',\n                 *args,\n                 **kwargs):\n        super().__init__(*args, **kwargs)\n        self.num_classes = num_classes\n        # get latest config of diffusion model\n        diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]\n        self.diffusion_config = OmegaConf.load(diffusion_config).model\n        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path\n        self.load_diffusion()\n\n        self.monitor = monitor\n        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1\n        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps\n        self.log_steps = log_steps\n\n        self.label_key = label_key if not hasattr(self.diffusion_model, 
'cond_stage_key') \\\n            else self.diffusion_model.cond_stage_key\n\n        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'\n\n        if self.label_key not in __models__:\n            raise NotImplementedError()\n\n        self.load_classifier(ckpt_path, pool)\n\n        self.scheduler_config = scheduler_config\n        self.use_scheduler = self.scheduler_config is not None\n        self.weight_decay = weight_decay\n\n    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):\n        sd = torch.load(path, map_location=\"cpu\")\n        if \"state_dict\" in list(sd.keys()):\n            sd = sd[\"state_dict\"]\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del sd[k]\n        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(\n            sd, strict=False)\n        print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n        if len(unexpected) > 0:\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def load_diffusion(self):\n        model = instantiate_from_config(self.diffusion_config)\n        self.diffusion_model = model.eval()\n        self.diffusion_model.train = disabled_train\n        for param in self.diffusion_model.parameters():\n            param.requires_grad = False\n\n    def load_classifier(self, ckpt_path, pool):\n        model_config = deepcopy(self.diffusion_config.params.unet_config.params)\n        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels\n        model_config.out_channels = self.num_classes\n        if self.label_key == 'class_label':\n            
model_config.pool = pool\n\n        self.model = __models__[self.label_key](**model_config)\n        if ckpt_path is not None:\n            print('#####################################################################')\n            print(f'load from ckpt \"{ckpt_path}\"')\n            print('#####################################################################')\n            self.init_from_ckpt(ckpt_path)\n\n    @torch.no_grad()\n    def get_x_noisy(self, x, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x))\n        continuous_sqrt_alpha_cumprod = None\n        if self.diffusion_model.use_continuous_noise:\n            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)\n            # todo: make sure t+1 is correct here\n\n        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,\n                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)\n\n    def forward(self, x_noisy, t, *args, **kwargs):\n        return self.model(x_noisy, t)\n\n    @torch.no_grad()\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., None]\n        x = rearrange(x, 'b h w c -> b c h w')\n        x = x.to(memory_format=torch.contiguous_format).float()\n        return x\n\n    @torch.no_grad()\n    def get_conditioning(self, batch, k=None):\n        if k is None:\n            k = self.label_key\n        assert k is not None, 'Needs to provide label key'\n\n        targets = batch[k].to(self.device)\n\n        if self.label_key == 'segmentation':\n            targets = rearrange(targets, 'b h w c -> b c h w')\n            for down in range(self.numd):\n                h, w = targets.shape[-2:]\n                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')\n\n            # targets = rearrange(targets,'b c h w -> b h w c')\n\n        return targets\n\n    def 
compute_top_k(self, logits, labels, k, reduction=\"mean\"):\n        _, top_ks = torch.topk(logits, k, dim=1)\n        if reduction == \"mean\":\n            return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()\n        elif reduction == \"none\":\n            return (top_ks == labels[:, None]).float().sum(dim=-1)\n\n    def on_train_epoch_start(self):\n        # save some memory\n        self.diffusion_model.model.to('cpu')\n\n    @torch.no_grad()\n    def write_logs(self, loss, logits, targets):\n        log_prefix = 'train' if self.training else 'val'\n        log = {}\n        log[f\"{log_prefix}/loss\"] = loss.mean()\n        log[f\"{log_prefix}/acc@1\"] = self.compute_top_k(\n            logits, targets, k=1, reduction=\"mean\"\n        )\n        log[f\"{log_prefix}/acc@5\"] = self.compute_top_k(\n            logits, targets, k=5, reduction=\"mean\"\n        )\n\n        self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)\n        self.log('loss', log[f\"{log_prefix}/loss\"], prog_bar=True, logger=False)\n        self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)\n        lr = self.optimizers().param_groups[0]['lr']\n        self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)\n\n    def shared_step(self, batch, t=None):\n        x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)\n        targets = self.get_conditioning(batch)\n        if targets.dim() == 4:\n            targets = targets.argmax(dim=1)\n        if t is None:\n            t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()\n        else:\n            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()\n        x_noisy = self.get_x_noisy(x, t)\n        logits = self(x_noisy, t)\n\n        loss = F.cross_entropy(logits, targets, reduction='none')\n\n        
self.write_logs(loss.detach(), logits.detach(), targets.detach())\n\n        loss = loss.mean()\n        return loss, logits, x_noisy, targets\n\n    def training_step(self, batch, batch_idx):\n        loss, *_ = self.shared_step(batch)\n        return loss\n\n    def reset_noise_accs(self):\n        self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in\n                          range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}\n\n    def on_validation_start(self):\n        self.reset_noise_accs()\n\n    @torch.no_grad()\n    def validation_step(self, batch, batch_idx):\n        loss, *_ = self.shared_step(batch)\n\n        for t in self.noisy_acc:\n            _, logits, _, targets = self.shared_step(batch, t)\n            self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))\n            self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))\n\n        return loss\n\n    def configure_optimizers(self):\n        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n\n        if self.use_scheduler:\n            scheduler = instantiate_from_config(self.scheduler_config)\n\n            print(\"Setting up LambdaLR scheduler...\")\n            scheduler = [\n                {\n                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),\n                    'interval': 'step',\n                    'frequency': 1\n                }]\n            return [optimizer], scheduler\n\n        return optimizer\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, *args, **kwargs):\n        log = dict()\n        x = self.get_input(batch, self.diffusion_model.first_stage_key)\n        log['inputs'] = x\n\n        y = self.get_conditioning(batch)\n\n        if self.label_key == 'class_label':\n            y = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"human_label\"])\n            
log['labels'] = y\n\n        if ismap(y):\n            log['labels'] = self.diffusion_model.to_rgb(y)\n\n            for step in range(self.log_steps):\n                current_time = step * self.log_time_interval\n\n                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)\n\n                log[f'inputs@t{current_time}'] = x_noisy\n\n                pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)\n                pred = rearrange(pred, 'b h w c -> b c h w')\n\n                log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)\n\n        for key in log:\n            log[key] = log[key][:N]\n\n        return log\n"
  },
  {
    "path": "ldm/models/diffusion/ddim.py",
    "content": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\nfrom einops import rearrange\n\nfrom ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor\nfrom ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding\n\n\nclass DDIMSampler(object):\n    def __init__(self, model, schedule=\"linear\", **kwargs):\n        super().__init__()\n        self.model = model\n        self.ddpm_num_timesteps = model.num_timesteps\n        self.schedule = schedule\n\n    def to(self, device):\n        \"\"\"Same as to in torch module\n        Don't really understand why this isn't a module in the first place\"\"\"\n        for k, v in self.__dict__.items():\n            if isinstance(v, torch.Tensor):\n                new_v = getattr(self, k).to(device)\n                setattr(self, k, new_v)\n\n\n    def register_buffer(self, name, attr):\n        if type(attr) == torch.Tensor:\n            if attr.device != torch.device(\"cuda\"):\n                attr = attr.to(torch.device(\"cuda\"))\n        setattr(self, name, attr)\n\n    def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n        alphas_cumprod = self.model.alphas_cumprod\n        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n\n        self.register_buffer('betas', to_torch(self.model.betas))\n        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n        
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n\n        # ddim sampling parameters\n        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n                                                                                   ddim_timesteps=self.ddim_timesteps,\n                                                                                   eta=ddim_eta,verbose=verbose)\n        self.register_buffer('ddim_sigmas', ddim_sigmas)\n        self.register_buffer('ddim_alphas', ddim_alphas)\n        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas))\n        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n\n    @torch.no_grad()\n    def sample(self,\n               S,\n               batch_size,\n               shape,\n               conditioning=None,\n               callback=None,\n               normals_sequence=None,\n               img_callback=None,\n               quantize_x0=False,\n               eta=0.,\n               mask=None,\n               x0=None,\n               temperature=1.,\n               noise_dropout=0.,\n               score_corrector=None,\n               corrector_kwargs=None,\n               verbose=True,\n               x_T=None,\n               log_every_t=100,\n               unconditional_guidance_scale=1.,\n               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. 
as encoded tokens, ...\n               dynamic_threshold=None,\n               **kwargs\n               ):\n        if conditioning is not None:\n            if isinstance(conditioning, dict):\n                ctmp = conditioning[list(conditioning.keys())[0]]\n                while isinstance(ctmp, list): ctmp = ctmp[0]\n                cbs = ctmp.shape[0]\n                if cbs != batch_size:\n                    print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n\n            else:\n                if conditioning.shape[0] != batch_size:\n                    print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n\n        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n        # sampling\n        C, H, W = shape\n        size = (batch_size, C, H, W)\n        # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n\n        samples, intermediates = self.ddim_sampling(conditioning, size,\n                                                    callback=callback,\n                                                    img_callback=img_callback,\n                                                    quantize_denoised=quantize_x0,\n                                                    mask=mask, x0=x0,\n                                                    ddim_use_original_steps=False,\n                                                    noise_dropout=noise_dropout,\n                                                    temperature=temperature,\n                                                    score_corrector=score_corrector,\n                                                    corrector_kwargs=corrector_kwargs,\n                                                    x_T=x_T,\n                                                    log_every_t=log_every_t,\n                                                    unconditional_guidance_scale=unconditional_guidance_scale,\n                          
                          unconditional_conditioning=unconditional_conditioning,\n                                                    dynamic_threshold=dynamic_threshold,\n                                                    )\n        return samples, intermediates\n\n    @torch.no_grad()\n    def ddim_sampling(self, cond, shape,\n                      x_T=None, ddim_use_original_steps=False,\n                      callback=None, timesteps=None, quantize_denoised=False,\n                      mask=None, x0=None, img_callback=None, log_every_t=100,\n                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,\n                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,\n                      t_start=-1):\n        device = self.model.betas.device\n        b = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=device)\n        else:\n            img = x_T\n\n        if timesteps is None:\n            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n        elif timesteps is not None and not ddim_use_original_steps:\n            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n            timesteps = self.ddim_timesteps[:subset_end]\n\n        timesteps = timesteps[:t_start]\n\n        intermediates = {'x_inter': [img], 'pred_x0': [img]}\n        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n        # print(f\"Running DDIM Sampling with {total_steps} timesteps\")\n\n        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)\n\n        for i, step in enumerate(iterator):\n            index = total_steps - i - 1\n            ts = torch.full((b,), step, device=device, dtype=torch.long)\n\n            if mask is not 
None:\n                assert x0 is not None\n                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?\n                img = img_orig * mask + (1. - mask) * img\n\n            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n                                      quantize_denoised=quantize_denoised, temperature=temperature,\n                                      noise_dropout=noise_dropout, score_corrector=score_corrector,\n                                      corrector_kwargs=corrector_kwargs,\n                                      unconditional_guidance_scale=unconditional_guidance_scale,\n                                      unconditional_conditioning=unconditional_conditioning,\n                                      dynamic_threshold=dynamic_threshold)\n            img, pred_x0 = outs\n            if callback:\n                img = callback(i, img, pred_x0)\n            if img_callback: \n                img_callback(pred_x0, i)\n\n            if index % log_every_t == 0 or index == total_steps - 1:\n                intermediates['x_inter'].append(img)\n                intermediates['pred_x0'].append(pred_x0)\n\n        return img, intermediates\n\n    @torch.no_grad()\n    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,\n                      unconditional_guidance_scale=1., unconditional_conditioning=None,\n                      dynamic_threshold=None):\n        b, *_, device = *x.shape, x.device\n\n        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:\n            e_t = self.model.apply_model(x, t, c)\n        else:\n            x_in = torch.cat([x] * 2)\n            t_in = torch.cat([t] * 2)\n            if isinstance(c, dict):\n                assert 
isinstance(unconditional_conditioning, dict)\n                c_in = dict()\n                for k in c:\n                    if isinstance(c[k], list):\n                        c_in[k] = [torch.cat([\n                            unconditional_conditioning[k][i],\n                            c[k][i]]) for i in range(len(c[k]))]\n                    else:\n                        c_in[k] = torch.cat([\n                                unconditional_conditioning[k],\n                                c[k]])\n            else:\n                c_in = torch.cat([unconditional_conditioning, c])\n            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)\n            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)\n\n        if score_corrector is not None:\n            assert self.model.parameterization == \"eps\"\n            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n\n        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n        sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n        # select parameters corresponding to the currently considered timestep\n        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n\n        # current prediction for x_0\n        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n\n        if quantize_denoised:\n            pred_x0, _, *_ = 
self.model.first_stage_model.quantize(pred_x0)\n\n        if dynamic_threshold is not None:\n            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)\n\n        # direction pointing to x_t\n        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n        if noise_dropout > 0.:\n            noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n        return x_prev, pred_x0\n\n    @torch.no_grad()\n    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,\n               unconditional_guidance_scale=1.0, unconditional_conditioning=None):\n        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]\n\n        assert t_enc <= num_reference_steps\n        num_steps = t_enc\n\n        if use_original_steps:\n            alphas_next = self.alphas_cumprod[:num_steps]\n            alphas = self.alphas_cumprod_prev[:num_steps]\n        else:\n            alphas_next = self.ddim_alphas[:num_steps]\n            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])\n\n        x_next = x0\n        intermediates = []\n        inter_steps = []\n        for i in tqdm(range(num_steps), desc='Encoding Image'):\n            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)\n            if unconditional_guidance_scale == 1.:\n                noise_pred = self.model.apply_model(x_next, t, c)\n            else:\n                assert unconditional_conditioning is not None\n                e_t_uncond, noise_pred = torch.chunk(\n                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),\n                                           torch.cat((unconditional_conditioning, c))), 2)\n                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)\n\n            
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next\n            weighted_noise_pred = alphas_next[i].sqrt() * (\n                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred\n            x_next = xt_weighted + weighted_noise_pred\n            if return_intermediates and i % (\n                    num_steps // return_intermediates) == 0 and i < num_steps - 1:\n                intermediates.append(x_next)\n                inter_steps.append(i)\n            elif return_intermediates and i >= num_steps - 2:\n                intermediates.append(x_next)\n                inter_steps.append(i)\n\n        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}\n        if return_intermediates:\n            out.update({'intermediates': intermediates})\n        return x_next, out\n\n    @torch.no_grad()\n    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):\n        # fast, but does not allow for exact reconstruction\n        # t serves as an index to gather the correct alphas\n        if use_original_steps:\n            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod\n            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod\n        else:\n            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)\n            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas\n\n        if noise is None:\n            noise = torch.randn_like(x0)\n        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +\n                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)\n\n    @torch.no_grad()\n    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,\n               use_original_steps=False):\n\n        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps\n        timesteps = timesteps[:t_start]\n\n        time_range = np.flip(timesteps)\n       
 total_steps = timesteps.shape[0]\n        # print(f\"Running DDIM Sampling with {total_steps} timesteps\")\n\n        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)\n        x_dec = x_latent\n        for i, step in enumerate(iterator):\n            index = total_steps - i - 1\n            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)\n            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,\n                                          unconditional_guidance_scale=unconditional_guidance_scale,\n                                          unconditional_conditioning=unconditional_conditioning)\n        return x_dec"
  },
  {
    "path": "ldm/models/diffusion/ddpm.py",
    "content": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py\nhttps://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py\nhttps://github.com/CompVis/taming-transformers\n-- merci\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom einops import rearrange, repeat\nfrom contextlib import contextmanager, nullcontext\nfrom functools import partial\nimport itertools\nfrom tqdm import tqdm\nfrom torchvision.utils import make_grid\nfrom pytorch_lightning.utilities.rank_zero import rank_zero_only \nfrom omegaconf import ListConfig\n\nfrom ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config\nfrom ldm.modules.ema import LitEma\nfrom ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution\nfrom ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL\nfrom ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like\nfrom ldm.models.diffusion.ddim import DDIMSampler\nfrom ldm.modules.attention import CrossAttention\n\n\n__conditioning_keys__ = {'concat': 'c_concat',\n                         'crossattn': 'c_crossattn',\n                         'adm': 'y'}\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef uniform_on_device(r1, r2, shape, device):\n    return (r1 - r2) * torch.rand(*shape, device=device) + r2\n\n\nclass DDPM(pl.LightningModule):\n    # classic DDPM with Gaussian diffusion, in image space\n    def __init__(self,\n                 unet_config,\n                 
timesteps=1000,\n                 beta_schedule=\"linear\",\n                 loss_type=\"l2\",\n                 ckpt_path=None,\n                 ignore_keys=[],\n                 load_only_unet=False,\n                 monitor=\"val/loss\",\n                 use_ema=True,\n                 first_stage_key=\"image\",\n                 image_size=256,\n                 channels=3,\n                 log_every_t=100,\n                 clip_denoised=True,\n                 linear_start=1e-4,\n                 linear_end=2e-2,\n                 cosine_s=8e-3,\n                 given_betas=None,\n                 original_elbo_weight=0.,\n                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta\n                 l_simple_weight=1.,\n                 conditioning_key=None,\n                 parameterization=\"eps\",  # all assuming fixed variance schedules\n                 scheduler_config=None,\n                 use_positional_encodings=False,\n                 learn_logvar=False,\n                 logvar_init=0.,\n                 make_it_fit=False,\n                 ucg_training=None,\n                 ):\n        super().__init__()\n        assert parameterization in [\"eps\", \"x0\"], 'currently only supporting \"eps\" and \"x0\"'\n        self.parameterization = parameterization\n        print(f\"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode\")\n        self.cond_stage_model = None\n        self.clip_denoised = clip_denoised\n        self.log_every_t = log_every_t\n        self.first_stage_key = first_stage_key\n        self.image_size = image_size  # try conv?\n        self.channels = channels\n        self.use_positional_encodings = use_positional_encodings\n        self.model = DiffusionWrapper(unet_config, conditioning_key)\n        count_params(self.model, verbose=True)\n        self.use_ema = use_ema\n        if self.use_ema:\n            self.model_ema = 
LitEma(self.model)\n            print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        self.use_scheduler = scheduler_config is not None\n        if self.use_scheduler:\n            self.scheduler_config = scheduler_config\n\n        self.v_posterior = v_posterior\n        self.original_elbo_weight = original_elbo_weight\n        self.l_simple_weight = l_simple_weight\n\n        if monitor is not None:\n            self.monitor = monitor\n        self.make_it_fit = make_it_fit\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)\n\n        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,\n                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)\n\n        self.loss_type = loss_type\n\n        self.learn_logvar = learn_logvar\n        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))\n        if self.learn_logvar:\n            self.logvar = nn.Parameter(self.logvar, requires_grad=True)\n\n        self.ucg_training = ucg_training or dict()\n        if self.ucg_training:\n            self.ucg_prng = np.random.RandomState()\n\n    def register_schedule(self, given_betas=None, beta_schedule=\"linear\", timesteps=1000,\n                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):\n        if exists(given_betas):\n            betas = given_betas\n        else:\n            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,\n                                       cosine_s=cosine_s)\n        alphas = 1. 
- betas\n        alphas_cumprod = np.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])\n\n        timesteps, = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.linear_start = linear_start\n        self.linear_end = linear_end\n        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'\n\n        to_torch = partial(torch.tensor, dtype=torch.float32)\n\n        self.register_buffer('betas', to_torch(betas))\n        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))\n        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))\n        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))\n        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))\n        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (\n                    1. - alphas_cumprod) + self.v_posterior * betas\n        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n        self.register_buffer('posterior_variance', to_torch(posterior_variance))\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))\n        self.register_buffer('posterior_mean_coef1', to_torch(\n            betas * np.sqrt(alphas_cumprod_prev) / (1. 
- alphas_cumprod)))\n        self.register_buffer('posterior_mean_coef2', to_torch(\n            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))\n\n        if self.parameterization == \"eps\":\n            lvlb_weights = self.betas ** 2 / (\n                        2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))\n        elif self.parameterization == \"x0\":\n            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))\n        else:\n            raise NotImplementedError(\"mu not supported\")\n        # TODO how to choose this term\n        lvlb_weights[0] = lvlb_weights[1]\n        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)\n        assert not torch.isnan(self.lvlb_weights).all()\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.model.parameters())\n            self.model_ema.copy_to(self.model)\n            if context is not None:\n                print(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.model.parameters())\n                if context is not None:\n                    print(f\"{context}: Restored training weights\")\n\n    @torch.no_grad()\n    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):\n        sd = torch.load(path, map_location=\"cpu\")\n        if \"state_dict\" in list(sd.keys()):\n            sd = sd[\"state_dict\"]\n        keys = list(sd.keys())\n\n        if self.make_it_fit:\n            n_params = len([name for name, _ in\n                            itertools.chain(self.named_parameters(),\n                                            self.named_buffers())])\n            for name, param in tqdm(\n                    itertools.chain(self.named_parameters(),\n                                
    self.named_buffers()),\n                    desc=\"Fitting old weights to new weights\",\n                    total=n_params\n            ):\n                if not name in sd:\n                    continue\n                old_shape = sd[name].shape\n                new_shape = param.shape\n                assert len(old_shape)==len(new_shape)\n                if len(new_shape) > 2:\n                    # we only modify first two axes\n                    assert new_shape[2:] == old_shape[2:]\n                # assumes first axis corresponds to output dim\n                if not new_shape == old_shape:\n                    new_param = param.clone()\n                    old_param = sd[name]\n                    if len(new_shape) == 1:\n                        for i in range(new_param.shape[0]):\n                            new_param[i] = old_param[i % old_shape[0]]\n                    elif len(new_shape) >= 2:\n                        for i in range(new_param.shape[0]):\n                            for j in range(new_param.shape[1]):\n                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]\n\n                        n_used_old = torch.ones(old_shape[1])\n                        for j in range(new_param.shape[1]):\n                            n_used_old[j % old_shape[1]] += 1\n                        n_used_new = torch.zeros(new_shape[1])\n                        for j in range(new_param.shape[1]):\n                            n_used_new[j] = n_used_old[j % old_shape[1]]\n\n                        n_used_new = n_used_new[None, :]\n                        while len(n_used_new.shape) < len(new_shape):\n                            n_used_new = n_used_new.unsqueeze(-1)\n                        new_param /= n_used_new\n\n                    sd[name] = new_param\n\n        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(\n            sd, strict=False)\n        
print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n        if len(unexpected) > 0:\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def q_mean_variance(self, x_start, t):\n        \"\"\"\n        Get the distribution q(x_t | x_0).\n        :param x_start: the [N x C x ...] tensor of noiseless inputs.\n        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.\n        :return: A tuple (mean, variance, log_variance), all of x_start's shape.\n        \"\"\"\n        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)\n        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)\n        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)\n        return mean, variance, log_variance\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n        )\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(self, x, t, clip_denoised: bool):\n        model_out = self.model(x, t)\n        if self.parameterization == \"eps\":\n            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)\n        elif 
self.parameterization == \"x0\":\n            x_recon = model_out\n        if clip_denoised:\n            x_recon.clamp_(-1., 1.)\n\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):\n        b, *_, device = *x.shape, x.device\n        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)\n        noise = noise_like(x.shape, device, repeat_noise)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.no_grad()\n    def p_sample_loop(self, shape, return_intermediates=False):\n        device = self.betas.device\n        b = shape[0]\n        img = torch.randn(shape, device=device)\n        intermediates = [img]\n        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):\n            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),\n                                clip_denoised=self.clip_denoised)\n            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:\n                intermediates.append(img)\n        if return_intermediates:\n            return img, intermediates\n        return img\n\n    @torch.no_grad()\n    def sample(self, batch_size=16, return_intermediates=False):\n        image_size = self.image_size\n        channels = self.channels\n        return self.p_sample_loop((batch_size, channels, image_size, image_size),\n                                  return_intermediates=return_intermediates)\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        return 
(extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)\n\n    def get_loss(self, pred, target, mean=True):\n        if self.loss_type == 'l1':\n            loss = (target - pred).abs()\n            if mean:\n                loss = loss.mean()\n        elif self.loss_type == 'l2':\n            if mean:\n                loss = torch.nn.functional.mse_loss(target, pred)\n            else:\n                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')\n        else:\n            raise NotImplementedError(f\"unknown loss type '{self.loss_type}'\")\n\n        return loss\n\n    def p_losses(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        model_out = self.model(x_noisy, t)\n\n        loss_dict = {}\n        if self.parameterization == \"eps\":\n            target = noise\n        elif self.parameterization == \"x0\":\n            target = x_start\n        else:\n            raise NotImplementedError(f\"Parameterization {self.parameterization} not yet supported\")\n\n        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])\n\n        log_prefix = 'train' if self.training else 'val'\n\n        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})\n        loss_simple = loss.mean() * self.l_simple_weight\n\n        loss_vlb = (self.lvlb_weights[t] * loss).mean()\n        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})\n\n        loss = loss_simple + self.original_elbo_weight * loss_vlb\n\n        loss_dict.update({f'{log_prefix}/loss': loss})\n\n        return loss, loss_dict\n\n    def forward(self, x, *args, **kwargs):\n        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size\n        # assert h == img_size and w == img_size, f'height and width of image 
must be {img_size}'\n        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()\n        return self.p_losses(x, t, *args, **kwargs)\n\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., None]\n        x = rearrange(x, 'b h w c -> b c h w')\n        x = x.to(memory_format=torch.contiguous_format).float()\n        return x\n\n    def shared_step(self, batch):\n        x = self.get_input(batch, self.first_stage_key)\n        loss, loss_dict = self(x)\n        return loss, loss_dict\n\n    def training_step(self, batch, batch_idx):\n        for k in self.ucg_training:\n            p = self.ucg_training[k][\"p\"]\n            val = self.ucg_training[k][\"val\"]\n            if val is None:\n                val = \"\"\n            for i in range(len(batch[k])):\n                if self.ucg_prng.choice(2, p=[1-p, p]):\n                    batch[k][i] = val\n\n        loss, loss_dict = self.shared_step(batch)\n\n        self.log_dict(loss_dict, prog_bar=True,\n                      logger=True, on_step=True, on_epoch=True)\n\n        self.log(\"global_step\", self.global_step,\n                 prog_bar=True, logger=True, on_step=True, on_epoch=False)\n\n        if self.use_scheduler:\n            lr = self.optimizers().param_groups[0]['lr']\n            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)\n\n        return loss\n\n    @torch.no_grad()\n    def validation_step(self, batch, batch_idx):\n        _, loss_dict_no_ema = self.shared_step(batch)\n        with self.ema_scope():\n            _, loss_dict_ema = self.shared_step(batch)\n            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}\n        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)\n        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)\n\n    def 
on_train_batch_end(self, *args, **kwargs):\n        if self.use_ema:\n            self.model_ema(self.model)\n\n    def _get_rows_from_list(self, samples):\n        n_imgs_per_row = len(samples)\n        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')\n        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')\n        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)\n        return denoise_grid\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):\n        log = dict()\n        x = self.get_input(batch, self.first_stage_key)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        x = x.to(self.device)[:N]\n        log[\"inputs\"] = x\n\n        # get diffusion row\n        diffusion_row = list()\n        x_start = x[:n_row]\n\n        for t in range(self.num_timesteps):\n            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n                t = t.to(self.device).long()\n                noise = torch.randn_like(x_start)\n                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n                diffusion_row.append(x_noisy)\n\n        log[\"diffusion_row\"] = self._get_rows_from_list(diffusion_row)\n\n        if sample:\n            # get denoise row\n            with self.ema_scope(\"Plotting\"):\n                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)\n\n            log[\"samples\"] = samples\n            log[\"denoise_row\"] = self._get_rows_from_list(denoise_row)\n\n        if return_keys:\n            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:\n                return log\n            else:\n                return {key: log[key] for key in return_keys}\n        return log\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        params = 
list(self.model.parameters())\n        if self.learn_logvar:\n            params = params + [self.logvar]\n        opt = torch.optim.AdamW(params, lr=lr)\n        return opt\n\n\nclass LatentDiffusion(DDPM):\n    \"\"\"main class\"\"\"\n    def __init__(self,\n                 first_stage_config,\n                 cond_stage_config,\n                 num_timesteps_cond=None,\n                 cond_stage_key=\"image\",\n                 cond_stage_trainable=False,\n                 concat_mode=True,\n                 cond_stage_forward=None,\n                 conditioning_key=None,\n                 scale_factor=1.0,\n                 scale_by_std=False,\n                 unet_trainable=True,\n                 *args, **kwargs):\n        self.num_timesteps_cond = default(num_timesteps_cond, 1)\n        self.scale_by_std = scale_by_std\n        assert self.num_timesteps_cond <= kwargs['timesteps']\n        # for backwards compatibility after implementation of DiffusionWrapper\n        if conditioning_key is None:\n            conditioning_key = 'concat' if concat_mode else 'crossattn'\n        if cond_stage_config == '__is_unconditional__':\n            conditioning_key = None\n        ckpt_path = kwargs.pop(\"ckpt_path\", None)\n        ignore_keys = kwargs.pop(\"ignore_keys\", [])\n        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)\n        self.concat_mode = concat_mode\n        self.cond_stage_trainable = cond_stage_trainable\n        self.unet_trainable = unet_trainable\n        self.cond_stage_key = cond_stage_key\n        try:\n            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1\n        except:\n            self.num_downs = 0\n        if not scale_by_std:\n            self.scale_factor = scale_factor\n        else:\n            self.register_buffer('scale_factor', torch.tensor(scale_factor))\n        self.instantiate_first_stage(first_stage_config)\n        self.instantiate_cond_stage(cond_stage_config)\n  
      self.cond_stage_forward = cond_stage_forward\n\n        # construct linear projection layer for concatenating image CLIP embedding and RT\n        self.cc_projection = nn.Linear(772, 768)\n        nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768])\n        nn.init.zeros_(list(self.cc_projection.parameters())[1])\n        self.cc_projection.requires_grad_(True)\n        \n        self.clip_denoised = False\n        self.bbox_tokenizer = None\n\n        self.restarted_from_ckpt = False\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys)\n            self.restarted_from_ckpt = True\n\n    def make_cond_schedule(self, ):\n        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)\n        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()\n        self.cond_ids[:self.num_timesteps_cond] = ids\n\n    @rank_zero_only\n    @torch.no_grad()\n    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):\n        # only for very first batch\n        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:\n            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'\n            # set rescale weight to 1./std of encodings\n            print(\"### USING STD-RESCALING ###\")\n            x = super().get_input(batch, self.first_stage_key)\n            x = x.to(self.device)\n            encoder_posterior = self.encode_first_stage(x)\n            z = self.get_first_stage_encoding(encoder_posterior).detach()\n            del self.scale_factor\n            self.register_buffer('scale_factor', 1. 
/ z.flatten().std())\n            print(f\"setting self.scale_factor to {self.scale_factor}\")\n            print(\"### USING STD-RESCALING ###\")\n\n    def register_schedule(self,\n                          given_betas=None, beta_schedule=\"linear\", timesteps=1000,\n                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):\n        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)\n\n        self.shorten_cond_schedule = self.num_timesteps_cond > 1\n        if self.shorten_cond_schedule:\n            self.make_cond_schedule()\n\n    def instantiate_first_stage(self, config):\n        model = instantiate_from_config(config)\n        self.first_stage_model = model.eval()\n        self.first_stage_model.train = disabled_train\n        for param in self.first_stage_model.parameters():\n            param.requires_grad = False\n\n    def instantiate_cond_stage(self, config):\n        if not self.cond_stage_trainable:\n            if config == \"__is_first_stage__\":\n                print(\"Using first stage also as cond stage.\")\n                self.cond_stage_model = self.first_stage_model\n            elif config == \"__is_unconditional__\":\n                print(f\"Training {self.__class__.__name__} as an unconditional model.\")\n                self.cond_stage_model = None\n                # self.be_unconditional = True\n            else:\n                model = instantiate_from_config(config)\n                self.cond_stage_model = model.eval()\n                self.cond_stage_model.train = disabled_train\n                for param in self.cond_stage_model.parameters():\n                    param.requires_grad = False\n        else:\n            assert config != '__is_first_stage__'\n            assert config != '__is_unconditional__'\n            model = instantiate_from_config(config)\n            self.cond_stage_model = model\n\n    def _get_denoise_row_from_list(self, samples, 
desc='', force_no_decoder_quantization=False):\n        denoise_row = []\n        for zd in tqdm(samples, desc=desc):\n            denoise_row.append(self.decode_first_stage(zd.to(self.device),\n                                                            force_not_quantize=force_no_decoder_quantization))\n        n_imgs_per_row = len(denoise_row)\n        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W\n        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')\n        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')\n        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)\n        return denoise_grid\n\n    def get_first_stage_encoding(self, encoder_posterior):\n        if isinstance(encoder_posterior, DiagonalGaussianDistribution):\n            z = encoder_posterior.sample()\n        elif isinstance(encoder_posterior, torch.Tensor):\n            z = encoder_posterior\n        else:\n            raise NotImplementedError(f\"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented\")\n        return self.scale_factor * z\n\n    def get_learned_conditioning(self, c):\n        if self.cond_stage_forward is None:\n            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):\n                c = self.cond_stage_model.encode(c)\n                if isinstance(c, DiagonalGaussianDistribution):\n                    c = c.mode()\n            else:\n                c = self.cond_stage_model(c)\n        else:\n            assert hasattr(self.cond_stage_model, self.cond_stage_forward)\n            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)\n        return c\n\n    def meshgrid(self, h, w):\n        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)\n        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)\n\n        arr = torch.cat([y, x], dim=-1)\n        return arr\n\n    def delta_border(self, h, w):\n        \"\"\"\n        
:param h: height\n        :param w: width\n        :return: normalized distance to image border,\n         with min distance = 0 at border and max dist = 0.5 at image center\n        \"\"\"\n        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)\n        arr = self.meshgrid(h, w) / lower_right_corner\n        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]\n        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]\n        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]\n        return edge_dist\n\n    def get_weighting(self, h, w, Ly, Lx, device):\n        weighting = self.delta_border(h, w)\n        weighting = torch.clip(weighting, self.split_input_params[\"clip_min_weight\"],\n                               self.split_input_params[\"clip_max_weight\"], )\n        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)\n\n        if self.split_input_params[\"tie_braker\"]:\n            L_weighting = self.delta_border(Ly, Lx)\n            L_weighting = torch.clip(L_weighting,\n                                     self.split_input_params[\"clip_min_tie_weight\"],\n                                     self.split_input_params[\"clip_max_tie_weight\"])\n\n            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)\n            weighting = weighting * L_weighting\n        return weighting\n\n    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code\n        \"\"\"\n        :param x: img of size (bs, c, h, w)\n        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])\n        \"\"\"\n        bs, nc, h, w = x.shape\n\n        # number of crops in image\n        Ly = (h - kernel_size[0]) // stride[0] + 1\n        Lx = (w - kernel_size[1]) // stride[1] + 1\n\n        if uf == 1 and df == 1:\n            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n           
 unfold = torch.nn.Unfold(**fold_params)\n\n            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)\n\n            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)\n            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap\n            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))\n\n        elif uf > 1 and df == 1:\n            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[1] * uf),\n                                dilation=1, padding=0,\n                                stride=(stride[0] * uf, stride[1] * uf))\n            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)\n\n            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)\n            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap\n            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))\n\n        elif df > 1 and uf == 1:\n            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[1] // df),\n                                dilation=1, padding=0,\n                                stride=(stride[0] // df, stride[1] // df))\n            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)\n\n            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)\n            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap\n            weighting = weighting.view((1, 1, 
kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))\n\n        else:\n            raise NotImplementedError\n\n        return fold, unfold, normalization, weighting\n\n    \n    @torch.no_grad()\n    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,\n                  cond_key=None, return_original_cond=False, bs=None, uncond=0.05):\n        x = super().get_input(batch, k)\n        T = batch['T'].to(memory_format=torch.contiguous_format).float()\n        \n        if bs is not None:\n            x = x[:bs]\n            T = T[:bs].to(self.device)\n\n        x = x.to(self.device)\n        encoder_posterior = self.encode_first_stage(x)\n        z = self.get_first_stage_encoding(encoder_posterior).detach()\n        cond_key = cond_key or self.cond_stage_key\n        xc = super().get_input(batch, cond_key).to(self.device)\n        if bs is not None:\n            xc = xc[:bs]\n        cond = {}\n\n        # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.\n        random = torch.rand(x.size(0), device=x.device)\n        prompt_mask = rearrange(random < 2 * uncond, \"n -> n 1 1\")\n        input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), \"n -> n 1 1 1\")\n        null_prompt = self.get_learned_conditioning([\"\"])\n\n        # z.shape: [8, 4, 64, 64]; c.shape: [8, 1, 768]\n        # print('=========== xc shape ===========', xc.shape)\n        with torch.enable_grad():\n            clip_emb = self.get_learned_conditioning(xc).detach()\n            null_prompt = self.get_learned_conditioning([\"\"]).detach()\n            cond[\"c_crossattn\"] = [self.cc_projection(torch.cat([torch.where(prompt_mask, null_prompt, clip_emb), T[:, None, :]], dim=-1))]\n        cond[\"c_concat\"] = [input_mask * self.encode_first_stage((xc.to(self.device))).mode().detach()]\n        out = [z, cond]\n        if return_first_stage_outputs:\n        
    xrec = self.decode_first_stage(z)\n            out.extend([x, xrec])\n        if return_original_cond:\n            out.append(xc)\n        return out\n\n    # @torch.no_grad()\n    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):\n        if predict_cids:\n            if z.dim() == 4:\n                z = torch.argmax(z.exp(), dim=1).long()\n            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)\n            z = rearrange(z, 'b h w c -> b c h w').contiguous()\n\n        z = 1. / self.scale_factor * z\n\n        if hasattr(self, \"split_input_params\"):\n            if self.split_input_params[\"patch_distributed_vq\"]:\n                ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n                stride = self.split_input_params[\"stride\"]  # eg. (64, 64)\n                uf = self.split_input_params[\"vqf\"]\n                bs, nc, h, w = z.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print(\"reducing Kernel\")\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print(\"reducing stride\")\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)\n\n                z = unfold(z)  # (bn, nc * prod(**ks), L)\n                # 1. Reshape to img shape\n                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                # 2. 
apply model loop over last dim\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],\n                                                                 force_not_quantize=predict_cids or force_not_quantize)\n                                   for i in range(z.shape[-1])]\n                else:\n\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])\n                                   for i in range(z.shape[-1])]\n\n                o = torch.stack(output_list, axis=-1)  # (bn, nc, ks[0], ks[1], L)\n                o = o * weighting\n                # Reverse 1. reshape to img shape\n                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization  # norm is shape (1, 1, h, w)\n                return decoded\n            else:\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n                else:\n                    return self.first_stage_model.decode(z)\n\n        else:\n            if isinstance(self.first_stage_model, VQModelInterface):\n                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n            else:\n                return self.first_stage_model.decode(z)\n\n    # @torch.no_grad() # wasted two hours to find this bug... why no grad here!\n    def encode_first_stage(self, x):\n        if hasattr(self, \"split_input_params\"):\n            if self.split_input_params[\"patch_distributed_vq\"]:\n                ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n                stride = self.split_input_params[\"stride\"]  # eg. 
(64, 64)\n                df = self.split_input_params[\"vqf\"]\n                self.split_input_params['original_image_size'] = x.shape[-2:]\n                bs, nc, h, w = x.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print(\"reducing Kernel\")\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print(\"reducing stride\")\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)\n                z = unfold(x)  # (bn, nc * prod(**ks), L)\n                # Reshape to img shape\n                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])\n                               for i in range(z.shape[-1])]\n\n                o = torch.stack(output_list, axis=-1)\n                o = o * weighting\n\n                # Reverse reshape to img shape\n                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization\n                return decoded\n\n            else:\n                return self.first_stage_model.encode(x)\n        else:\n            return self.first_stage_model.encode(x)\n\n    def shared_step(self, batch, **kwargs):\n        x, c = self.get_input(batch, self.first_stage_key)\n        loss = self(x, c)\n        return loss\n\n    def forward(self, x, c, *args, **kwargs):\n        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()\n        if self.model.conditioning_key is not None:\n            assert c is not None\n            # if self.cond_stage_trainable:\n            #     c = self.get_learned_conditioning(c)\n            if 
self.shorten_cond_schedule:  # TODO: drop this option\n                tc = self.cond_ids[t].to(self.device)\n                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))\n        return self.p_losses(x, c, t, *args, **kwargs)\n\n    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset\n        def rescale_bbox(bbox):\n            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])\n            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])\n            w = min(bbox[2] / crop_coordinates[2], 1 - x0)\n            h = min(bbox[3] / crop_coordinates[3], 1 - y0)\n            return x0, y0, w, h\n\n        return [rescale_bbox(b) for b in bboxes]\n\n    def apply_model(self, x_noisy, t, cond, return_ids=False):\n\n        if isinstance(cond, dict):\n            # hybrid case, cond is expected to be a dict\n            pass\n        else:\n            if not isinstance(cond, list):\n                cond = [cond]\n            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'\n            cond = {key: cond}\n\n        if hasattr(self, \"split_input_params\"):\n            assert len(cond) == 1  # todo can only deal with one conditioning atm\n            assert not return_ids\n            ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n            stride = self.split_input_params[\"stride\"]  # eg. 
(64, 64)\n\n            h, w = x_noisy.shape[-2:]\n\n            fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)\n\n            z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)\n            # Reshape to img shape\n            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]\n\n            if self.cond_stage_key in [\"image\", \"LR_image\", \"segmentation\",\n                                       'bbox_img'] and self.model.conditioning_key:  # todo check for completeness\n                c_key = next(iter(cond.keys()))  # get key\n                c = next(iter(cond.values()))  # get value\n                assert (len(c) == 1)  # todo extend to list with more than one elem\n                c = c[0]  # get element\n\n                c = unfold(c)\n                c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]\n\n            elif self.cond_stage_key == 'coordinates_bbox':\n                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'\n\n                # assuming padding of unfold is always 0 and its dilation is always 1\n                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)\n                full_img_h, full_img_w = self.split_input_params['original_image_size']\n                # as we are operating on latents, we need the factor from the original image size to the\n                # spatial latent size to properly rescale the crops for regenerating the bbox annotations\n                num_downs = self.first_stage_model.encoder.num_resolutions - 1\n                rescale_latent = 2 ** (num_downs)\n\n                # get top-left positions of patches as required by the bbox tokenizer, therefore we\n                # need 
to rescale the tl patch coordinates to be in between (0,1)\n                tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,\n                                         rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)\n                                        for patch_nr in range(z.shape[-1])]\n\n                # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)\n                patch_limits = [(x_tl, y_tl,\n                                 rescale_latent * ks[0] / full_img_w,\n                                 rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]\n                # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]\n\n                # tokenize crop coordinates for the bounding boxes of the respective patches\n                patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)\n                                      for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)\n                # cut tknzd crop position from conditioning\n                assert isinstance(cond, dict), 'cond must be dict to be fed into model'\n                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)\n\n                adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])\n                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')\n                adapted_cond = self.get_learned_conditioning(adapted_cond)\n                adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])\n\n                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]\n\n            else:\n                cond_list = [cond for i in range(z.shape[-1])]  # Todo make this more efficient\n\n            # apply model by loop over crops\n      
      output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]\n            assert not isinstance(output_list[0],\n                                  tuple)  # todo cant deal with multiple model outputs check this never happens\n\n            o = torch.stack(output_list, axis=-1)\n            o = o * weighting\n            # Reverse reshape to img shape\n            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n            # stitch crops together\n            x_recon = fold(o) / normalization\n\n        else:\n            x_recon = self.model(x_noisy, t, **cond)\n\n        if isinstance(x_recon, tuple) and not return_ids:\n            return x_recon[0]\n        else:\n            return x_recon\n\n    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):\n        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \\\n               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n\n    def _prior_bpd(self, x_start):\n        \"\"\"\n        Get the prior KL term for the variational lower-bound, measured in\n        bits-per-dim.\n        This term can't be optimized, as it only depends on the encoder.\n        :param x_start: the [N x C x ...] 
tensor of inputs.\n        :return: a batch of [N] KL values (in bits), one per batch element.\n        \"\"\"\n        batch_size = x_start.shape[0]\n        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)\n        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)\n        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)\n        return mean_flat(kl_prior) / np.log(2.0)\n\n    def p_losses(self, x_start, cond, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        model_output = self.apply_model(x_noisy, t, cond)\n\n        loss_dict = {}\n        prefix = 'train' if self.training else 'val'\n\n        if self.parameterization == \"x0\":\n            target = x_start\n        elif self.parameterization == \"eps\":\n            target = noise\n        else:\n            raise NotImplementedError()\n\n        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])\n        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})\n\n        logvar_t = self.logvar[t].to(self.device)\n        loss = loss_simple / torch.exp(logvar_t) + logvar_t\n        # loss = loss_simple / torch.exp(self.logvar) + self.logvar\n        if self.learn_logvar:\n            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})\n            loss_dict.update({'logvar': self.logvar.data.mean()})\n\n        loss = self.l_simple_weight * loss.mean()\n\n        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))\n        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()\n        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})\n        loss += (self.original_elbo_weight * loss_vlb)\n        loss_dict.update({f'{prefix}/loss': loss})\n\n        return loss, loss_dict\n\n    def p_mean_variance(self, x, c, t, clip_denoised: bool, 
return_codebook_ids=False, quantize_denoised=False,\n                        return_x0=False, score_corrector=None, corrector_kwargs=None):\n        t_in = t\n        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)\n\n        if score_corrector is not None:\n            assert self.parameterization == \"eps\"\n            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)\n\n        if return_codebook_ids:\n            model_out, logits = model_out\n\n        if self.parameterization == \"eps\":\n            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)\n        elif self.parameterization == \"x0\":\n            x_recon = model_out\n        else:\n            raise NotImplementedError()\n\n        if clip_denoised:\n            x_recon.clamp_(-1., 1.)\n        if quantize_denoised:\n            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        if return_codebook_ids:\n            return model_mean, posterior_variance, posterior_log_variance, logits\n        elif return_x0:\n            return model_mean, posterior_variance, posterior_log_variance, x_recon\n        else:\n            return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,\n                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,\n                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n        b, *_, device = *x.shape, x.device\n        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,\n                                       return_codebook_ids=return_codebook_ids,\n                                       quantize_denoised=quantize_denoised,\n                                       
return_x0=return_x0,\n                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)\n        if return_codebook_ids:\n            raise DeprecationWarning(\"Support dropped.\")\n            model_mean, _, model_log_variance, logits = outputs\n        elif return_x0:\n            model_mean, _, model_log_variance, x0 = outputs\n        else:\n            model_mean, _, model_log_variance = outputs\n\n        noise = noise_like(x.shape, device, repeat_noise) * temperature\n        if noise_dropout > 0.:\n            noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n\n        if return_codebook_ids:\n            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)\n        if return_x0:\n            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0\n        else:\n            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.no_grad()\n    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,\n                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,\n                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,\n                              log_every_t=None):\n        if not log_every_t:\n            log_every_t = self.log_every_t\n        timesteps = self.num_timesteps\n        if batch_size is not None:\n            b = batch_size if batch_size is not None else shape[0]\n            shape = [batch_size] + list(shape)\n        else:\n            b = batch_size = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=self.device)\n        else:\n            img = x_T\n        intermediates = []\n        if cond is 
not None:\n            if isinstance(cond, dict):\n                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else\n                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}\n            else:\n                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]\n\n        if start_T is not None:\n            timesteps = min(timesteps, start_T)\n        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',\n                        total=timesteps) if verbose else reversed(\n            range(0, timesteps))\n        if type(temperature) == float:\n            temperature = [temperature] * timesteps\n\n        for i in iterator:\n            ts = torch.full((b,), i, device=self.device, dtype=torch.long)\n            if self.shorten_cond_schedule:\n                assert self.model.conditioning_key != 'hybrid'\n                tc = self.cond_ids[ts].to(cond.device)\n                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))\n\n            img, x0_partial = self.p_sample(img, cond, ts,\n                                            clip_denoised=self.clip_denoised,\n                                            quantize_denoised=quantize_denoised, return_x0=True,\n                                            temperature=temperature[i], noise_dropout=noise_dropout,\n                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)\n            if mask is not None:\n                assert x0 is not None\n                img_orig = self.q_sample(x0, ts)\n                img = img_orig * mask + (1. 
- mask) * img\n\n            if i % log_every_t == 0 or i == timesteps - 1:\n                intermediates.append(x0_partial)\n            if callback: callback(i)\n            if img_callback: img_callback(img, i)\n        return img, intermediates\n\n    @torch.no_grad()\n    def p_sample_loop(self, cond, shape, return_intermediates=False,\n                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,\n                      mask=None, x0=None, img_callback=None, start_T=None,\n                      log_every_t=None):\n\n        if not log_every_t:\n            log_every_t = self.log_every_t\n        device = self.betas.device\n        b = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=device)\n        else:\n            img = x_T\n\n        intermediates = [img]\n        if timesteps is None:\n            timesteps = self.num_timesteps\n\n        if start_T is not None:\n            timesteps = min(timesteps, start_T)\n        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(\n            range(0, timesteps))\n\n        if mask is not None:\n            assert x0 is not None\n            assert x0.shape[2:4] == mask.shape[2:4]  # spatial size (H, W) has to match\n\n        for i in iterator:\n            ts = torch.full((b,), i, device=device, dtype=torch.long)\n            if self.shorten_cond_schedule:\n                assert self.model.conditioning_key != 'hybrid'\n                tc = self.cond_ids[ts].to(cond.device)\n                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))\n\n            img = self.p_sample(img, cond, ts,\n                                clip_denoised=self.clip_denoised,\n                                quantize_denoised=quantize_denoised)\n            if mask is not None:\n                img_orig = self.q_sample(x0, ts)\n                img = img_orig * mask + (1. 
- mask) * img\n\n            if i % log_every_t == 0 or i == timesteps - 1:\n                intermediates.append(img)\n            if callback: callback(i)\n            if img_callback: img_callback(img, i)\n\n        if return_intermediates:\n            return img, intermediates\n        return img\n\n    @torch.no_grad()\n    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,\n               verbose=True, timesteps=None, quantize_denoised=False,\n               mask=None, x0=None, shape=None,**kwargs):\n        if shape is None:\n            shape = (batch_size, self.channels, self.image_size, self.image_size)\n        if cond is not None:\n            if isinstance(cond, dict):\n                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else\n                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}\n            else:\n                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]\n        return self.p_sample_loop(cond,\n                                  shape,\n                                  return_intermediates=return_intermediates, x_T=x_T,\n                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,\n                                  mask=mask, x0=x0)\n\n    @torch.no_grad()\n    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):\n        if ddim:\n            ddim_sampler = DDIMSampler(self)\n            shape = (self.channels, self.image_size, self.image_size)\n            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,\n                                                         shape, cond, verbose=False, **kwargs)\n\n        else:\n            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,\n                                                 return_intermediates=True, **kwargs)\n\n        return samples, intermediates\n\n    
@torch.no_grad()\n    def get_unconditional_conditioning(self, batch_size, null_label=None, image_size=512):\n        if null_label is not None:\n            xc = null_label\n            if isinstance(xc, ListConfig):\n                xc = list(xc)\n            if isinstance(xc, dict) or isinstance(xc, list):\n                c = self.get_learned_conditioning(xc)\n            else:\n                if hasattr(xc, \"to\"):\n                    xc = xc.to(self.device)\n                c = self.get_learned_conditioning(xc)\n        else:\n            # todo: get null label from cond_stage_model\n            raise NotImplementedError()\n        c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)\n        cond = {}\n        cond[\"c_crossattn\"] = [c]\n        cond[\"c_concat\"] = [torch.zeros([batch_size, 4, image_size // 8, image_size // 8]).to(self.device)]\n        return cond\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,\n                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,\n                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,\n                   use_ema_scope=True,\n                   **kwargs):\n        ema_scope = self.ema_scope if use_ema_scope else nullcontext\n        use_ddim = ddim_steps is not None\n\n        log = dict()\n        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,\n                                           return_first_stage_outputs=True,\n                                           force_c_encode=True,\n                                           return_original_cond=True,\n                                           bs=N)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        log[\"inputs\"] = x\n        log[\"reconstruction\"] = xrec\n        if self.model.conditioning_key is not None:\n      
      if hasattr(self.cond_stage_model, \"decode\"):\n                xc = self.cond_stage_model.decode(c)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key in [\"caption\", \"txt\"]:\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key == 'class_label':\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"human_label\"], size=x.shape[2]//25)\n                log['conditioning'] = xc\n            elif isimage(xc):\n                log[\"conditioning\"] = xc\n            if ismap(xc):\n                log[\"original_conditioning\"] = self.to_rgb(xc)\n\n        if plot_diffusion_rows:\n            # get diffusion row\n            diffusion_row = list()\n            z_start = z[:n_row]\n            for t in range(self.num_timesteps):\n                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n                    t = t.to(self.device).long()\n                    noise = torch.randn_like(z_start)\n                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)\n                    diffusion_row.append(self.decode_first_stage(z_noisy))\n\n            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W\n            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')\n            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')\n            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])\n            log[\"diffusion_row\"] = diffusion_grid\n\n        if sample:\n            # get denoise row\n            with ema_scope(\"Sampling\"):\n                samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,\n                                                         
ddim_steps=ddim_steps,eta=ddim_eta)\n                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)\n            x_samples = self.decode_first_stage(samples)\n            log[\"samples\"] = x_samples\n            if plot_denoise_rows:\n                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)\n                log[\"denoise_row\"] = denoise_grid\n\n            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(\n                    self.first_stage_model, IdentityFirstStage):\n                # also display when quantizing x0 while sampling\n                with ema_scope(\"Plotting Quantized Denoised\"):\n                    samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,\n                                                             ddim_steps=ddim_steps,eta=ddim_eta,\n                                                             quantize_denoised=True)\n                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,\n                    #                                      quantize_denoised=True)\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log[\"samples_x0_quantized\"] = x_samples\n\n        if unconditional_guidance_scale > 1.0:\n            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label, image_size=x.shape[-1])\n            # uc = torch.zeros_like(c)\n            with ema_scope(\"Sampling with classifier-free guidance\"):\n                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,\n                                                 ddim_steps=ddim_steps, eta=ddim_eta,\n                                                 unconditional_guidance_scale=unconditional_guidance_scale,\n                                                 unconditional_conditioning=uc,\n                              
                   )\n                x_samples_cfg = self.decode_first_stage(samples_cfg)\n                log[f\"samples_cfg_scale_{unconditional_guidance_scale:.2f}\"] = x_samples_cfg\n\n        if inpaint:\n            # make a simple center square\n            b, h, w = z.shape[0], z.shape[2], z.shape[3]\n            mask = torch.ones(N, h, w).to(self.device)\n            # zeros will be filled in\n            mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.\n            mask = mask[:, None, ...]\n            with ema_scope(\"Plotting Inpaint\"):\n\n                samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,\n                                            ddim_steps=ddim_steps, x0=z[:N], mask=mask)\n            x_samples = self.decode_first_stage(samples.to(self.device))\n            log[\"samples_inpainting\"] = x_samples\n            log[\"mask\"] = mask\n\n            # outpaint\n            mask = 1. - mask\n            with ema_scope(\"Plotting Outpaint\"):\n                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,\n                                            ddim_steps=ddim_steps, x0=z[:N], mask=mask)\n            x_samples = self.decode_first_stage(samples.to(self.device))\n            log[\"samples_outpainting\"] = x_samples\n\n        if plot_progressive_rows:\n            with ema_scope(\"Plotting Progressives\"):\n                img, progressives = self.progressive_denoising(c,\n                                                               shape=(self.channels, self.image_size, self.image_size),\n                                                               batch_size=N)\n            prog_row = self._get_denoise_row_from_list(progressives, desc=\"Progressive Generation\")\n            log[\"progressive_row\"] = prog_row\n\n        if return_keys:\n            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:\n                return log\n            else:\n         
       return {key: log[key] for key in return_keys}\n        return log\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        params = []\n        if self.unet_trainable == \"attn\":\n            print(\"Training only unet attention layers\")\n            for n, m in self.model.named_modules():\n                if isinstance(m, CrossAttention) and n.endswith('attn2'):\n                    params.extend(m.parameters())\n        elif self.unet_trainable == \"conv_in\":\n            print(\"Training only unet input conv layers\")\n            params = list(self.model.diffusion_model.input_blocks[0][0].parameters())\n        elif self.unet_trainable is True or self.unet_trainable == \"all\":\n            print(\"Training the full unet\")\n            params = list(self.model.parameters())\n        else:\n            raise ValueError(f\"Unrecognised setting for unet_trainable: {self.unet_trainable}\")\n\n        if self.cond_stage_trainable:\n            print(f\"{self.__class__.__name__}: Also optimizing conditioner params!\")\n            params = params + list(self.cond_stage_model.parameters())\n        if self.learn_logvar:\n            print('Diffusion model optimizing logvar')\n            params.append(self.logvar)\n\n        if self.cc_projection is not None:\n            params = params + list(self.cc_projection.parameters())\n            print('========== optimizing for cc projection weight ==========')\n\n        opt = torch.optim.AdamW([{\"params\": self.model.parameters(), \"lr\": lr},\n                                {\"params\": self.cc_projection.parameters(), \"lr\": 10. 
* lr}], lr=lr)\n        if self.use_scheduler:\n            assert 'target' in self.scheduler_config\n            scheduler = instantiate_from_config(self.scheduler_config)\n\n            print(\"Setting up LambdaLR scheduler...\")\n            scheduler = [\n                {\n                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),\n                    'interval': 'step',\n                    'frequency': 1\n                }]\n            return [opt], scheduler\n        return opt\n\n    @torch.no_grad()\n    def to_rgb(self, x):\n        x = x.float()\n        if not hasattr(self, \"colorize\"):\n            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)\n        x = nn.functional.conv2d(x, weight=self.colorize)\n        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.\n        return x\n\n\nclass DiffusionWrapper(pl.LightningModule):\n    def __init__(self, diff_model_config, conditioning_key):\n        super().__init__()\n        self.diffusion_model = instantiate_from_config(diff_model_config)\n        self.conditioning_key = conditioning_key\n        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']\n\n    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):\n        if self.conditioning_key is None:\n            out = self.diffusion_model(x, t)\n        elif self.conditioning_key == 'concat':\n            xc = torch.cat([x] + c_concat, dim=1)\n            out = self.diffusion_model(xc, t)\n        elif self.conditioning_key == 'crossattn':\n            # c_crossattn[0] shape: torch.Size([8, 1, 768])\n            # cc shape: torch.Size([8, 1, 768])\n            cc = torch.cat(c_crossattn, 1)\n            out = self.diffusion_model(x, t, context=cc)\n        elif self.conditioning_key == 'hybrid':\n            xc = torch.cat([x] + c_concat, dim=1)\n            cc = torch.cat(c_crossattn, 1)\n            out = self.diffusion_model(xc, t, 
context=cc)\n        elif self.conditioning_key == 'hybrid-adm':\n            assert c_adm is not None\n            xc = torch.cat([x] + c_concat, dim=1)\n            cc = torch.cat(c_crossattn, 1)\n            out = self.diffusion_model(xc, t, context=cc, y=c_adm)\n        elif self.conditioning_key == 'adm':\n            cc = c_crossattn[0]\n            out = self.diffusion_model(x, t, y=cc)\n        else:\n            raise NotImplementedError()\n\n        return out\n\n\nclass LatentUpscaleDiffusion(LatentDiffusion):\n    def __init__(self, *args, low_scale_config, low_scale_key=\"LR\", **kwargs):\n        super().__init__(*args, **kwargs)\n        # assumes that neither the cond_stage nor the low_scale_model contain trainable params\n        assert not self.cond_stage_trainable\n        self.instantiate_low_stage(low_scale_config)\n        self.low_scale_key = low_scale_key\n\n    def instantiate_low_stage(self, config):\n        model = instantiate_from_config(config)\n        self.low_scale_model = model.eval()\n        self.low_scale_model.train = disabled_train\n        for param in self.low_scale_model.parameters():\n            param.requires_grad = False\n\n    @torch.no_grad()\n    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):\n        if not log_mode:\n            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)\n        else:\n            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,\n                                                  force_c_encode=True, return_original_cond=True, bs=bs)\n        x_low = batch[self.low_scale_key][:bs]\n        x_low = rearrange(x_low, 'b h w c -> b c h w')\n        x_low = x_low.to(memory_format=torch.contiguous_format).float()\n        zx, noise_level = self.low_scale_model(x_low)\n        all_conds = {\"c_concat\": [zx], \"c_crossattn\": [c], \"c_adm\": noise_level}\n        #import pudb; pu.db\n        if log_mode:\n       
     # TODO: maybe disable if too expensive\n            interpretability = False\n            if interpretability:\n                zx = zx[:, :, ::2, ::2]\n            x_low_rec = self.low_scale_model.decode(zx)\n            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level\n        return z, all_conds\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,\n                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,\n                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,\n                   **kwargs):\n        ema_scope = self.ema_scope if use_ema_scope else nullcontext\n        use_ddim = ddim_steps is not None\n\n        log = dict()\n        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,\n                                                                          log_mode=True)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        log[\"inputs\"] = x\n        log[\"reconstruction\"] = xrec\n        log[\"x_lr\"] = x_low\n        log[f\"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}\"] = x_low_rec\n        if self.model.conditioning_key is not None:\n            if hasattr(self.cond_stage_model, \"decode\"):\n                xc = self.cond_stage_model.decode(c)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key in [\"caption\", \"txt\"]:\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key == 'class_label':\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"human_label\"], size=x.shape[2]//25)\n                log['conditioning'] = xc\n            elif isimage(xc):\n            
    log[\"conditioning\"] = xc\n            if ismap(xc):\n                log[\"original_conditioning\"] = self.to_rgb(xc)\n\n        if plot_diffusion_rows:\n            # get diffusion row\n            diffusion_row = list()\n            z_start = z[:n_row]\n            for t in range(self.num_timesteps):\n                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n                    t = t.to(self.device).long()\n                    noise = torch.randn_like(z_start)\n                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)\n                    diffusion_row.append(self.decode_first_stage(z_noisy))\n\n            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W\n            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')\n            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')\n            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])\n            log[\"diffusion_row\"] = diffusion_grid\n\n        if sample:\n            # get denoise row\n            with ema_scope(\"Sampling\"):\n                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,\n                                                         ddim_steps=ddim_steps, eta=ddim_eta)\n                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)\n            x_samples = self.decode_first_stage(samples)\n            log[\"samples\"] = x_samples\n            if plot_denoise_rows:\n                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)\n                log[\"denoise_row\"] = denoise_grid\n\n        if unconditional_guidance_scale > 1.0:\n            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)\n            # TODO explore better \"unconditional\" choices for the other keys\n          
  # maybe guide away from empty text label and highest noise level and maximally degraded zx?\n            uc = dict()\n            for k in c:\n                if k == \"c_crossattn\":\n                    assert isinstance(c[k], list) and len(c[k]) == 1\n                    uc[k] = [uc_tmp]\n                elif k == \"c_adm\":  # todo: only run with text-based guidance?\n                    assert isinstance(c[k], torch.Tensor)\n                    uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level\n                elif isinstance(c[k], list):\n                    uc[k] = [c[k][i] for i in range(len(c[k]))]\n                else:\n                    uc[k] = c[k]\n\n            with ema_scope(\"Sampling with classifier-free guidance\"):\n                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,\n                                                 ddim_steps=ddim_steps, eta=ddim_eta,\n                                                 unconditional_guidance_scale=unconditional_guidance_scale,\n                                                 unconditional_conditioning=uc,\n                                                 )\n                x_samples_cfg = self.decode_first_stage(samples_cfg)\n                log[f\"samples_cfg_scale_{unconditional_guidance_scale:.2f}\"] = x_samples_cfg\n\n        if plot_progressive_rows:\n            with ema_scope(\"Plotting Progressives\"):\n                img, progressives = self.progressive_denoising(c,\n                                                               shape=(self.channels, self.image_size, self.image_size),\n                                                               batch_size=N)\n            prog_row = self._get_denoise_row_from_list(progressives, desc=\"Progressive Generation\")\n            log[\"progressive_row\"] = prog_row\n\n        return log\n\n\nclass LatentInpaintDiffusion(LatentDiffusion):\n    \"\"\"\n    can either run as pure inpainting model (only 
concat mode) or with mixed conditionings,\n    e.g. mask as concat and text via cross-attn.\n    To disable finetuning mode, set finetune_keys to None\n     \"\"\"\n    def __init__(self,\n                 finetune_keys=(\"model.diffusion_model.input_blocks.0.0.weight\",\n                                \"model_ema.diffusion_modelinput_blocks00weight\"\n                                ),\n                 concat_keys=(\"mask\", \"masked_image\"),\n                 masked_image_key=\"masked_image\",\n                 keep_finetune_dims=4,  # if model was trained without concat mode before and we would like to keep these channels\n                 c_concat_log_start=None, # to log reconstruction of c_concat codes\n                 c_concat_log_end=None,\n                 *args, **kwargs\n                 ):\n        ckpt_path = kwargs.pop(\"ckpt_path\", None)\n        ignore_keys = kwargs.pop(\"ignore_keys\", list())\n        super().__init__(*args, **kwargs)\n        self.masked_image_key = masked_image_key\n        assert self.masked_image_key in concat_keys\n        self.finetune_keys = finetune_keys\n        self.concat_keys = concat_keys\n        self.keep_dims = keep_finetune_dims\n        self.c_concat_log_start = c_concat_log_start\n        self.c_concat_log_end = c_concat_log_end\n        if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'\n        if exists(ckpt_path):\n            self.init_from_ckpt(ckpt_path, ignore_keys)\n\n    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):\n        sd = torch.load(path, map_location=\"cpu\")\n        if \"state_dict\" in list(sd.keys()):\n            sd = sd[\"state_dict\"]\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del sd[k]\n\n            # make it explicit, finetune 
by including extra input channels\n            if exists(self.finetune_keys) and k in self.finetune_keys:\n                new_entry = None\n                for name, param in self.named_parameters():\n                    if name in self.finetune_keys:\n                        print(f\"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only\")\n                        new_entry = torch.zeros_like(param)  # zero init\n                assert exists(new_entry), 'did not find matching parameter to modify'\n                new_entry[:, :self.keep_dims, ...] = sd[k]\n                sd[k] = new_entry\n\n        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)\n        print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n        if len(unexpected) > 0:\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    @torch.no_grad()\n    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):\n        # note: restricted to non-trainable encoders currently\n        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'\n        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,\n                                              force_c_encode=True, return_original_cond=True, bs=bs)\n\n        assert exists(self.concat_keys)\n        c_cat = list()\n        for ck in self.concat_keys:\n            cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()\n            if bs is not None:\n                cc = cc[:bs]\n                cc = cc.to(self.device)\n            bchw = z.shape\n            if ck != self.masked_image_key:\n                cc = torch.nn.functional.interpolate(cc, 
size=bchw[-2:])\n            else:\n                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))\n            c_cat.append(cc)\n        c_cat = torch.cat(c_cat, dim=1)\n        all_conds = {\"c_concat\": [c_cat], \"c_crossattn\": [c]}\n        if return_first_stage_outputs:\n            return z, all_conds, x, xrec, xc\n        return z, all_conds\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,\n                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,\n                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,\n                   use_ema_scope=True,\n                   **kwargs):\n        ema_scope = self.ema_scope if use_ema_scope else nullcontext\n        use_ddim = ddim_steps is not None\n\n        log = dict()\n        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)\n        c_cat, c = c[\"c_concat\"][0], c[\"c_crossattn\"][0]\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        log[\"inputs\"] = x\n        log[\"reconstruction\"] = xrec\n        if self.model.conditioning_key is not None:\n            if hasattr(self.cond_stage_model, \"decode\"):\n                xc = self.cond_stage_model.decode(c)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key in [\"caption\", \"txt\"]:\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key == 'class_label':\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"human_label\"], size=x.shape[2] // 25)\n                log['conditioning'] = xc\n            elif isimage(xc):\n                log[\"conditioning\"] = xc\n            if ismap(xc):\n       
         log[\"original_conditioning\"] = self.to_rgb(xc)\n\n        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):\n            log[\"c_concat_decoded\"] = self.decode_first_stage(c_cat[:,self.c_concat_log_start:self.c_concat_log_end])\n\n        if plot_diffusion_rows:\n            # get diffusion row\n            diffusion_row = list()\n            z_start = z[:n_row]\n            for t in range(self.num_timesteps):\n                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n                    t = t.to(self.device).long()\n                    noise = torch.randn_like(z_start)\n                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)\n                    diffusion_row.append(self.decode_first_stage(z_noisy))\n\n            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W\n            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')\n            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')\n            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])\n            log[\"diffusion_row\"] = diffusion_grid\n\n        if sample:\n            # get denoise row\n            with ema_scope(\"Sampling\"):\n                samples, z_denoise_row = self.sample_log(cond={\"c_concat\": [c_cat], \"c_crossattn\": [c]},\n                                                         batch_size=N, ddim=use_ddim,\n                                                         ddim_steps=ddim_steps, eta=ddim_eta)\n                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)\n            x_samples = self.decode_first_stage(samples)\n            log[\"samples\"] = x_samples\n            if plot_denoise_rows:\n                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)\n                log[\"denoise_row\"] = 
denoise_grid\n\n        if unconditional_guidance_scale > 1.0:\n            uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)\n            uc_cat = c_cat\n            uc_full = {\"c_concat\": [uc_cat], \"c_crossattn\": [uc_cross]}\n            with ema_scope(\"Sampling with classifier-free guidance\"):\n                samples_cfg, _ = self.sample_log(cond={\"c_concat\": [c_cat], \"c_crossattn\": [c]},\n                                                 batch_size=N, ddim=use_ddim,\n                                                 ddim_steps=ddim_steps, eta=ddim_eta,\n                                                 unconditional_guidance_scale=unconditional_guidance_scale,\n                                                 unconditional_conditioning=uc_full,\n                                                 )\n                x_samples_cfg = self.decode_first_stage(samples_cfg)\n                log[f\"samples_cfg_scale_{unconditional_guidance_scale:.2f}\"] = x_samples_cfg\n\n        log[\"masked_image\"] = rearrange(batch[\"masked_image\"],\n                                        'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()\n        return log\n\n\nclass Layout2ImgDiffusion(LatentDiffusion):\n    # TODO: move all layout-specific hacks to this class\n    def __init__(self, cond_stage_key, *args, **kwargs):\n        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key=\"coordinates_bbox\"'\n        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)\n\n    def log_images(self, batch, N=8, *args, **kwargs):\n        logs = super().log_images(batch=batch, N=N, *args, **kwargs)\n\n        key = 'train' if self.training else 'validation'\n        dset = self.trainer.datamodule.datasets[key]\n        mapper = dset.conditional_builders[self.cond_stage_key]\n\n        bbox_imgs = []\n        map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))\n    
    for tknzd_bbox in batch[self.cond_stage_key][:N]:\n            bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))\n            bbox_imgs.append(bboximg)\n\n        cond_img = torch.stack(bbox_imgs, dim=0)\n        logs['bbox_image'] = cond_img\n        return logs\n\n\nclass SimpleUpscaleDiffusion(LatentDiffusion):\n    def __init__(self, *args, low_scale_key=\"LR\", **kwargs):\n        super().__init__(*args, **kwargs)\n        # assumes that neither the cond_stage nor the low_scale_model contain trainable params\n        assert not self.cond_stage_trainable\n        self.low_scale_key = low_scale_key\n\n    @torch.no_grad()\n    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):\n        if not log_mode:\n            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)\n        else:\n            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,\n                                                  force_c_encode=True, return_original_cond=True, bs=bs)\n        x_low = batch[self.low_scale_key][:bs]\n        x_low = rearrange(x_low, 'b h w c -> b c h w')\n        x_low = x_low.to(memory_format=torch.contiguous_format).float()\n\n        encoder_posterior = self.encode_first_stage(x_low)\n        zx = self.get_first_stage_encoding(encoder_posterior).detach()\n        all_conds = {\"c_concat\": [zx], \"c_crossattn\": [c]}\n\n        if log_mode:\n            # TODO: maybe disable if too expensive\n            interpretability = False\n            if interpretability:\n                zx = zx[:, :, ::2, ::2]\n            return z, all_conds, x, xrec, xc, x_low\n        return z, all_conds\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,\n                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,\n                   unconditional_guidance_scale=1., 
unconditional_guidance_label=None, use_ema_scope=True,\n                   **kwargs):\n        ema_scope = self.ema_scope if use_ema_scope else nullcontext\n        use_ddim = ddim_steps is not None\n\n        log = dict()\n        z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        log[\"inputs\"] = x\n        log[\"reconstruction\"] = xrec\n        log[\"x_lr\"] = x_low\n\n        if self.model.conditioning_key is not None:\n            if hasattr(self.cond_stage_model, \"decode\"):\n                xc = self.cond_stage_model.decode(c)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key in [\"caption\", \"txt\"]:\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key == 'class_label':\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"human_label\"], size=x.shape[2]//25)\n                log['conditioning'] = xc\n            elif isimage(xc):\n                log[\"conditioning\"] = xc\n            if ismap(xc):\n                log[\"original_conditioning\"] = self.to_rgb(xc)\n\n        if sample:\n            # get denoise row\n            with ema_scope(\"Sampling\"):\n                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,\n                                                         ddim_steps=ddim_steps, eta=ddim_eta)\n                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)\n            x_samples = self.decode_first_stage(samples)\n            log[\"samples\"] = x_samples\n\n        if unconditional_guidance_scale > 1.0:\n            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)\n            uc = dict()\n            for k in c:\n             
   if k == \"c_crossattn\":\n                    assert isinstance(c[k], list) and len(c[k]) == 1\n                    uc[k] = [uc_tmp]\n                elif isinstance(c[k], list):\n                    uc[k] = [c[k][i] for i in range(len(c[k]))]\n                else:\n                    uc[k] = c[k]\n\n            with ema_scope(\"Sampling with classifier-free guidance\"):\n                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,\n                                                 ddim_steps=ddim_steps, eta=ddim_eta,\n                                                 unconditional_guidance_scale=unconditional_guidance_scale,\n                                                 unconditional_conditioning=uc,\n                                                 )\n                x_samples_cfg = self.decode_first_stage(samples_cfg)\n                log[f\"samples_cfg_scale_{unconditional_guidance_scale:.2f}\"] = x_samples_cfg\n        return log\n\nclass MultiCatFrameDiffusion(LatentDiffusion):\n    def __init__(self, *args, low_scale_key=\"LR\", **kwargs):\n        super().__init__(*args, **kwargs)\n        # assumes that neither the cond_stage nor the low_scale_model contain trainable params\n        assert not self.cond_stage_trainable\n        self.low_scale_key = low_scale_key\n\n    @torch.no_grad()\n    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):\n        n = 2\n        if not log_mode:\n            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)\n        else:\n            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,\n                                                  force_c_encode=True, return_original_cond=True, bs=bs)\n        cat_conds = batch[self.low_scale_key][:bs]\n        cats = []\n        for i in range(n):\n            x_low = cat_conds[:,:,:,3*i:3*(i+1)]\n            x_low = rearrange(x_low, 'b h w c -> b c h w')\n            
x_low = x_low.to(memory_format=torch.contiguous_format).float()\n            encoder_posterior = self.encode_first_stage(x_low)\n            zx = self.get_first_stage_encoding(encoder_posterior).detach()\n            cats.append(zx)\n\n        all_conds = {\"c_concat\": [torch.cat(cats, dim=1)], \"c_crossattn\": [c]}\n\n        if log_mode:\n            # TODO: maybe disable if too expensive\n            interpretability = False\n            if interpretability:\n                zx = zx[:, :, ::2, ::2]\n            return z, all_conds, x, xrec, xc, x_low\n        return z, all_conds\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,\n                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,\n                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,\n                   **kwargs):\n        ema_scope = self.ema_scope if use_ema_scope else nullcontext\n        use_ddim = ddim_steps is not None\n\n        log = dict()\n        z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        log[\"inputs\"] = x\n        log[\"reconstruction\"] = xrec\n        log[\"x_lr\"] = x_low\n\n        if self.model.conditioning_key is not None:\n            if hasattr(self.cond_stage_model, \"decode\"):\n                xc = self.cond_stage_model.decode(c)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key in [\"caption\", \"txt\"]:\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key == 'class_label':\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"human_label\"], size=x.shape[2]//25)\n                
log['conditioning'] = xc\n            elif isimage(xc):\n                log[\"conditioning\"] = xc\n            if ismap(xc):\n                log[\"original_conditioning\"] = self.to_rgb(xc)\n\n        if sample:\n            # get denoise row\n            with ema_scope(\"Sampling\"):\n                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,\n                                                         ddim_steps=ddim_steps, eta=ddim_eta)\n                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)\n            x_samples = self.decode_first_stage(samples)\n            log[\"samples\"] = x_samples\n\n        if unconditional_guidance_scale > 1.0:\n            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)\n            uc = dict()\n            for k in c:\n                if k == \"c_crossattn\":\n                    assert isinstance(c[k], list) and len(c[k]) == 1\n                    uc[k] = [uc_tmp]\n                elif isinstance(c[k], list):\n                    uc[k] = [c[k][i] for i in range(len(c[k]))]\n                else:\n                    uc[k] = c[k]\n\n            with ema_scope(\"Sampling with classifier-free guidance\"):\n                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,\n                                                 ddim_steps=ddim_steps, eta=ddim_eta,\n                                                 unconditional_guidance_scale=unconditional_guidance_scale,\n                                                 unconditional_conditioning=uc,\n                                                 )\n                x_samples_cfg = self.decode_first_stage(samples_cfg)\n                log[f\"samples_cfg_scale_{unconditional_guidance_scale:.2f}\"] = x_samples_cfg\n        return log\n"
  },
  {
    "path": "ldm/models/diffusion/plms.py",
    "content": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like\nfrom ldm.models.diffusion.sampling_util import norm_thresholding\n\n\nclass PLMSSampler(object):\n    def __init__(self, model, schedule=\"linear\", **kwargs):\n        super().__init__()\n        self.model = model\n        self.ddpm_num_timesteps = model.num_timesteps\n        self.schedule = schedule\n\n    def register_buffer(self, name, attr):\n        if type(attr) == torch.Tensor:\n            if attr.device != torch.device(\"cuda\"):\n                attr = attr.to(torch.device(\"cuda\"))\n        setattr(self, name, attr)\n\n    def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n        if ddim_eta != 0:\n            raise ValueError('ddim_eta must be 0 for PLMS')\n        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n        alphas_cumprod = self.model.alphas_cumprod\n        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n\n        self.register_buffer('betas', to_torch(self.model.betas))\n        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod.cpu())))\n        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n\n        # ddim sampling parameters\n        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n                                                                                   ddim_timesteps=self.ddim_timesteps,\n                                                                                   eta=ddim_eta,verbose=verbose)\n        self.register_buffer('ddim_sigmas', ddim_sigmas)\n        self.register_buffer('ddim_alphas', ddim_alphas)\n        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n\n    @torch.no_grad()\n    def sample(self,\n               S,\n               batch_size,\n               shape,\n               conditioning=None,\n               callback=None,\n               normals_sequence=None,\n               img_callback=None,\n               quantize_x0=False,\n               eta=0.,\n               mask=None,\n               x0=None,\n               temperature=1.,\n               noise_dropout=0.,\n               score_corrector=None,\n               corrector_kwargs=None,\n               verbose=True,\n               x_T=None,\n               log_every_t=100,\n               unconditional_guidance_scale=1.,\n               
unconditional_conditioning=None,\n               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...\n               dynamic_threshold=None,\n               **kwargs\n               ):\n        if conditioning is not None:\n            if isinstance(conditioning, dict):\n                ctmp = conditioning[list(conditioning.keys())[0]]\n                while isinstance(ctmp, list): ctmp = ctmp[0]\n                cbs = ctmp.shape[0]\n                if cbs != batch_size:\n                    print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n            else:\n                if conditioning.shape[0] != batch_size:\n                    print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n\n        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n        # sampling\n        C, H, W = shape\n        size = (batch_size, C, H, W)\n        print(f'Data shape for PLMS sampling is {size}')\n\n        samples, intermediates = self.plms_sampling(conditioning, size,\n                                                    callback=callback,\n                                                    img_callback=img_callback,\n                                                    quantize_denoised=quantize_x0,\n                                                    mask=mask, x0=x0,\n                                                    ddim_use_original_steps=False,\n                                                    noise_dropout=noise_dropout,\n                                                    temperature=temperature,\n                                                    score_corrector=score_corrector,\n                                                    corrector_kwargs=corrector_kwargs,\n                                                    x_T=x_T,\n                                                    log_every_t=log_every_t,\n                                      
              unconditional_guidance_scale=unconditional_guidance_scale,\n                                                    unconditional_conditioning=unconditional_conditioning,\n                                                    dynamic_threshold=dynamic_threshold,\n                                                    )\n        return samples, intermediates\n\n    @torch.no_grad()\n    def plms_sampling(self, cond, shape,\n                      x_T=None, ddim_use_original_steps=False,\n                      callback=None, timesteps=None, quantize_denoised=False,\n                      mask=None, x0=None, img_callback=None, log_every_t=100,\n                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,\n                      unconditional_guidance_scale=1., unconditional_conditioning=None,\n                      dynamic_threshold=None):\n        device = self.model.betas.device\n        b = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=device)\n        else:\n            img = x_T\n\n        if timesteps is None:\n            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n        elif timesteps is not None and not ddim_use_original_steps:\n            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n            timesteps = self.ddim_timesteps[:subset_end]\n\n        intermediates = {'x_inter': [img], 'pred_x0': [img]}\n        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)\n        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n        print(f\"Running PLMS Sampling with {total_steps} timesteps\")\n\n        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)\n        old_eps = []\n\n        for i, step in enumerate(iterator):\n            index = total_steps - i - 1\n            ts = torch.full((b,), 
step, device=device, dtype=torch.long)\n            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)\n\n            if mask is not None:\n                assert x0 is not None\n                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?\n                img = img_orig * mask + (1. - mask) * img\n\n            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n                                      quantize_denoised=quantize_denoised, temperature=temperature,\n                                      noise_dropout=noise_dropout, score_corrector=score_corrector,\n                                      corrector_kwargs=corrector_kwargs,\n                                      unconditional_guidance_scale=unconditional_guidance_scale,\n                                      unconditional_conditioning=unconditional_conditioning,\n                                      old_eps=old_eps, t_next=ts_next,\n                                      dynamic_threshold=dynamic_threshold)\n            img, pred_x0, e_t = outs\n            old_eps.append(e_t)\n            if len(old_eps) >= 4:\n                old_eps.pop(0)\n            if callback: callback(i)\n            if img_callback: img_callback(pred_x0, i)\n\n            if index % log_every_t == 0 or index == total_steps - 1:\n                intermediates['x_inter'].append(img)\n                intermediates['pred_x0'].append(pred_x0)\n\n        return img, intermediates\n\n    @torch.no_grad()\n    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,\n                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,\n                      dynamic_threshold=None):\n        b, *_, device = 
*x.shape, x.device\n\n        def get_model_output(x, t):\n            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:\n                e_t = self.model.apply_model(x, t, c)\n            else:\n                x_in = torch.cat([x] * 2)\n                t_in = torch.cat([t] * 2)\n                if isinstance(c, dict):\n                    assert isinstance(unconditional_conditioning, dict)\n                    c_in = dict()\n                    for k in c:\n                        if isinstance(c[k], list):\n                            c_in[k] = [torch.cat([\n                                unconditional_conditioning[k][i],\n                                c[k][i]]) for i in range(len(c[k]))]\n                        else:\n                            c_in[k] = torch.cat([\n                                    unconditional_conditioning[k],\n                                    c[k]])\n                else:\n                    c_in = torch.cat([unconditional_conditioning, c])\n                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)\n                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)\n\n            if score_corrector is not None:\n                assert self.model.parameterization == \"eps\"\n                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n\n            return e_t\n\n        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n\n        def get_x_prev_and_pred_x0(e_t, index):\n            # select parameters corresponding to the currently considered timestep\n      
      a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)\n\n            # current prediction for x_0\n            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n            if quantize_denoised:\n                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n            if dynamic_threshold is not None:\n                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)\n            # direction pointing to x_t\n            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n            if noise_dropout > 0.:\n                noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n            return x_prev, pred_x0\n\n        e_t = get_model_output(x, t)\n        if len(old_eps) == 0:\n            # Pseudo Improved Euler (2nd order)\n            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)\n            e_t_next = get_model_output(x_prev, t_next)\n            e_t_prime = (e_t + e_t_next) / 2\n        elif len(old_eps) == 1:\n            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)\n            e_t_prime = (3 * e_t - old_eps[-1]) / 2\n        elif len(old_eps) == 2:\n            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)\n            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12\n        elif len(old_eps) >= 3:\n            # 4th order Pseudo Linear Multistep (Adams-Bashforth)\n            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24\n\n        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)\n\n        return x_prev, pred_x0, e_t\n"
  },
  {
    "path": "ldm/models/diffusion/sampling_util.py",
    "content": "import torch\nimport numpy as np\nfrom einops import rearrange\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\n    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')\n    return x[(...,) + (None,) * dims_to_append]\n\n\ndef renorm_thresholding(x0, value):\n    # renorm\n    pred_max = x0.max()\n    pred_min = x0.min()\n    pred_x0 = (x0 - pred_min) / (pred_max - pred_min)  # 0 ... 1\n    pred_x0 = 2 * pred_x0 - 1.  # -1 ... 1\n\n    s = torch.quantile(\n        rearrange(pred_x0, 'b ... -> b (...)').abs(),\n        value,\n        dim=-1\n    )\n    s.clamp_(min=1.0)\n    s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))\n\n    # clip by threshold\n    # pred_x0 = pred_x0.clamp(-s, s) / s  # needs newer pytorch  # TODO bring back to pure-gpu with min/max\n\n    # temporary hack: numpy on cpu, then move the result back to the input's device\n    pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()\n    pred_x0 = torch.tensor(pred_x0).to(x0.device)\n\n    # undo the renorm\n    pred_x0 = (pred_x0 + 1.) / 2.  # 0 ... 1\n    pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min  # orig range\n    return pred_x0\n\n\ndef norm_thresholding(x0, value):\n    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)\n    return x0 * (value / s)\n\n\ndef spatial_norm_thresholding(x0, value):\n    # b c h w\n    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)\n    return x0 * (value / s)"
  },
  {
    "path": "ldm/modules/attention.py",
    "content": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfrom einops import rearrange, repeat\n\nfrom ldm.modules.diffusionmodules.util import checkpoint\n\n\ndef exists(val):\n    return val is not None\n\n\ndef uniq(arr):\n    return{el: True for el in arr}.keys()\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef max_neg_value(t):\n    return -torch.finfo(t.dtype).max\n\n\ndef init_(tensor):\n    dim = tensor.shape[-1]\n    std = 1 / math.sqrt(dim)\n    tensor.uniform_(-std, std)\n    return tensor\n\n\n# feedforward\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = nn.Sequential(\n            nn.Linear(dim, inner_dim),\n            nn.GELU()\n        ) if not glu else GEGLU(dim, inner_dim)\n\n        self.net = nn.Sequential(\n            project_in,\n            nn.Dropout(dropout),\n            nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef Normalize(in_channels):\n    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)\n\n\nclass LinearAttention(nn.Module):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super().__init__()\n        self.heads = heads\n        hidden_dim = dim_head 
* heads\n        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)\n        self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n\n    def forward(self, x):\n        b, c, h, w = x.shape\n        qkv = self.to_qkv(x)\n        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)\n        k = k.softmax(dim=-1)  \n        context = torch.einsum('bhdn,bhen->bhde', k, v)\n        out = torch.einsum('bhde,bhdn->bhen', context, q)\n        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)\n        return self.to_out(out)\n\n\nclass SpatialSelfAttention(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.k = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.v = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.proj_out = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=1,\n                                        stride=1,\n                                        padding=0)\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b,c,h,w = q.shape\n        q = rearrange(q, 'b c 
h w -> b (h w) c')\n        k = rearrange(k, 'b c h w -> b c (h w)')\n        w_ = torch.einsum('bij,bjk->bik', q, k)\n\n        w_ = w_ * (int(c)**(-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = rearrange(v, 'b c h w -> b c (h w)')\n        w_ = rearrange(w_, 'b i j -> b j i')\n        h_ = torch.einsum('bij,bjk->bik', v, w_)\n        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)\n        h_ = self.proj_out(h_)\n\n        return x+h_\n\n\nclass CrossAttention(nn.Module):\n    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):\n        super().__init__()\n        inner_dim = dim_head * heads\n        context_dim = default(context_dim, query_dim)\n\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, query_dim),\n            nn.Dropout(dropout)\n        )\n\n    def forward(self, x, context=None, mask=None):\n        h = self.heads\n\n        q = self.to_q(x)\n        context = default(context, x)\n        k = self.to_k(context)\n        v = self.to_v(context)\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))\n\n        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale\n\n        if exists(mask):\n            mask = rearrange(mask, 'b ... 
-> b (...)')\n            max_neg_value = -torch.finfo(sim.dtype).max\n            mask = repeat(mask, 'b j -> (b h) () j', h=h)\n            sim.masked_fill_(~mask, max_neg_value)\n\n        # attention, what we cannot get enough of\n        attn = sim.softmax(dim=-1)\n\n        out = einsum('b i j, b j d -> b i d', attn, v)\n        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)\n        return self.to_out(out)\n\n\nclass BasicTransformerBlock(nn.Module):\n    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,\n                 disable_self_attn=False):\n        super().__init__()\n        self.disable_self_attn = disable_self_attn\n        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,\n                                    context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn\n        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)\n        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,\n                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n        self.norm3 = nn.LayerNorm(dim)\n        self.checkpoint = checkpoint\n\n    def forward(self, x, context=None):\n        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)\n\n    def _forward(self, x, context=None):\n        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x\n        x = self.attn2(self.norm2(x), context=context) + x\n        x = self.ff(self.norm3(x)) + x\n        return x\n\n\nclass SpatialTransformer(nn.Module):\n    \"\"\"\n    Transformer block for image-like data.\n    First, project the input (aka embedding)\n    and reshape to b, t, d.\n    Then apply standard transformer action.\n    Finally, 
reshape to image\n    \"\"\"\n    def __init__(self, in_channels, n_heads, d_head,\n                 depth=1, dropout=0., context_dim=None,\n                 disable_self_attn=False):\n        super().__init__()\n        self.in_channels = in_channels\n        inner_dim = n_heads * d_head\n        self.norm = Normalize(in_channels)\n\n        self.proj_in = nn.Conv2d(in_channels,\n                                 inner_dim,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n\n        self.transformer_blocks = nn.ModuleList(\n            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,\n                                   disable_self_attn=disable_self_attn)\n                for d in range(depth)]\n        )\n\n        self.proj_out = zero_module(nn.Conv2d(inner_dim,\n                                              in_channels,\n                                              kernel_size=1,\n                                              stride=1,\n                                              padding=0))\n\n    def forward(self, x, context=None):\n        # note: if no context is given, cross-attention defaults to self-attention\n        b, c, h, w = x.shape\n        x_in = x\n        x = self.norm(x)\n        x = self.proj_in(x)\n        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()\n        for block in self.transformer_blocks:\n            x = block(x, context=context)\n        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()\n        x = self.proj_out(x)\n        return x + x_in\n"
  },
  {
    "path": "ldm/modules/diffusionmodules/__init__.py",
    "content": ""
  },
  {
    "path": "ldm/modules/diffusionmodules/model.py",
    "content": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import rearrange\n\nfrom ldm.util import instantiate_from_config\nfrom ldm.modules.attention import LinearAttention\n\n\ndef get_timestep_embedding(timesteps, embedding_dim):\n    \"\"\"\n    This matches the implementation in Denoising Diffusion Probabilistic Models:\n    From Fairseq.\n    Build sinusoidal embeddings.\n    This matches the implementation in tensor2tensor, but differs slightly\n    from the description in Section 3.5 of \"Attention Is All You Need\".\n    \"\"\"\n    assert len(timesteps.shape) == 1\n\n    half_dim = embedding_dim // 2\n    emb = math.log(10000) / (half_dim - 1)\n    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)\n    emb = emb.to(device=timesteps.device)\n    emb = timesteps.float()[:, None] * emb[None, :]\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n    if embedding_dim % 2 == 1:  # zero pad\n        emb = torch.nn.functional.pad(emb, (0,1,0,0))\n    return emb\n\n\ndef nonlinearity(x):\n    # swish\n    return x*torch.sigmoid(x)\n\n\ndef Normalize(in_channels, num_groups=32):\n    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n\n\nclass Upsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n        return x\n\n\nclass Downsample(nn.Module):\n    def 
__init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=3,\n                                        stride=2,\n                                        padding=0)\n\n    def forward(self, x):\n        if self.with_conv:\n            pad = (0,1,0,1)\n            x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n            x = self.conv(x)\n        else:\n            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)\n        return x\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,\n                 dropout, temb_channels=512):\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = Normalize(in_channels)\n        self.conv1 = torch.nn.Conv2d(in_channels,\n                                     out_channels,\n                                     kernel_size=3,\n                                     stride=1,\n                                     padding=1)\n        if temb_channels > 0:\n            self.temb_proj = torch.nn.Linear(temb_channels,\n                                             out_channels)\n        self.norm2 = Normalize(out_channels)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = torch.nn.Conv2d(out_channels,\n                                     out_channels,\n                                     kernel_size=3,\n                                     stride=1,\n                                     padding=1)\n        if 
self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = torch.nn.Conv2d(in_channels,\n                                                     out_channels,\n                                                     kernel_size=3,\n                                                     stride=1,\n                                                     padding=1)\n            else:\n                self.nin_shortcut = torch.nn.Conv2d(in_channels,\n                                                    out_channels,\n                                                    kernel_size=1,\n                                                    stride=1,\n                                                    padding=0)\n\n    def forward(self, x, temb):\n        h = x\n        h = self.norm1(h)\n        h = nonlinearity(h)\n        h = self.conv1(h)\n\n        if temb is not None:\n            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]\n\n        h = self.norm2(h)\n        h = nonlinearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                x = self.conv_shortcut(x)\n            else:\n                x = self.nin_shortcut(x)\n\n        return x+h\n\n\nclass LinAttnBlock(LinearAttention):\n    \"\"\"to match AttnBlock usage\"\"\"\n    def __init__(self, in_channels):\n        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)\n\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.k = torch.nn.Conv2d(in_channels,\n           
                      in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.v = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.proj_out = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=1,\n                                        stride=1,\n                                        padding=0)\n\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b,c,h,w = q.shape\n        q = q.reshape(b,c,h*w)\n        q = q.permute(0,2,1)   # b,hw,c\n        k = k.reshape(b,c,h*w) # b,c,hw\n        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]\n        w_ = w_ * (int(c)**(-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = v.reshape(b,c,h*w)\n        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)\n        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]\n        h_ = h_.reshape(b,c,h,w)\n\n        h_ = self.proj_out(h_)\n\n        return x+h_\n\n\ndef make_attn(in_channels, attn_type=\"vanilla\"):\n    assert attn_type in [\"vanilla\", \"linear\", \"none\"], f'attn_type {attn_type} unknown'\n    print(f\"making attention of type '{attn_type}' with {in_channels} in_channels\")\n    if attn_type == \"vanilla\":\n        return AttnBlock(in_channels)\n    elif attn_type == \"none\":\n        return nn.Identity(in_channels)\n    else:\n        return LinAttnBlock(in_channels)\n\n\nclass Model(nn.Module):\n    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), 
num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,\n                 resolution, use_timestep=True, use_linear_attn=False, attn_type=\"vanilla\"):\n        super().__init__()\n        if use_linear_attn: attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = self.ch*4\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n\n        self.use_timestep = use_timestep\n        if self.use_timestep:\n            # timestep embedding\n            self.temb = nn.Module()\n            self.temb.dense = nn.ModuleList([\n                torch.nn.Linear(self.ch,\n                                self.temb_ch),\n                torch.nn.Linear(self.temb_ch,\n                                self.temb_ch),\n            ])\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(in_channels,\n                                       self.ch,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        curr_res = resolution\n        in_ch_mult = (1,)+tuple(ch_mult)\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch*in_ch_mult[i_level]\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n          
  down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions-1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch*ch_mult[i_level]\n            skip_in = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks+1):\n                if i_block == self.num_res_blocks:\n                    skip_in = ch*in_ch_mult[i_level]\n                block.append(ResnetBlock(in_channels=block_in+skip_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = 
curr_res * 2\n            self.up.insert(0, up) # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_ch,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x, t=None, context=None):\n        #assert x.shape[2] == x.shape[3] == self.resolution\n        if context is not None:\n            # assume aligned context, cat along channel axis\n            x = torch.cat((x, context), dim=1)\n        if self.use_timestep:\n            # timestep embedding\n            assert t is not None\n            temb = get_timestep_embedding(t, self.ch)\n            temb = self.temb.dense[0](temb)\n            temb = nonlinearity(temb)\n            temb = self.temb.dense[1](temb)\n        else:\n            temb = None\n\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n                hs.append(h)\n            if i_level != self.num_resolutions-1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks+1):\n                h = self.up[i_level].block[i_block](\n                    torch.cat([h, hs.pop()], dim=1), temb)\n                if len(self.up[i_level].attn) > 0:\n                    h = 
self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n    def get_last_layer(self):\n        return self.conv_out.weight\n\n\nclass Encoder(nn.Module):\n    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,\n                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type=\"vanilla\",\n                 **ignore_kwargs):\n        super().__init__()\n        if use_linear_attn: attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(in_channels,\n                                       self.ch,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        curr_res = resolution\n        in_ch_mult = (1,)+tuple(ch_mult)\n        self.in_ch_mult = in_ch_mult\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch*in_ch_mult[i_level]\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n         
           attn.append(make_attn(block_in, attn_type=attn_type))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions-1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        2*z_channels if double_z else z_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        # timestep embedding\n        temb = None\n\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n                hs.append(h)\n            if i_level != self.num_resolutions-1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n     
   h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass Decoder(nn.Module):\n    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,\n                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,\n                 attn_type=\"vanilla\", **ignorekwargs):\n        super().__init__()\n        if use_linear_attn: attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.give_pre_end = give_pre_end\n        self.tanh_out = tanh_out\n\n        # compute in_ch_mult, block_in and curr_res at lowest res\n        in_ch_mult = (1,)+tuple(ch_mult)\n        block_in = ch*ch_mult[self.num_resolutions-1]\n        curr_res = resolution // 2**(self.num_resolutions-1)\n        self.z_shape = (1,z_channels,curr_res,curr_res)\n        print(\"Working with z of shape {} = {} dimensions.\".format(\n            self.z_shape, np.prod(self.z_shape)))\n\n        # z to block_in\n        self.conv_in = torch.nn.Conv2d(z_channels,\n                                       block_in,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 
= ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks+1):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up) # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_ch,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, z):\n        #assert z.shape[1:] == self.z_shape[1:]\n        self.last_z_shape = z.shape\n\n        # timestep embedding\n        temb = None\n\n        # z to block_in\n        h = self.conv_in(z)\n\n        # middle\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in 
reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks+1):\n                h = self.up[i_level].block[i_block](h, temb)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        if self.give_pre_end:\n            return h\n\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        if self.tanh_out:\n            h = torch.tanh(h)\n        return h\n\n\nclass SimpleDecoder(nn.Module):\n    def __init__(self, in_channels, out_channels, *args, **kwargs):\n        super().__init__()\n        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),\n                                     ResnetBlock(in_channels=in_channels,\n                                                 out_channels=2 * in_channels,\n                                                 temb_channels=0, dropout=0.0),\n                                     ResnetBlock(in_channels=2 * in_channels,\n                                                out_channels=4 * in_channels,\n                                                temb_channels=0, dropout=0.0),\n                                     ResnetBlock(in_channels=4 * in_channels,\n                                                out_channels=2 * in_channels,\n                                                temb_channels=0, dropout=0.0),\n                                     nn.Conv2d(2*in_channels, in_channels, 1),\n                                     Upsample(in_channels, with_conv=True)])\n        # end\n        self.norm_out = Normalize(in_channels)\n        self.conv_out = torch.nn.Conv2d(in_channels,\n                                        out_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        
padding=1)\n\n    def forward(self, x):\n        for i, layer in enumerate(self.model):\n            if i in [1,2,3]:\n                x = layer(x, None)\n            else:\n                x = layer(x)\n\n        h = self.norm_out(x)\n        h = nonlinearity(h)\n        x = self.conv_out(h)\n        return x\n\n\nclass UpsampleDecoder(nn.Module):\n    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,\n                 ch_mult=(2,2), dropout=0.0):\n        super().__init__()\n        # upsampling\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        block_in = in_channels\n        curr_res = resolution // 2 ** (self.num_resolutions - 1)\n        self.res_blocks = nn.ModuleList()\n        self.upsample_blocks = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            res_block = []\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks + 1):\n                res_block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n            self.res_blocks.append(nn.ModuleList(res_block))\n            if i_level != self.num_resolutions - 1:\n                self.upsample_blocks.append(Upsample(block_in, True))\n                curr_res = curr_res * 2\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        # upsampling\n        h = x\n        for k, i_level in 
enumerate(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.res_blocks[i_level][i_block](h, None)\n            if i_level != self.num_resolutions - 1:\n                h = self.upsample_blocks[k](h)\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass LatentRescaler(nn.Module):\n    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):\n        super().__init__()\n        # residual block, interpolate, residual block\n        self.factor = factor\n        self.conv_in = nn.Conv2d(in_channels,\n                                 mid_channels,\n                                 kernel_size=3,\n                                 stride=1,\n                                 padding=1)\n        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,\n                                                     out_channels=mid_channels,\n                                                     temb_channels=0,\n                                                     dropout=0.0) for _ in range(depth)])\n        self.attn = AttnBlock(mid_channels)\n        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,\n                                                     out_channels=mid_channels,\n                                                     temb_channels=0,\n                                                     dropout=0.0) for _ in range(depth)])\n\n        self.conv_out = nn.Conv2d(mid_channels,\n                                  out_channels,\n                                  kernel_size=1,\n                                  )\n\n    def forward(self, x):\n        x = self.conv_in(x)\n        for block in self.res_block1:\n            x = block(x, None)\n        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))\n        x = 
self.attn(x)\n        for block in self.res_block2:\n            x = block(x, None)\n        x = self.conv_out(x)\n        return x\n\n\nclass MergedRescaleEncoder(nn.Module):\n    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True,\n                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):\n        super().__init__()\n        intermediate_chn = ch * ch_mult[-1]\n        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,\n                               z_channels=intermediate_chn, double_z=False, resolution=resolution,\n                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,\n                               out_ch=None)\n        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,\n                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)\n\n    def forward(self, x):\n        x = self.encoder(x)\n        x = self.rescaler(x)\n        return x\n\n\nclass MergedRescaleDecoder(nn.Module):\n    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),\n                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):\n        super().__init__()\n        tmp_chn = z_channels*ch_mult[-1]\n        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,\n                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,\n                               ch_mult=ch_mult, resolution=resolution, ch=ch)\n        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,\n                                       out_channels=tmp_chn, 
depth=rescale_module_depth)\n\n    def forward(self, x):\n        x = self.rescaler(x)\n        x = self.decoder(x)\n        return x\n\n\nclass Upsampler(nn.Module):\n    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):\n        super().__init__()\n        assert out_size >= in_size\n        num_blocks = int(np.log2(out_size//in_size))+1\n        factor_up = 1.+ (out_size % in_size)\n        print(f\"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}\")\n        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,\n                                       out_channels=in_channels)\n        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,\n                               attn_resolutions=[], in_channels=None, ch=in_channels,\n                               ch_mult=[ch_mult for _ in range(num_blocks)])\n\n    def forward(self, x):\n        x = self.rescaler(x)\n        x = self.decoder(x)\n        return x\n\n\nclass Resize(nn.Module):\n    def __init__(self, in_channels=None, learned=False, mode=\"bilinear\"):\n        super().__init__()\n        self.with_conv = learned\n        self.mode = mode\n        if self.with_conv:\n            print(f\"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode\")\n            raise NotImplementedError()\n            assert in_channels is not None\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=4,\n                                        stride=2,\n                                        padding=1)\n\n    def forward(self, x, scale_factor=1.0):\n        if scale_factor==1.0:\n            return x\n        else:\n           
 x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)\n        return x\n\nclass FirstStagePostProcessor(nn.Module):\n\n    def __init__(self, ch_mult:list, in_channels,\n                 pretrained_model:nn.Module=None,\n                 reshape=False,\n                 n_channels=None,\n                 dropout=0.,\n                 pretrained_config=None):\n        super().__init__()\n        if pretrained_config is None:\n            assert pretrained_model is not None, 'Either \"pretrained_model\" or \"pretrained_config\" must not be None'\n            self.pretrained_model = pretrained_model\n        else:\n            assert pretrained_config is not None, 'Either \"pretrained_model\" or \"pretrained_config\" must not be None'\n            self.instantiate_pretrained(pretrained_config)\n\n        self.do_reshape = reshape\n\n        if n_channels is None:\n            n_channels = self.pretrained_model.encoder.ch\n\n        self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)\n        self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,\n                            stride=1,padding=1)\n\n        blocks = []\n        downs = []\n        ch_in = n_channels\n        for m in ch_mult:\n            blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))\n            ch_in = m * n_channels\n            downs.append(Downsample(ch_in, with_conv=False))\n\n        self.model = nn.ModuleList(blocks)\n        self.downsampler = nn.ModuleList(downs)\n\n\n    def instantiate_pretrained(self, config):\n        model = instantiate_from_config(config)\n        self.pretrained_model = model.eval()\n        # self.pretrained_model.train = False\n        for param in self.pretrained_model.parameters():\n            param.requires_grad = False\n\n\n    @torch.no_grad()\n    def encode_with_pretrained(self,x):\n        c = self.pretrained_model.encode(x)\n        if isinstance(c, 
DiagonalGaussianDistribution):\n            c = c.mode()\n        return  c\n\n    def forward(self,x):\n        z_fs = self.encode_with_pretrained(x)\n        z = self.proj_norm(z_fs)\n        z = self.proj(z)\n        z = nonlinearity(z)\n\n        for submodel, downmodel in zip(self.model,self.downsampler):\n            z = submodel(z,temb=None)\n            z = downmodel(z)\n\n        if self.do_reshape:\n            z = rearrange(z,'b c h w -> b (h w) c')\n        return z\n\n"
  },
  {
    "path": "ldm/modules/diffusionmodules/openaimodel.py",
    "content": "from abc import abstractmethod\nfrom functools import partial\nimport math\nfrom typing import Iterable\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom ldm.modules.diffusionmodules.util import (\n    checkpoint,\n    conv_nd,\n    linear,\n    avg_pool_nd,\n    zero_module,\n    normalization,\n    timestep_embedding,\n)\nfrom ldm.modules.attention import SpatialTransformer\nfrom ldm.util import exists\n\n\n# dummy replace\ndef convert_module_to_f16(x):\n    pass\n\ndef convert_module_to_f32(x):\n    pass\n\n\n## go\nclass AttentionPool2d(nn.Module):\n    \"\"\"\n    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py\n    \"\"\"\n\n    def __init__(\n        self,\n        spacial_dim: int,\n        embed_dim: int,\n        num_heads_channels: int,\n        output_dim: int = None,\n    ):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)\n        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)\n        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)\n        self.num_heads = embed_dim // num_heads_channels\n        self.attention = QKVAttention(self.num_heads)\n\n    def forward(self, x):\n        b, c, *_spatial = x.shape\n        x = x.reshape(b, c, -1)  # NC(HW)\n        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)\n        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)\n        x = self.qkv_proj(x)\n        x = self.attention(x)\n        x = self.c_proj(x)\n        return x[:, :, 0]\n\n\nclass TimestepBlock(nn.Module):\n    \"\"\"\n    Any module where forward() takes timestep embeddings as a second argument.\n    \"\"\"\n\n    @abstractmethod\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the module to `x` given `emb` timestep embeddings.\n        \"\"\"\n\n\nclass 
TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n    \"\"\"\n    A sequential module that passes timestep embeddings to the children that\n    support it as an extra input.\n    \"\"\"\n\n    def forward(self, x, emb, context=None):\n        for layer in self:\n            if isinstance(layer, TimestepBlock):\n                x = layer(x, emb)\n            elif isinstance(layer, SpatialTransformer):\n                x = layer(x, context)\n            else:\n                x = layer(x)\n        return x\n\n\nclass Upsample(nn.Module):\n    \"\"\"\n    An upsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then\n                 upsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        if use_conv:\n            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        if self.dims == 3:\n            x = F.interpolate(\n                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode=\"nearest\"\n            )\n        else:\n            x = F.interpolate(x, scale_factor=2, mode=\"nearest\")\n        if self.use_conv:\n            x = self.conv(x)\n        return x\n\nclass TransposedUpsample(nn.Module):\n    'Learned 2x upsampling without padding'\n    def __init__(self, channels, out_channels=None, ks=5):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n\n        self.up = 
nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)\n\n    def forward(self,x):\n        return self.up(x)\n\n\nclass Downsample(nn.Module):\n    \"\"\"\n    A downsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then\n                 downsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        stride = 2 if dims != 3 else (1, 2, 2)\n        if use_conv:\n            self.op = conv_nd(\n                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding\n            )\n        else:\n            assert self.channels == self.out_channels\n            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        return self.op(x)\n\n\nclass ResBlock(TimestepBlock):\n    \"\"\"\n    A residual block that can optionally change the number of channels.\n    :param channels: the number of input channels.\n    :param emb_channels: the number of timestep embedding channels.\n    :param dropout: the rate of dropout.\n    :param out_channels: if specified, the number of out channels.\n    :param use_conv: if True and out_channels is specified, use a spatial\n        convolution instead of a smaller 1x1 convolution to change the\n        channels in the skip connection.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param use_checkpoint: if True, use gradient checkpointing on this module.\n    :param up: if True, use this block for upsampling.\n    :param down: if True, use 
this block for downsampling.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        emb_channels,\n        dropout,\n        out_channels=None,\n        use_conv=False,\n        use_scale_shift_norm=False,\n        dims=2,\n        use_checkpoint=False,\n        up=False,\n        down=False,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.emb_channels = emb_channels\n        self.dropout = dropout\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_checkpoint = use_checkpoint\n        self.use_scale_shift_norm = use_scale_shift_norm\n\n        self.in_layers = nn.Sequential(\n            normalization(channels),\n            nn.SiLU(),\n            conv_nd(dims, channels, self.out_channels, 3, padding=1),\n        )\n\n        self.updown = up or down\n\n        if up:\n            self.h_upd = Upsample(channels, False, dims)\n            self.x_upd = Upsample(channels, False, dims)\n        elif down:\n            self.h_upd = Downsample(channels, False, dims)\n            self.x_upd = Downsample(channels, False, dims)\n        else:\n            self.h_upd = self.x_upd = nn.Identity()\n\n        self.emb_layers = nn.Sequential(\n            nn.SiLU(),\n            linear(\n                emb_channels,\n                2 * self.out_channels if use_scale_shift_norm else self.out_channels,\n            ),\n        )\n        self.out_layers = nn.Sequential(\n            normalization(self.out_channels),\n            nn.SiLU(),\n            nn.Dropout(p=dropout),\n            zero_module(\n                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)\n            ),\n        )\n\n        if self.out_channels == channels:\n            self.skip_connection = nn.Identity()\n        elif use_conv:\n            self.skip_connection = conv_nd(\n                dims, channels, self.out_channels, 3, padding=1\n            )\n        else:\n     
       self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)\n\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the block to a Tensor, conditioned on a timestep embedding.\n        :param x: an [N x C x ...] Tensor of features.\n        :param emb: an [N x emb_channels] Tensor of timestep embeddings.\n        :return: an [N x C x ...] Tensor of outputs.\n        \"\"\"\n        return checkpoint(\n            self._forward, (x, emb), self.parameters(), self.use_checkpoint\n        )\n\n\n    def _forward(self, x, emb):\n        if self.updown:\n            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n            h = in_rest(x)\n            h = self.h_upd(h)\n            x = self.x_upd(x)\n            h = in_conv(h)\n        else:\n            h = self.in_layers(x)\n        emb_out = self.emb_layers(emb).type(h.dtype)\n        while len(emb_out.shape) < len(h.shape):\n            emb_out = emb_out[..., None]\n        if self.use_scale_shift_norm:\n            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n            scale, shift = th.chunk(emb_out, 2, dim=1)\n            h = out_norm(h) * (1 + scale) + shift\n            h = out_rest(h)\n        else:\n            h = h + emb_out\n            h = self.out_layers(h)\n        return self.skip_connection(x) + h\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    An attention block that allows spatial positions to attend to each other.\n    Originally ported from here, but adapted to the N-d case.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        num_heads=1,\n        num_head_channels=-1,\n        use_checkpoint=False,\n        use_new_attention_order=False,\n    ):\n        super().__init__()\n        self.channels = channels\n        if num_head_channels == -1:\n            self.num_heads = num_heads\n     
   else:\n            assert (\n                channels % num_head_channels == 0\n            ), f\"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}\"\n            self.num_heads = channels // num_head_channels\n        self.use_checkpoint = use_checkpoint\n        self.norm = normalization(channels)\n        self.qkv = conv_nd(1, channels, channels * 3, 1)\n        if use_new_attention_order:\n            # split qkv before split heads\n            self.attention = QKVAttention(self.num_heads)\n        else:\n            # split heads before split qkv\n            self.attention = QKVAttentionLegacy(self.num_heads)\n\n        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))\n\n    def forward(self, x):\n        return checkpoint(self._forward, (x,), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!\n        #return pt_checkpoint(self._forward, x)  # pytorch\n\n    def _forward(self, x):\n        b, c, *spatial = x.shape\n        x = x.reshape(b, c, -1)\n        qkv = self.qkv(self.norm(x))\n        h = self.attention(qkv)\n        h = self.proj_out(h)\n        return (x + h).reshape(b, c, *spatial)\n\n\ndef count_flops_attn(model, _x, y):\n    \"\"\"\n    A counter for the `thop` package to count the operations in an\n    attention operation.\n    Meant to be used like:\n        macs, params = thop.profile(\n            model,\n            inputs=(inputs, timestamps),\n            custom_ops={QKVAttention: QKVAttention.count_flops},\n        )\n    \"\"\"\n    b, c, *spatial = y[0].shape\n    num_spatial = int(np.prod(spatial))\n    # We perform two matmuls with the same number of ops.\n    # The first computes the weight matrix, the second computes\n    # the combination of the value vectors.\n    matmul_ops = 2 * b * (num_spatial ** 2) * c\n    model.total_ops += th.DoubleTensor([matmul_ops])\n\n\nclass QKVAttentionLegacy(nn.Module):\n    \"\"\"\n    A module which 
performs QKV attention. Matches legacy QKVAttention + input/output heads shaping\n    \"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            \"bct,bcs->bts\", q * scale, k * scale\n        )  # More stable with f16 than dividing afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum(\"bts,bcs->bct\", weight, v)\n        return a.reshape(bs, -1, length)\n\n    @staticmethod\n    def count_flops(model, _x, y):\n        return count_flops_attn(model, _x, y)\n\n\nclass QKVAttention(nn.Module):\n    \"\"\"\n    A module which performs QKV attention and splits in a different order.\n    \"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = qkv.chunk(3, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            \"bct,bcs->bts\",\n            (q * scale).view(bs * self.n_heads, ch, length),\n            (k * scale).view(bs * self.n_heads, ch, length),\n        )  # More stable with f16 than dividing 
afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum(\"bts,bcs->bct\", weight, v.reshape(bs * self.n_heads, ch, length))\n        return a.reshape(bs, -1, length)\n\n    @staticmethod\n    def count_flops(model, _x, y):\n        return count_flops_attn(model, _x, y)\n\n\nclass UNetModel(nn.Module):\n    \"\"\"\n    The full UNet model with attention and timestep embedding.\n    :param in_channels: channels in the input Tensor.\n    :param model_channels: base channel count for the model.\n    :param out_channels: channels in the output Tensor.\n    :param num_res_blocks: number of residual blocks per downsample.\n    :param attention_resolutions: a collection of downsample rates at which\n        attention will take place. May be a set, list, or tuple.\n        For example, if this contains 4, then at 4x downsampling, attention\n        will be used.\n    :param dropout: the dropout probability.\n    :param channel_mult: channel multiplier for each level of the UNet.\n    :param conv_resample: if True, use learned convolutions for upsampling and\n        downsampling.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param num_classes: if specified (as an int), then this model will be\n        class-conditional with `num_classes` classes.\n    :param use_checkpoint: use gradient checkpointing to reduce memory usage.\n    :param num_heads: the number of attention heads in each attention layer.\n    :param num_heads_channels: if specified, ignore num_heads and instead use\n                               a fixed channel width per attention head.\n    :param num_heads_upsample: works with num_heads to set a different number\n                               of heads for upsampling. 
Deprecated.\n    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.\n    :param resblock_updown: use residual blocks for up/downsampling.\n    :param use_new_attention_order: use a different attention pattern for potentially\n                                    increased efficiency.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        num_classes=None,\n        use_checkpoint=False,\n        use_fp16=False,\n        num_heads=-1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        use_new_attention_order=False,\n        use_spatial_transformer=False,    # custom transformer support\n        transformer_depth=1,              # custom transformer support\n        context_dim=None,                 # custom transformer support\n        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model\n        legacy=True,\n        disable_self_attentions=None,\n        num_attention_blocks=None\n    ):\n        super().__init__()\n        if use_spatial_transformer:\n            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'\n\n        if context_dim is not None:\n            assert use_spatial_transformer, 'Fool!! 
You forgot to use the spatial transformer for your cross-attention conditioning...'\n            from omegaconf.listconfig import ListConfig\n            if type(context_dim) == ListConfig:\n                context_dim = list(context_dim)\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        if num_heads == -1:\n            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'\n\n        if num_head_channels == -1:\n            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'\n\n        self.image_size = image_size\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        if isinstance(num_res_blocks, int):\n            self.num_res_blocks = len(channel_mult) * [num_res_blocks]\n        else:\n            if len(num_res_blocks) != len(channel_mult):\n                raise ValueError(\"provide num_res_blocks either as an int (globally constant) or \"\n                                 \"as a list/tuple (per-level) with the same length as channel_mult\")\n            self.num_res_blocks = num_res_blocks\n        #self.num_res_blocks = num_res_blocks\n        if disable_self_attentions is not None:\n            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not\n            assert len(disable_self_attentions) == len(channel_mult)\n        if num_attention_blocks is not None:\n            assert len(num_attention_blocks) == len(self.num_res_blocks)\n            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))\n            print(f\"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
\"\n                  f\"This option has LESS priority than attention_resolutions {attention_resolutions}, \"\n                  f\"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, \"\n                  f\"attention will still not be set.\")  # todo: convert to warning\n\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.use_checkpoint = use_checkpoint\n        self.dtype = th.float16 if use_fp16 else th.float32\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n        self.predict_codebook_ids = n_embed is not None\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        if self.num_classes is not None:\n            self.label_emb = nn.Embedding(num_classes, time_embed_dim)\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for nr in range(self.num_res_blocks[level]):\n                layers = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        
use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n                    if legacy:\n                        #num_heads = 1\n                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels\n                    if exists(disable_self_attentions):\n                        disabled_sa = disable_self_attentions[level]\n                    else:\n                        disabled_sa = False\n\n                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:\n                        layers.append(\n                            AttentionBlock(\n                                ch,\n                                use_checkpoint=use_checkpoint,\n                                num_heads=num_heads,\n                                num_head_channels=dim_head,\n                                use_new_attention_order=use_new_attention_order,\n                            ) if not use_spatial_transformer else SpatialTransformer(\n                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,\n                                disable_self_attn=disabled_sa\n                            )\n                        )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                            ch,\n            
                time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        if num_head_channels == -1:\n            dim_head = ch // num_heads\n        else:\n            num_heads = ch // num_head_channels\n            dim_head = num_head_channels\n        if legacy:\n            #num_heads = 1\n            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels\n        self.middle_block = TimestepEmbedSequential(\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            AttentionBlock(\n                ch,\n                use_checkpoint=use_checkpoint,\n                num_heads=num_heads,\n                num_head_channels=dim_head,\n                use_new_attention_order=use_new_attention_order,\n            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn\n                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim\n                        ),\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                
use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(channel_mult))[::-1]:\n            for i in range(self.num_res_blocks[level] + 1):\n                ich = input_block_chans.pop()\n                layers = [\n                    ResBlock(\n                        ch + ich,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=model_channels * mult,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = model_channels * mult\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n                    if legacy:\n                        #num_heads = 1\n                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels\n                    if exists(disable_self_attentions):\n                        disabled_sa = disable_self_attentions[level]\n                    else:\n                        disabled_sa = False\n\n                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:\n                        layers.append(\n                            AttentionBlock(\n                                ch,\n                                use_checkpoint=use_checkpoint,\n                                num_heads=num_heads_upsample,\n                                num_head_channels=dim_head,\n                                
use_new_attention_order=use_new_attention_order,\n                            ) if not use_spatial_transformer else SpatialTransformer(\n                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,\n                                disable_self_attn=disabled_sa\n                            )\n                        )\n                if level and i == self.num_res_blocks[level]:\n                    out_ch = ch\n                    layers.append(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            up=True,\n                        )\n                        if resblock_updown\n                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)\n                    )\n                    ds //= 2\n                self.output_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n\n        self.out = nn.Sequential(\n            normalization(ch),\n            nn.SiLU(),\n            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),\n        )\n        if self.predict_codebook_ids:\n            self.id_predictor = nn.Sequential(\n            normalization(ch),\n            conv_nd(dims, model_channels, n_embed, 1),\n            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits\n        )\n\n    def convert_to_fp16(self):\n        \"\"\"\n        Convert the torso of the model to float16.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f16)\n        self.middle_block.apply(convert_module_to_f16)\n        
self.output_blocks.apply(convert_module_to_f16)\n\n    def convert_to_fp32(self):\n        \"\"\"\n        Convert the torso of the model to float32.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f32)\n        self.middle_block.apply(convert_module_to_f32)\n        self.output_blocks.apply(convert_module_to_f32)\n\n    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :param context: conditioning plugged in via crossattn\n        :param y: an [N] Tensor of labels, if class-conditional.\n        :return: an [N x C x ...] Tensor of outputs.\n        \"\"\"\n        assert (y is not None) == (\n            self.num_classes is not None\n        ), \"must specify y if and only if the model is class-conditional\"\n        hs = []\n        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)\n        emb = self.time_embed(t_emb)\n\n        if self.num_classes is not None:\n            assert y.shape == (x.shape[0],)\n            emb = emb + self.label_emb(y)\n\n        h = x.type(self.dtype)\n        for module in self.input_blocks:\n            h = module(h, emb, context)\n            hs.append(h)\n        h = self.middle_block(h, emb, context)\n        for module in self.output_blocks:\n            h = th.cat([h, hs.pop()], dim=1)\n            h = module(h, emb, context)\n        h = h.type(x.dtype)\n        if self.predict_codebook_ids:\n            return self.id_predictor(h)\n        else:\n            return self.out(h)\n\n\nclass EncoderUNetModel(nn.Module):\n    \"\"\"\n    The half UNet model with attention and timestep embedding.\n    For usage, see UNet.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        
attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        use_checkpoint=False,\n        use_fp16=False,\n        num_heads=1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        use_new_attention_order=False,\n        pool=\"adaptive\",\n        *args,\n        **kwargs\n    ):\n        super().__init__()\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.use_checkpoint = use_checkpoint\n        self.dtype = th.float16 if use_fp16 else th.float32\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim,\n                    
    dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            use_checkpoint=use_checkpoint,\n                            num_heads=num_heads,\n                            num_head_channels=num_head_channels,\n                            use_new_attention_order=use_new_attention_order,\n                        )\n                    )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        self.middle_block = TimestepEmbedSequential(\n            
ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            AttentionBlock(\n                ch,\n                use_checkpoint=use_checkpoint,\n                num_heads=num_heads,\n                num_head_channels=num_head_channels,\n                use_new_attention_order=use_new_attention_order,\n            ),\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n        self.pool = pool\n        if pool == \"adaptive\":\n            self.out = nn.Sequential(\n                normalization(ch),\n                nn.SiLU(),\n                nn.AdaptiveAvgPool2d((1, 1)),\n                zero_module(conv_nd(dims, ch, out_channels, 1)),\n                nn.Flatten(),\n            )\n        elif pool == \"attention\":\n            assert num_head_channels != -1\n            self.out = nn.Sequential(\n                normalization(ch),\n                nn.SiLU(),\n                AttentionPool2d(\n                    (image_size // ds), ch, num_head_channels, out_channels\n                ),\n            )\n        elif pool == \"spatial\":\n            self.out = nn.Sequential(\n                nn.Linear(self._feature_size, 2048),\n                nn.ReLU(),\n                nn.Linear(2048, self.out_channels),\n            )\n        elif pool == \"spatial_v2\":\n            self.out = nn.Sequential(\n                nn.Linear(self._feature_size, 2048),\n                normalization(2048),\n                nn.SiLU(),\n                nn.Linear(2048, self.out_channels),\n            )\n        else:\n            raise 
NotImplementedError(f\"Unexpected {pool} pooling\")\n\n    def convert_to_fp16(self):\n        \"\"\"\n        Convert the torso of the model to float16.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f16)\n        self.middle_block.apply(convert_module_to_f16)\n\n    def convert_to_fp32(self):\n        \"\"\"\n        Convert the torso of the model to float32.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f32)\n        self.middle_block.apply(convert_module_to_f32)\n\n    def forward(self, x, timesteps):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :return: an [N x K] Tensor of outputs.\n        \"\"\"\n        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))\n\n        results = []\n        h = x.type(self.dtype)\n        for module in self.input_blocks:\n            h = module(h, emb)\n            if self.pool.startswith(\"spatial\"):\n                results.append(h.type(x.dtype).mean(dim=(2, 3)))\n        h = self.middle_block(h, emb)\n        if self.pool.startswith(\"spatial\"):\n            results.append(h.type(x.dtype).mean(dim=(2, 3)))\n            h = th.cat(results, axis=-1)\n            return self.out(h)\n        else:\n            h = h.type(x.dtype)\n            return self.out(h)\n\n"
  },
  {
    "path": "ldm/modules/diffusionmodules/util.py",
    "content": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py\n# and\n# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py\n#\n# thanks!\n\n\nimport os\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import repeat\n\nfrom ldm.util import instantiate_from_config\n\n\ndef make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):\n    if schedule == \"linear\":\n        betas = (\n                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2\n        )\n\n    elif schedule == \"cosine\":\n        timesteps = (\n                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s\n        )\n        alphas = timesteps / (1 + cosine_s) * np.pi / 2\n        alphas = torch.cos(alphas).pow(2)\n        alphas = alphas / alphas[0]\n        betas = 1 - alphas[1:] / alphas[:-1]\n        betas = np.clip(betas, a_min=0, a_max=0.999)\n\n    elif schedule == \"sqrt_linear\":\n        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)\n    elif schedule == \"sqrt\":\n        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5\n    else:\n        raise ValueError(f\"schedule '{schedule}' unknown.\")\n    return betas.numpy()\n\n\ndef make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):\n    if ddim_discr_method == 'uniform':\n        c = num_ddpm_timesteps // num_ddim_timesteps\n        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))\n    elif ddim_discr_method == 'quad':\n        ddim_timesteps = ((np.linspace(0, 
np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)\n    else:\n        raise NotImplementedError(f'There is no ddim discretization method called \"{ddim_discr_method}\"')\n\n    # assert ddim_timesteps.shape[0] == num_ddim_timesteps\n    # add one to get the final alpha values right (the ones from first scale to data during sampling)\n    steps_out = ddim_timesteps + 1\n    if verbose:\n        print(f'Selected timesteps for ddim sampler: {steps_out}')\n    return steps_out\n\n\ndef make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):\n    # select alphas for computing the variance schedule\n    alphas = alphacums[ddim_timesteps]\n    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())\n\n    # according to the formula provided in https://arxiv.org/abs/2010.02502\n    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))\n    if verbose:\n        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')\n        print(f'For the chosen value of eta, which is {eta}, '\n              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')\n    return sigmas, alphas, alphas_prev\n\n\ndef betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function,\n    which defines the cumulative product of (1-beta) over time from t = [0,1].\n    :param num_diffusion_timesteps: the number of betas to produce.\n    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and\n                      produces the cumulative product of (1-beta) up to that\n                      part of the diffusion process.\n    :param max_beta: the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n    \"\"\"\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / 
num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))\n    return np.array(betas)\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef checkpoint(func, inputs, params, flag):\n    \"\"\"\n    Evaluate a function without caching intermediate activations, allowing for\n    reduced memory at the expense of extra compute in the backward pass.\n    :param func: the function to evaluate.\n    :param inputs: the argument sequence to pass to `func`.\n    :param params: a sequence of parameters `func` depends on but does not\n                   explicitly take as arguments.\n    :param flag: if False, disable gradient checkpointing.\n    \"\"\"\n    if flag:\n        args = tuple(inputs) + tuple(params)\n        return CheckpointFunction.apply(func, len(inputs), *args)\n    else:\n        return func(*inputs)\n\n\nclass CheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n\n        with torch.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]\n        with torch.enable_grad():\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]\n            output_tensors = ctx.run_function(*shallow_copies)\n        input_grads = torch.autograd.grad(\n            output_tensors,\n            
ctx.input_tensors + ctx.input_params,\n            output_grads,\n            allow_unused=True,\n        )\n        del ctx.input_tensors\n        del ctx.input_params\n        del output_tensors\n        return (None, None) + input_grads\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    if not repeat_only:\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n    else:\n        embedding = repeat(timesteps, 'b -> b d', d=dim)\n    return embedding\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef scale_module(module, scale):\n    \"\"\"\n    Scale the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().mul_(scale)\n    return module\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef normalization(channels):\n    \"\"\"\n    Make a standard normalization layer.\n    :param channels: number of input channels.\n    :return: an nn.Module for normalization.\n    \"\"\"\n    
return GroupNorm32(32, channels)\n\n\n# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.\nclass SiLU(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\n\nclass GroupNorm32(nn.GroupNorm):\n    def forward(self, x):\n        return super().forward(x.float()).type(x.dtype)\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef linear(*args, **kwargs):\n    \"\"\"\n    Create a linear module.\n    \"\"\"\n    return nn.Linear(*args, **kwargs)\n\n\ndef avg_pool_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D average pooling module.\n    \"\"\"\n    if dims == 1:\n        return nn.AvgPool1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.AvgPool2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.AvgPool3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\nclass HybridConditioner(nn.Module):\n\n    def __init__(self, c_concat_config, c_crossattn_config):\n        super().__init__()\n        self.concat_conditioner = instantiate_from_config(c_concat_config)\n        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)\n\n    def forward(self, c_concat, c_crossattn):\n        c_concat = self.concat_conditioner(c_concat)\n        c_crossattn = self.crossattn_conditioner(c_crossattn)\n        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}\n\n\ndef noise_like(shape, device, repeat=False):\n    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))\n    noise = lambda: torch.randn(shape, device=device)\n    return repeat_noise() if repeat else noise()"
  },
  {
    "path": "ldm/modules/distributions/__init__.py",
    "content": ""
  },
  {
    "path": "ldm/modules/distributions/distributions.py",
    "content": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n    def sample(self):\n        raise NotImplementedError()\n\n    def mode(self):\n        raise NotImplementedError()\n\n\nclass DiracDistribution(AbstractDistribution):\n    def __init__(self, value):\n        self.value = value\n\n    def sample(self):\n        return self.value\n\n    def mode(self):\n        return self.value\n\n\nclass DiagonalGaussianDistribution(object):\n    def __init__(self, parameters, deterministic=False):\n        self.parameters = parameters\n        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)\n        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)\n        self.deterministic = deterministic\n        self.std = torch.exp(0.5 * self.logvar)\n        self.var = torch.exp(self.logvar)\n        if self.deterministic:\n            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)\n\n    def sample(self):\n        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)\n        return x\n\n    def kl(self, other=None):\n        if self.deterministic:\n            return torch.Tensor([0.])\n        else:\n            if other is None:\n                return 0.5 * torch.sum(torch.pow(self.mean, 2)\n                                       + self.var - 1.0 - self.logvar,\n                                       dim=[1, 2, 3])\n            else:\n                return 0.5 * torch.sum(\n                    torch.pow(self.mean - other.mean, 2) / other.var\n                    + self.var / other.var - 1.0 - self.logvar + other.logvar,\n                    dim=[1, 2, 3])\n\n    def nll(self, sample, dims=[1,2,3]):\n        if self.deterministic:\n            return torch.Tensor([0.])\n        logtwopi = np.log(2.0 * np.pi)\n        return 0.5 * torch.sum(\n            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,\n            dim=dims)\n\n    def 
mode(self):\n        return self.mean\n\n\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n    \"\"\"\n    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12\n    Compute the KL divergence between two gaussians.\n    Shapes are automatically broadcasted, so batches can be compared to\n    scalars, among other use cases.\n    \"\"\"\n    tensor = None\n    for obj in (mean1, logvar1, mean2, logvar2):\n        if isinstance(obj, torch.Tensor):\n            tensor = obj\n            break\n    assert tensor is not None, \"at least one argument must be a Tensor\"\n\n    # Force variances to be Tensors. Broadcasting helps convert scalars to\n    # Tensors, but it does not work for torch.exp().\n    logvar1, logvar2 = [\n        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)\n        for x in (logvar1, logvar2)\n    ]\n\n    return 0.5 * (\n        -1.0\n        + logvar2\n        - logvar1\n        + torch.exp(logvar1 - logvar2)\n        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)\n    )\n"
  },
  {
    "path": "ldm/modules/ema.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates=True):\n        super().__init__()\n        if decay < 0.0 or decay > 1.0:\n            raise ValueError('Decay must be between 0 and 1')\n\n        self.m_name2s_name = {}\n        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))\n        self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates\n                             else torch.tensor(-1,dtype=torch.int))\n\n        for name, p in model.named_parameters():\n            if p.requires_grad:\n                #remove as '.'-character is not allowed in buffers\n                s_name = name.replace('.','')\n                self.m_name2s_name.update({name:s_name})\n                self.register_buffer(s_name,p.clone().detach().data)\n\n        self.collected_params = []\n\n    def forward(self,model):\n        decay = self.decay\n\n        if self.num_updates >= 0:\n            self.num_updates += 1\n            decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))\n\n        one_minus_decay = 1.0 - decay\n\n        with torch.no_grad():\n            m_param = dict(model.named_parameters())\n            shadow_params = dict(self.named_buffers())\n\n            for key in m_param:\n                if m_param[key].requires_grad:\n                    sname = self.m_name2s_name[key]\n                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])\n                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))\n                else:\n                    assert not key in self.m_name2s_name\n\n    def copy_to(self, model):\n        m_param = dict(model.named_parameters())\n        shadow_params = dict(self.named_buffers())\n        for key in m_param:\n            if m_param[key].requires_grad:\n                
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)\n            else:\n                assert not key in self.m_name2s_name\n\n    def store(self, parameters):\n        \"\"\"\n        Save the current parameters for restoring later.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            temporarily stored.\n        \"\"\"\n        self.collected_params = [param.clone() for param in parameters]\n\n    def restore(self, parameters):\n        \"\"\"\n        Restore the parameters stored with the `store` method.\n        Useful to validate the model with EMA parameters without affecting the\n        original optimization process. Store the parameters before the\n        `copy_to` method. After validation (or model saving), use this to\n        restore the former parameters.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            updated with the stored parameters.\n        \"\"\"\n        for c_param, param in zip(self.collected_params, parameters):\n            param.data.copy_(c_param.data)\n"
  },
  {
    "path": "ldm/modules/encoders/__init__.py",
    "content": ""
  },
  {
    "path": "ldm/modules/encoders/modules.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom functools import partial\nimport kornia\n\nfrom ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test\nfrom ldm.util import default\nimport clip\n\n\nclass AbstractEncoder(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def encode(self, *args, **kwargs):\n        raise NotImplementedError\n\nclass IdentityEncoder(AbstractEncoder):\n\n    def encode(self, x):\n        return x\n\nclass FaceClipEncoder(AbstractEncoder):\n    def __init__(self, augment=True, retreival_key=None):\n        super().__init__()\n        self.encoder = FrozenCLIPImageEmbedder()\n        self.augment = augment\n        self.retreival_key = retreival_key\n\n    def forward(self, img):\n        encodings = []\n        with torch.no_grad():\n            x_offset = 125\n            if self.retreival_key:\n                # Assumes retrieved images are packed into the second half of channels\n                face = img[:,3:,190:440,x_offset:(512-x_offset)]\n                other = img[:,:3,...].clone()\n            else:\n                face = img[:,:,190:440,x_offset:(512-x_offset)]\n                other = img.clone()\n\n            if self.augment:\n                face = K.RandomHorizontalFlip()(face)\n\n            other[:,:,190:440,x_offset:(512-x_offset)] *= 0\n            encodings = [\n                self.encoder.encode(face),\n                self.encoder.encode(other),\n            ]\n\n        return torch.cat(encodings, dim=1)\n\n    def encode(self, img):\n        if isinstance(img, list):\n            # Uncondition\n            return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)\n\n        return self(img)\n\nclass FaceIdClipEncoder(AbstractEncoder):\n    def __init__(self):\n        super().__init__()\n        self.encoder = 
FrozenCLIPImageEmbedder()\n        for p in self.encoder.parameters():\n            p.requires_grad = False\n        self.id = FrozenFaceEncoder(\"/home/jpinkney/code/stable-diffusion/model_ir_se50.pth\", augment=True)\n\n    def forward(self, img):\n        encodings = []\n        with torch.no_grad():\n            face = kornia.geometry.resize(img, (256, 256),\n                            interpolation='bilinear', align_corners=True)\n\n            other = img.clone()\n            other[:,:,184:452,122:396] *= 0\n            encodings = [\n                self.id.encode(face),\n                self.encoder.encode(other),\n            ]\n\n        return torch.cat(encodings, dim=1)\n\n    def encode(self, img):\n        if isinstance(img, list):\n            # Uncondition\n            return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)\n\n        return self(img)\n\nclass ClassEmbedder(nn.Module):\n    def __init__(self, embed_dim, n_classes=1000, key='class'):\n        super().__init__()\n        self.key = key\n        self.embedding = nn.Embedding(n_classes, embed_dim)\n\n    def forward(self, batch, key=None):\n        if key is None:\n            key = self.key\n        # this is for use in crossattn\n        c = batch[key][:, None]\n        c = self.embedding(c)\n        return c\n\n\nclass TransformerEmbedder(AbstractEncoder):\n    \"\"\"Some transformer encoder layers\"\"\"\n    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device=\"cuda\"):\n        super().__init__()\n        self.device = device\n        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,\n                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))\n\n    def forward(self, tokens):\n        tokens = tokens.to(self.device)  # meh\n        z = self.transformer(tokens, return_embeddings=True)\n        return z\n\n    def encode(self, x):\n        return 
self(x)\n\n\nclass BERTTokenizer(AbstractEncoder):\n    \"\"\" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)\"\"\"\n    def __init__(self, device=\"cuda\", vq_interface=True, max_length=77):\n        super().__init__()\n        from transformers import BertTokenizerFast  # TODO: add to requirements\n        self.tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")\n        self.device = device\n        self.vq_interface = vq_interface\n        self.max_length = max_length\n\n    def forward(self, text):\n        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,\n                                        return_overflowing_tokens=False, padding=\"max_length\", return_tensors=\"pt\")\n        tokens = batch_encoding[\"input_ids\"].to(self.device)\n        return tokens\n\n    @torch.no_grad()\n    def encode(self, text):\n        tokens = self(text)\n        if not self.vq_interface:\n            return tokens\n        return None, None, [None, None, tokens]\n\n    def decode(self, text):\n        return text\n\n\nclass BERTEmbedder(AbstractEncoder):\n    \"\"\"Uses the BERT tokenizer model and adds some transformer encoder layers\"\"\"\n    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,\n                 device=\"cuda\",use_tokenizer=True, embedding_dropout=0.0):\n        super().__init__()\n        self.use_tknz_fn = use_tokenizer\n        if self.use_tknz_fn:\n            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)\n        self.device = device\n        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,\n                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),\n                                              emb_dropout=embedding_dropout)\n\n    def forward(self, text):\n        if self.use_tknz_fn:\n            tokens = 
self.tknz_fn(text)#.to(self.device)\n        else:\n            tokens = text\n        z = self.transformer(tokens, return_embeddings=True)\n        return z\n\n    def encode(self, text):\n        # output of length 77\n        return self(text)\n\n\nfrom transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\nclass FrozenT5Embedder(AbstractEncoder):\n    \"\"\"Uses the T5 transformer encoder for text\"\"\"\n    def __init__(self, version=\"google/t5-v1_1-large\", device=\"cuda\", max_length=77):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl\n        super().__init__()\n        self.tokenizer = T5Tokenizer.from_pretrained(version)\n        self.transformer = T5EncoderModel.from_pretrained(version)\n        self.device = device\n        self.max_length = max_length   # TODO: typical value?\n        self.freeze()\n\n    def freeze(self):\n        self.transformer = self.transformer.eval()\n        #self.train = disabled_train\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text):\n        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,\n                                        return_overflowing_tokens=False, padding=\"max_length\", return_tensors=\"pt\")\n        tokens = batch_encoding[\"input_ids\"].to(self.device)\n        outputs = self.transformer(input_ids=tokens)\n\n        z = outputs.last_hidden_state\n        return z\n\n    def encode(self, text):\n        return self(text)\n\nfrom ldm.thirdp.psp.id_loss import IDFeatures\nimport kornia.augmentation as K\n\nclass FrozenFaceEncoder(AbstractEncoder):\n    def __init__(self, model_path, augment=False):\n        super().__init__()\n        self.loss_fn = IDFeatures(model_path)\n        # 
face encoder is frozen\n        for p in self.loss_fn.parameters():\n            p.requires_grad = False\n        # Mapper is trainable\n        self.mapper = torch.nn.Linear(512, 768)\n        p = 0.25\n        if augment:\n            self.augment = K.AugmentationSequential(\n                K.RandomHorizontalFlip(p=0.5),\n                K.RandomEqualize(p=p),\n                # K.RandomPlanckianJitter(p=p),\n                # K.RandomPlasmaBrightness(p=p),\n                # K.RandomPlasmaContrast(p=p),\n                # K.ColorJiggle(0.02, 0.2, 0.2, p=p),\n            )\n        else:\n            # None (not False) so the `is not None` check in forward() skips augmentation\n            self.augment = None\n\n    def forward(self, img):\n        if isinstance(img, list):\n            # Uncondition\n            return torch.zeros((1, 1, 768), device=self.mapper.weight.device)\n\n        if self.augment is not None:\n            # Transforms require 0-1\n            img = self.augment((img + 1)/2)\n            img = 2*img - 1\n\n        feat = self.loss_fn(img, crop=True)\n        feat = self.mapper(feat.unsqueeze(1))\n        return feat\n\n    def encode(self, img):\n        return self(img)\n\nclass FrozenCLIPEmbedder(AbstractEncoder):\n    \"\"\"Uses the CLIP transformer encoder for text (from huggingface)\"\"\"\n    def __init__(self, version=\"openai/clip-vit-large-patch14\", device=\"cuda\", max_length=77):  # clip-vit-base-patch32\n        super().__init__()\n        self.tokenizer = CLIPTokenizer.from_pretrained(version)\n        self.transformer = CLIPTextModel.from_pretrained(version)\n        self.device = device\n        self.max_length = max_length   # TODO: typical value?\n        self.freeze()\n\n    def freeze(self):\n        self.transformer = self.transformer.eval()\n        #self.train = disabled_train\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text):\n        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,\n     
                                   return_overflowing_tokens=False, padding=\"max_length\", return_tensors=\"pt\")\n        tokens = batch_encoding[\"input_ids\"].to(self.device)\n        outputs = self.transformer(input_ids=tokens)\n\n        z = outputs.last_hidden_state\n        return z\n\n    def encode(self, text):\n        return self(text)\n\nimport torch.nn.functional as F\nfrom transformers import CLIPVisionModel\nclass ClipImageProjector(AbstractEncoder):\n    \"\"\"\n        Uses the CLIP image encoder.\n        \"\"\"\n    def __init__(self, version=\"openai/clip-vit-large-patch14\", max_length=77):  # clip-vit-base-patch32\n        super().__init__()\n        self.model = CLIPVisionModel.from_pretrained(version)\n        self.model.train()\n        self.max_length = max_length   # TODO: typical value?\n        self.antialias = True\n        self.mapper = torch.nn.Linear(1024, 768)\n        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)\n        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)\n        null_cond = self.get_null_cond(version, max_length)\n        self.register_buffer('null_cond', null_cond)\n\n    @torch.no_grad()\n    def get_null_cond(self, version, max_length):\n        device = self.mean.device\n        embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)\n        null_cond = embedder([\"\"])\n        return null_cond\n\n    def preprocess(self, x):\n        # Expects inputs in the range -1, 1\n        x = kornia.geometry.resize(x, (224, 224),\n                                   interpolation='bicubic',align_corners=True,\n                                   antialias=self.antialias)\n        x = (x + 1.) 
/ 2.\n        # renormalize according to clip\n        x = kornia.enhance.normalize(x, self.mean, self.std)\n        return x\n\n    def forward(self, x):\n        if isinstance(x, list):\n            return self.null_cond\n        # x is assumed to be in range [-1,1]\n        x = self.preprocess(x)\n        outputs = self.model(pixel_values=x)\n        last_hidden_state = outputs.last_hidden_state\n        last_hidden_state = self.mapper(last_hidden_state)\n        return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0])\n\n    def encode(self, im):\n        return self(im)\n\nclass ProjectedFrozenCLIPEmbedder(AbstractEncoder):\n    def __init__(self, version=\"openai/clip-vit-large-patch14\", device=\"cuda\", max_length=77):  # clip-vit-base-patch32\n        super().__init__()\n        self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)\n        self.projection = torch.nn.Linear(768, 768)\n\n    def forward(self, text):\n        z = self.embedder(text)\n        return self.projection(z)\n\n    def encode(self, text):\n        return self(text)\n\nclass FrozenCLIPImageEmbedder(AbstractEncoder):\n    \"\"\"\n        Uses the CLIP image encoder.\n        Not actually frozen... 
If you want that set cond_stage_trainable=False in cfg\n        \"\"\"\n    def __init__(\n            self,\n            model='ViT-L/14',\n            jit=False,\n            device='cpu',\n            antialias=False,\n        ):\n        super().__init__()\n        self.model, _ = clip.load(name=model, device=device, jit=jit)\n        # We don't use the text part so delete it\n        del self.model.transformer\n        self.antialias = antialias\n        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)\n        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)\n\n    def preprocess(self, x):\n        # Expects inputs in the range -1, 1\n        x = kornia.geometry.resize(x, (224, 224),\n                                   interpolation='bicubic',align_corners=True,\n                                   antialias=self.antialias)\n        x = (x + 1.) / 2.\n        # renormalize according to clip\n        x = kornia.enhance.normalize(x, self.mean, self.std)\n        return x\n\n    def forward(self, x):\n        # x is assumed to be in range [-1,1]\n        if isinstance(x, list):\n            # [\"\"] denotes condition dropout for ucg\n            device = self.model.visual.conv1.weight.device\n            return torch.zeros(1, 768, device=device)\n        return self.model.encode_image(self.preprocess(x)).float()\n\n    def encode(self, im):\n        return self(im).unsqueeze(1)\n\nfrom torchvision import transforms\nimport random\n\nclass FrozenCLIPImageMutliEmbedder(AbstractEncoder):\n    \"\"\"\n        Uses the CLIP image encoder.\n        Not actually frozen... 
If you want that set cond_stage_trainable=False in cfg\n        \"\"\"\n    def __init__(\n            self,\n            model='ViT-L/14',\n            jit=False,\n            device='cpu',\n            antialias=True,\n            max_crops=5,\n        ):\n        super().__init__()\n        self.model, _ = clip.load(name=model, device=device, jit=jit)\n        # We don't use the text part so delete it\n        del self.model.transformer\n        self.antialias = antialias\n        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)\n        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)\n        self.max_crops = max_crops\n\n    def preprocess(self, x):\n\n        # Expects inputs in the range -1, 1\n        randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1))\n        max_crops = self.max_crops\n        patches = []\n        crops = [randcrop(x) for _ in range(max_crops)]\n        patches.extend(crops)\n        x = torch.cat(patches, dim=0)\n        x = (x + 1.) 
/ 2.\n        # renormalize according to clip\n        x = kornia.enhance.normalize(x, self.mean, self.std)\n        return x\n\n    def forward(self, x):\n        # x is assumed to be in range [-1,1]\n        if isinstance(x, list):\n            # [\"\"] denotes condition dropout for ucg\n            device = self.model.visual.conv1.weight.device\n            return torch.zeros(1, self.max_crops, 768, device=device)\n        batch_tokens = []\n        for im in x:\n            patches = self.preprocess(im.unsqueeze(0))\n            tokens = self.model.encode_image(patches).float()\n            for t in tokens:\n                if random.random() < 0.1:\n                    t *= 0\n            batch_tokens.append(tokens.unsqueeze(0))\n\n        return torch.cat(batch_tokens, dim=0)\n\n    def encode(self, im):\n        return self(im)\n\nclass SpatialRescaler(nn.Module):\n    def __init__(self,\n                 n_stages=1,\n                 method='bilinear',\n                 multiplier=0.5,\n                 in_channels=3,\n                 out_channels=None,\n                 bias=False):\n        super().__init__()\n        self.n_stages = n_stages\n        assert self.n_stages >= 0\n        assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']\n        self.multiplier = multiplier\n        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)\n        self.remap_output = out_channels is not None\n        if self.remap_output:\n            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')\n            self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)\n\n    def forward(self,x):\n        for stage in range(self.n_stages):\n            x = self.interpolator(x, scale_factor=self.multiplier)\n\n\n        if self.remap_output:\n            x = self.channel_mapper(x)\n        return x\n\n    def encode(self, x):\n        return self(x)\n\n\nfrom ldm.util 
import instantiate_from_config\nfrom ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like\n\n\nclass LowScaleEncoder(nn.Module):\n    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,\n                 scale_factor=1.0):\n        super().__init__()\n        self.max_noise_level = max_noise_level\n        self.model = instantiate_from_config(model_config)\n        self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,\n                                                            linear_end=linear_end)\n        self.out_size = output_size\n        self.scale_factor = scale_factor\n\n    def register_schedule(self, beta_schedule=\"linear\", timesteps=1000,\n                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):\n        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,\n                                   cosine_s=cosine_s)\n        alphas = 1. - betas\n        alphas_cumprod = np.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])\n\n        timesteps, = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.linear_start = linear_start\n        self.linear_end = linear_end\n        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'\n\n        to_torch = partial(torch.tensor, dtype=torch.float32)\n\n        self.register_buffer('betas', to_torch(betas))\n        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))\n        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod)))\n        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))\n        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))\n        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)\n\n    def forward(self, x):\n        z = self.model.encode(x).sample()\n        z = z * self.scale_factor\n        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()\n        z = self.q_sample(z, noise_level)\n        if self.out_size is not None:\n            z = torch.nn.functional.interpolate(z, size=self.out_size, mode=\"nearest\")  # TODO: experiment with mode\n        # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)\n        return z, noise_level\n\n    def decode(self, z):\n        z = z / self.scale_factor\n        return self.model.decode(z)\n\n\nif __name__ == \"__main__\":\n    from ldm.util import count_params\n    sentences = [\"a hedgehog drinking a whiskey\", \"der mond ist aufgegangen\", \"Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'\"]\n    model = FrozenT5Embedder(version=\"google/t5-v1_1-xl\").cuda()\n    count_params(model, True)\n    z = model(sentences)\n    print(z.shape)\n\n    model = FrozenCLIPEmbedder().cuda()\n    count_params(model, True)\n    z = model(sentences)\n    print(z.shape)\n\n    print(\"done.\")\n"
  },
  {
    "path": "ldm/modules/evaluate/adm_evaluator.py",
    "content": "import argparse\nimport io\nimport os\nimport random\nimport warnings\nimport zipfile\nfrom abc import ABC, abstractmethod\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom multiprocessing import cpu_count\nfrom multiprocessing.pool import ThreadPool\nfrom typing import Iterable, Optional, Tuple\nimport yaml\n\nimport numpy as np\nimport requests\nimport tensorflow.compat.v1 as tf\nfrom scipy import linalg\nfrom tqdm.auto import tqdm\n\nINCEPTION_V3_URL = \"https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb\"\nINCEPTION_V3_PATH = \"classify_image_graph_def.pb\"\n\nFID_POOL_NAME = \"pool_3:0\"\nFID_SPATIAL_NAME = \"mixed_6/conv:0\"\n\nREQUIREMENTS = f\"This script has the following requirements: \\n\" \\\n               'tensorflow-gpu>=2.0' + \"\\n\" + 'scipy' + \"\\n\" + \"requests\" + \"\\n\" + \"tqdm\"\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--ref_batch\", help=\"path to reference batch npz file\")\n    parser.add_argument(\"--sample_batch\", help=\"path to sample batch npz file\")\n    args = parser.parse_args()\n\n    config = tf.ConfigProto(\n        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph\n    )\n    config.gpu_options.allow_growth = True\n    evaluator = Evaluator(tf.Session(config=config))\n\n    print(\"warming up TensorFlow...\")\n    # This will cause TF to print a bunch of verbose stuff now rather\n    # than after the next print(), to help prevent confusion.\n    evaluator.warmup()\n\n    print(\"computing reference batch activations...\")\n    ref_acts = evaluator.read_activations(args.ref_batch)\n    print(\"computing/reading reference batch statistics...\")\n    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)\n\n    print(\"computing sample batch activations...\")\n    sample_acts = evaluator.read_activations(args.sample_batch)\n    
print(\"computing/reading sample batch statistics...\")\n    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)\n\n    print(\"Computing evaluations...\")\n    is_ = evaluator.compute_inception_score(sample_acts[0])\n    print(\"Inception Score:\", is_)\n    fid = sample_stats.frechet_distance(ref_stats)\n    print(\"FID:\", fid)\n    sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)\n    print(\"sFID:\", sfid)\n    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])\n    print(\"Precision:\", prec)\n    print(\"Recall:\", recall)\n\n    savepath = '/'.join(args.sample_batch.split('/')[:-1])\n    results_file = os.path.join(savepath, 'evaluation_metrics.yaml')\n    print(f'Saving evaluation results to \"{results_file}\"')\n\n    # frechet_distance returns NumPy scalars; cast to float so yaml.dump writes plain numbers\n    results = {\n        'IS': is_,\n        'FID': float(fid),\n        'sFID': float(sfid),\n        'Precision': prec,\n        'Recall': recall\n    }\n\n    with open(results_file, 'w') as f:\n        yaml.dump(results, f, default_flow_style=False)\n\n\nclass InvalidFIDException(Exception):\n    pass\n\n\nclass FIDStatistics:\n    def __init__(self, mu: np.ndarray, sigma: np.ndarray):\n        self.mu = mu\n        self.sigma = sigma\n\n    def frechet_distance(self, other, eps=1e-6):\n        \"\"\"\n        Compute the Frechet distance between two sets of statistics.\n        \"\"\"\n        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132\n        mu1, sigma1 = self.mu, self.sigma\n        mu2, sigma2 = other.mu, other.sigma\n\n        mu1 = np.atleast_1d(mu1)\n        mu2 = np.atleast_1d(mu2)\n\n        sigma1 = np.atleast_2d(sigma1)\n        sigma2 = np.atleast_2d(sigma2)\n\n        assert (\n            mu1.shape == mu2.shape\n        ), f\"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}\"\n        assert (\n            sigma1.shape == sigma2.shape\n        ), f\"Training and test covariances 
have different dimensions: {sigma1.shape}, {sigma2.shape}\"\n\n        diff = mu1 - mu2\n\n        # product might be almost singular\n        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n        if not np.isfinite(covmean).all():\n            msg = (\n                \"fid calculation produces singular product; adding %s to diagonal of cov estimates\"\n                % eps\n            )\n            warnings.warn(msg)\n            offset = np.eye(sigma1.shape[0]) * eps\n            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n\n        # numerical error might give slight imaginary component\n        if np.iscomplexobj(covmean):\n            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n                m = np.max(np.abs(covmean.imag))\n                raise ValueError(\"Imaginary component {}\".format(m))\n            covmean = covmean.real\n\n        tr_covmean = np.trace(covmean)\n\n        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean\n\n\nclass Evaluator:\n    def __init__(\n        self,\n        session,\n        batch_size=64,\n        softmax_batch_size=512,\n    ):\n        self.sess = session\n        self.batch_size = batch_size\n        self.softmax_batch_size = softmax_batch_size\n        self.manifold_estimator = ManifoldEstimator(session)\n        with self.sess.graph.as_default():\n            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])\n            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])\n            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)\n            self.softmax = _create_softmax_graph(self.softmax_input)\n\n    def warmup(self):\n        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))\n\n    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:\n        with open_npz_array(npz_path, \"arr_0\") as reader:\n            return 
self.compute_activations(reader.read_batches(self.batch_size))\n\n    def compute_activations(self, batches: Iterable[np.ndarray], silent=False) -> Tuple[np.ndarray, np.ndarray]:\n        \"\"\"\n        Compute image features for downstream evals.\n\n        :param batches: an iterator over NHWC numpy arrays in [0, 255].\n        :param silent: if True, do not wrap the batches in a tqdm progress bar.\n        :return: a tuple of numpy arrays of shape [N x X], where X is a feature\n                 dimension. The tuple is (pool_3, spatial).\n        \"\"\"\n        preds = []\n        spatial_preds = []\n        it = batches if silent else tqdm(batches)\n        for batch in it:\n            batch = batch.astype(np.float32)\n            pred, spatial_pred = self.sess.run(\n                [self.pool_features, self.spatial_features], {self.image_input: batch}\n            )\n            preds.append(pred.reshape([pred.shape[0], -1]))\n            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))\n        return (\n            np.concatenate(preds, axis=0),\n            np.concatenate(spatial_preds, axis=0),\n        )\n\n    def read_statistics(\n        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]\n    ) -> Tuple[FIDStatistics, FIDStatistics]:\n        obj = np.load(npz_path)\n        if \"mu\" in list(obj.keys()):\n            return FIDStatistics(obj[\"mu\"], obj[\"sigma\"]), FIDStatistics(\n                obj[\"mu_s\"], obj[\"sigma_s\"]\n            )\n        return tuple(self.compute_statistics(x) for x in activations)\n\n    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:\n        mu = np.mean(activations, axis=0)\n        sigma = np.cov(activations, rowvar=False)\n        return FIDStatistics(mu, sigma)\n\n    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:\n        softmax_out = []\n        for i in range(0, len(activations), self.softmax_batch_size):\n            acts = activations[i : i + self.softmax_batch_size]\n            
softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))\n        preds = np.concatenate(softmax_out, axis=0)\n        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46\n        scores = []\n        for i in range(0, len(preds), split_size):\n            part = preds[i : i + split_size]\n            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))\n            kl = np.mean(np.sum(kl, 1))\n            scores.append(np.exp(kl))\n        return float(np.mean(scores))\n\n    def compute_prec_recall(\n        self, activations_ref: np.ndarray, activations_sample: np.ndarray\n    ) -> Tuple[float, float]:\n        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)\n        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)\n        pr = self.manifold_estimator.evaluate_pr(\n            activations_ref, radii_1, activations_sample, radii_2\n        )\n        return (float(pr[0][0]), float(pr[1][0]))\n\n\nclass ManifoldEstimator:\n    \"\"\"\n    A helper for comparing manifolds of feature vectors.\n\n    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57\n    \"\"\"\n\n    def __init__(\n        self,\n        session,\n        row_batch_size=10000,\n        col_batch_size=10000,\n        nhood_sizes=(3,),\n        clamp_to_percentile=None,\n        eps=1e-5,\n    ):\n        \"\"\"\n        Estimate the manifold of given feature vectors.\n\n        :param session: the TensorFlow session.\n        :param row_batch_size: row batch size to compute pairwise distances\n                               (parameter to trade-off between memory usage and performance).\n        :param col_batch_size: column batch size to compute pairwise distances.\n        :param nhood_sizes: number of neighbors used to estimate the manifold.\n        
:param clamp_to_percentile: prune hyperspheres that have radius larger than\n                                    the given percentile.\n        :param eps: small number for numerical stability.\n        \"\"\"\n        self.distance_block = DistanceBlock(session)\n        self.row_batch_size = row_batch_size\n        self.col_batch_size = col_batch_size\n        self.nhood_sizes = nhood_sizes\n        self.num_nhoods = len(nhood_sizes)\n        self.clamp_to_percentile = clamp_to_percentile\n        self.eps = eps\n\n    def warmup(self):\n        feats, radii = (\n            np.zeros([1, 2048], dtype=np.float32),\n            np.zeros([1, 1], dtype=np.float32),\n        )\n        self.evaluate_pr(feats, radii, feats, radii)\n\n    def manifold_radii(self, features: np.ndarray) -> np.ndarray:\n        num_images = len(features)\n\n        # Estimate manifold of features by calculating distances to k-NN of each sample.\n        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)\n        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)\n        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)\n\n        for begin1 in range(0, num_images, self.row_batch_size):\n            end1 = min(begin1 + self.row_batch_size, num_images)\n            row_batch = features[begin1:end1]\n\n            for begin2 in range(0, num_images, self.col_batch_size):\n                end2 = min(begin2 + self.col_batch_size, num_images)\n                col_batch = features[begin2:end2]\n\n                # Compute distances between batches.\n                distance_batch[\n                    0 : end1 - begin1, begin2:end2\n                ] = self.distance_block.pairwise_distances(row_batch, col_batch)\n\n            # Find the k-nearest neighbor from the current batch.\n            radii[begin1:end1, :] = np.concatenate(\n                [\n                    x[:, self.nhood_sizes]\n                    for x in 
_numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)\n                ],\n                axis=0,\n            )\n\n        if self.clamp_to_percentile is not None:\n            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)\n            radii[radii > max_distances] = 0\n        return radii\n\n    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):\n        \"\"\"\n        Evaluate if new feature vectors are at the manifold.\n        \"\"\"\n        num_eval_images = eval_features.shape[0]\n        num_ref_images = radii.shape[0]\n        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)\n        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)\n        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)\n        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)\n\n        for begin1 in range(0, num_eval_images, self.row_batch_size):\n            end1 = min(begin1 + self.row_batch_size, num_eval_images)\n            feature_batch = eval_features[begin1:end1]\n\n            for begin2 in range(0, num_ref_images, self.col_batch_size):\n                end2 = min(begin2 + self.col_batch_size, num_ref_images)\n                ref_batch = features[begin2:end2]\n\n                distance_batch[\n                    0 : end1 - begin1, begin2:end2\n                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)\n\n            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.\n            # If a feature vector is inside a hypersphere of some reference sample, then\n            # the new sample lies at the estimated manifold.\n            # The radii of the hyperspheres are determined from distances of neighborhood size k.\n            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii\n            
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)\n\n            max_realism_score[begin1:end1] = np.max(\n                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1\n            )\n            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)\n\n        return {\n            \"fraction\": float(np.mean(batch_predictions)),\n            \"batch_predictions\": batch_predictions,\n            \"max_realisim_score\": max_realism_score,\n            \"nearest_indices\": nearest_indices,\n        }\n\n    def evaluate_pr(\n        self,\n        features_1: np.ndarray,\n        radii_1: np.ndarray,\n        features_2: np.ndarray,\n        radii_2: np.ndarray,\n    ) -> Tuple[np.ndarray, np.ndarray]:\n        \"\"\"\n        Evaluate precision and recall efficiently.\n\n        :param features_1: [N1 x D] feature vectors for reference batch.\n        :param radii_1: [N1 x K1] radii for reference vectors.\n        :param features_2: [N2 x D] feature vectors for the other batch.\n        :param radii_2: [N2 x K2] radii for other vectors.\n        :return: a tuple of arrays for (precision, recall):\n                 - precision: an np.ndarray of length K1\n                 - recall: an np.ndarray of length K2\n        \"\"\"\n        # the np.bool alias was removed in NumPy 1.24; use the builtin bool\n        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)\n        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)\n        for begin_1 in range(0, len(features_1), self.row_batch_size):\n            end_1 = begin_1 + self.row_batch_size\n            batch_1 = features_1[begin_1:end_1]\n            for begin_2 in range(0, len(features_2), self.col_batch_size):\n                end_2 = begin_2 + self.col_batch_size\n                batch_2 = features_2[begin_2:end_2]\n                batch_1_in, batch_2_in = self.distance_block.less_thans(\n                    batch_1, 
radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]\n                )\n                features_1_status[begin_1:end_1] |= batch_1_in\n                features_2_status[begin_2:end_2] |= batch_2_in\n        return (\n            np.mean(features_2_status.astype(np.float64), axis=0),\n            np.mean(features_1_status.astype(np.float64), axis=0),\n        )\n\n\nclass DistanceBlock:\n    \"\"\"\n    Calculate pairwise distances between vectors.\n\n    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34\n    \"\"\"\n\n    def __init__(self, session):\n        self.session = session\n\n        # Initialize TF graph to calculate pairwise distances.\n        with session.graph.as_default():\n            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])\n            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])\n            distance_block_16 = _batch_pairwise_distances(\n                tf.cast(self._features_batch1, tf.float16),\n                tf.cast(self._features_batch2, tf.float16),\n            )\n            self.distance_block = tf.cond(\n                tf.reduce_all(tf.math.is_finite(distance_block_16)),\n                lambda: tf.cast(distance_block_16, tf.float32),\n                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),\n            )\n\n            # Extra logic for less thans.\n            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])\n            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])\n            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]\n            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)\n            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)\n\n    def pairwise_distances(self, U, V):\n        \"\"\"\n        Evaluate pairwise distances 
between two batches of feature vectors.\n        \"\"\"\n        return self.session.run(\n            self.distance_block,\n            feed_dict={self._features_batch1: U, self._features_batch2: V},\n        )\n\n    def less_thans(self, batch_1, radii_1, batch_2, radii_2):\n        return self.session.run(\n            [self._batch_1_in, self._batch_2_in],\n            feed_dict={\n                self._features_batch1: batch_1,\n                self._features_batch2: batch_2,\n                self._radii1: radii_1,\n                self._radii2: radii_2,\n            },\n        )\n\n\ndef _batch_pairwise_distances(U, V):\n    \"\"\"\n    Compute pairwise distances between two batches of feature vectors.\n    \"\"\"\n    with tf.variable_scope(\"pairwise_dist_block\"):\n        # Squared norms of each row in U and V.\n        norm_u = tf.reduce_sum(tf.square(U), 1)\n        norm_v = tf.reduce_sum(tf.square(V), 1)\n\n        # norm_u as a column and norm_v as a row vectors.\n        norm_u = tf.reshape(norm_u, [-1, 1])\n        norm_v = tf.reshape(norm_v, [1, -1])\n\n        # Pairwise squared Euclidean distances.\n        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)\n\n    return D\n\n\nclass NpzArrayReader(ABC):\n    @abstractmethod\n    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:\n        pass\n\n    @abstractmethod\n    def remaining(self) -> int:\n        pass\n\n    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:\n        def gen_fn():\n            while True:\n                batch = self.read_batch(batch_size)\n                if batch is None:\n                    break\n                yield batch\n\n        rem = self.remaining()\n        num_batches = rem // batch_size + int(rem % batch_size != 0)\n        return BatchIterator(gen_fn, num_batches)\n\n\nclass BatchIterator:\n    def __init__(self, gen_fn, length):\n        self.gen_fn = gen_fn\n        self.length = length\n\n    def 
__len__(self):\n        return self.length\n\n    def __iter__(self):\n        return self.gen_fn()\n\n\nclass StreamingNpzArrayReader(NpzArrayReader):\n    def __init__(self, arr_f, shape, dtype):\n        self.arr_f = arr_f\n        self.shape = shape\n        self.dtype = dtype\n        self.idx = 0\n\n    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:\n        if self.idx >= self.shape[0]:\n            return None\n\n        bs = min(batch_size, self.shape[0] - self.idx)\n        self.idx += bs\n\n        if self.dtype.itemsize == 0:\n            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)\n\n        read_count = bs * np.prod(self.shape[1:])\n        read_size = int(read_count * self.dtype.itemsize)\n        data = _read_bytes(self.arr_f, read_size, \"array data\")\n        return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])\n\n    def remaining(self) -> int:\n        return max(0, self.shape[0] - self.idx)\n\n\nclass MemoryNpzArrayReader(NpzArrayReader):\n    def __init__(self, arr):\n        self.arr = arr\n        self.idx = 0\n\n    @classmethod\n    def load(cls, path: str, arr_name: str):\n        with open(path, \"rb\") as f:\n            arr = np.load(f)[arr_name]\n        return cls(arr)\n\n    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:\n        if self.idx >= self.arr.shape[0]:\n            return None\n\n        res = self.arr[self.idx : self.idx + batch_size]\n        self.idx += batch_size\n        return res\n\n    def remaining(self) -> int:\n        return max(0, self.arr.shape[0] - self.idx)\n\n\n@contextmanager\ndef open_npz_array(path: str, arr_name: str) -> NpzArrayReader:\n    with _open_npy_file(path, arr_name) as arr_f:\n        version = np.lib.format.read_magic(arr_f)\n        if version == (1, 0):\n            header = np.lib.format.read_array_header_1_0(arr_f)\n        elif version == (2, 0):\n            header = np.lib.format.read_array_header_2_0(arr_f)\n    
    else:\n            yield MemoryNpzArrayReader.load(path, arr_name)\n            return\n        shape, fortran, dtype = header\n        if fortran or dtype.hasobject:\n            yield MemoryNpzArrayReader.load(path, arr_name)\n        else:\n            yield StreamingNpzArrayReader(arr_f, shape, dtype)\n\n\ndef _read_bytes(fp, size, error_template=\"ran out of data\"):\n    \"\"\"\n    Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886\n\n    Read from file-like object until size bytes are read.\n    Raises ValueError if not EOF is encountered before size bytes are read.\n    Non-blocking objects only supported if they derive from io objects.\n    Required as e.g. ZipExtFile in python 2.6 can return less data than\n    requested.\n    \"\"\"\n    data = bytes()\n    while True:\n        # io files (default in python3) return None or raise on\n        # would-block, python2 file will truncate, probably nothing can be\n        # done about that.  
note that regular files can't be non-blocking\n        try:\n            r = fp.read(size - len(data))\n            data += r\n            if len(r) == 0 or len(data) == size:\n                break\n        except io.BlockingIOError:\n            pass\n    if len(data) != size:\n        msg = \"EOF: reading %s, expected %d bytes got %d\"\n        raise ValueError(msg % (error_template, size, len(data)))\n    else:\n        return data\n\n\n@contextmanager\ndef _open_npy_file(path: str, arr_name: str):\n    with open(path, \"rb\") as f:\n        with zipfile.ZipFile(f, \"r\") as zip_f:\n            if f\"{arr_name}.npy\" not in zip_f.namelist():\n                raise ValueError(f\"missing {arr_name} in npz file\")\n            with zip_f.open(f\"{arr_name}.npy\", \"r\") as arr_f:\n                yield arr_f\n\n\ndef _download_inception_model():\n    if os.path.exists(INCEPTION_V3_PATH):\n        return\n    print(\"downloading InceptionV3 model...\")\n    with requests.get(INCEPTION_V3_URL, stream=True) as r:\n        r.raise_for_status()\n        tmp_path = INCEPTION_V3_PATH + \".tmp\"\n        with open(tmp_path, \"wb\") as f:\n            for chunk in tqdm(r.iter_content(chunk_size=8192)):\n                f.write(chunk)\n        os.rename(tmp_path, INCEPTION_V3_PATH)\n\n\ndef _create_feature_graph(input_batch):\n    _download_inception_model()\n    prefix = f\"{random.randrange(2**32)}_{random.randrange(2**32)}\"\n    with open(INCEPTION_V3_PATH, \"rb\") as f:\n        graph_def = tf.GraphDef()\n        graph_def.ParseFromString(f.read())\n    pool3, spatial = tf.import_graph_def(\n        graph_def,\n        input_map={f\"ExpandDims:0\": input_batch},\n        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],\n        name=prefix,\n    )\n    _update_shapes(pool3)\n    spatial = spatial[..., :7]\n    return pool3, spatial\n\n\ndef _create_softmax_graph(input_batch):\n    _download_inception_model()\n    prefix = 
f\"{random.randrange(2**32)}_{random.randrange(2**32)}\"\n    with open(INCEPTION_V3_PATH, \"rb\") as f:\n        graph_def = tf.GraphDef()\n        graph_def.ParseFromString(f.read())\n    (matmul,) = tf.import_graph_def(\n        graph_def, return_elements=[f\"softmax/logits/MatMul\"], name=prefix\n    )\n    w = matmul.inputs[1]\n    logits = tf.matmul(input_batch, w)\n    return tf.nn.softmax(logits)\n\n\ndef _update_shapes(pool3):\n    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63\n    ops = pool3.graph.get_operations()\n    for op in ops:\n        for o in op.outputs:\n            shape = o.get_shape()\n            if shape._dims is not None:  # pylint: disable=protected-access\n                # shape = [s.value for s in shape] TF 1.x\n                shape = [s for s in shape]  # TF 2.x\n                new_shape = []\n                for j, s in enumerate(shape):\n                    if s == 1 and j == 0:\n                        new_shape.append(None)\n                    else:\n                        new_shape.append(s)\n                o.__dict__[\"_shape_val\"] = tf.TensorShape(new_shape)\n    return pool3\n\n\ndef _numpy_partition(arr, kth, **kwargs):\n    num_workers = min(cpu_count(), len(arr))\n    chunk_size = len(arr) // num_workers\n    extra = len(arr) % num_workers\n\n    start_idx = 0\n    batches = []\n    for i in range(num_workers):\n        size = chunk_size + (1 if i < extra else 0)\n        batches.append(arr[start_idx : start_idx + size])\n        start_idx += size\n\n    with ThreadPool(num_workers) as pool:\n        return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))\n\n\nif __name__ == \"__main__\":\n    print(REQUIREMENTS)\n    main()\n"
  },
  {
    "path": "ldm/modules/evaluate/evaluate_perceptualsim.py",
    "content": "import argparse\nimport glob\nimport os\nfrom tqdm import tqdm\nfrom collections import namedtuple\n\nimport numpy as np\nimport torch\nimport torchvision.transforms as transforms\nfrom torchvision import models\nfrom PIL import Image\n\nfrom ldm.modules.evaluate.ssim import ssim\n\n\ntransform = transforms.Compose([transforms.ToTensor()])\n\ndef normalize_tensor(in_feat, eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1)).view(\n        in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3]\n    )\n    return in_feat / (norm_factor.expand_as(in_feat) + eps)\n\n\ndef cos_sim(in0, in1):\n    in0_norm = normalize_tensor(in0)\n    in1_norm = normalize_tensor(in1)\n    N = in0.size()[0]\n    X = in0.size()[2]\n    Y = in0.size()[3]\n\n    return torch.mean(\n        torch.mean(\n            torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2\n        ).view(N, 1, 1, Y),\n        dim=3,\n    ).view(N)\n\n\nclass squeezenet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(squeezenet, self).__init__()\n        pretrained_features = models.squeezenet1_1(\n            pretrained=pretrained\n        ).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.slice6 = torch.nn.Sequential()\n        self.slice7 = torch.nn.Sequential()\n        self.N_slices = 7\n        for x in range(2):\n            self.slice1.add_module(str(x), pretrained_features[x])\n        for x in range(2, 5):\n            self.slice2.add_module(str(x), pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), pretrained_features[x])\n        for x in range(10, 11):\n            
self.slice5.add_module(str(x), pretrained_features[x])\n        for x in range(11, 12):\n            self.slice6.add_module(str(x), pretrained_features[x])\n        for x in range(12, 13):\n            self.slice7.add_module(str(x), pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        h = self.slice6(h)\n        h_relu6 = h\n        h = self.slice7(h)\n        h_relu7 = h\n        vgg_outputs = namedtuple(\n            \"SqueezeOutputs\",\n            [\"relu1\", \"relu2\", \"relu3\", \"relu4\", \"relu5\", \"relu6\", \"relu7\"],\n        )\n        out = vgg_outputs(\n            h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7\n        )\n\n        return out\n\n\nclass alexnet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(alexnet, self).__init__()\n        alexnet_pretrained_features = models.alexnet(\n            pretrained=pretrained\n        ).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(2):\n            self.slice1.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(2, 5):\n            self.slice2.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), alexnet_pretrained_features[x])\n        for 
x in range(10, 12):\n            self.slice5.add_module(str(x), alexnet_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        alexnet_outputs = namedtuple(\n            \"AlexnetOutputs\", [\"relu1\", \"relu2\", \"relu3\", \"relu4\", \"relu5\"]\n        )\n        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)\n\n        return out\n\n\nclass vgg16(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(vgg16, self).__init__()\n        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(4):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(4, 9):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(9, 16):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(16, 23):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n     
   h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        vgg_outputs = namedtuple(\n            \"VggOutputs\",\n            [\"relu1_2\", \"relu2_2\", \"relu3_3\", \"relu4_3\", \"relu5_3\"],\n        )\n        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n\n        return out\n\n\nclass resnet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True, num=18):\n        super(resnet, self).__init__()\n        if num == 18:\n            self.net = models.resnet18(pretrained=pretrained)\n        elif num == 34:\n            self.net = models.resnet34(pretrained=pretrained)\n        elif num == 50:\n            self.net = models.resnet50(pretrained=pretrained)\n        elif num == 101:\n            self.net = models.resnet101(pretrained=pretrained)\n        elif num == 152:\n            self.net = models.resnet152(pretrained=pretrained)\n        self.N_slices = 5\n\n        self.conv1 = self.net.conv1\n        self.bn1 = self.net.bn1\n        self.relu = self.net.relu\n        self.maxpool = self.net.maxpool\n        self.layer1 = self.net.layer1\n        self.layer2 = self.net.layer2\n        self.layer3 = self.net.layer3\n        self.layer4 = self.net.layer4\n\n    def forward(self, X):\n        h = self.conv1(X)\n        h = self.bn1(h)\n        h = self.relu(h)\n        h_relu1 = h\n        h = self.maxpool(h)\n        h = self.layer1(h)\n        h_conv2 = h\n        h = self.layer2(h)\n        h_conv3 = h\n        h = self.layer3(h)\n        h_conv4 = h\n        h = self.layer4(h)\n        h_conv5 = h\n\n        outputs = namedtuple(\n            \"Outputs\", [\"relu1\", \"conv2\", \"conv3\", \"conv4\", \"conv5\"]\n        )\n        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)\n\n        return out\n\n# Off-the-shelf deep network\nclass PNet(torch.nn.Module):\n    \"\"\"Pre-trained network with all channels equally weighted by 
default\"\"\"\n\n    def __init__(self, pnet_type=\"vgg\", pnet_rand=False, use_gpu=True):\n        super(PNet, self).__init__()\n\n        self.use_gpu = use_gpu\n\n        self.pnet_type = pnet_type\n        self.pnet_rand = pnet_rand\n\n        self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)\n        self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)\n\n        if self.pnet_type in [\"vgg\", \"vgg16\"]:\n            self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False)\n        elif self.pnet_type == \"alex\":\n            self.net = alexnet(\n                pretrained=not self.pnet_rand, requires_grad=False\n            )\n        elif self.pnet_type[:-2] == \"resnet\":\n            self.net = resnet(\n                pretrained=not self.pnet_rand,\n                requires_grad=False,\n                num=int(self.pnet_type[-2:]),\n            )\n        elif self.pnet_type == \"squeeze\":\n            self.net = squeezenet(\n                pretrained=not self.pnet_rand, requires_grad=False\n            )\n\n        self.L = self.net.N_slices\n\n        if use_gpu:\n            self.net.cuda()\n            self.shift = self.shift.cuda()\n            self.scale = self.scale.cuda()\n\n    def forward(self, in0, in1, retPerLayer=False):\n        in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)\n        in1_sc = (in1 - self.shift.expand_as(in1)) / self.scale.expand_as(in1)\n\n        outs0 = self.net.forward(in0_sc)\n        outs1 = self.net.forward(in1_sc)\n\n        if retPerLayer:\n            all_scores = []\n        for (kk, out0) in enumerate(outs0):\n            cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk])\n            if kk == 0:\n                val = 1.0 * cur_score\n            else:\n                val = val + cur_score\n            if retPerLayer:\n                all_scores += [cur_score]\n\n        if retPerLayer:\n            return (val, all_scores)\n        else:\n     
       return val\n\n\n\n\n# The SSIM metric\ndef ssim_metric(img1, img2, mask=None):\n    return ssim(img1, img2, mask=mask, size_average=False)\n\n\n# The PSNR metric\ndef psnr(img1, img2, mask=None,reshape=False):\n    b = img1.size(0)\n    if not (mask is None):\n        b = img1.size(0)\n        mse_err = (img1 - img2).pow(2) * mask\n        if reshape:\n            mse_err = mse_err.reshape(b, -1).sum(dim=1) / (\n                    3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)\n            )\n        else:\n            mse_err = mse_err.view(b, -1).sum(dim=1) / (\n                3 * mask.view(b, -1).sum(dim=1).clamp(min=1)\n            )\n    else:\n        if reshape:\n            mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1)\n        else:\n            mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1)\n\n    psnr = 10 * (1 / mse_err).log10()\n    return psnr\n\n\n# The perceptual similarity metric\ndef perceptual_sim(img1, img2, vgg16):\n    # First extract features\n    dist = vgg16(img1 * 2 - 1, img2 * 2 - 1)\n\n    return dist\n\ndef load_img(img_name, size=None):\n    try:\n        img = Image.open(img_name)\n\n        if type(size) == int:\n            img = img.resize((size, size))\n        elif size is not None:\n            img = img.resize((size[1], size[0]))\n\n        img = transform(img).cuda()\n        img = img.unsqueeze(0)\n    except Exception as e:\n        print(\"Failed at loading %s \" % img_name)\n        print(e)\n        img = torch.zeros(1, 3, 256, 256).cuda()\n        raise\n    return img\n\n\ndef compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other):\n\n    # Load VGG16 for feature similarity\n    vgg16 = PNet().to(\"cuda\")\n    vgg16.eval()\n    vgg16.cuda()\n\n    values_percsim = []\n    values_ssim = []\n    values_psnr = []\n    folders = os.listdir(folder)\n    for i, f in tqdm(enumerate(sorted(folders))):\n        pred_imgs = glob.glob(folder + f + \"/\" + pred_img)\n        
tgt_imgs = glob.glob(folder + f + \"/\" + tgt_img)\n        assert len(tgt_imgs) == 1\n\n        perc_sim = 10000\n        ssim_sim = -10\n        psnr_sim = -10\n        for p_img in pred_imgs:\n            t_img = load_img(tgt_imgs[0])\n            p_img = load_img(p_img, size=t_img.shape[2:])\n            t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()\n            perc_sim = min(perc_sim, t_perc_sim)\n\n            ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())\n            psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())\n\n        values_percsim += [perc_sim]\n        values_ssim += [ssim_sim]\n        values_psnr += [psnr_sim]\n\n    if take_every_other:\n        n_valuespercsim = []\n        n_valuesssim = []\n        n_valuespsnr = []\n        for i in range(0, len(values_percsim) // 2):\n            n_valuespercsim += [\n                min(values_percsim[2 * i], values_percsim[2 * i + 1])\n            ]\n            n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]\n            n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]\n\n        values_percsim = n_valuespercsim\n        values_ssim = n_valuesssim\n        values_psnr = n_valuespsnr\n\n    avg_percsim = np.mean(np.array(values_percsim))\n    std_percsim = np.std(np.array(values_percsim))\n\n    avg_psnr = np.mean(np.array(values_psnr))\n    std_psnr = np.std(np.array(values_psnr))\n\n    avg_ssim = np.mean(np.array(values_ssim))\n    std_ssim = np.std(np.array(values_ssim))\n\n    return {\n        \"Perceptual similarity\": (avg_percsim, std_percsim),\n        \"PSNR\": (avg_psnr, std_psnr),\n        \"SSIM\": (avg_ssim, std_ssim),\n    }\n\n\ndef compute_perceptual_similarity_from_list(pred_imgs_list, tgt_imgs_list,\n                                            take_every_other,\n                                            simple_format=True):\n\n    # Load VGG16 for feature similarity\n    vgg16 = PNet().to(\"cuda\")\n    vgg16.eval()\n    
vgg16.cuda()\n\n    values_percsim = []\n    values_ssim = []\n    values_psnr = []\n    equal_count = 0\n    ambig_count = 0\n    for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):\n        pred_imgs = pred_imgs_list[i]\n        tgt_imgs = [tgt_img]\n        assert len(tgt_imgs) == 1\n\n        if type(pred_imgs) != list:\n            pred_imgs = [pred_imgs]\n\n        perc_sim = 10000\n        ssim_sim = -10\n        psnr_sim = -10\n        assert len(pred_imgs)>0\n        for p_img in pred_imgs:\n            t_img = load_img(tgt_imgs[0])\n            p_img = load_img(p_img, size=t_img.shape[2:])\n            t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()\n            perc_sim = min(perc_sim, t_perc_sim)\n\n            ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())\n            psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())\n\n        values_percsim += [perc_sim]\n        values_ssim += [ssim_sim]\n        if psnr_sim != float(\"inf\"):\n            values_psnr += [psnr_sim]\n        else:\n            if torch.allclose(p_img, t_img):\n                equal_count += 1\n                print(\"{} equal src and wrp images.\".format(equal_count))\n            else:\n                ambig_count += 1\n                print(\"{} ambiguous src and wrp images.\".format(ambig_count))\n\n    if take_every_other:\n        n_valuespercsim = []\n        n_valuesssim = []\n        n_valuespsnr = []\n        for i in range(0, len(values_percsim) // 2):\n            n_valuespercsim += [\n                min(values_percsim[2 * i], values_percsim[2 * i + 1])\n            ]\n            n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]\n            n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]\n\n        values_percsim = n_valuespercsim\n        values_ssim = n_valuesssim\n        values_psnr = n_valuespsnr\n\n    avg_percsim = np.mean(np.array(values_percsim))\n    std_percsim = np.std(np.array(values_percsim))\n\n    
avg_psnr = np.mean(np.array(values_psnr))\n    std_psnr = np.std(np.array(values_psnr))\n\n    avg_ssim = np.mean(np.array(values_ssim))\n    std_ssim = np.std(np.array(values_ssim))\n\n    if simple_format:\n        # just to make yaml formatting readable\n        return {\n            \"Perceptual similarity\": [float(avg_percsim), float(std_percsim)],\n            \"PSNR\": [float(avg_psnr), float(std_psnr)],\n            \"SSIM\": [float(avg_ssim), float(std_ssim)],\n        }\n    else:\n        return {\n            \"Perceptual similarity\": (avg_percsim, std_percsim),\n            \"PSNR\": (avg_psnr, std_psnr),\n            \"SSIM\": (avg_ssim, std_ssim),\n        }\n\n\ndef compute_perceptual_similarity_from_list_topk(pred_imgs_list, tgt_imgs_list,\n                                                 take_every_other, resize=False):\n\n    # Load VGG16 for feature similarity\n    vgg16 = PNet().to(\"cuda\")\n    vgg16.eval()\n    vgg16.cuda()\n\n    values_percsim = []\n    values_ssim = []\n    values_psnr = []\n    individual_percsim = []\n    individual_ssim = []\n    individual_psnr = []\n    for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):\n        pred_imgs = pred_imgs_list[i]\n        tgt_imgs = [tgt_img]\n        assert len(tgt_imgs) == 1\n\n        if type(pred_imgs) != list:\n            assert False\n            pred_imgs = [pred_imgs]\n\n        perc_sim = 10000\n        ssim_sim = -10\n        psnr_sim = -10\n        sample_percsim = list()\n        sample_ssim = list()\n        sample_psnr = list()\n        for p_img in pred_imgs:\n            if resize:\n                t_img = load_img(tgt_imgs[0], size=(256,256))\n            else:\n                t_img = load_img(tgt_imgs[0])\n            p_img = load_img(p_img, size=t_img.shape[2:])\n\n            t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()\n            sample_percsim.append(t_perc_sim)\n            perc_sim = min(perc_sim, t_perc_sim)\n\n            t_ssim = 
ssim_metric(p_img, t_img).item()\n            sample_ssim.append(t_ssim)\n            ssim_sim = max(ssim_sim, t_ssim)\n\n            t_psnr = psnr(p_img, t_img).item()\n            sample_psnr.append(t_psnr)\n            psnr_sim = max(psnr_sim, t_psnr)\n\n        values_percsim += [perc_sim]\n        values_ssim += [ssim_sim]\n        values_psnr += [psnr_sim]\n        individual_percsim.append(sample_percsim)\n        individual_ssim.append(sample_ssim)\n        individual_psnr.append(sample_psnr)\n\n    if take_every_other:\n        assert False, \"Do this later, after specifying topk to get proper results\"\n        n_valuespercsim = []\n        n_valuesssim = []\n        n_valuespsnr = []\n        for i in range(0, len(values_percsim) // 2):\n            n_valuespercsim += [\n                min(values_percsim[2 * i], values_percsim[2 * i + 1])\n            ]\n            n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]\n            n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]\n\n        values_percsim = n_valuespercsim\n        values_ssim = n_valuesssim\n        values_psnr = n_valuespsnr\n\n    avg_percsim = np.mean(np.array(values_percsim))\n    std_percsim = np.std(np.array(values_percsim))\n\n    avg_psnr = np.mean(np.array(values_psnr))\n    std_psnr = np.std(np.array(values_psnr))\n\n    avg_ssim = np.mean(np.array(values_ssim))\n    std_ssim = np.std(np.array(values_ssim))\n\n    individual_percsim = np.array(individual_percsim)\n    individual_psnr = np.array(individual_psnr)\n    individual_ssim = np.array(individual_ssim)\n\n    return {\n        \"avg_of_best\": {\n            \"Perceptual similarity\": [float(avg_percsim), float(std_percsim)],\n            \"PSNR\": [float(avg_psnr), float(std_psnr)],\n            \"SSIM\": [float(avg_ssim), float(std_ssim)],\n        },\n        \"individual\": {\n            \"PSIM\": individual_percsim,\n            \"PSNR\": individual_psnr,\n            \"SSIM\": 
individual_ssim,\n        }\n    }\n\n\nif __name__ == \"__main__\":\n    args = argparse.ArgumentParser()\n    args.add_argument(\"--folder\", type=str, default=\"\")\n    args.add_argument(\"--pred_image\", type=str, default=\"\")\n    args.add_argument(\"--target_image\", type=str, default=\"\")\n    args.add_argument(\"--take_every_other\", action=\"store_true\", default=False)\n    args.add_argument(\"--output_file\", type=str, default=\"\")\n\n    opts = args.parse_args()\n\n    folder = opts.folder\n    pred_img = opts.pred_image\n    tgt_img = opts.target_image\n\n    results = compute_perceptual_similarity(\n        folder, pred_img, tgt_img, opts.take_every_other\n    )\n\n    f = open(opts.output_file, 'w')\n    for key in results:\n        print(\"%s for %s: \\n\" % (key, opts.folder))\n        print(\n            \"\\t {:0.4f} | {:0.4f} \\n\".format(results[key][0], results[key][1])\n        )\n\n        f.write(\"%s for %s: \\n\" % (key, opts.folder))\n        f.write(\n            \"\\t {:0.4f} | {:0.4f} \\n\".format(results[key][0], results[key][1])\n        )\n\n    f.close()\n"
  },
  {
    "path": "ldm/modules/evaluate/frechet_video_distance.py",
    "content": "# coding=utf-8\n# Copyright 2022 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python2, python3\n\"\"\"Minimal Reference implementation for the Frechet Video Distance (FVD).\n\nFVD is a metric for the quality of video generation models. It is inspired by\nthe FID (Frechet Inception Distance) used for images, but uses a different\nembedding that is better suited to videos.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n\nimport six\nimport tensorflow.compat.v1 as tf\nimport tensorflow_gan as tfgan\nimport tensorflow_hub as hub\n\n\ndef preprocess(videos, target_resolution):\n  \"\"\"Runs some preprocessing on the videos for I3D model.\n\n  Args:\n    videos: <T>[batch_size, num_frames, height, width, depth] The videos to be\n      preprocessed. We don't care about the specific dtype of the videos, it can\n      be anything that tf.image.resize_bilinear accepts. 
Values are expected to\n      be in the range 0-255.\n    target_resolution: (height, width): target video resolution\n\n  Returns:\n    videos: <float32>[batch_size, num_frames, height, width, depth]\n  \"\"\"\n  videos_shape = list(videos.shape)\n  all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])\n  resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)\n  target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]\n  output_videos = tf.reshape(resized_videos, target_shape)\n  scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1\n  return scaled_videos\n\n\ndef _is_in_graph(tensor_name):\n  \"\"\"Checks whether a given tensor exists in the graph.\"\"\"\n  try:\n    tf.get_default_graph().get_tensor_by_name(tensor_name)\n  except KeyError:\n    return False\n  return True\n\n\ndef create_id3_embedding(videos,warmup=False,batch_size=16):\n  \"\"\"Embeds the given videos using the Inflated 3D Convolution network.\n\n  Downloads the graph of the I3D from tf.hub and adds it to the graph on the\n  first call.\n\n  Args:\n    videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].\n      Expected range is [-1, 1].\n    warmup: if True, also check the static batch size of videos against\n      batch_size when the graph is built.\n    batch_size: expected batch size, asserted at graph-execution time.\n\n  Returns:\n    embedding: <float32>[batch_size, embedding_size]. 
embedding_size depends\n               on the model used.\n  \"\"\"\n\n  # batch_size = 16\n  module_spec = \"https://tfhub.dev/deepmind/i3d-kinetics-400/1\"\n\n\n  # Making sure that we import the graph separately for\n  # each different input video tensor.\n  module_name = \"fvd_kinetics-400_id3_module_\" + six.ensure_str(\n      videos.name).replace(\":\", \"_\")\n\n\n\n  assert_ops = [\n      tf.Assert(\n          tf.reduce_max(videos) <= 1.001,\n          [\"max value in frame is > 1\", videos]),\n      tf.Assert(\n          tf.reduce_min(videos) >= -1.001,\n          [\"min value in frame is < -1\", videos]),\n      tf.assert_equal(\n          tf.shape(videos)[0],\n          batch_size, [\"invalid frame batch size: \",\n                       tf.shape(videos)],\n          summarize=6),\n  ]\n  with tf.control_dependencies(assert_ops):\n    videos = tf.identity(videos)\n\n  module_scope = \"%s_apply_default/\" % module_name\n\n  # To check whether the module has already been loaded into the graph, we look\n  # for a given tensor name. If this tensor name exists, we assume the function\n  # has been called before and the graph was imported. Otherwise we import it.\n  # Note: in theory, the tensor could exist, but have wrong shapes.\n  # This will happen if create_id3_embedding is called with a frames_placeholder\n  # of wrong size/batch size, because even though that will throw a tf.Assert\n  # at graph-execution time, it will insert the tensor (with wrong shape) into\n  # the graph. 
This is why we need the following assert.\n  if warmup:\n      video_batch_size = int(videos.shape[0])\n      assert video_batch_size in [batch_size, -1, None], f\"Invalid batch size {video_batch_size}\"\n  tensor_name = module_scope + \"RGB/inception_i3d/Mean:0\"\n  if not _is_in_graph(tensor_name):\n    i3d_model = hub.Module(module_spec, name=module_name)\n    i3d_model(videos)\n\n  # gets the kinetics-i3d-400-logits layer\n  tensor_name = module_scope + \"RGB/inception_i3d/Mean:0\"\n  tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)\n  return tensor\n\n\ndef calculate_fvd(real_activations,\n                  generated_activations):\n  \"\"\"Returns a list of ops that compute metrics as funcs of activations.\n\n  Args:\n    real_activations: <float32>[num_samples, embedding_size]\n    generated_activations: <float32>[num_samples, embedding_size]\n\n  Returns:\n    A scalar that contains the requested FVD.\n  \"\"\"\n  return tfgan.eval.frechet_classifier_distance_from_activations(\n      real_activations, generated_activations)\n"
  },
  {
    "path": "ldm/modules/evaluate/ssim.py",
    "content": "# MIT Licence\n\n# Methods to compute the SSIM, taken from\n# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py\n\nfrom math import exp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor(\n        [\n            exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2))\n            for x in range(window_size)\n        ]\n    )\n    return gauss / gauss.sum()\n\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(\n        _2D_window.expand(channel, 1, window_size, window_size).contiguous()\n    )\n    return window\n\n\ndef _ssim(\n    img1, img2, window, window_size, channel, mask=None, size_average=True\n):\n    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = (\n        F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)\n        - mu1_sq\n    )\n    sigma2_sq = (\n        F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)\n        - mu2_sq\n    )\n    sigma12 = (\n        F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)\n        - mu1_mu2\n    )\n\n    C1 = (0.01) ** 2\n    C2 = (0.03) ** 2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (\n        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)\n    )\n\n    if not (mask is None):\n        b = mask.size(0)\n        ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask\n        ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(\n            dim=1\n        ).clamp(min=1)\n        return ssim_map\n\n
    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1).mean(1).mean(1)\n\n\nclass SSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True):\n        super(SSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.channel = 1\n        self.window = create_window(window_size, self.channel)\n\n    def forward(self, img1, img2, mask=None):\n        (_, channel, _, _) = img1.size()\n\n        if (\n            channel == self.channel\n            and self.window.data.type() == img1.data.type()\n        ):\n            window = self.window\n        else:\n            window = create_window(self.window_size, channel)\n\n            if img1.is_cuda:\n                window = window.cuda(img1.get_device())\n            window = window.type_as(img1)\n\n            self.window = window\n            self.channel = channel\n\n        return _ssim(\n            img1,\n            img2,\n            window,\n            self.window_size,\n            channel,\n            mask,\n            self.size_average,\n        )\n\n\ndef ssim(img1, img2, window_size=11, mask=None, size_average=True):\n    (_, channel, _, _) = img1.size()\n    window = create_window(window_size, channel)\n\n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n\n    return _ssim(img1, img2, window, window_size, channel, mask, size_average)\n"
  },
  {
    "path": "ldm/modules/evaluate/torch_frechet_video_distance.py",
    "content": "# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!\nimport os\nimport numpy as np\nimport io\nimport re\nimport requests\nimport html\nimport hashlib\nimport urllib\nimport urllib.request\nimport scipy.linalg\nimport multiprocessing as mp\nimport glob\n\n\nfrom tqdm import tqdm\nfrom typing import Any, List, Tuple, Union, Dict, Callable\n\nfrom torchvision.io import read_video\nimport torch; torch.set_grad_enabled(False)\nfrom einops import rearrange\n\nfrom nitro.util import isvideo\n\ndef compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float:\n    print('Calculate frechet distance...')\n    m = np.square(mu_sample - mu_ref).sum()\n    s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member\n    fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2))\n\n    return float(fid)\n\n\ndef compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n    mu = feats.mean(axis=0) # [d]\n    sigma = np.cov(feats, rowvar=False) # [d, d]\n\n    return mu, sigma\n\n\ndef open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:\n    \"\"\"Download the given URL and return a binary-mode file object to access the data.\"\"\"\n    assert num_attempts >= 1\n\n    # Doesn't look like a URL scheme so interpret it as a local filename.\n    if not re.match('^[a-z]+://', url):\n        return url if return_filename else open(url, \"rb\")\n\n    # Handle file URLs.  This code handles unusual file:// patterns that\n    # arise on Windows:\n    #\n    # file:///c:/foo.txt\n    #\n    # which would translate to a local '/c:/foo.txt' filename that's\n    # invalid.  
Drop the forward slash for such pathnames.\n    #\n    # If you touch this code path, you should test it on both Linux and\n    # Windows.\n    #\n    # Some internet resources suggest using urllib.request.url2pathname(), but\n    # that converts forward slashes to backslashes and this causes\n    # its own set of problems.\n    if url.startswith('file://'):\n        filename = urllib.parse.urlparse(url).path\n        if re.match(r'^/[a-zA-Z]:', filename):\n            filename = filename[1:]\n        return filename if return_filename else open(filename, \"rb\")\n\n    url_md5 = hashlib.md5(url.encode(\"utf-8\")).hexdigest()\n\n    # Download.\n    url_name = None\n    url_data = None\n    with requests.Session() as session:\n        if verbose:\n            print(\"Downloading %s ...\" % url, end=\"\", flush=True)\n        for attempts_left in reversed(range(num_attempts)):\n            try:\n                with session.get(url) as res:\n                    res.raise_for_status()\n                    if len(res.content) == 0:\n                        raise IOError(\"No data received\")\n\n                    if len(res.content) < 8192:\n                        content_str = res.content.decode(\"utf-8\")\n                        if \"download_warning\" in res.headers.get(\"Set-Cookie\", \"\"):\n                            links = [html.unescape(link) for link in content_str.split('\"') if \"export=download\" in link]\n                            if len(links) == 1:\n                                url = requests.compat.urljoin(url, links[0])\n                                raise IOError(\"Google Drive virus checker nag\")\n                        if \"Google Drive - Quota exceeded\" in content_str:\n                            raise IOError(\"Google Drive download quota exceeded -- please try again later\")\n\n                    match = re.search(r'filename=\"([^\"]*)\"', res.headers.get(\"Content-Disposition\", \"\"))\n                    url_name = 
match[1] if match else url\n                    url_data = res.content\n                    if verbose:\n                        print(\" done\")\n                    break\n            except KeyboardInterrupt:\n                raise\n            except:\n                if not attempts_left:\n                    if verbose:\n                        print(\" failed\")\n                    raise\n                if verbose:\n                    print(\".\", end=\"\", flush=True)\n\n    # Return data as file object.\n    assert not return_filename\n    return io.BytesIO(url_data)\n\ndef load_video(ip):\n    vid, *_ = read_video(ip)\n    vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8)\n    return vid\n\ndef get_data_from_str(input_str,nprc = None):\n    assert os.path.isdir(input_str), f'Specified input folder \"{input_str}\" is not a directory'\n    vid_filelist = glob.glob(os.path.join(input_str,'*.mp4'))\n    print(f'Found {len(vid_filelist)} videos in dir {input_str}')\n\n    if nprc is None:\n        try:\n            nprc = mp.cpu_count()\n        except NotImplementedError:\n            print('WARNING: cpu_count() not available, using only 1 CPU for video loading')\n            nprc = 1\n\n    pool = mp.Pool(processes=nprc)\n\n    vids = []\n    for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'):\n        vids.append(v)\n    pool.close()\n    pool.join()\n\n    vids = torch.stack(vids,dim=0).float()\n\n    return vids\n\ndef get_stats(stats):\n    assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}'\n\n    print(f'Using precomputed statistics under {stats}')\n    stats = np.load(stats)\n    stats = {key: stats[key] for key in stats.files}\n\n    return stats\n\n\n\n\n@torch.no_grad()\ndef compute_fvd(ref_input, sample_input, bs=32,\n                ref_stats=None,\n                sample_stats=None,\n                nprc_load=None):\n\n\n\n    calc_stats = ref_stats is None or 
sample_stats is None\n\n    if calc_stats:\n\n        only_ref = sample_stats is not None\n        only_sample = ref_stats is not None\n\n\n        if isinstance(ref_input,str) and not only_sample:\n            ref_input = get_data_from_str(ref_input,nprc_load)\n\n        if isinstance(sample_input, str) and not only_ref:\n            sample_input = get_data_from_str(sample_input, nprc_load)\n\n        stats = compute_statistics(sample_input,ref_input,\n                                        device='cuda' if torch.cuda.is_available() else 'cpu',\n                                        bs=bs,\n                                        only_ref=only_ref,\n                                        only_sample=only_sample)\n\n        if only_ref:\n            stats.update(get_stats(sample_stats))\n        elif only_sample:\n            stats.update(get_stats(ref_stats))\n\n\n\n    else:\n        stats = get_stats(sample_stats)\n        stats.update(get_stats(ref_stats))\n\n    fvd = compute_frechet_distance(**stats)\n\n    return {'FVD' : fvd,}\n\n\n@torch.no_grad()\ndef compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict:\n    detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'\n    detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer.\n\n    with open_url(detector_url, verbose=False) as f:\n        detector = torch.jit.load(f).eval().to(device)\n\n\n\n    assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive'\n\n    ref_embed, sample_embed = [], []\n\n    info = f'Computing I3D activations for FVD score with batch size {bs}'\n\n    if only_ref:\n\n        if not isvideo(videos_real):\n            # if not a video, we assume a numpy array of shape (n_vids, t, h, w, c) in range [0,255]\n            videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 
3).float()\n            print(videos_real.shape)\n\n        if videos_real.shape[0] % bs == 0:\n            n_secs = videos_real.shape[0] // bs\n        else:\n            n_secs = videos_real.shape[0] // bs + 1\n\n        videos_real = torch.tensor_split(videos_real, n_secs, dim=0)\n\n        for ref_v in tqdm(videos_real, total=len(videos_real),desc=info):\n\n            feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()\n            ref_embed.append(feats_ref)\n\n    elif only_sample:\n\n        if not isvideo(videos_fake):\n            # if not a video, we assume a numpy array of shape (n_vids, t, h, w, c) in range [0,255]\n            videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()\n            print(videos_fake.shape)\n\n        if videos_fake.shape[0] % bs == 0:\n            n_secs = videos_fake.shape[0] // bs\n        else:\n            n_secs = videos_fake.shape[0] // bs + 1\n\n        videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)\n\n        for sample_v in tqdm(videos_fake, total=len(videos_fake),desc=info):\n            feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()\n            sample_embed.append(feats_sample)\n\n\n    else:\n\n        if not isvideo(videos_real):\n            # if not a video, we assume a numpy array of shape (n_vids, t, h, w, c) in range [0,255]\n            videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()\n\n        if not isvideo(videos_fake):\n            videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()\n\n        if videos_fake.shape[0] % bs == 0:\n            n_secs = videos_fake.shape[0] // bs\n        else:\n            n_secs = videos_fake.shape[0] // bs + 1\n\n        videos_real = torch.tensor_split(videos_real, n_secs, dim=0)\n        videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)\n\n        for ref_v, sample_v in 
tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info):\n            # print(ref_v.shape)\n            # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)\n            # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)\n\n\n            feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()\n            feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()\n            sample_embed.append(feats_sample)\n            ref_embed.append(feats_ref)\n\n    out = dict()\n    if len(sample_embed) > 0:\n        sample_embed = np.concatenate(sample_embed,axis=0)\n        mu_sample, sigma_sample = compute_stats(sample_embed)\n        out.update({'mu_sample': mu_sample,\n                    'sigma_sample': sigma_sample})\n\n    if len(ref_embed) > 0:\n        ref_embed = np.concatenate(ref_embed,axis=0)\n        mu_ref, sigma_ref = compute_stats(ref_embed)\n        out.update({'mu_ref': mu_ref,\n                    'sigma_ref': sigma_ref})\n\n\n    return out\n"
  },
  {
    "path": "ldm/modules/image_degradation/__init__.py",
    "content": "from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr\nfrom ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light\n"
  },
  {
    "path": "ldm/modules/image_degradation/bsrgan.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# --------------------------------------------\n#\n# Kai Zhang (cskaizhang@gmail.com)\n# https://github.com/cszn\n# From 2019/03--2021/08\n# --------------------------------------------\n\"\"\"\n\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom scipy import ndimage\nimport scipy\nimport scipy.stats as ss\nfrom scipy.interpolate import interp2d\nfrom scipy.linalg import orth\nimport albumentations\n\nimport ldm.modules.image_degradation.utils_image as util\n\n\ndef modcrop_np(img, sf):\n    '''\n    Args:\n        img: numpy image, WxH or WxHxC\n        sf: scale factor\n    Return:\n        cropped image\n    '''\n    w, h = img.shape[:2]\n    im = np.copy(img)\n    return im[:w - w % sf, :h - h % sf, ...]\n\n\n\"\"\"\n# --------------------------------------------\n# anisotropic Gaussian kernels\n# --------------------------------------------\n\"\"\"\n\n\ndef analytic_kernel(k):\n    \"\"\"Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)\"\"\"\n    k_size = k.shape[0]\n    # Calculate the big kernel's size\n    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))\n    # Loop over the small kernel to fill the big one\n    for r in range(k_size):\n        for c in range(k_size):\n            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k\n    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR\n    crop = k_size // 2\n    cropped_big_k = big_k[crop:-crop, crop:-crop]\n    # Normalize to 1\n    return cropped_big_k / cropped_big_k.sum()\n\n\ndef anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):\n    \"\"\" generate an anisotropic Gaussian kernel\n    Args:\n        ksize : e.g., 15, kernel size\n        theta : [0, pi], rotation angle range\n        l1    : [0.1,50], scaling of eigenvalues\n        l2    : 
[0.1,l1], scaling of eigenvalues\n        If l1 = l2, will get an isotropic Gaussian kernel.\n    Returns:\n        k     : kernel\n    \"\"\"\n\n    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))\n    V = np.array([[v[0], v[1]], [v[1], -v[0]]])\n    D = np.array([[l1, 0], [0, l2]])\n    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))\n    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)\n\n    return k\n\n\ndef gm_blur_kernel(mean, cov, size=15):\n    center = size / 2.0 + 0.5\n    k = np.zeros([size, size])\n    for y in range(size):\n        for x in range(size):\n            cy = y - center + 1\n            cx = x - center + 1\n            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)\n\n    k = k / np.sum(k)\n    return k\n\n\ndef shift_pixel(x, sf, upper_left=True):\n    \"\"\"shift pixel for super-resolution with different scale factors\n    Args:\n        x: WxHxC or WxH\n        sf: scale factor\n        upper_left: shift direction\n    \"\"\"\n    h, w = x.shape[:2]\n    shift = (sf - 1) * 0.5\n    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)\n    if upper_left:\n        x1 = xv + shift\n        y1 = yv + shift\n    else:\n        x1 = xv - shift\n        y1 = yv - shift\n\n    x1 = np.clip(x1, 0, w - 1)\n    y1 = np.clip(y1, 0, h - 1)\n\n    if x.ndim == 2:\n        x = interp2d(xv, yv, x)(x1, y1)\n    if x.ndim == 3:\n        for i in range(x.shape[-1]):\n            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)\n\n    return x\n\n\ndef blur(x, k):\n    '''\n    x: image, NxcxHxW\n    k: kernel, Nx1xhxw\n    '''\n    n, c = x.shape[:2]\n    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2\n    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')\n    k = k.repeat(1, c, 1, 1)\n    k = k.view(-1, 1, k.shape[2], k.shape[3])\n    x = x.view(1, -1, x.shape[2], x.shape[3])\n    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, 
padding=0, groups=n * c)\n    x = x.view(n, c, x.shape[2], x.shape[3])\n\n    return x\n\n\ndef gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):\n    \"\"\"\n    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator\n    # Kai Zhang\n    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var\n    # max_var = 2.5 * sf\n    \"\"\"\n    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix\n    lambda_1 = min_var + np.random.rand() * (max_var - min_var)\n    lambda_2 = min_var + np.random.rand() * (max_var - min_var)\n    theta = np.random.rand() * np.pi  # random theta\n    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2\n\n    # Set COV matrix using Lambdas and Theta\n    LAMBDA = np.diag([lambda_1, lambda_2])\n    Q = np.array([[np.cos(theta), -np.sin(theta)],\n                  [np.sin(theta), np.cos(theta)]])\n    SIGMA = Q @ LAMBDA @ Q.T\n    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]\n\n    # Set expectation position (shifting kernel for aligned image)\n    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)\n    MU = MU[None, None, :, None]\n\n    # Create meshgrid for Gaussian\n    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))\n    Z = np.stack([X, Y], 2)[:, :, :, None]\n\n    # Calculate Gaussian for every pixel of the kernel\n    ZZ = Z - MU\n    ZZ_t = ZZ.transpose(0, 1, 3, 2)\n    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)\n\n    # shift the kernel so it will be centered\n    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)\n\n    # Normalize the kernel and return\n    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)\n    kernel = raw_kernel / np.sum(raw_kernel)\n    return kernel\n\n\ndef fspecial_gaussian(hsize, sigma):\n    hsize = [hsize, hsize]\n    siz = [(hsize[0] - 1.0) / 2.0, 
(hsize[1] - 1.0) / 2.0]\n    std = sigma\n    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))\n    arg = -(x * x + y * y) / (2 * std * std)\n    h = np.exp(arg)\n    h[h < np.finfo(float).eps * h.max()] = 0\n    sumh = h.sum()\n    if sumh != 0:\n        h = h / sumh\n    return h\n\n\ndef fspecial_laplacian(alpha):\n    alpha = max([0, min([alpha, 1])])\n    h1 = alpha / (alpha + 1)\n    h2 = (1 - alpha) / (alpha + 1)\n    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]\n    h = np.array(h)\n    return h\n\n\ndef fspecial(filter_type, *args, **kwargs):\n    '''\n    python code from:\n    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py\n    '''\n    if filter_type == 'gaussian':\n        return fspecial_gaussian(*args, **kwargs)\n    if filter_type == 'laplacian':\n        return fspecial_laplacian(*args, **kwargs)\n\n\n\"\"\"\n# --------------------------------------------\n# degradation models\n# --------------------------------------------\n\"\"\"\n\n\ndef bicubic_degradation(x, sf=3):\n    '''\n    Args:\n        x: HxWxC image, [0, 1]\n        sf: down-scale factor\n    Return:\n        bicubically downsampled LR image\n    '''\n    x = util.imresize_np(x, scale=1 / sf)\n    return x\n\n\ndef srmd_degradation(x, k, sf=3):\n    ''' blur + bicubic downsampling\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2018learning,\n          title={Learning a single convolutional super-resolution network for multiple degradations},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={3262--3271},\n          year={2018}\n        }\n    '''\n    x = ndimage.convolve(x, 
np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'\n    x = bicubic_degradation(x, sf=sf)\n    return x\n\n\ndef dpsr_degradation(x, k, sf=3):\n    ''' bicubic downsampling + blur\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2019deep,\n          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={1671--1681},\n          year={2019}\n        }\n    '''\n    x = bicubic_degradation(x, sf=sf)\n    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    return x\n\n\ndef classical_degradation(x, k, sf=3):\n    ''' blur + downsampling\n    Args:\n        x: HxWxC image, [0, 1]/[0, 255]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    '''\n    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))\n    st = 0\n    return x[st::sf, st::sf, ...]\n\n\ndef add_sharpening(img, weight=0.5, radius=50, threshold=10):\n    \"\"\"USM sharpening, borrowed from Real-ESRGAN\n    Input image: I; Blurry image: B.\n    1. K = I + weight * (I - B)\n    2. Mask = 1 if abs(I - B) > threshold, else: 0\n    3. Blur the mask to obtain a soft mask\n    4. Out = Mask * K + (1 - Mask) * I\n    Args:\n        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].\n        weight (float): Sharp weight. Default: 0.5.\n        radius (float): Kernel size of Gaussian blur. 
Default: 50.\n        threshold (int): Threshold on abs(I - B), measured on the [0, 255] scale. Default: 10.\n    \"\"\"\n    if radius % 2 == 0:\n        radius += 1\n    blur = cv2.GaussianBlur(img, (radius, radius), 0)\n    residual = img - blur\n    mask = np.abs(residual) * 255 > threshold\n    mask = mask.astype('float32')\n    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)\n\n    K = img + weight * residual\n    K = np.clip(K, 0, 1)\n    return soft_mask * K + (1 - soft_mask) * img\n\n\ndef add_blur(img, sf=4):\n    wd2 = 4.0 + sf\n    wd = 2.0 + 0.2 * sf\n    if random.random() < 0.5:\n        l1 = wd2 * random.random()\n        l2 = wd2 * random.random()\n        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)\n    else:\n        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())\n    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')\n\n    return img\n\n\ndef add_resize(img, sf=4):\n    rnum = np.random.rand()\n    if rnum > 0.8:  # up\n        sf1 = random.uniform(1, 2)\n    elif rnum < 0.7:  # down\n        sf1 = random.uniform(0.5 / sf, 1)\n    else:\n        sf1 = 1.0\n    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))\n    img = np.clip(img, 0.0, 1.0)\n\n    return img\n\n\n# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n#     noise_level = random.randint(noise_level1, noise_level2)\n#     rnum = np.random.rand()\n#     if rnum > 0.6:  # add color Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n#     elif rnum < 0.4:  # add grayscale Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n#     else:  # add channel-correlated noise\n#         L = noise_level2 / 255.\n#         D = np.diag(np.random.rand(3))\n#         U = orth(np.random.rand(3, 3))\n#         conv = np.dot(np.dot(np.transpose(U), D), U)\n#         img 
+= np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n#     img = np.clip(img, 0.0, 1.0)\n#     return img\n\ndef add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    rnum = np.random.rand()\n    if rnum > 0.6:  # add color Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n    elif rnum < 0.4:  # add grayscale Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n    else:  # add channel-correlated noise\n        L = noise_level2 / 255.\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_speckle_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    img = np.clip(img, 0.0, 1.0)\n    rnum = random.random()\n    if rnum > 0.6:\n        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n    elif rnum < 0.4:\n        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n    else:\n        L = noise_level2 / 255.\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_Poisson_noise(img):\n    img = np.clip((img * 255.0).round(), 0, 255) / 255.\n    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]\n    if random.random() < 0.5:\n        img = np.random.poisson(img * vals).astype(np.float32) / vals\n    
else:\n        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])\n        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.\n        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray\n        img += noise_gray[:, :, np.newaxis]\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_JPEG_noise(img):\n    quality_factor = random.randint(30, 95)\n    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)\n    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])\n    img = cv2.imdecode(encimg, 1)\n    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)\n    return img\n\n\ndef random_crop(lq, hq, sf=4, lq_patchsize=64):\n    h, w = lq.shape[:2]\n    rnd_h = random.randint(0, h - lq_patchsize)\n    rnd_w = random.randint(0, w - lq_patchsize)\n    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]\n\n    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)\n    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]\n    return lq, hq\n\n\ndef degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)\n    sf: scale factor\n    lq_patchsize: size of the low-quality patch returned by the random crop\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = img.shape[:2]\n    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  
# mod crop\n    h, w = img.shape[:2]\n\n    if h < lq_patchsize * sf or w < lq_patchsize * sf:\n        raise ValueError(f'img size ({h1}X{w1}) is too small!')\n\n    hq = img.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),\n                             interpolation=random.choice([1, 2, 3]))\n        else:\n            img = util.imresize_np(img, 1 / 2, True)\n        img = np.clip(img, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # make sure downsample2 runs before downsample3\n        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]\n\n    for i in shuffle_order:\n\n        if i == 0:\n            img = add_blur(img, sf=sf)\n\n        elif i == 1:\n            img = add_blur(img, sf=sf)\n\n        elif i == 2:\n            a, b = img.shape[1], img.shape[0]\n            # downsample2\n            if random.random() < 0.75:\n                sf1 = random.uniform(1, 2 * sf)\n                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),\n                                 interpolation=random.choice([1, 2, 3]))\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel\n                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')\n                img = img[0::sf, 0::sf, ...]  
# nearest downsampling\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                img = add_JPEG_noise(img)\n\n        elif i == 6:\n            # add processed camera sensor noise\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    img = add_JPEG_noise(img)\n\n    # random crop\n    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)\n\n    return img, hq\n\n\n# todo no isp_model?\ndef degradation_bsrgan_variant(image, sf=4, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    image: HXWXC, uint8, [0, 255]\n    sf: scale factor\n    isp_model: camera ISP model (currently unused; the ISP step below is commented out)\n    Returns\n    -------\n    example: dict whose \"image\" entry is the degraded low-quality image, about (H/sf)X(W/sf)XC, uint8, [0, 255]\n    \"\"\"\n    image = util.uint2single(image)\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = image.shape[:2]\n    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  
# mod crop\n    h, w = image.shape[:2]\n\n    hq = image.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),\n                               interpolation=random.choice([1, 2, 3]))\n        else:\n            image = util.imresize_np(image, 1 / 2, True)\n        image = np.clip(image, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # make sure downsample2 runs before downsample3\n        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]\n\n    for i in shuffle_order:\n\n        if i == 0:\n            image = add_blur(image, sf=sf)\n\n        elif i == 1:\n            image = add_blur(image, sf=sf)\n\n        elif i == 2:\n            a, b = image.shape[1], image.shape[0]\n            # downsample2\n            if random.random() < 0.75:\n                sf1 = random.uniform(1, 2 * sf)\n                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),\n                                   interpolation=random.choice([1, 2, 3]))\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel\n                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')\n                image = image[0::sf, 0::sf, ...]  
# nearest downsampling\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                image = add_JPEG_noise(image)\n\n        # elif i == 6:\n        #     # add processed camera sensor noise\n        #     if random.random() < isp_prob and isp_model is not None:\n        #         with torch.no_grad():\n        #             img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    image = add_JPEG_noise(image)\n    image = util.single2uint(image)\n    example = {\"image\":image}\n    return example\n\n\n# TODO in case there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...\ndef degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):\n    \"\"\"\n    This is an extended degradation model by combining\n    the degradation models of BSRGAN and Real-ESRGAN\n    ----------\n    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)\n    sf: scale factor\n    shuffle_prob: probability of fully shuffling the degradation order\n    use_sharp: whether to sharpen the img with add_sharpening before degradation (the sharpened image is also used as hq)\n    lq_patchsize: size of the low-quality patch returned by the random crop\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n\n    h1, w1 = img.shape[:2]\n    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  
# mod crop\n    h, w = img.shape[:2]\n\n    if h < lq_patchsize * sf or w < lq_patchsize * sf:\n        raise ValueError(f'img size ({h1}X{w1}) is too small!')\n\n    if use_sharp:\n        img = add_sharpening(img)\n    hq = img.copy()\n\n    if random.random() < shuffle_prob:\n        shuffle_order = random.sample(range(13), 13)\n    else:\n        shuffle_order = list(range(13))\n        # local shuffle for noise, JPEG is always the last one\n        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))\n        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))\n\n    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1\n\n    for i in shuffle_order:\n        if i == 0:\n            img = add_blur(img, sf=sf)\n        elif i == 1:\n            img = add_resize(img, sf=sf)\n        elif i == 2:\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)\n        elif i == 3:\n            if random.random() < poisson_prob:\n                img = add_Poisson_noise(img)\n        elif i == 4:\n            if random.random() < speckle_prob:\n                img = add_speckle_noise(img)\n        elif i == 5:\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n        elif i == 6:\n            img = add_JPEG_noise(img)\n        elif i == 7:\n            img = add_blur(img, sf=sf)\n        elif i == 8:\n            img = add_resize(img, sf=sf)\n        elif i == 9:\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)\n        elif i == 10:\n            if random.random() < poisson_prob:\n                img = add_Poisson_noise(img)\n        elif i == 11:\n            if random.random() < speckle_prob:\n                img = add_speckle_noise(img)\n        elif i == 12:\n            if random.random() < isp_prob and isp_model is not None:\n                with 
torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n        else:\n            print('check the shuffle!')\n\n    # resize to desired size\n    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),\n                     interpolation=random.choice([1, 2, 3]))\n\n    # add final JPEG compression noise\n    img = add_JPEG_noise(img)\n\n    # random crop\n    img, hq = random_crop(img, hq, sf, lq_patchsize)\n\n    return img, hq\n\n\nif __name__ == '__main__':\n    img = util.imread_uint('utils/test.png', 3)\n    img = img[:448, :448]\n    h = img.shape[0] // 4\n    print(\"resizing to\", h)\n    sf = 4\n    deg_fn = partial(degradation_bsrgan_variant, sf=sf)\n    # degradation_bsrgan_variant expects a uint8 image and returns {\"image\": uint8 low-quality image}\n    img_hq = util.uint2single(img)\n    for i in range(20):\n        print(i)\n        img_lq = util.uint2single(deg_fn(img)[\"image\"])\n        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[\"image\"]\n        print(img_lq.shape)\n        print(\"bicubic\", img_lq_bicubic.shape)\n        print(img_hq.shape)\n        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n                                interpolation=0)\n        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n                                        interpolation=0)\n        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)\n        util.imsave(img_concat, str(i) + '.png')\n\n\n"
  },
  {
    "path": "ldm/modules/image_degradation/bsrgan_light.py",
    "content": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom scipy import ndimage\nimport scipy\nimport scipy.stats as ss\nfrom scipy.interpolate import interp2d\nfrom scipy.linalg import orth\nimport albumentations\n\nimport ldm.modules.image_degradation.utils_image as util\n\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# --------------------------------------------\n#\n# Kai Zhang (cskaizhang@gmail.com)\n# https://github.com/cszn\n# From 2019/03--2021/08\n# --------------------------------------------\n\"\"\"\n\n\ndef modcrop_np(img, sf):\n    '''\n    Args:\n        img: numpy image, WxH or WxHxC\n        sf: scale factor\n    Return:\n        cropped image\n    '''\n    w, h = img.shape[:2]\n    im = np.copy(img)\n    return im[:w - w % sf, :h - h % sf, ...]\n\n\n\"\"\"\n# --------------------------------------------\n# anisotropic Gaussian kernels\n# --------------------------------------------\n\"\"\"\n\n\ndef analytic_kernel(k):\n    \"\"\"Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)\"\"\"\n    k_size = k.shape[0]\n    # Calculate the big kernel's size\n    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))\n    # Loop over the small kernel to fill the big one\n    for r in range(k_size):\n        for c in range(k_size):\n            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k\n    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR\n    crop = k_size // 2\n    cropped_big_k = big_k[crop:-crop, crop:-crop]\n    # Normalize to 1\n    return cropped_big_k / cropped_big_k.sum()\n\n\ndef anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):\n    \"\"\" generate an anisotropic Gaussian kernel\n    Args:\n        ksize : e.g., 15, kernel size\n        theta : [0, pi], rotation angle range\n        l1    : [0.1,50], scaling of eigenvalues\n        l2    : 
[0.1,l1], scaling of eigenvalues\n        If l1 = l2, will get an isotropic Gaussian kernel.\n    Returns:\n        k     : kernel\n    \"\"\"\n\n    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))\n    V = np.array([[v[0], v[1]], [v[1], -v[0]]])\n    D = np.array([[l1, 0], [0, l2]])\n    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))\n    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)\n\n    return k\n\n\ndef gm_blur_kernel(mean, cov, size=15):\n    center = size / 2.0 + 0.5\n    k = np.zeros([size, size])\n    for y in range(size):\n        for x in range(size):\n            cy = y - center + 1\n            cx = x - center + 1\n            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)\n\n    k = k / np.sum(k)\n    return k\n\n\ndef shift_pixel(x, sf, upper_left=True):\n    \"\"\"shift pixel for super-resolution with different scale factors\n    Args:\n        x: WxHxC or WxH\n        sf: scale factor\n        upper_left: shift direction\n    \"\"\"\n    h, w = x.shape[:2]\n    shift = (sf - 1) * 0.5\n    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)\n    if upper_left:\n        x1 = xv + shift\n        y1 = yv + shift\n    else:\n        x1 = xv - shift\n        y1 = yv - shift\n\n    x1 = np.clip(x1, 0, w - 1)\n    y1 = np.clip(y1, 0, h - 1)\n\n    if x.ndim == 2:\n        x = interp2d(xv, yv, x)(x1, y1)\n    if x.ndim == 3:\n        for i in range(x.shape[-1]):\n            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)\n\n    return x\n\n\ndef blur(x, k):\n    '''\n    x: image, NxcxHxW\n    k: kernel, Nx1xhxw\n    '''\n    n, c = x.shape[:2]\n    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2\n    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')\n    k = k.repeat(1, c, 1, 1)\n    k = k.view(-1, 1, k.shape[2], k.shape[3])\n    x = x.view(1, -1, x.shape[2], x.shape[3])\n    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, 
padding=0, groups=n * c)\n    x = x.view(n, c, x.shape[2], x.shape[3])\n\n    return x\n\n\ndef gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):\n    \"\"\"\n    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator\n    # Kai Zhang\n    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var\n    # max_var = 2.5 * sf\n    \"\"\"\n    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix\n    lambda_1 = min_var + np.random.rand() * (max_var - min_var)\n    lambda_2 = min_var + np.random.rand() * (max_var - min_var)\n    theta = np.random.rand() * np.pi  # random theta\n    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2\n\n    # Set COV matrix using Lambdas and Theta\n    LAMBDA = np.diag([lambda_1, lambda_2])\n    Q = np.array([[np.cos(theta), -np.sin(theta)],\n                  [np.sin(theta), np.cos(theta)]])\n    SIGMA = Q @ LAMBDA @ Q.T\n    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]\n\n    # Set expectation position (shifting kernel for aligned image)\n    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)\n    MU = MU[None, None, :, None]\n\n    # Create meshgrid for Gaussian\n    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))\n    Z = np.stack([X, Y], 2)[:, :, :, None]\n\n    # Calculate Gaussian for every pixel of the kernel\n    ZZ = Z - MU\n    ZZ_t = ZZ.transpose(0, 1, 3, 2)\n    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)\n\n    # shift the kernel so it will be centered\n    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)\n\n    # Normalize the kernel and return\n    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)\n    kernel = raw_kernel / np.sum(raw_kernel)\n    return kernel\n\n\ndef fspecial_gaussian(hsize, sigma):\n    hsize = [hsize, hsize]\n    siz = [(hsize[0] - 1.0) / 2.0, 
(hsize[1] - 1.0) / 2.0]\n    std = sigma\n    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))\n    arg = -(x * x + y * y) / (2 * std * std)\n    h = np.exp(arg)\n    h[h < np.finfo(float).eps * h.max()] = 0\n    sumh = h.sum()\n    if sumh != 0:\n        h = h / sumh\n    return h\n\n\ndef fspecial_laplacian(alpha):\n    alpha = max([0, min([alpha, 1])])\n    h1 = alpha / (alpha + 1)\n    h2 = (1 - alpha) / (alpha + 1)\n    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]\n    h = np.array(h)\n    return h\n\n\ndef fspecial(filter_type, *args, **kwargs):\n    '''\n    python code from:\n    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py\n    '''\n    if filter_type == 'gaussian':\n        return fspecial_gaussian(*args, **kwargs)\n    if filter_type == 'laplacian':\n        return fspecial_laplacian(*args, **kwargs)\n\n\n\"\"\"\n# --------------------------------------------\n# degradation models\n# --------------------------------------------\n\"\"\"\n\n\ndef bicubic_degradation(x, sf=3):\n    '''\n    Args:\n        x: HxWxC image, [0, 1]\n        sf: down-scale factor\n    Return:\n        bicubicly downsampled LR image\n    '''\n    x = util.imresize_np(x, scale=1 / sf)\n    return x\n\n\ndef srmd_degradation(x, k, sf=3):\n    ''' blur + bicubic downsampling\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2018learning,\n          title={Learning a single convolutional super-resolution network for multiple degradations},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={3262--3271},\n          year={2018}\n        }\n    '''\n    x = ndimage.convolve(x, 
np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'\n    x = bicubic_degradation(x, sf=sf)\n    return x\n\n\ndef dpsr_degradation(x, k, sf=3):\n    ''' bicubic downsampling + blur\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2019deep,\n          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={1671--1681},\n          year={2019}\n        }\n    '''\n    x = bicubic_degradation(x, sf=sf)\n    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    return x\n\n\ndef classical_degradation(x, k, sf=3):\n    ''' blur + downsampling\n    Args:\n        x: HxWxC image, [0, 1]/[0, 255]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    '''\n    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))\n    st = 0\n    return x[st::sf, st::sf, ...]\n\n\ndef add_sharpening(img, weight=0.5, radius=50, threshold=10):\n    \"\"\"USM sharpening, borrowed from Real-ESRGAN.\n    Input image: I; Blurry image: B.\n    1. K = I + weight * (I - B)\n    2. Mask = 1 if abs(I - B) * 255 > threshold, else: 0\n    3. Blur mask: soft_mask = GaussianBlur(Mask)\n    4. Out = soft_mask * K + (1 - soft_mask) * I\n    Args:\n        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].\n        weight (float): Sharp weight. Default: 0.5.\n        radius (int): Kernel size of Gaussian blur. 
Default: 50. Even values are incremented to the next odd value.\n        threshold (int): Residual threshold on the [0, 255] scale; only pixels whose residual exceeds it are sharpened. Default: 10.\n    \"\"\"\n    if radius % 2 == 0:\n        radius += 1\n    blur = cv2.GaussianBlur(img, (radius, radius), 0)\n    residual = img - blur\n    mask = np.abs(residual) * 255 > threshold\n    mask = mask.astype('float32')\n    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)\n\n    K = img + weight * residual\n    K = np.clip(K, 0, 1)\n    return soft_mask * K + (1 - soft_mask) * img\n\n\ndef add_blur(img, sf=4):\n    wd2 = 4.0 + sf\n    wd = 2.0 + 0.2 * sf\n\n    wd2 = wd2/4\n    wd = wd/4\n\n    if random.random() < 0.5:\n        l1 = wd2 * random.random()\n        l2 = wd2 * random.random()\n        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)\n    else:\n        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())\n    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')\n\n    return img\n\n\ndef add_resize(img, sf=4):\n    rnum = np.random.rand()\n    if rnum > 0.8:  # up\n        sf1 = random.uniform(1, 2)\n    elif rnum < 0.7:  # down\n        sf1 = random.uniform(0.5 / sf, 1)\n    else:\n        sf1 = 1.0\n    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))\n    img = np.clip(img, 0.0, 1.0)\n\n    return img\n\n\n# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n#     noise_level = random.randint(noise_level1, noise_level2)\n#     rnum = np.random.rand()\n#     if rnum > 0.6:  # add color Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n#     elif rnum < 0.4:  # add grayscale Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n#     else:  # add  noise\n#         L = noise_level2 / 255.\n#         D = np.diag(np.random.rand(3))\n#         U = orth(np.random.rand(3, 3))\n#         conv = np.dot(np.dot(np.transpose(U), D), 
U)\n#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n#     img = np.clip(img, 0.0, 1.0)\n#     return img\n\ndef add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    rnum = np.random.rand()\n    if rnum > 0.6:  # add color Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n    elif rnum < 0.4:  # add grayscale Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n    else:  # add  noise\n        L = noise_level2 / 255.\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_speckle_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    img = np.clip(img, 0.0, 1.0)\n    rnum = random.random()\n    if rnum > 0.6:\n        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n    elif rnum < 0.4:\n        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n    else:\n        L = noise_level2 / 255.\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_Poisson_noise(img):\n    img = np.clip((img * 255.0).round(), 0, 255) / 255.\n    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]\n    if random.random() < 0.5:\n        img = np.random.poisson(img * 
vals).astype(np.float32) / vals\n    else:\n        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])\n        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.\n        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray\n        img += noise_gray[:, :, np.newaxis]\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_JPEG_noise(img):\n    quality_factor = random.randint(80, 95)\n    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)\n    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])\n    img = cv2.imdecode(encimg, 1)\n    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)\n    return img\n\n\ndef random_crop(lq, hq, sf=4, lq_patchsize=64):\n    h, w = lq.shape[:2]\n    rnd_h = random.randint(0, h - lq_patchsize)\n    rnd_w = random.randint(0, w - lq_patchsize)\n    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]\n\n    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)\n    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]\n    return lq, hq\n\n\ndef degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    Parameters\n    ----------\n    img: HxWxC, [0, 1], its size should be at least (lq_patchsizexsf)x(lq_patchsizexsf)\n    sf: scale factor\n    lq_patchsize: size of the low-quality output patch\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = img.shape[:2]\n    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  
# mod crop\n    h, w = img.shape[:2]\n\n    if h < lq_patchsize * sf or w < lq_patchsize * sf:\n        raise ValueError(f'img size ({h1}X{w1}) is too small!')\n\n    hq = img.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),\n                             interpolation=random.choice([1, 2, 3]))\n        else:\n            img = util.imresize_np(img, 1 / 2, True)\n        img = np.clip(img, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample3 last\n        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]\n\n    for i in shuffle_order:\n\n        if i == 0:\n            img = add_blur(img, sf=sf)\n\n        elif i == 1:\n            img = add_blur(img, sf=sf)\n\n        elif i == 2:\n            a, b = img.shape[1], img.shape[0]\n            # downsample2\n            if random.random() < 0.75:\n                sf1 = random.uniform(1, 2 * sf)\n                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),\n                                 interpolation=random.choice([1, 2, 3]))\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel\n                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')\n                img = img[0::sf, 0::sf, ...]  
# nearest downsampling\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                img = add_JPEG_noise(img)\n\n        elif i == 6:\n            # add processed camera sensor noise\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    img = add_JPEG_noise(img)\n\n    # random crop\n    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)\n\n    return img, hq\n\n\n# todo no isp_model?\ndef degradation_bsrgan_variant(image, sf=4, isp_model=None):\n    \"\"\"\n    This is a variant of the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    Parameters\n    ----------\n    image: HxWxC uint8 image, [0, 255]\n    sf: scale factor\n    isp_model: camera ISP model (currently unused: the ISP step below is commented out)\n    Returns\n    -------\n    example: dict whose 'image' entry is the degraded low-quality image, uint8, HxWxC, [0, 255]\n    \"\"\"\n    image = util.uint2single(image)\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = image.shape[:2]\n    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  
# mod crop\n    h, w = image.shape[:2]\n\n    hq = image.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),\n                               interpolation=random.choice([1, 2, 3]))\n        else:\n            image = util.imresize_np(image, 1 / 2, True)\n        image = np.clip(image, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample3 last\n        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]\n\n    for i in shuffle_order:\n\n        if i == 0:\n            image = add_blur(image, sf=sf)\n\n        # elif i == 1:\n        #     image = add_blur(image, sf=sf)\n\n        if i == 0:\n            pass\n\n        elif i == 2:\n            a, b = image.shape[1], image.shape[0]\n            # downsample2\n            if random.random() < 0.8:\n                sf1 = random.uniform(1, 2 * sf)\n                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),\n                                   interpolation=random.choice([1, 2, 3]))\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel\n                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')\n                image = image[0::sf, 0::sf, ...]  
# nearest downsampling\n\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                image = add_JPEG_noise(image)\n        #\n        # elif i == 6:\n        #     # add processed camera sensor noise\n        #     if random.random() < isp_prob and isp_model is not None:\n        #         with torch.no_grad():\n        #             img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    image = add_JPEG_noise(image)\n    image = util.single2uint(image)\n    example = {\"image\": image}\n    return example\n\n\n\n\nif __name__ == '__main__':\n    print(\"hey\")\n    img = util.imread_uint('utils/test.png', 3)\n    img = img[:448, :448]\n    h = img.shape[0] // 4\n    print(\"resizing to\", h)\n    sf = 4\n    deg_fn = partial(degradation_bsrgan_variant, sf=sf)\n    for i in range(20):\n        print(i)\n        img_hq = img\n        img_lq = deg_fn(img)[\"image\"]\n        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)\n        print(img_lq)\n        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[\"image\"]\n        print(img_lq.shape)\n        print(\"bicubic\", img_lq_bicubic.shape)\n        print(img_hq.shape)\n        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n                                interpolation=0)\n        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),\n                                        (int(sf * 
img_lq.shape[1]), int(sf * img_lq.shape[0])),\n                                        interpolation=0)\n        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)\n        util.imsave(img_concat, str(i) + '.png')\n"
  },
  {
    "path": "ldm/modules/image_degradation/utils_image.py",
    "content": "import os\nimport math\nimport random\nimport numpy as np\nimport torch\nimport cv2\nfrom torchvision.utils import make_grid\nfrom datetime import datetime\n#import matplotlib.pyplot as plt   # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py\n\n\nos.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n\n\n'''\n# --------------------------------------------\n# Kai Zhang (github: https://github.com/cszn)\n# 03/Mar/2019\n# --------------------------------------------\n# https://github.com/twhui/SRGAN-pyTorch\n# https://github.com/xinntao/BasicSR\n# --------------------------------------------\n'''\n\n\nIMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef get_timestamp():\n    return datetime.now().strftime('%y%m%d-%H%M%S')\n\n\ndef imshow(x, title=None, cbar=False, figsize=None):\n    import matplotlib.pyplot as plt  # imported lazily because the module-level import is disabled\n    plt.figure(figsize=figsize)\n    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')\n    if title:\n        plt.title(title)\n    if cbar:\n        plt.colorbar()\n    plt.show()\n\n\ndef surf(Z, cmap='rainbow', figsize=None):\n    import matplotlib.pyplot as plt  # imported lazily because the module-level import is disabled\n    plt.figure(figsize=figsize)\n    ax3 = plt.axes(projection='3d')\n\n    h, w = Z.shape[:2]\n    xx = np.arange(0,w,1)\n    yy = np.arange(0,h,1)\n    X, Y = np.meshgrid(xx, yy)  # X and Y have shape (h, w), matching Z\n    ax3.plot_surface(X,Y,Z,cmap=cmap)\n    #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)\n    plt.show()\n\n\n'''\n# --------------------------------------------\n# get image paths\n# --------------------------------------------\n'''\n\n\ndef get_image_paths(dataroot):\n    paths = None  # return None if dataroot is None\n    if dataroot is not None:\n        paths = sorted(_get_paths_from_images(dataroot))\n    return paths\n\n\ndef _get_paths_from_images(path):\n    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)\n    images = []\n    for dirpath, _, 
fnames in sorted(os.walk(path)):\n        for fname in sorted(fnames):\n            if is_image_file(fname):\n                img_path = os.path.join(dirpath, fname)\n                images.append(img_path)\n    assert images, '{:s} has no valid image file'.format(path)\n    return images\n\n\n'''\n# --------------------------------------------\n# split large images into small images\n# --------------------------------------------\n'''\n\n\ndef patches_from_image(img, p_size=512, p_overlap=64, p_max=800):\n    w, h = img.shape[:2]\n    patches = []\n    if w > p_max and h > p_max:\n        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))\n        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))\n        w1.append(w-p_size)\n        h1.append(h-p_size)\n#        print(w1)\n#        print(h1)\n        for i in w1:\n            for j in h1:\n                patches.append(img[i:i+p_size, j:j+p_size,:])\n    else:\n        patches.append(img)\n\n    return patches\n\n\ndef imssave(imgs, img_path):\n    \"\"\"\n    imgs: list, N images of size WxHxC\n    \"\"\"\n    img_name, ext = os.path.splitext(os.path.basename(img_path))\n\n    for i, img in enumerate(imgs):\n        if img.ndim == 3:\n            img = img[:, :, [2, 1, 0]]\n        new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')\n        cv2.imwrite(new_path, img)\n\n\ndef split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):\n    \"\"\"\n    Split the large images from original_dataroot into small overlapping images of size (p_size)x(p_size)\n    and save them into taget_dataroot; only images larger than (p_max)x(p_max)\n    are split.\n    Args:\n        original_dataroot: directory containing the source images\n        taget_dataroot: directory where the patches are saved\n        p_size: size of small images\n        p_overlap: overlap between adjacent patches; the training patch size is a good choice\n        p_max: images with smaller size than (p_max)x(p_max) keep 
unchanged.\n    \"\"\"\n    paths = get_image_paths(original_dataroot)\n    for img_path in paths:\n        # img_name, ext = os.path.splitext(os.path.basename(img_path))\n        img = imread_uint(img_path, n_channels=n_channels)\n        patches = patches_from_image(img, p_size, p_overlap, p_max)\n        imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))\n        #if original_dataroot == taget_dataroot:\n        #del img_path\n\n'''\n# --------------------------------------------\n# makedir\n# --------------------------------------------\n'''\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\ndef mkdirs(paths):\n    if isinstance(paths, str):\n        mkdir(paths)\n    else:\n        for path in paths:\n            mkdir(path)\n\n\ndef mkdir_and_rename(path):\n    if os.path.exists(path):\n        new_name = path + '_archived_' + get_timestamp()\n        print('Path already exists. Rename it to [{:s}]'.format(new_name))\n        os.rename(path, new_name)\n    os.makedirs(path)\n\n\n'''\n# --------------------------------------------\n# read image from path\n# OpenCV is fast, but reads images in BGR order\n# --------------------------------------------\n'''\n\n\n# --------------------------------------------\n# get uint8 image of size HxWxn_channels (RGB)\n# --------------------------------------------\ndef imread_uint(path, n_channels=3):\n    #  input: path\n    # output: HxWx3(RGB or GGG), or HxWx1 (G)\n    if n_channels == 1:\n        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE\n        img = np.expand_dims(img, axis=2)  # HxWx1\n    elif n_channels == 3:\n        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G\n        if img.ndim == 2:\n            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG\n        else:\n            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB\n    return img\n\n\n# --------------------------------------------\n# matlab's imwrite\n# 
--------------------------------------------\ndef imsave(img, img_path):\n    img = np.squeeze(img)\n    if img.ndim == 3:\n        img = img[:, :, [2, 1, 0]]\n    cv2.imwrite(img_path, img)\n\ndef imwrite(img, img_path):\n    img = np.squeeze(img)\n    if img.ndim == 3:\n        img = img[:, :, [2, 1, 0]]\n    cv2.imwrite(img_path, img)\n\n\n\n# --------------------------------------------\n# get single image of size HxWxn_channels (BGR)\n# --------------------------------------------\ndef read_img(path):\n    # read image by cv2\n    # return: Numpy float32, HWC, BGR, [0,1]\n    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # gray, BGR or BGRA\n    img = img.astype(np.float32) / 255.\n    if img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    # some images have 4 channels\n    if img.shape[2] > 3:\n        img = img[:, :, :3]\n    return img\n\n\n'''\n# --------------------------------------------\n# image format conversion\n# --------------------------------------------\n# numpy(single) <--->  numpy(uint)\n# numpy(single) <--->  tensor\n# numpy(uint)   <--->  tensor\n# --------------------------------------------\n'''\n\n\n# --------------------------------------------\n# numpy(single) [0, 1] <--->  numpy(uint)\n# --------------------------------------------\n\n\ndef uint2single(img):\n\n    return np.float32(img/255.)\n\n\ndef single2uint(img):\n\n    return np.uint8((img.clip(0, 1)*255.).round())\n\n\ndef uint162single(img):\n\n    return np.float32(img/65535.)\n\n\ndef single2uint16(img):\n\n    return np.uint16((img.clip(0, 1)*65535.).round())\n\n\n# --------------------------------------------\n# numpy(uint) (HxWxC or HxW) <--->  tensor\n# --------------------------------------------\n\n\n# convert uint to 4-dimensional torch tensor\ndef uint2tensor4(img):\n    if img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)\n\n\n# convert uint 
to 3-dimensional torch tensor\ndef uint2tensor3(img):\n    if img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)\n\n\n# convert 2/3/4-dimensional torch tensor to uint\ndef tensor2uint(img):\n    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()\n    if img.ndim == 3:\n        img = np.transpose(img, (1, 2, 0))\n    return np.uint8((img*255.0).round())\n\n\n# --------------------------------------------\n# numpy(single) (HxWxC) <--->  tensor\n# --------------------------------------------\n\n\n# convert single (HxWxC) to 3-dimensional torch tensor\ndef single2tensor3(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()\n\n\n# convert single (HxWxC) to 4-dimensional torch tensor\ndef single2tensor4(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)\n\n\n# convert torch tensor to single\ndef tensor2single(img):\n    img = img.data.squeeze().float().cpu().numpy()\n    if img.ndim == 3:\n        img = np.transpose(img, (1, 2, 0))\n\n    return img\n\n# convert torch tensor to single\ndef tensor2single3(img):\n    img = img.data.squeeze().float().cpu().numpy()\n    if img.ndim == 3:\n        img = np.transpose(img, (1, 2, 0))\n    elif img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    return img\n\n\ndef single2tensor5(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)\n\n\ndef single32tensor5(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)\n\n\ndef single42tensor4(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()\n\n\n# from skimage.io import imread, imsave\ndef tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):\n    '''\n    Converts a torch Tensor into an image Numpy array of BGR channel order\n    Input: 4D(B,(3/1),H,W), 
3D(C,H,W), or 2D(H,W), any range, RGB channel order\n    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)\n    '''\n    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp\n    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]\n    n_dim = tensor.dim()\n    if n_dim == 4:\n        n_img = len(tensor)\n        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()\n        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR\n    elif n_dim == 3:\n        img_np = tensor.numpy()\n        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR\n    elif n_dim == 2:\n        img_np = tensor.numpy()\n    else:\n        raise TypeError(\n            'Only 4D, 3D and 2D tensors are supported, but received dimension: {:d}'.format(n_dim))\n    if out_type == np.uint8:\n        img_np = (img_np * 255.0).round()\n        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.\n    return img_np.astype(out_type)\n\n\n'''\n# --------------------------------------------\n# Augmentation, flip and/or rotate\n# --------------------------------------------\n# The following two are enough.\n# (1) augment_img: numpy image of WxHxC or WxH\n# (2) augment_img_tensor4: tensor image 1xCxWxH\n# --------------------------------------------\n'''\n\n\ndef augment_img(img, mode=0):\n    '''Kai Zhang (github: https://github.com/cszn)\n    '''\n    if mode == 0:\n        return img\n    elif mode == 1:\n        return np.flipud(np.rot90(img))\n    elif mode == 2:\n        return np.flipud(img)\n    elif mode == 3:\n        return np.rot90(img, k=3)\n    elif mode == 4:\n        return np.flipud(np.rot90(img, k=2))\n    elif mode == 5:\n        return np.rot90(img)\n    elif mode == 6:\n        return np.rot90(img, k=2)\n    elif mode == 7:\n        return np.flipud(np.rot90(img, k=3))\n\n\ndef augment_img_tensor4(img, mode=0):\n    '''Kai Zhang (github: 
https://github.com/cszn)\n    '''\n    if mode == 0:\n        return img\n    elif mode == 1:\n        return img.rot90(1, [2, 3]).flip([2])\n    elif mode == 2:\n        return img.flip([2])\n    elif mode == 3:\n        return img.rot90(3, [2, 3])\n    elif mode == 4:\n        return img.rot90(2, [2, 3]).flip([2])\n    elif mode == 5:\n        return img.rot90(1, [2, 3])\n    elif mode == 6:\n        return img.rot90(2, [2, 3])\n    elif mode == 7:\n        return img.rot90(3, [2, 3]).flip([2])\n\n\ndef augment_img_tensor(img, mode=0):\n    '''Kai Zhang (github: https://github.com/cszn)\n    '''\n    img_size = img.size()\n    img_np = img.data.cpu().numpy()\n    if len(img_size) == 3:\n        img_np = np.transpose(img_np, (1, 2, 0))\n    elif len(img_size) == 4:\n        img_np = np.transpose(img_np, (2, 3, 1, 0))\n    img_np = augment_img(img_np, mode=mode)\n    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))\n    if len(img_size) == 3:\n        img_tensor = img_tensor.permute(2, 0, 1)\n    elif len(img_size) == 4:\n        img_tensor = img_tensor.permute(3, 2, 0, 1)\n\n    return img_tensor.type_as(img)\n\n\ndef augment_img_np3(img, mode=0):\n    if mode == 0:\n        return img\n    elif mode == 1:\n        return img.transpose(1, 0, 2)\n    elif mode == 2:\n        return img[::-1, :, :]\n    elif mode == 3:\n        img = img[::-1, :, :]\n        img = img.transpose(1, 0, 2)\n        return img\n    elif mode == 4:\n        return img[:, ::-1, :]\n    elif mode == 5:\n        img = img[:, ::-1, :]\n        img = img.transpose(1, 0, 2)\n        return img\n    elif mode == 6:\n        img = img[:, ::-1, :]\n        img = img[::-1, :, :]\n        return img\n    elif mode == 7:\n        img = img[:, ::-1, :]\n        img = img[::-1, :, :]\n        img = img.transpose(1, 0, 2)\n        return img\n\n\ndef augment_imgs(img_list, hflip=True, rot=True):\n    # horizontal flip OR rotate\n    hflip = hflip and random.random() < 0.5\n    vflip = rot 
and random.random() < 0.5\n    rot90 = rot and random.random() < 0.5\n\n    def _augment(img):\n        if hflip:\n            img = img[:, ::-1, :]\n        if vflip:\n            img = img[::-1, :, :]\n        if rot90:\n            img = img.transpose(1, 0, 2)\n        return img\n\n    return [_augment(img) for img in img_list]\n\n\n'''\n# --------------------------------------------\n# modcrop and shave\n# --------------------------------------------\n'''\n\n\ndef modcrop(img_in, scale):\n    # img_in: Numpy, HWC or HW\n    img = np.copy(img_in)\n    if img.ndim == 2:\n        H, W = img.shape\n        H_r, W_r = H % scale, W % scale\n        img = img[:H - H_r, :W - W_r]\n    elif img.ndim == 3:\n        H, W, C = img.shape\n        H_r, W_r = H % scale, W % scale\n        img = img[:H - H_r, :W - W_r, :]\n    else:\n        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))\n    return img\n\n\ndef shave(img_in, border=0):\n    # img_in: Numpy, HWC or HW\n    img = np.copy(img_in)\n    h, w = img.shape[:2]\n    img = img[border:h-border, border:w-border]\n    return img\n\n\n'''\n# --------------------------------------------\n# image processing on numpy images\n# channel_convert(in_c, tar_type, img_list):\n# rgb2ycbcr(img, only_y=True):\n# bgr2ycbcr(img, only_y=True):\n# ycbcr2rgb(img):\n# --------------------------------------------\n'''\n\n\ndef rgb2ycbcr(img, only_y=True):\n    '''same as matlab rgb2ycbcr\n    only_y: only return Y channel\n    Input:\n        uint8, [0, 255]\n        float, [0, 1]\n    '''\n    in_img_type = img.dtype\n    img = img.astype(np.float32)  # work on a float copy so the caller's array is not modified\n    if in_img_type != np.uint8:\n        img *= 255.\n    # convert\n    if only_y:\n        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0\n    else:\n        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],\n                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]\n    if in_img_type == np.uint8:\n        rlt = 
rlt.round()\n    else:\n        rlt /= 255.\n    return rlt.astype(in_img_type)\n\n\ndef ycbcr2rgb(img):\n    '''same as matlab ycbcr2rgb\n    Input:\n        uint8, [0, 255]\n        float, [0, 1]\n    '''\n    in_img_type = img.dtype\n    # astype returns a copy; assign it so the caller's array is never modified in place\n    img = img.astype(np.float32)\n    if in_img_type != np.uint8:\n        img *= 255.\n    # convert\n    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],\n                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]\n    if in_img_type == np.uint8:\n        rlt = rlt.round()\n    else:\n        rlt /= 255.\n    return rlt.astype(in_img_type)\n\n\ndef bgr2ycbcr(img, only_y=True):\n    '''bgr version of rgb2ycbcr\n    only_y: only return Y channel\n    Input:\n        uint8, [0, 255]\n        float, [0, 1]\n    '''\n    in_img_type = img.dtype\n    # astype returns a copy; assign it so the caller's array is never modified in place\n    img = img.astype(np.float32)\n    if in_img_type != np.uint8:\n        img *= 255.\n    # convert\n    if only_y:\n        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0\n    else:\n        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],\n                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]\n    if in_img_type == np.uint8:\n        rlt = rlt.round()\n    else:\n        rlt /= 255.\n    return rlt.astype(in_img_type)\n\n\ndef channel_convert(in_c, tar_type, img_list):\n    # conversion among BGR, gray and y\n    if in_c == 3 and tar_type == 'gray':  # BGR to gray\n        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]\n        return [np.expand_dims(img, axis=2) for img in gray_list]\n    elif in_c == 3 and tar_type == 'y':  # BGR to y\n        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]\n        return [np.expand_dims(img, axis=2) for img in y_list]\n    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR\n        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]\n    else:\n        return 
img_list\n\n\n'''\n# --------------------------------------------\n# metric, PSNR and SSIM\n# --------------------------------------------\n'''\n\n\n# --------------------------------------------\n# PSNR\n# --------------------------------------------\ndef calculate_psnr(img1, img2, border=0):\n    # img1 and img2 have range [0, 255]\n    #img1 = img1.squeeze()\n    #img2 = img2.squeeze()\n    if not img1.shape == img2.shape:\n        raise ValueError('Input images must have the same dimensions.')\n    h, w = img1.shape[:2]\n    img1 = img1[border:h-border, border:w-border]\n    img2 = img2[border:h-border, border:w-border]\n\n    img1 = img1.astype(np.float64)\n    img2 = img2.astype(np.float64)\n    mse = np.mean((img1 - img2)**2)\n    if mse == 0:\n        return float('inf')\n    return 20 * math.log10(255.0 / math.sqrt(mse))\n\n\n# --------------------------------------------\n# SSIM\n# --------------------------------------------\ndef calculate_ssim(img1, img2, border=0):\n    '''calculate SSIM\n    the same outputs as MATLAB's\n    img1, img2: [0, 255]\n    '''\n    #img1 = img1.squeeze()\n    #img2 = img2.squeeze()\n    if not img1.shape == img2.shape:\n        raise ValueError('Input images must have the same dimensions.')\n    h, w = img1.shape[:2]\n    img1 = img1[border:h-border, border:w-border]\n    img2 = img2[border:h-border, border:w-border]\n\n    if img1.ndim == 2:\n        return ssim(img1, img2)\n    elif img1.ndim == 3:\n        if img1.shape[2] == 3:\n            ssims = []\n            for i in range(3):\n                ssims.append(ssim(img1[:,:,i], img2[:,:,i]))\n            return np.array(ssims).mean()\n        elif img1.shape[2] == 1:\n            return ssim(np.squeeze(img1), np.squeeze(img2))\n    else:\n        raise ValueError('Wrong input image dimensions.')\n\n\ndef ssim(img1, img2):\n    C1 = (0.01 * 255)**2\n    C2 = (0.03 * 255)**2\n\n    img1 = img1.astype(np.float64)\n    img2 = img2.astype(np.float64)\n    kernel = 
cv2.getGaussianKernel(11, 1.5)\n    window = np.outer(kernel, kernel.transpose())\n\n    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid\n    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]\n    mu1_sq = mu1**2\n    mu2_sq = mu2**2\n    mu1_mu2 = mu1 * mu2\n    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq\n    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq\n    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *\n                                                            (sigma1_sq + sigma2_sq + C2))\n    return ssim_map.mean()\n\n\n'''\n# --------------------------------------------\n# matlab's bicubic imresize (numpy and torch) [0, 1]\n# --------------------------------------------\n'''\n\n\n# matlab 'imresize' function, now only support 'bicubic'\ndef cubic(x):\n    absx = torch.abs(x)\n    absx2 = absx**2\n    absx3 = absx**3\n    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \\\n        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))\n\n\ndef calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):\n    if (scale < 1) and (antialiasing):\n        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width\n        kernel_width = kernel_width / scale\n\n    # Output-space coordinates\n    x = torch.linspace(1, out_length, out_length)\n\n    # Input-space coordinates. Calculate the inverse mapping such that 0.5\n    # in output space maps to 0.5 in input space, and 0.5+scale in output\n    # space maps to 1.5 in input space.\n    u = x / scale + 0.5 * (1 - 1 / scale)\n\n    # What is the left-most pixel that can be involved in the computation?\n    left = torch.floor(u - kernel_width / 2)\n\n    # What is the maximum number of pixels that can be involved in the\n    # computation?  
Note: it's OK to use an extra pixel here; if the\n    # corresponding weights are all zero, it will be eliminated at the end\n    # of this function.\n    P = math.ceil(kernel_width) + 2\n\n    # The indices of the input pixels involved in computing the k-th output\n    # pixel are in row k of the indices matrix.\n    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(\n        1, P).expand(out_length, P)\n\n    # The weights used to compute the k-th output pixel are in row k of the\n    # weights matrix.\n    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices\n    # apply cubic kernel\n    if (scale < 1) and (antialiasing):\n        weights = scale * cubic(distance_to_center * scale)\n    else:\n        weights = cubic(distance_to_center)\n    # Normalize the weights matrix so that each row sums to 1.\n    weights_sum = torch.sum(weights, 1).view(out_length, 1)\n    weights = weights / weights_sum.expand(out_length, P)\n\n    # If a column in weights is all zero, get rid of it. 
only consider the first and last column.\n    weights_zero_tmp = torch.sum((weights == 0), 0)\n    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):\n        indices = indices.narrow(1, 1, P - 2)\n        weights = weights.narrow(1, 1, P - 2)\n    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):\n        indices = indices.narrow(1, 0, P - 2)\n        weights = weights.narrow(1, 0, P - 2)\n    weights = weights.contiguous()\n    indices = indices.contiguous()\n    sym_len_s = -indices.min() + 1\n    sym_len_e = indices.max() - in_length\n    indices = indices + sym_len_s - 1\n    return weights, indices, int(sym_len_s), int(sym_len_e)\n\n\n# --------------------------------------------\n# imresize for tensor image [0, 1]\n# --------------------------------------------\ndef imresize(img, scale, antialiasing=True):\n    # Now the scale should be the same for H and W\n    # input: img: pytorch tensor, CHW or HW [0,1]\n    # output: CHW or HW [0,1] w/o round\n    need_squeeze = True if img.dim() == 2 else False\n    if need_squeeze:\n        img.unsqueeze_(0)\n    in_C, in_H, in_W = img.size()\n    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)\n    kernel_width = 4\n    kernel = 'cubic'\n\n    # Return the desired dimension order for performing the resize.  
The\n    # strategy is to perform the resize first along the dimension with the\n    # smallest scale factor.\n    # Now we do not support this.\n\n    # get weights and indices\n    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(\n        in_H, out_H, scale, kernel, kernel_width, antialiasing)\n    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(\n        in_W, out_W, scale, kernel, kernel_width, antialiasing)\n    # process H dimension\n    # symmetric copying\n    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)\n    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)\n\n    sym_patch = img[:, :sym_len_Hs, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)\n\n    sym_patch = img[:, -sym_len_He:, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)\n\n    out_1 = torch.FloatTensor(in_C, out_H, in_W)\n    kernel_width = weights_H.size(1)\n    for i in range(out_H):\n        idx = int(indices_H[i][0])\n        for j in range(out_C):\n            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])\n\n    # process W dimension\n    # symmetric copying\n    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)\n    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)\n\n    sym_patch = out_1[:, :, :sym_len_Ws]\n    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(2, inv_idx)\n    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)\n\n    sym_patch = out_1[:, :, -sym_len_We:]\n    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(2, inv_idx)\n    
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)\n\n    out_2 = torch.FloatTensor(in_C, out_H, out_W)\n    kernel_width = weights_W.size(1)\n    for i in range(out_W):\n        idx = int(indices_W[i][0])\n        for j in range(out_C):\n            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])\n    if need_squeeze:\n        out_2.squeeze_()\n    return out_2\n\n\n# --------------------------------------------\n# imresize for numpy image [0, 1]\n# --------------------------------------------\ndef imresize_np(img, scale, antialiasing=True):\n    # Now the scale should be the same for H and W\n    # input: img: Numpy, HWC or HW [0,1]\n    # output: HWC or HW [0,1] w/o round\n    img = torch.from_numpy(img)\n    need_squeeze = True if img.dim() == 2 else False\n    if need_squeeze:\n        img.unsqueeze_(2)\n\n    in_H, in_W, in_C = img.size()\n    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)\n    kernel_width = 4\n    kernel = 'cubic'\n\n    # Return the desired dimension order for performing the resize.  
The\n    # strategy is to perform the resize first along the dimension with the\n    # smallest scale factor.\n    # Now we do not support this.\n\n    # get weights and indices\n    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(\n        in_H, out_H, scale, kernel, kernel_width, antialiasing)\n    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(\n        in_W, out_W, scale, kernel, kernel_width, antialiasing)\n    # process H dimension\n    # symmetric copying\n    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)\n    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)\n\n    sym_patch = img[:sym_len_Hs, :, :]\n    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(0, inv_idx)\n    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)\n\n    sym_patch = img[-sym_len_He:, :, :]\n    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(0, inv_idx)\n    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)\n\n    out_1 = torch.FloatTensor(out_H, in_W, in_C)\n    kernel_width = weights_H.size(1)\n    for i in range(out_H):\n        idx = int(indices_H[i][0])\n        for j in range(out_C):\n            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])\n\n    # process W dimension\n    # symmetric copying\n    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)\n    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)\n\n    sym_patch = out_1[:, :sym_len_Ws, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)\n\n    sym_patch = out_1[:, -sym_len_We:, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)\n\n    out_2 = torch.FloatTensor(out_H, out_W, in_C)\n    kernel_width = weights_W.size(1)\n    for i in range(out_W):\n        idx = int(indices_W[i][0])\n        for j in range(out_C):\n            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])\n    if need_squeeze:\n        out_2.squeeze_()\n\n    return out_2.numpy()\n\n\nif __name__ == '__main__':\n    print('---')\n#    img = imread_uint('test.bmp', 3)\n#    img = uint2single(img)\n#    img_bicubic = imresize_np(img, 1/4)"
  },
  {
    "path": "ldm/modules/losses/__init__.py",
    "content": "from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator"
  },
  {
    "path": "ldm/modules/losses/contperceptual.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?\n\n\nclass LPIPSWithDiscriminator(nn.Module):\n    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,\n                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,\n                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,\n                 disc_loss=\"hinge\"):\n\n        super().__init__()\n        assert disc_loss in [\"hinge\", \"vanilla\"]\n        self.kl_weight = kl_weight\n        self.pixel_weight = pixelloss_weight\n        self.perceptual_loss = LPIPS().eval()\n        self.perceptual_weight = perceptual_weight\n        # output log variance\n        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)\n\n        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,\n                                                 n_layers=disc_num_layers,\n                                                 use_actnorm=use_actnorm\n                                                 ).apply(weights_init)\n        self.discriminator_iter_start = disc_start\n        self.disc_loss = hinge_d_loss if disc_loss == \"hinge\" else vanilla_d_loss\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.disc_conditional = disc_conditional\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n        if last_layer is not None:\n            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n        else:\n            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 
1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,\n                global_step, last_layer=None, cond=None, split=\"train\",\n                weights=None):\n        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())\n        if self.perceptual_weight > 0:\n            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())\n            rec_loss = rec_loss + self.perceptual_weight * p_loss\n\n        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar\n        weighted_nll_loss = nll_loss\n        if weights is not None:\n            weighted_nll_loss = weights*nll_loss\n        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]\n        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n        kl_loss = posteriors.kl()\n        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            # generator update\n            if cond is None:\n                assert not self.disc_conditional\n                logits_fake = self.discriminator(reconstructions.contiguous())\n            else:\n                assert self.disc_conditional\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))\n            g_loss = -torch.mean(logits_fake)\n\n            if self.disc_factor > 0.0:\n                try:\n                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)\n                except RuntimeError:\n                    assert not self.training\n                    d_weight = torch.tensor(0.0)\n            else:\n                d_weight = torch.tensor(0.0)\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, 
threshold=self.discriminator_iter_start)\n            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss\n\n            log = {\"{}/total_loss\".format(split): loss.clone().detach().mean(), \"{}/logvar\".format(split): self.logvar.detach(),\n                   \"{}/kl_loss\".format(split): kl_loss.detach().mean(), \"{}/nll_loss\".format(split): nll_loss.detach().mean(),\n                   \"{}/rec_loss\".format(split): rec_loss.detach().mean(),\n                   \"{}/d_weight\".format(split): d_weight.detach(),\n                   \"{}/disc_factor\".format(split): torch.tensor(disc_factor),\n                   \"{}/g_loss\".format(split): g_loss.detach().mean(),\n                   }\n            return loss, log\n\n        if optimizer_idx == 1:\n            # second pass for discriminator update\n            if cond is None:\n                logits_real = self.discriminator(inputs.contiguous().detach())\n                logits_fake = self.discriminator(reconstructions.contiguous().detach())\n            else:\n                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\"{}/disc_loss\".format(split): d_loss.clone().detach().mean(),\n                   \"{}/logits_real\".format(split): logits_real.detach().mean(),\n                   \"{}/logits_fake\".format(split): logits_fake.detach().mean()\n                   }\n            return d_loss, log\n\n"
  },
  {
    "path": "ldm/modules/losses/vqperceptual.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discriminator.model import NLayerDiscriminator, weights_init\nfrom taming.modules.losses.lpips import LPIPS\nfrom taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss\n\n\ndef hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):\n    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]\n    loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])\n    loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])\n    loss_real = (weights * loss_real).sum() / weights.sum()\n    loss_fake = (weights * loss_fake).sum() / weights.sum()\n    d_loss = 0.5 * (loss_real + loss_fake)\n    return d_loss\n\ndef adopt_weight(weight, global_step, threshold=0, value=0.):\n    if global_step < threshold:\n        weight = value\n    return weight\n\n\ndef exists(val):\n    # used by VQLPIPSWithDiscriminator.forward; not defined or imported elsewhere in this module\n    return val is not None\n\n\ndef measure_perplexity(predicted_indices, n_embed):\n    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py\n    # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally\n    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)\n    avg_probs = encodings.mean(0)\n    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()\n    cluster_use = torch.sum(avg_probs > 0)\n    return perplexity, cluster_use\n\ndef l1(x, y):\n    return torch.abs(x-y)\n\n\ndef l2(x, y):\n    return torch.pow((x-y), 2)\n\n\nclass VQLPIPSWithDiscriminator(nn.Module):\n    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,\n                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,\n                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,\n                 disc_ndf=64, disc_loss=\"hinge\", n_classes=None, perceptual_loss=\"lpips\",\n                 pixel_loss=\"l1\"):\n        super().__init__()\n        assert disc_loss in [\"hinge\", \"vanilla\"]\n        assert perceptual_loss in [\"lpips\", \"clips\", \"dists\"]\n        assert pixel_loss in [\"l1\", \"l2\"]\n        self.codebook_weight = codebook_weight\n        self.pixel_weight = pixelloss_weight\n        if perceptual_loss == \"lpips\":\n            print(f\"{self.__class__.__name__}: Running with LPIPS.\")\n            self.perceptual_loss = LPIPS().eval()\n        else:\n            raise ValueError(f\"Unknown perceptual loss: >> {perceptual_loss} <<\")\n        self.perceptual_weight = perceptual_weight\n\n        if pixel_loss == \"l1\":\n            self.pixel_loss = l1\n        else:\n            self.pixel_loss = l2\n\n        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,\n                                                 n_layers=disc_num_layers,\n                                                 use_actnorm=use_actnorm,\n                                                 ndf=disc_ndf\n                                                 ).apply(weights_init)\n        
self.discriminator_iter_start = disc_start\n        if disc_loss == \"hinge\":\n            self.disc_loss = hinge_d_loss\n        elif disc_loss == \"vanilla\":\n            self.disc_loss = vanilla_d_loss\n        else:\n            raise ValueError(f\"Unknown GAN loss '{disc_loss}'.\")\n        print(f\"VQLPIPSWithDiscriminator running with {disc_loss} loss.\")\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.disc_conditional = disc_conditional\n        self.n_classes = n_classes\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n        if last_layer is not None:\n            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n        else:\n            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,\n                global_step, last_layer=None, cond=None, split=\"train\", predicted_indices=None):\n        if not exists(codebook_loss):\n            codebook_loss = torch.tensor([0.]).to(inputs.device)\n        #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())\n        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())\n        if self.perceptual_weight > 0:\n            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())\n            rec_loss = rec_loss + self.perceptual_weight * p_loss\n        else:\n            p_loss = torch.tensor([0.0])\n\n        nll_loss = rec_loss\n 
       #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n        nll_loss = torch.mean(nll_loss)\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            # generator update\n            if cond is None:\n                assert not self.disc_conditional\n                logits_fake = self.discriminator(reconstructions.contiguous())\n            else:\n                assert self.disc_conditional\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))\n            g_loss = -torch.mean(logits_fake)\n\n            try:\n                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)\n            except RuntimeError:\n                assert not self.training\n                d_weight = torch.tensor(0.0)\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()\n\n            log = {\"{}/total_loss\".format(split): loss.clone().detach().mean(),\n                   \"{}/quant_loss\".format(split): codebook_loss.detach().mean(),\n                   \"{}/nll_loss\".format(split): nll_loss.detach().mean(),\n                   \"{}/rec_loss\".format(split): rec_loss.detach().mean(),\n                   \"{}/p_loss\".format(split): p_loss.detach().mean(),\n                   \"{}/d_weight\".format(split): d_weight.detach(),\n                   \"{}/disc_factor\".format(split): torch.tensor(disc_factor),\n                   \"{}/g_loss\".format(split): g_loss.detach().mean(),\n                   }\n            if predicted_indices is not None:\n                assert self.n_classes is not None\n                with torch.no_grad():\n                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)\n                log[f\"{split}/perplexity\"] = perplexity\n                
log[f\"{split}/cluster_usage\"] = cluster_usage\n            return loss, log\n\n        if optimizer_idx == 1:\n            # second pass for discriminator update\n            if cond is None:\n                logits_real = self.discriminator(inputs.contiguous().detach())\n                logits_fake = self.discriminator(reconstructions.contiguous().detach())\n            else:\n                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\"{}/disc_loss\".format(split): d_loss.clone().detach().mean(),\n                   \"{}/logits_real\".format(split): logits_real.detach().mean(),\n                   \"{}/logits_fake\".format(split): logits_fake.detach().mean()\n                   }\n            return d_loss, log\n"
  },
  {
    "path": "ldm/modules/x_transformer.py",
    "content": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom functools import partial\nfrom inspect import isfunction\nfrom collections import namedtuple\nfrom einops import rearrange, repeat, reduce\n\n# constants\n\nDEFAULT_DIM_HEAD = 64\n\nIntermediates = namedtuple('Intermediates', [\n    'pre_softmax_attn',\n    'post_softmax_attn'\n])\n\nLayerIntermediates = namedtuple('LayerIntermediates', [\n    'hiddens',\n    'attn_intermediates'\n])\n\n\nclass AbsolutePositionalEmbedding(nn.Module):\n    def __init__(self, dim, max_seq_len):\n        super().__init__()\n        self.emb = nn.Embedding(max_seq_len, dim)\n        self.init_()\n\n    def init_(self):\n        nn.init.normal_(self.emb.weight, std=0.02)\n\n    def forward(self, x):\n        n = torch.arange(x.shape[1], device=x.device)\n        return self.emb(n)[None, :, :]\n\n\nclass FixedPositionalEmbedding(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim))\n        self.register_buffer('inv_freq', inv_freq)\n\n    def forward(self, x, seq_dim=1, offset=0):\n        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset\n        sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)\n        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)\n        return emb[None, :, :]\n\n\n# helpers\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef always(val):\n    def inner(*args, **kwargs):\n        return val\n    return inner\n\n\ndef not_equals(val):\n    def inner(x):\n        return x != val\n    return inner\n\n\ndef equals(val):\n    def inner(x):\n        return x == val\n    return inner\n\n\ndef max_neg_value(tensor):\n    return -torch.finfo(tensor.dtype).max\n\n\n# keyword argument helpers\n\ndef pick_and_pop(keys, d):\n    values = list(map(lambda key: d.pop(key), keys))\n    return dict(zip(keys, values))\n\n\ndef group_dict_by_key(cond, d):\n    return_val = [dict(), dict()]\n    for key in d.keys():\n        match = bool(cond(key))\n        ind = int(not match)\n        return_val[ind][key] = d[key]\n    return (*return_val,)\n\n\ndef string_begins_with(prefix, str):\n    return str.startswith(prefix)\n\n\ndef group_by_key_prefix(prefix, d):\n    return group_dict_by_key(partial(string_begins_with, prefix), d)\n\n\ndef groupby_prefix_and_trim(prefix, d):\n    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)\n    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))\n    return kwargs_without_prefix, kwargs\n\n\n# classes\nclass Scale(nn.Module):\n    def __init__(self, value, fn):\n        super().__init__()\n        self.value = value\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        x, 
*rest = self.fn(x, **kwargs)\n        return (x * self.value, *rest)\n\n\nclass Rezero(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n        self.g = nn.Parameter(torch.zeros(1))\n\n    def forward(self, x, **kwargs):\n        x, *rest = self.fn(x, **kwargs)\n        return (x * self.g, *rest)\n\n\nclass ScaleNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.scale = dim ** -0.5\n        self.eps = eps\n        self.g = nn.Parameter(torch.ones(1))\n\n    def forward(self, x):\n        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale\n        return x / norm.clamp(min=self.eps) * self.g\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dim, eps=1e-8):\n        super().__init__()\n        self.scale = dim ** -0.5\n        self.eps = eps\n        self.g = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale\n        return x / norm.clamp(min=self.eps) * self.g\n\n\nclass Residual(nn.Module):\n    def forward(self, x, residual):\n        return x + residual\n\n\nclass GRUGating(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.gru = nn.GRUCell(dim, dim)\n\n    def forward(self, x, residual):\n        gated_output = self.gru(\n            rearrange(x, 'b n d -> (b n) d'),\n            rearrange(residual, 'b n d -> (b n) d')\n        )\n\n        return gated_output.reshape_as(x)\n\n\n# feedforward\n\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, 
dim)\n        project_in = nn.Sequential(\n            nn.Linear(dim, inner_dim),\n            nn.GELU()\n        ) if not glu else GEGLU(dim, inner_dim)\n\n        self.net = nn.Sequential(\n            project_in,\n            nn.Dropout(dropout),\n            nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\n# attention.\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            dim_head=DEFAULT_DIM_HEAD,\n            heads=8,\n            causal=False,\n            mask=None,\n            talking_heads=False,\n            sparse_topk=None,\n            use_entmax15=False,\n            num_mem_kv=0,\n            dropout=0.,\n            on_attn=False\n    ):\n        super().__init__()\n        if use_entmax15:\n            raise NotImplementedError(\"Check out entmax activation instead of softmax activation!\")\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        self.causal = causal\n        self.mask = mask\n\n        inner_dim = dim_head * heads\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(dim, inner_dim, bias=False)\n        self.dropout = nn.Dropout(dropout)\n\n        # talking heads\n        self.talking_heads = talking_heads\n        if talking_heads:\n            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))\n            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))\n\n        # explicit topk sparse attention\n        self.sparse_topk = sparse_topk\n\n        # entmax\n        #self.attn_fn = entmax15 if use_entmax15 else F.softmax\n        self.attn_fn = F.softmax\n\n        # add memory key / values\n        self.num_mem_kv = num_mem_kv\n        if num_mem_kv > 0:\n            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))\n            self.mem_v = 
nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))\n\n        # attention on attention\n        self.attn_on_attn = on_attn\n        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)\n\n    def forward(\n            self,\n            x,\n            context=None,\n            mask=None,\n            context_mask=None,\n            rel_pos=None,\n            sinusoidal_emb=None,\n            prev_attn=None,\n            mem=None\n    ):\n        b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device\n        kv_input = default(context, x)\n\n        q_input = x\n        k_input = kv_input\n        v_input = kv_input\n\n        if exists(mem):\n            k_input = torch.cat((mem, k_input), dim=-2)\n            v_input = torch.cat((mem, v_input), dim=-2)\n\n        if exists(sinusoidal_emb):\n            # in shortformer, the query would start at a position offset depending on the past cached memory\n            offset = k_input.shape[-2] - q_input.shape[-2]\n            q_input = q_input + sinusoidal_emb(q_input, offset=offset)\n            k_input = k_input + sinusoidal_emb(k_input)\n\n        q = self.to_q(q_input)\n        k = self.to_k(k_input)\n        v = self.to_v(v_input)\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))\n\n        input_mask = None\n        if any(map(exists, (mask, context_mask))):\n            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())\n            k_mask = q_mask if not exists(context) else context_mask\n            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())\n            q_mask = rearrange(q_mask, 'b i -> b () i ()')\n            k_mask = rearrange(k_mask, 'b j -> b () () j')\n            input_mask = q_mask * k_mask\n\n        if self.num_mem_kv > 0:\n            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), 
(self.mem_k, self.mem_v))\n            k = torch.cat((mem_k, k), dim=-2)\n            v = torch.cat((mem_v, v), dim=-2)\n            if exists(input_mask):\n                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)\n\n        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale\n        mask_value = max_neg_value(dots)\n\n        if exists(prev_attn):\n            dots = dots + prev_attn\n\n        pre_softmax_attn = dots\n\n        if talking_heads:\n            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()\n\n        if exists(rel_pos):\n            dots = rel_pos(dots)\n\n        if exists(input_mask):\n            dots.masked_fill_(~input_mask, mask_value)\n            del input_mask\n\n        if self.causal:\n            i, j = dots.shape[-2:]\n            r = torch.arange(i, device=device)\n            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')\n            mask = F.pad(mask, (j - i, 0), value=False)\n            dots.masked_fill_(mask, mask_value)\n            del mask\n\n        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:\n            top, _ = dots.topk(self.sparse_topk, dim=-1)\n            vk = top[..., -1].unsqueeze(-1).expand_as(dots)\n            mask = dots < vk\n            dots.masked_fill_(mask, mask_value)\n            del mask\n\n        attn = self.attn_fn(dots, dim=-1)\n        post_softmax_attn = attn\n\n        attn = self.dropout(attn)\n\n        if talking_heads:\n            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()\n\n        out = einsum('b h i j, b h j d -> b h i d', attn, v)\n        out = rearrange(out, 'b h n d -> b n (h d)')\n\n        intermediates = Intermediates(\n            pre_softmax_attn=pre_softmax_attn,\n            post_softmax_attn=post_softmax_attn\n        )\n\n        return self.to_out(out), intermediates\n\n\nclass AttentionLayers(nn.Module):\n    def 
__init__(\n            self,\n            dim,\n            depth,\n            heads=8,\n            causal=False,\n            cross_attend=False,\n            only_cross=False,\n            use_scalenorm=False,\n            use_rmsnorm=False,\n            use_rezero=False,\n            rel_pos_num_buckets=32,\n            rel_pos_max_distance=128,\n            position_infused_attn=False,\n            custom_layers=None,\n            sandwich_coef=None,\n            par_ratio=None,\n            residual_attn=False,\n            cross_residual_attn=False,\n            macaron=False,\n            pre_norm=True,\n            gate_residual=False,\n            **kwargs\n    ):\n        super().__init__()\n        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)\n        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)\n\n        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)\n\n        self.dim = dim\n        self.depth = depth\n        self.layers = nn.ModuleList([])\n\n        self.has_pos_emb = position_infused_attn\n        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None\n        self.rotary_pos_emb = always(None)\n\n        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'\n        self.rel_pos = None\n\n        self.pre_norm = pre_norm\n\n        self.residual_attn = residual_attn\n        self.cross_residual_attn = cross_residual_attn\n\n        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm\n        norm_class = RMSNorm if use_rmsnorm else norm_class\n        norm_fn = partial(norm_class, dim)\n\n        norm_fn = nn.Identity if use_rezero else norm_fn\n        branch_fn = Rezero if use_rezero else None\n\n        if cross_attend and not only_cross:\n            default_block = ('a', 'c', 'f')\n        elif cross_attend and only_cross:\n            default_block = ('c', 'f')\n        else:\n   
         default_block = ('a', 'f')\n\n        if macaron:\n            default_block = ('f',) + default_block\n\n        if exists(custom_layers):\n            layer_types = custom_layers\n        elif exists(par_ratio):\n            par_depth = depth * len(default_block)\n            assert 1 < par_ratio <= par_depth, 'par ratio out of range'\n            default_block = tuple(filter(not_equals('f'), default_block))\n            par_attn = par_depth // par_ratio\n            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper\n            par_width = (depth_cut + depth_cut // par_attn) // par_attn\n            assert len(default_block) <= par_width, 'default block is too large for par_ratio'\n            par_block = default_block + ('f',) * (par_width - len(default_block))\n            par_head = par_block * par_attn\n            layer_types = par_head + ('f',) * (par_depth - len(par_head))\n        elif exists(sandwich_coef):\n            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'\n            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef\n        else:\n            layer_types = default_block * depth\n\n        self.layer_types = layer_types\n        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))\n\n        for layer_type in self.layer_types:\n            if layer_type == 'a':\n                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)\n            elif layer_type == 'c':\n                layer = Attention(dim, heads=heads, **attn_kwargs)\n            elif layer_type == 'f':\n                layer = FeedForward(dim, **ff_kwargs)\n                layer = layer if not macaron else Scale(0.5, layer)\n            else:\n                raise Exception(f'invalid layer type {layer_type}')\n\n            if isinstance(layer, Attention) and exists(branch_fn):\n               
 layer = branch_fn(layer)\n\n            if gate_residual:\n                residual_fn = GRUGating(dim)\n            else:\n                residual_fn = Residual()\n\n            self.layers.append(nn.ModuleList([\n                norm_fn(),\n                layer,\n                residual_fn\n            ]))\n\n    def forward(\n            self,\n            x,\n            context=None,\n            mask=None,\n            context_mask=None,\n            mems=None,\n            return_hiddens=False\n    ):\n        hiddens = []\n        intermediates = []\n        prev_attn = None\n        prev_cross_attn = None\n\n        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers\n\n        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):\n            is_last = ind == (len(self.layers) - 1)\n\n            if layer_type == 'a':\n                hiddens.append(x)\n                layer_mem = mems.pop(0)\n\n            residual = x\n\n            if self.pre_norm:\n                x = norm(x)\n\n            if layer_type == 'a':\n                out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,\n                                   prev_attn=prev_attn, mem=layer_mem)\n            elif layer_type == 'c':\n                out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)\n            elif layer_type == 'f':\n                out = block(x)\n\n            x = residual_fn(out, residual)\n\n            if layer_type in ('a', 'c'):\n                intermediates.append(inter)\n\n            if layer_type == 'a' and self.residual_attn:\n                prev_attn = inter.pre_softmax_attn\n            elif layer_type == 'c' and self.cross_residual_attn:\n                prev_cross_attn = inter.pre_softmax_attn\n\n            if not self.pre_norm and not is_last:\n                x = norm(x)\n\n        if return_hiddens:\n 
           intermediates = LayerIntermediates(\n                hiddens=hiddens,\n                attn_intermediates=intermediates\n            )\n\n            return x, intermediates\n\n        return x\n\n\nclass Encoder(AttentionLayers):\n    def __init__(self, **kwargs):\n        assert 'causal' not in kwargs, 'cannot set causality on encoder'\n        super().__init__(causal=False, **kwargs)\n\n\n\nclass TransformerWrapper(nn.Module):\n    def __init__(\n            self,\n            *,\n            num_tokens,\n            max_seq_len,\n            attn_layers,\n            emb_dim=None,\n            max_mem_len=0.,\n            emb_dropout=0.,\n            num_memory_tokens=None,\n            tie_embedding=False,\n            use_pos_emb=True\n    ):\n        super().__init__()\n        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'\n\n        dim = attn_layers.dim\n        emb_dim = default(emb_dim, dim)\n\n        self.max_seq_len = max_seq_len\n        self.max_mem_len = max_mem_len\n        self.num_tokens = num_tokens\n\n        self.token_emb = nn.Embedding(num_tokens, emb_dim)\n        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (\n                    use_pos_emb and not attn_layers.has_pos_emb) else always(0)\n        self.emb_dropout = nn.Dropout(emb_dropout)\n\n        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()\n        self.attn_layers = attn_layers\n        self.norm = nn.LayerNorm(dim)\n\n        self.init_()\n\n        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()\n\n        # memory tokens (like [cls]) from Memory Transformers paper\n        num_memory_tokens = default(num_memory_tokens, 0)\n        self.num_memory_tokens = num_memory_tokens\n        if num_memory_tokens > 0:\n            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))\n\n    
        # let funnel encoder know number of memory tokens, if specified\n            if hasattr(attn_layers, 'num_memory_tokens'):\n                attn_layers.num_memory_tokens = num_memory_tokens\n\n    def init_(self):\n        nn.init.normal_(self.token_emb.weight, std=0.02)\n\n    def forward(\n            self,\n            x,\n            return_embeddings=False,\n            mask=None,\n            return_mems=False,\n            return_attn=False,\n            mems=None,\n            **kwargs\n    ):\n        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens\n        x = self.token_emb(x)\n        x += self.pos_emb(x)\n        x = self.emb_dropout(x)\n\n        x = self.project_emb(x)\n\n        if num_mem > 0:\n            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)\n            x = torch.cat((mem, x), dim=1)\n\n            # auto-handle masking after appending memory tokens\n            if exists(mask):\n                mask = F.pad(mask, (num_mem, 0), value=True)\n\n        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)\n        x = self.norm(x)\n\n        mem, x = x[:, :num_mem], x[:, num_mem:]\n\n        out = self.to_logits(x) if not return_embeddings else x\n\n        if return_mems:\n            hiddens = intermediates.hiddens\n            new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens\n            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))\n            return out, new_mems\n\n        if return_attn:\n            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))\n            return out, attn_maps\n\n        return out\n\n"
  },
  {
    "path": "ldm/thirdp/psp/helpers.py",
    "content": "# https://github.com/eladrich/pixel2style2pixel\n\nfrom collections import namedtuple\nimport torch\nfrom torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module\n\n\"\"\"\nArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)\n\"\"\"\n\n\nclass Flatten(Module):\n\tdef forward(self, input):\n\t\treturn input.view(input.size(0), -1)\n\n\ndef l2_norm(input, axis=1):\n\tnorm = torch.norm(input, 2, axis, True)\n\toutput = torch.div(input, norm)\n\treturn output\n\n\nclass Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):\n\t\"\"\" A named tuple describing a ResNet block. \"\"\"\n\n\ndef get_block(in_channel, depth, num_units, stride=2):\n\treturn [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]\n\n\ndef get_blocks(num_layers):\n\tif num_layers == 50:\n\t\tblocks = [\n\t\t\tget_block(in_channel=64, depth=64, num_units=3),\n\t\t\tget_block(in_channel=64, depth=128, num_units=4),\n\t\t\tget_block(in_channel=128, depth=256, num_units=14),\n\t\t\tget_block(in_channel=256, depth=512, num_units=3)\n\t\t]\n\telif num_layers == 100:\n\t\tblocks = [\n\t\t\tget_block(in_channel=64, depth=64, num_units=3),\n\t\t\tget_block(in_channel=64, depth=128, num_units=13),\n\t\t\tget_block(in_channel=128, depth=256, num_units=30),\n\t\t\tget_block(in_channel=256, depth=512, num_units=3)\n\t\t]\n\telif num_layers == 152:\n\t\tblocks = [\n\t\t\tget_block(in_channel=64, depth=64, num_units=3),\n\t\t\tget_block(in_channel=64, depth=128, num_units=8),\n\t\t\tget_block(in_channel=128, depth=256, num_units=36),\n\t\t\tget_block(in_channel=256, depth=512, num_units=3)\n\t\t]\n\telse:\n\t\traise ValueError(\"Invalid number of layers: {}. 
Must be one of [50, 100, 152]\".format(num_layers))\n\treturn blocks\n\n\nclass SEModule(Module):\n\tdef __init__(self, channels, reduction):\n\t\tsuper(SEModule, self).__init__()\n\t\tself.avg_pool = AdaptiveAvgPool2d(1)\n\t\tself.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)\n\t\tself.relu = ReLU(inplace=True)\n\t\tself.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)\n\t\tself.sigmoid = Sigmoid()\n\n\tdef forward(self, x):\n\t\tmodule_input = x\n\t\tx = self.avg_pool(x)\n\t\tx = self.fc1(x)\n\t\tx = self.relu(x)\n\t\tx = self.fc2(x)\n\t\tx = self.sigmoid(x)\n\t\treturn module_input * x\n\n\nclass bottleneck_IR(Module):\n\tdef __init__(self, in_channel, depth, stride):\n\t\tsuper(bottleneck_IR, self).__init__()\n\t\tif in_channel == depth:\n\t\t\tself.shortcut_layer = MaxPool2d(1, stride)\n\t\telse:\n\t\t\tself.shortcut_layer = Sequential(\n\t\t\t\tConv2d(in_channel, depth, (1, 1), stride, bias=False),\n\t\t\t\tBatchNorm2d(depth)\n\t\t\t)\n\t\tself.res_layer = Sequential(\n\t\t\tBatchNorm2d(in_channel),\n\t\t\tConv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),\n\t\t\tConv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)\n\t\t)\n\n\tdef forward(self, x):\n\t\tshortcut = self.shortcut_layer(x)\n\t\tres = self.res_layer(x)\n\t\treturn res + shortcut\n\n\nclass bottleneck_IR_SE(Module):\n\tdef __init__(self, in_channel, depth, stride):\n\t\tsuper(bottleneck_IR_SE, self).__init__()\n\t\tif in_channel == depth:\n\t\t\tself.shortcut_layer = MaxPool2d(1, stride)\n\t\telse:\n\t\t\tself.shortcut_layer = Sequential(\n\t\t\t\tConv2d(in_channel, depth, (1, 1), stride, bias=False),\n\t\t\t\tBatchNorm2d(depth)\n\t\t\t)\n\t\tself.res_layer = Sequential(\n\t\t\tBatchNorm2d(in_channel),\n\t\t\tConv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),\n\t\t\tPReLU(depth),\n\t\t\tConv2d(depth, depth, (3, 3), stride, 1, 
bias=False),\n\t\t\tBatchNorm2d(depth),\n\t\t\tSEModule(depth, 16)\n\t\t)\n\n\tdef forward(self, x):\n\t\tshortcut = self.shortcut_layer(x)\n\t\tres = self.res_layer(x)\n\t\treturn res + shortcut"
  },
  {
    "path": "ldm/thirdp/psp/id_loss.py",
    "content": "# https://github.com/eladrich/pixel2style2pixel\nimport torch\nfrom torch import nn\nfrom ldm.thirdp.psp.model_irse import Backbone\n\n\nclass IDFeatures(nn.Module):\n    def __init__(self, model_path):\n        super(IDFeatures, self).__init__()\n        print('Loading ResNet ArcFace')\n        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')\n        self.facenet.load_state_dict(torch.load(model_path, map_location=\"cpu\"))\n        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))\n        self.facenet.eval()\n\n    def forward(self, x, crop=False):\n        # Not sure of the image range here\n        if crop:\n            x = torch.nn.functional.interpolate(x, (256, 256), mode=\"area\")\n            x = x[:, :, 35:223, 32:220]\n        x = self.face_pool(x)\n        x_feats = self.facenet(x)\n        return x_feats\n"
  },
  {
    "path": "ldm/thirdp/psp/model_irse.py",
    "content": "# https://github.com/eladrich/pixel2style2pixel\n\nfrom torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module\nfrom ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm\n\n\"\"\"\nModified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)\n\"\"\"\n\n\nclass Backbone(Module):\n\tdef __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):\n\t\tsuper(Backbone, self).__init__()\n\t\tassert input_size in [112, 224], \"input_size should be 112 or 224\"\n\t\tassert num_layers in [50, 100, 152], \"num_layers should be 50, 100 or 152\"\n\t\tassert mode in ['ir', 'ir_se'], \"mode should be ir or ir_se\"\n\t\tblocks = get_blocks(num_layers)\n\t\tif mode == 'ir':\n\t\t\tunit_module = bottleneck_IR\n\t\telif mode == 'ir_se':\n\t\t\tunit_module = bottleneck_IR_SE\n\t\tself.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),\n\t\t\t\t\t\t\t\t\t  BatchNorm2d(64),\n\t\t\t\t\t\t\t\t\t  PReLU(64))\n\t\tif input_size == 112:\n\t\t\tself.output_layer = Sequential(BatchNorm2d(512),\n\t\t\t                               Dropout(drop_ratio),\n\t\t\t                               Flatten(),\n\t\t\t                               Linear(512 * 7 * 7, 512),\n\t\t\t                               BatchNorm1d(512, affine=affine))\n\t\telse:\n\t\t\tself.output_layer = Sequential(BatchNorm2d(512),\n\t\t\t                               Dropout(drop_ratio),\n\t\t\t                               Flatten(),\n\t\t\t                               Linear(512 * 14 * 14, 512),\n\t\t\t                               BatchNorm1d(512, affine=affine))\n\n\t\tmodules = []\n\t\tfor block in blocks:\n\t\t\tfor bottleneck in block:\n\t\t\t\tmodules.append(unit_module(bottleneck.in_channel,\n\t\t\t\t\t\t\t\t\t\t   bottleneck.depth,\n\t\t\t\t\t\t\t\t\t\t   bottleneck.stride))\n\t\tself.body = Sequential(*modules)\n\n\tdef forward(self, x):\n\t\tx = 
self.input_layer(x)\n\t\tx = self.body(x)\n\t\tx = self.output_layer(x)\n\t\treturn l2_norm(x)\n\n\ndef IR_50(input_size):\n\t\"\"\"Constructs a ir-50 model.\"\"\"\n\tmodel = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)\n\treturn model\n\n\ndef IR_101(input_size):\n\t\"\"\"Constructs a ir-101 model.\"\"\"\n\tmodel = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)\n\treturn model\n\n\ndef IR_152(input_size):\n\t\"\"\"Constructs a ir-152 model.\"\"\"\n\tmodel = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)\n\treturn model\n\n\ndef IR_SE_50(input_size):\n\t\"\"\"Constructs a ir_se-50 model.\"\"\"\n\tmodel = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)\n\treturn model\n\n\ndef IR_SE_101(input_size):\n\t\"\"\"Constructs a ir_se-101 model.\"\"\"\n\tmodel = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)\n\treturn model\n\n\ndef IR_SE_152(input_size):\n\t\"\"\"Constructs a ir_se-152 model.\"\"\"\n\tmodel = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)\n\treturn model"
  },
  {
    "path": "ldm/util.py",
    "content": "import importlib\n\nimport torchvision\nimport torch\nfrom torch import optim\nimport numpy as np\n\nfrom inspect import isfunction\nfrom PIL import Image, ImageDraw, ImageFont\n\nimport os\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nimport torch\nimport time\nimport cv2\n\nimport PIL\n\ndef pil_rectangle_crop(im):\n    width, height = im.size   # Get dimensions\n    \n    if width <= height:\n        left = 0\n        right = width\n        top = (height - width)/2\n        bottom = (height + width)/2\n    else:\n        top = 0\n        bottom = height\n        left = (width - height) / 2\n        right = (width + height) / 2\n\n    # Crop the largest centered square from the image\n    im = im.crop((left, top, right, bottom))\n    return im\n\n\ndef log_txt_as_img(wh, xc, size=10):\n    # wh a tuple of (width, height)\n    # xc a list of captions to plot\n    b = len(xc)\n    txts = list()\n    for bi in range(b):\n        txt = Image.new(\"RGB\", wh, color=\"white\")\n        draw = ImageDraw.Draw(txt)\n        font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)\n        nc = int(40 * (wh[0] / 256))\n        lines = \"\\n\".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))\n\n        try:\n            draw.text((0, 0), lines, fill=\"black\", font=font)\n        except UnicodeEncodeError:\n            print(\"Can't encode string for logging. 
Skipping.\")\n\n        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0\n        txts.append(txt)\n    txts = np.stack(txts)\n    txts = torch.tensor(txts)\n    return txts\n\n\ndef ismap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] > 3)\n\n\ndef isimage(x):\n    if not isinstance(x,torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)\n\n\ndef exists(x):\n    return x is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef count_params(model, verbose=False):\n    total_params = sum(p.numel() for p in model.parameters())\n    if verbose:\n        print(f\"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.\")\n    return total_params\n\n\ndef instantiate_from_config(config):\n    if not \"target\" in config:\n        if config == '__is_first_stage__':\n            return None\n        elif config == \"__is_unconditional__\":\n            return None\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()))\n\n\ndef get_obj_from_str(string, reload=False):\n    module, cls = string.rsplit(\".\", 1)\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\nclass AdamWwithEMAandWings(optim.Optimizer):\n    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298\n    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # 
TODO: check hyperparameters before using\n                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,   # ema decay to match previous code\n                 ema_power=1., param_names=()):\n        \"\"\"AdamW that saves EMA versions of the parameters.\"\"\"\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n        if not 0.0 <= weight_decay:\n            raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n        if not 0.0 <= ema_decay <= 1.0:\n            raise ValueError(\"Invalid ema_decay value: {}\".format(ema_decay))\n        defaults = dict(lr=lr, betas=betas, eps=eps,\n                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,\n                        ema_power=ema_power, param_names=param_names)\n        super().__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super().__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault('amsgrad', False)\n\n    @torch.no_grad()\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n        Args:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            with torch.enable_grad():\n                loss = closure()\n\n        for group in self.param_groups:\n            params_with_grad = []\n            grads = []\n            exp_avgs = []\n            exp_avg_sqs = []\n            ema_params_with_grad = []\n            
state_sums = []\n            max_exp_avg_sqs = []\n            state_steps = []\n            amsgrad = group['amsgrad']\n            beta1, beta2 = group['betas']\n            ema_decay = group['ema_decay']\n            ema_power = group['ema_power']\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                params_with_grad.append(p)\n                if p.grad.is_sparse:\n                    raise RuntimeError('AdamW does not support sparse gradients')\n                grads.append(p.grad)\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state['step'] = 0\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)\n                    if amsgrad:\n                        # Maintains max of all exp. moving avg. of sq. grad. 
values\n                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)\n                    # Exponential moving average of parameter values\n                    state['param_exp_avg'] = p.detach().float().clone()\n\n                exp_avgs.append(state['exp_avg'])\n                exp_avg_sqs.append(state['exp_avg_sq'])\n                ema_params_with_grad.append(state['param_exp_avg'])\n\n                if amsgrad:\n                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])\n\n                # update the steps for each param group update\n                state['step'] += 1\n                # record the step after step update\n                state_steps.append(state['step'])\n\n            optim._functional.adamw(params_with_grad,\n                    grads,\n                    exp_avgs,\n                    exp_avg_sqs,\n                    max_exp_avg_sqs,\n                    state_steps,\n                    amsgrad=amsgrad,\n                    beta1=beta1,\n                    beta2=beta2,\n                    lr=group['lr'],\n                    weight_decay=group['weight_decay'],\n                    eps=group['eps'],\n                    maximize=False)\n\n            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)\n            for param, ema_param in zip(params_with_grad, ema_params_with_grad):\n                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)\n\n        return loss"
  },
  {
    "path": "main.py",
    "content": "import torch\nimport argparse\nimport pandas as pd\nimport sys\n\nfrom nerf.provider import NeRFDataset\nfrom nerf.utils import *\n\n# torch.autograd.set_detect_anomaly(True)\n\nif __name__ == '__main__':\n    # See https://stackoverflow.com/questions/27433316/how-to-get-argparse-to-read-arguments-from-a-file-with-an-option-rather-than-pre\n    class LoadFromFile (argparse.Action):\n        def __call__ (self, parser, namespace, values, option_string = None):\n            with values as f:\n                # parse arguments in the file and store them in the target namespace\n                parser.parse_args(f.read().split(), namespace)\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--file', type=open, action=LoadFromFile, help=\"specify a file filled with more arguments\")\n    parser.add_argument('--text', default=None, help=\"text prompt\")\n    parser.add_argument('--negative', default='', type=str, help=\"negative text prompt\")\n    parser.add_argument('-O', action='store_true', help=\"equals --fp16 --cuda_ray\")\n    parser.add_argument('-O2', action='store_true', help=\"equals --fp16 --backbone vanilla --progressive_level\")\n    parser.add_argument('--test', action='store_true', help=\"test mode\")\n    parser.add_argument('--six_views', action='store_true', help=\"six_views mode: save the images of the six views\")\n    parser.add_argument('--eval_interval', type=int, default=1, help=\"evaluate on the valid set every interval epochs\")\n    parser.add_argument('--test_interval', type=int, default=100, help=\"test on the test set every interval epochs\")\n    parser.add_argument('--workspace', type=str, default='workspace')\n    parser.add_argument('--seed', default=None)\n\n    parser.add_argument('--image', default=None, help=\"image prompt\")\n    parser.add_argument('--image_config', default=None, help=\"image config csv\")\n\n    parser.add_argument('--known_view_interval', type=int, default=4, help=\"train default view with RGB loss every N 
iters, only valid if --image is not None.\")\n\n    parser.add_argument('--IF', action='store_true', help=\"experimental: use DeepFloyd IF as the guidance model for nerf stage\")\n\n    parser.add_argument('--guidance', type=str, nargs='*', default=['SD'], help='guidance model')\n    parser.add_argument('--guidance_scale', type=float, default=100, help=\"diffusion model classifier-free guidance scale\")\n\n    parser.add_argument('--save_mesh', action='store_true', help=\"export an obj mesh with texture\")\n    parser.add_argument('--mcubes_resolution', type=int, default=256, help=\"mcubes resolution for extracting mesh\")\n    parser.add_argument('--decimate_target', type=int, default=50000, help=\"target face number for mesh decimation\")\n\n    parser.add_argument('--dmtet', action='store_true', help=\"use dmtet finetuning\")\n    parser.add_argument('--tet_grid_size', type=int, default=128, help=\"tet grid size\")\n    parser.add_argument('--init_with', type=str, default='', help=\"ckpt to init dmtet\")\n    parser.add_argument('--lock_geo', action='store_true', help=\"prevent dmtet from updating the geometry\")\n\n    ## Perp-Neg options\n    parser.add_argument('--perpneg', action='store_true', help=\"use perp_neg\")\n    parser.add_argument('--negative_w', type=float, default=-2, help=\"The scale of the weights of negative prompts. A larger absolute value will help to avoid the Janus problem, but may cause flat faces. 
Vary between 0 and -4, depending on the prompt\")\n    parser.add_argument('--front_decay_factor', type=float, default=2, help=\"decay factor for the front prompt\")\n    parser.add_argument('--side_decay_factor', type=float, default=10, help=\"decay factor for the side prompt\")\n\n    ### training options\n    parser.add_argument('--iters', type=int, default=10000, help=\"training iters\")\n    parser.add_argument('--lr', type=float, default=1e-3, help=\"max learning rate\")\n    parser.add_argument('--ckpt', type=str, default='latest', help=\"possible options are ['latest', 'scratch', 'best', 'latest_model']\")\n    parser.add_argument('--cuda_ray', action='store_true', help=\"use CUDA raymarching instead of PyTorch\")\n    parser.add_argument('--taichi_ray', action='store_true', help=\"use taichi raymarching\")\n    parser.add_argument('--max_steps', type=int, default=1024, help=\"max num steps sampled per ray (only valid when using --cuda_ray)\")\n    parser.add_argument('--num_steps', type=int, default=64, help=\"num steps sampled per ray (only valid when not using --cuda_ray)\")\n    parser.add_argument('--upsample_steps', type=int, default=32, help=\"num steps up-sampled per ray (only valid when not using --cuda_ray)\")\n    parser.add_argument('--update_extra_interval', type=int, default=16, help=\"iter interval to update extra status (only valid when using --cuda_ray)\")\n    parser.add_argument('--max_ray_batch', type=int, default=4096, help=\"batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)\")\n    parser.add_argument('--latent_iter_ratio', type=float, default=0.2, help=\"ratio of training iters used for latent warmup (as_latent)\")\n    parser.add_argument('--albedo_iter_ratio', type=float, default=0, help=\"ratio of training iters that only use albedo shading\")\n    parser.add_argument('--min_ambient_ratio', type=float, default=0.1, help=\"minimum ambient ratio to use in lambertian shading\")\n    parser.add_argument('--textureless_ratio', 
type=float, default=0.2, help=\"ratio of textureless shading\")\n    parser.add_argument('--jitter_pose', action='store_true', help=\"add jitter to the randomly sampled camera poses\")\n    parser.add_argument('--jitter_center', type=float, default=0.2, help=\"amount of jitter to add to sampled camera pose's center (camera location)\")\n    parser.add_argument('--jitter_target', type=float, default=0.2, help=\"amount of jitter to add to sampled camera pose's target (i.e. 'look-at')\")\n    parser.add_argument('--jitter_up', type=float, default=0.02, help=\"amount of jitter to add to sampled camera pose's up-axis (i.e. 'camera roll')\")\n    parser.add_argument('--uniform_sphere_rate', type=float, default=0, help=\"probability of sampling the camera location uniformly on the sphere surface\")\n    parser.add_argument('--grad_clip', type=float, default=-1, help=\"clip all gradients to this limit; a negative value disables it\")\n    parser.add_argument('--grad_clip_rgb', type=float, default=-1, help=\"clip gradients in RGB space to this limit; a negative value disables it\")\n    # model options\n    parser.add_argument('--bg_radius', type=float, default=1.4, help=\"if positive, use a background model at sphere(bg_radius)\")\n    parser.add_argument('--density_activation', type=str, default='exp', choices=['softplus', 'exp'], help=\"density activation function\")\n    parser.add_argument('--density_thresh', type=float, default=10, help=\"threshold for density grid to be occupied\")\n    parser.add_argument('--blob_density', type=float, default=5, help=\"max (center) density for the density blob\")\n    parser.add_argument('--blob_radius', type=float, default=0.2, help=\"control the radius for the density blob\")\n    # network backbone\n    parser.add_argument('--backbone', type=str, default='grid', choices=['grid_tcnn', 'grid', 'vanilla', 'grid_taichi'], help=\"nerf backbone\")\n    parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam'], 
help=\"optimizer\")\n    parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help=\"Stable Diffusion version\")\n    parser.add_argument('--hf_key', type=str, default=None, help=\"Hugging Face Stable Diffusion model key\")\n    # try these if CUDA runs out of memory (OOM)\n    parser.add_argument('--fp16', action='store_true', help=\"use float16 for training\")\n    parser.add_argument('--vram_O', action='store_true', help=\"optimization for low VRAM usage\")\n    # rendering resolution in training: increase these for better quality, or decrease them if CUDA still runs out of memory with --vram_O enabled.\n    parser.add_argument('--w', type=int, default=64, help=\"render width for NeRF in training\")\n    parser.add_argument('--h', type=int, default=64, help=\"render height for NeRF in training\")\n    parser.add_argument('--known_view_scale', type=float, default=1.5, help=\"multiply --h/w by this for known view rendering\")\n    parser.add_argument('--known_view_noise_scale', type=float, default=2e-3, help=\"random camera noise added to rays_o and rays_d\")\n    parser.add_argument('--dmtet_reso_scale', type=float, default=8, help=\"multiply --h/w by this for dmtet finetuning\")\n    parser.add_argument('--batch_size', type=int, default=1, help=\"images to render per batch using NeRF\")\n\n    ### dataset options\n    parser.add_argument('--bound', type=float, default=1, help=\"assume the scene is bounded in box(-bound, bound)\")\n    parser.add_argument('--dt_gamma', type=float, default=0, help=\"dt_gamma (>=0) for adaptive ray marching. 
Set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)\")\n    parser.add_argument('--min_near', type=float, default=0.01, help=\"minimum near distance for camera\")\n\n    parser.add_argument('--radius_range', type=float, nargs='*', default=[3.0, 3.5], help=\"training camera radius range\")\n    parser.add_argument('--theta_range', type=float, nargs='*', default=[45, 105], help=\"training camera range along the polar angles (i.e. up and down). See advanced.md for details.\")\n    parser.add_argument('--phi_range', type=float, nargs='*', default=[-180, 180], help=\"training camera range along the azimuth angles (i.e. left and right). See advanced.md for details.\")\n    parser.add_argument('--fovy_range', type=float, nargs='*', default=[10, 30], help=\"training camera fovy range\")\n\n    parser.add_argument('--default_radius', type=float, default=3.2, help=\"radius for the default view\")\n    parser.add_argument('--default_polar', type=float, default=90, help=\"polar angle for the default view\")\n    parser.add_argument('--default_azimuth', type=float, default=0, help=\"azimuth angle for the default view\")\n    parser.add_argument('--default_fovy', type=float, default=20, help=\"fovy for the default view\")\n\n    parser.add_argument('--progressive_view', action='store_true', help=\"progressively expand view sampling range from default to full\")\n    parser.add_argument('--progressive_view_init_ratio', type=float, default=0.2, help=\"initial ratio of final range, used for progressive_view\")\n    \n    parser.add_argument('--progressive_level', action='store_true', help=\"progressively increase gridencoder's max_level\")\n\n    parser.add_argument('--angle_overhead', type=float, default=30, help=\"[0, angle_overhead] is the overhead region\")\n    parser.add_argument('--angle_front', type=float, default=60, help=\"[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.\")\n
    parser.add_argument('--t_range', type=float, nargs='*', default=[0.02, 0.98], help=\"Stable Diffusion timestep range, as a fraction of the total number of timesteps\")\n    parser.add_argument('--dont_override_stuff', action='store_true', help=\"Don't override t_range, etc.\")\n\n\n    ### regularizations\n    parser.add_argument('--lambda_entropy', type=float, default=1e-3, help=\"loss scale for alpha entropy\")\n    parser.add_argument('--lambda_opacity', type=float, default=0, help=\"loss scale for alpha value\")\n    parser.add_argument('--lambda_orient', type=float, default=1e-2, help=\"loss scale for orientation\")\n    parser.add_argument('--lambda_tv', type=float, default=0, help=\"loss scale for total variation\")\n    parser.add_argument('--lambda_wd', type=float, default=0, help=\"loss scale for weight decay\")\n\n    parser.add_argument('--lambda_mesh_normal', type=float, default=0.5, help=\"loss scale for mesh normal smoothness\")\n    parser.add_argument('--lambda_mesh_laplacian', type=float, default=0.5, help=\"loss scale for mesh laplacian\")\n\n    parser.add_argument('--lambda_guidance', type=float, default=1, help=\"loss scale for SDS\")\n    parser.add_argument('--lambda_rgb', type=float, default=1000, help=\"loss scale for RGB\")\n    parser.add_argument('--lambda_mask', type=float, default=500, help=\"loss scale for mask (alpha)\")\n    parser.add_argument('--lambda_normal', type=float, default=0, help=\"loss scale for normal map\")\n    parser.add_argument('--lambda_depth', type=float, default=10, help=\"loss scale for relative depth\")\n    parser.add_argument('--lambda_2d_normal_smooth', type=float, default=0, help=\"loss scale for 2D normal image smoothness\")\n    parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0, help=\"loss scale for 3D normal image smoothness\")\n\n    ### debugging options\n    parser.add_argument('--save_guidance', action='store_true', help=\"save images of the per-iteration NeRF renders, added noise, denoised (i.e. guidance), fully-denoised. 
Useful for debugging, but VERY SLOW and takes lots of memory!\")\n    parser.add_argument('--save_guidance_interval', type=int, default=10, help=\"save guidance images every X steps\")\n\n    ### GUI options\n    parser.add_argument('--gui', action='store_true', help=\"start a GUI\")\n    parser.add_argument('--W', type=int, default=800, help=\"GUI width\")\n    parser.add_argument('--H', type=int, default=800, help=\"GUI height\")\n    parser.add_argument('--radius', type=float, default=5, help=\"default GUI camera radius from center\")\n    parser.add_argument('--fovy', type=float, default=20, help=\"default GUI camera fovy\")\n    parser.add_argument('--light_theta', type=float, default=60, help=\"default GUI light direction in [0, 180], corresponding to elevation [90, -90]\")\n    parser.add_argument('--light_phi', type=float, default=0, help=\"default GUI light direction in [0, 360), azimuth\")\n    parser.add_argument('--max_spp', type=int, default=1, help=\"GUI rendering max samples per pixel\")\n\n    parser.add_argument('--zero123_config', type=str, default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml', help=\"config file for zero123\")\n    parser.add_argument('--zero123_ckpt', type=str, default='pretrained/zero123/zero123-xl.ckpt', help=\"ckpt for zero123\")\n    parser.add_argument('--zero123_grad_scale', type=str, default='angle', help=\"whether to scale the gradients based on 'angle' or 'None'\")\n\n    parser.add_argument('--dataset_size_train', type=int, default=100, help=\"Length of train dataset, i.e. 
# of iterations per epoch\")\n    parser.add_argument('--dataset_size_valid', type=int, default=8, help=\"# of frames to render in the turntable video in validation\")\n    parser.add_argument('--dataset_size_test', type=int, default=100, help=\"# of frames to render in the turntable video at test time\")\n\n    parser.add_argument('--exp_start_iter', type=int, default=None, help=\"start iter # for experiment, to calculate progressive_view and progressive_level\")\n    parser.add_argument('--exp_end_iter', type=int, default=None, help=\"end iter # for experiment, to calculate progressive_view and progressive_level\")\n\n    opt = parser.parse_args()\n\n    if opt.O:\n        opt.fp16 = True\n        opt.cuda_ray = True\n\n    elif opt.O2:\n        opt.fp16 = True\n        opt.backbone = 'vanilla'\n        opt.progressive_level = True\n\n    if opt.IF:\n        if 'SD' in opt.guidance:\n            opt.guidance.remove('SD')\n            opt.guidance.append('IF')\n        opt.latent_iter_ratio = 0 # must not do as_latent\n\n    opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], []\n    opt.default_zero123_w = 1\n\n    opt.exp_start_iter = opt.exp_start_iter or 0\n    opt.exp_end_iter = opt.exp_end_iter or opt.iters\n\n    # parameters for image-conditioned generation\n    if opt.image is not None or opt.image_config is not None:\n\n        if opt.text is None:\n            # use zero123 guidance model when only providing image\n            opt.guidance = ['zero123']\n            if not opt.dont_override_stuff:\n                opt.fovy_range = [opt.default_fovy, opt.default_fovy] # fix fov as zero123 doesn't support changing fov\n                opt.guidance_scale = 5\n                opt.lambda_3d_normal_smooth = 10\n        else:\n            # use stable-diffusion when providing both text and image\n            opt.guidance = ['SD', 'clip']\n            \n            if not opt.dont_override_stuff:\n                
opt.guidance_scale = 10\n                opt.t_range = [0.2, 0.6]\n                opt.known_view_interval = 2\n                opt.lambda_3d_normal_smooth = 20\n            opt.bg_radius = -1\n\n        # smoothness\n        opt.lambda_entropy = 1\n        opt.lambda_orient = 1\n\n        # latent warmup is not needed\n        opt.latent_iter_ratio = 0\n        if not opt.dont_override_stuff:\n            opt.albedo_iter_ratio = 0\n            \n            # make shape init more stable\n            opt.progressive_view = True\n            opt.progressive_level = True\n\n        if opt.image is not None:\n            opt.images += [opt.image]\n            opt.ref_radii += [opt.default_radius]\n            opt.ref_polars += [opt.default_polar]\n            opt.ref_azimuths += [opt.default_azimuth]\n            opt.zero123_ws += [opt.default_zero123_w]\n\n        if opt.image_config is not None:\n            # for multiview (zero123)\n            conf = pd.read_csv(opt.image_config, skipinitialspace=True)\n            opt.images += list(conf.image)\n            opt.ref_radii += list(conf.radius)\n            opt.ref_polars += list(conf.polar)\n            opt.ref_azimuths += list(conf.azimuth)\n            opt.zero123_ws += list(conf.zero123_weight)\n            if opt.image is None:\n                opt.default_radius = opt.ref_radii[0]\n                opt.default_polar = opt.ref_polars[0]\n                opt.default_azimuth = opt.ref_azimuths[0]\n                opt.default_zero123_w = opt.zero123_ws[0]\n\n    # reset to None\n    if len(opt.images) == 0:\n        opt.images = None\n\n    # default parameters for finetuning\n    if opt.dmtet:\n\n        opt.h = int(opt.h * opt.dmtet_reso_scale)\n        opt.w = int(opt.w * opt.dmtet_reso_scale)\n        opt.known_view_scale = 1\n\n        if not opt.dont_override_stuff:            \n            opt.t_range = [0.02, 0.50] # ref: magic3D\n\n        if opt.images is not None:\n\n            opt.lambda_normal = 0\n  
          opt.lambda_depth = 0\n\n            if opt.text is not None and not opt.dont_override_stuff:\n                opt.t_range = [0.20, 0.50]\n\n        # assume finetuning\n        opt.latent_iter_ratio = 0\n        opt.albedo_iter_ratio = 0\n        opt.progressive_view = False\n        # opt.progressive_level = False\n\n    # record full range for progressive view expansion\n    if opt.progressive_view:\n        if not opt.dont_override_stuff:\n            # disable as they disturb progressive view\n            opt.jitter_pose = False\n            \n        opt.uniform_sphere_rate = 0\n        # back up full range\n        opt.full_radius_range = opt.radius_range\n        opt.full_theta_range = opt.theta_range\n        opt.full_phi_range = opt.phi_range\n        opt.full_fovy_range = opt.fovy_range\n\n    if opt.backbone == 'vanilla':\n        from nerf.network import NeRFNetwork\n    elif opt.backbone == 'grid':\n        from nerf.network_grid import NeRFNetwork\n    elif opt.backbone == 'grid_tcnn':\n        from nerf.network_grid_tcnn import NeRFNetwork\n    elif opt.backbone == 'grid_taichi':\n        opt.cuda_ray = False\n        opt.taichi_ray = True\n        import taichi as ti\n        from nerf.network_grid_taichi import NeRFNetwork\n        taichi_half2_opt = True\n        taichi_init_args = {\"arch\": ti.cuda, \"device_memory_GB\": 4.0}\n        if taichi_half2_opt:\n            taichi_init_args[\"half2_vectorization\"] = True\n        ti.init(**taichi_init_args)\n    else:\n        raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')\n\n    print(opt)\n\n    if opt.seed is not None:\n        seed_everything(int(opt.seed))\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n    model = NeRFNetwork(opt).to(device)\n\n    if opt.dmtet and opt.init_with != '':\n        if opt.init_with.endswith('.pth'):\n            # load pretrained weights to init dmtet\n            state_dict = 
torch.load(opt.init_with, map_location=device)\n            model.load_state_dict(state_dict['model'], strict=False)\n            if opt.cuda_ray:\n                model.mean_density = state_dict['mean_density']\n            model.init_tet()\n        else:\n            # assume a mesh to init dmtet (experimental, not working well now!)\n            import trimesh\n            mesh = trimesh.load(opt.init_with, force='mesh', skip_material=True, process=False)\n            model.init_tet(mesh=mesh)\n\n    print(model)\n\n    if opt.six_views:\n        guidance = None # no need to load guidance model at test\n\n        trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)\n\n        test_loader = NeRFDataset(opt, device=device, type='six_views', H=opt.H, W=opt.W, size=6).dataloader(batch_size=1)\n        trainer.test(test_loader, write_video=False)\n\n        if opt.save_mesh:\n            trainer.save_mesh()\n\n    elif opt.test:\n        guidance = None # no need to load guidance model at test\n\n        trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)\n\n        if opt.gui:\n            from nerf.gui import NeRFGUI\n            gui = NeRFGUI(opt, trainer)\n            gui.render()\n\n        else:\n            test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1)\n            trainer.test(test_loader)\n\n            if opt.save_mesh:\n                trainer.save_mesh()\n\n    else:\n\n        train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=opt.dataset_size_train * opt.batch_size).dataloader()\n\n        if opt.optim == 'adan':\n            from optimizer import Adan\n            # Adan usually requires a larger LR\n            optimizer = lambda model: 
Adan(model.get_params(5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)\n        else: # adam\n            optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)\n\n        if opt.backbone == 'vanilla':\n            scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))\n        else:\n            scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed\n            # scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))\n\n        guidance = nn.ModuleDict()\n\n        if 'SD' in opt.guidance:\n            from guidance.sd_utils import StableDiffusion\n            guidance['SD'] = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range)\n\n        if 'IF' in opt.guidance:\n            from guidance.if_utils import IF\n            guidance['IF'] = IF(device, opt.vram_O, opt.t_range)\n\n        if 'zero123' in opt.guidance:\n            from guidance.zero123_utils import Zero123\n            guidance['zero123'] = Zero123(device=device, fp16=opt.fp16, config=opt.zero123_config, ckpt=opt.zero123_ckpt, vram_O=opt.vram_O, t_range=opt.t_range, opt=opt)\n\n        if 'clip' in opt.guidance:\n            from guidance.clip_utils import CLIP\n            guidance['clip'] = CLIP(device)\n\n        trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, scheduler_update_every_step=True)\n\n        trainer.default_view_data = train_loader._data.get_default_view_data()\n\n        if opt.gui:\n            from nerf.gui import NeRFGUI\n            gui = NeRFGUI(opt, trainer, train_loader)\n            gui.render()\n\n        else:\n            valid_loader = NeRFDataset(opt, 
device=device, type='val', H=opt.H, W=opt.W, size=opt.dataset_size_valid).dataloader(batch_size=1)\n            test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1)\n\n            max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)\n            trainer.train(train_loader, valid_loader, test_loader, max_epoch)\n\n            if opt.save_mesh:\n                trainer.save_mesh()\n"
  },
  {
    "path": "meshutils.py",
    "content": "import numpy as np\nimport pymeshlab as pml\n\ndef poisson_mesh_reconstruction(points, normals=None):\n    # points/normals: [N, 3] np.ndarray\n\n    import open3d as o3d\n\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(points)\n\n    # outlier removal\n    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)\n\n    # normals\n    if normals is None:\n        pcd.estimate_normals()\n    else:\n        pcd.normals = o3d.utility.Vector3dVector(normals[ind])\n\n    # visualize\n    o3d.visualization.draw_geometries([pcd], point_show_normal=False)\n    \n    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9)\n    # remove the ~10% of vertices with the lowest Poisson density (poorly supported by the input points)\n    densities = np.asarray(densities)\n    vertices_to_remove = densities < np.quantile(densities, 0.1)\n    mesh.remove_vertices_by_mask(vertices_to_remove)\n\n    # visualize\n    o3d.visualization.draw_geometries([mesh])\n\n    vertices = np.asarray(mesh.vertices)\n    triangles = np.asarray(mesh.triangles)\n\n    print(f'[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}')\n\n    return vertices, triangles\n    \n\ndef decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True):\n    # optimalplacement: default is True, but for flat meshes it must be set to False to prevent spike artifacts.\n\n    _ori_vert_shape = verts.shape\n    _ori_face_shape = faces.shape\n\n    if backend == 'pyfqmr':\n        import pyfqmr\n        solver = pyfqmr.Simplify()\n        solver.setMesh(verts, faces)\n        solver.simplify_mesh(target_count=int(target), preserve_border=False, verbose=False)\n        verts, faces, normals = solver.getMesh()\n    else:\n        \n        m = pml.Mesh(verts, faces)\n        ms = pml.MeshSet()\n        ms.add_mesh(m, 'mesh') # will copy!\n\n        # filters\n        # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))\n
        ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement)\n\n        if remesh:\n            # ms.apply_coord_taubin_smoothing()\n            ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))\n\n        # extract mesh\n        m = ms.current_mesh()\n        verts = m.vertex_matrix()\n        faces = m.face_matrix()\n\n    print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')\n\n    return verts, faces\n\n\ndef clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01):\n    # verts: [N, 3]\n    # faces: [M, 3]\n\n    _ori_vert_shape = verts.shape\n    _ori_face_shape = faces.shape\n\n    m = pml.Mesh(verts, faces)\n    ms = pml.MeshSet()\n    ms.add_mesh(m, 'mesh') # will copy!\n\n    # filters\n    ms.meshing_remove_unreferenced_vertices() # verts not referenced by any face\n\n    if v_pct > 0:\n        ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct)) # threshold is v_pct percent of the bounding box diagonal\n\n    ms.meshing_remove_duplicate_faces() # faces defined by the same verts\n    ms.meshing_remove_null_faces() # faces with area == 0\n\n    if min_d > 0:\n        ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))\n    \n    if min_f > 0:\n        ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)\n\n    if repair:\n        # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)\n        ms.meshing_repair_non_manifold_edges(method=0)\n        ms.meshing_repair_non_manifold_vertices(vertdispratio=0)\n    \n    if remesh:\n        # ms.apply_coord_taubin_smoothing()\n        ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size))\n\n    # extract mesh\n    m = ms.current_mesh()\n    verts = m.vertex_matrix()\n    faces = m.face_matrix()\n\n    print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')\n\n
    return verts, faces"
  },
  {
    "path": "nerf/gui.py",
    "content": "import math\nimport torch\nimport numpy as np\nimport dearpygui.dearpygui as dpg\nfrom scipy.spatial.transform import Rotation as R\n\nfrom nerf.utils import *\n\n\nclass OrbitCamera:\n    def __init__(self, W, H, r=2, fovy=60):\n        self.W = W\n        self.H = H\n        self.radius = r # camera distance from center\n        self.fovy = fovy # in degrees\n        self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point\n        self.rot = R.from_matrix(np.eye(3))\n        self.up = np.array([0, 1, 0], dtype=np.float32) # must be normalized!\n        self.near = 0.001\n        self.far = 1000\n\n    # pose\n    @property\n    def pose(self):\n        # first move camera to radius\n        res = np.eye(4, dtype=np.float32)\n        res[2, 3] = self.radius\n        # rotate\n        rot = np.eye(4, dtype=np.float32)\n        rot[:3, :3] = self.rot.as_matrix()\n        res = rot @ res\n        # translate\n        res[:3, 3] -= self.center\n        return res\n    \n    # intrinsics\n    @property\n    def intrinsics(self):\n        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))\n        return np.array([focal, focal, self.W // 2, self.H // 2])\n\n    @property\n    def mvp(self):\n        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))\n        projection = np.array([\n            [2*focal/self.W, 0, 0, 0], \n            [0, -2*focal/self.H, 0, 0],\n            [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],\n            [0, 0, -1, 0]\n        ], dtype=np.float32)\n\n        return projection @ np.linalg.inv(self.pose) # [4, 4]\n    \n    def orbit(self, dx, dy):\n        # rotate around the camera's up/side axes\n        side = self.rot.as_matrix()[:3, 0] # the first column of the rotation matrix is the camera's side (x) axis; 
# already normalized.\n        rotvec_x = self.up * np.deg2rad(-0.1 * dx)\n        rotvec_y = side * np.deg2rad(-0.1 * dy)\n        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot\n\n    def scale(self, delta):\n        self.radius *= 1.1 ** (-delta)\n\n    def pan(self, dx, dy, dz=0):\n        # pan in camera coordinate system (careful on the sensitivity!)\n        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, -dy, dz])\n\n\nclass NeRFGUI:\n    def __init__(self, opt, trainer, loader=None, debug=True):\n        self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.\n        self.W = opt.W\n        self.H = opt.H\n        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)\n        self.debug = debug\n        self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg\n        self.training = False\n        self.step = 0 # training step \n\n        self.trainer = trainer\n        self.loader = loader\n        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)\n        self.need_update = True # camera moved, should reset accumulation\n        self.spp = 1 # sample per pixel\n        self.light_dir = np.array([opt.light_theta, opt.light_phi])\n        self.ambient_ratio = 1.0\n        self.mode = 'image' # choose from ['image', 'depth']\n        self.shading = 'albedo'\n\n        self.dynamic_resolution = True if not self.opt.dmtet else False\n        self.downscale = 1\n        self.train_steps = 16\n\n        dpg.create_context()\n        self.register_dpg()\n        self.test_step()\n        \n\n    def __del__(self):\n        dpg.destroy_context()\n\n\n    def train_step(self):\n\n        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)\n        starter.record()\n\n        outputs = self.trainer.train_gui(self.loader, step=self.train_steps)\n\n        ender.record()\n        
torch.cuda.synchronize()\n        t = starter.elapsed_time(ender)\n\n        self.step += self.train_steps\n        self.need_update = True\n\n        dpg.set_value(\"_log_train_time\", f'{t:.4f}ms ({int(1000/t)} FPS)')\n        dpg.set_value(\"_log_train_log\", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs[\"loss\"]:.4f}, lr = {outputs[\"lr\"]:.5f}')\n\n        # dynamic train steps\n        # max allowed train time per-frame is 500 ms\n        full_t = t / self.train_steps * 16\n        train_steps = min(16, max(4, int(16 * 500 / full_t)))\n        if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:\n            self.train_steps = train_steps\n\n\n    def prepare_buffer(self, outputs):\n        if self.mode == 'image':\n            return outputs['image'].astype(np.float32)\n        else:\n            depth = outputs['depth'].astype(np.float32)\n            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)\n            return np.expand_dims(depth, -1).repeat(3, -1)\n\n    \n    def test_step(self):\n\n        if self.need_update or self.spp < self.opt.max_spp:\n        \n            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)\n            starter.record()\n\n            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.cam.mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)\n\n            ender.record()\n            torch.cuda.synchronize()\n            t = starter.elapsed_time(ender)\n\n            # update dynamic resolution\n            if self.dynamic_resolution:\n                # max allowed infer time per-frame is 200 ms\n                full_t = t / (self.downscale ** 2)\n                downscale = min(1, max(1/4, math.sqrt(200 / full_t)))\n                if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:\n                    
self.downscale = downscale\n\n            if self.need_update:\n                self.render_buffer = self.prepare_buffer(outputs)\n                self.spp = 1\n                self.need_update = False\n            else:\n                self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)\n                self.spp += 1\n\n            dpg.set_value(\"_log_infer_time\", f'{t:.4f}ms ({int(1000/t)} FPS)')\n            dpg.set_value(\"_log_resolution\", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')\n            dpg.set_value(\"_log_spp\", self.spp)\n            dpg.set_value(\"_texture\", self.render_buffer)\n\n        \n    def register_dpg(self):\n\n        ### register texture \n\n        with dpg.texture_registry(show=False):\n            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag=\"_texture\")\n\n        ### register window\n\n        # the rendered image, as the primary window\n        with dpg.window(tag=\"_primary_window\", width=self.W, height=self.H):\n\n            # add the texture\n            dpg.add_image(\"_texture\")\n\n        dpg.set_primary_window(\"_primary_window\", True)\n\n        # control window\n        with dpg.window(label=\"Control\", tag=\"_control_window\", width=400, height=300):\n\n            # text prompt\n            if self.opt.text is not None:\n                dpg.add_text(\"text: \" + self.opt.text, tag=\"_log_prompt_text\")\n            \n            if self.opt.negative != '':\n                dpg.add_text(\"negative text: \" + self.opt.negative, tag=\"_log_prompt_negative_text\")\n\n            # button theme\n            with dpg.theme() as theme_button:\n                with dpg.theme_component(dpg.mvButton):\n                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))\n                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))\n                    
dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))\n                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)\n                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)\n\n            # time\n            if not self.opt.test:\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"Train time: \")\n                    dpg.add_text(\"no data\", tag=\"_log_train_time\")                    \n\n            with dpg.group(horizontal=True):\n                dpg.add_text(\"Infer time: \")\n                dpg.add_text(\"no data\", tag=\"_log_infer_time\")\n            \n            with dpg.group(horizontal=True):\n                dpg.add_text(\"SPP: \")\n                dpg.add_text(\"1\", tag=\"_log_spp\")\n\n            # train button\n            if not self.opt.test:\n                with dpg.collapsing_header(label=\"Train\", default_open=True):\n                    with dpg.group(horizontal=True):\n                        dpg.add_text(\"Train: \")\n\n                        def callback_train(sender, app_data):\n                            if self.training:\n                                self.training = False\n                                dpg.configure_item(\"_button_train\", label=\"start\")\n                            else:\n                                self.training = True\n                                dpg.configure_item(\"_button_train\", label=\"stop\")\n\n                        dpg.add_button(label=\"start\", tag=\"_button_train\", callback=callback_train)\n                        dpg.bind_item_theme(\"_button_train\", theme_button)\n\n                        def callback_reset(sender, app_data):\n                            @torch.no_grad()\n                            def weight_reset(m: nn.Module):\n                                reset_parameters = getattr(m, \"reset_parameters\", None)\n                                if callable(reset_parameters):\n                     
               m.reset_parameters()\n                            self.trainer.model.apply(fn=weight_reset)\n                            self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter\n                            self.need_update = True\n\n                        dpg.add_button(label=\"reset\", tag=\"_button_reset\", callback=callback_reset)\n                        dpg.bind_item_theme(\"_button_reset\", theme_button)\n\n\n                    with dpg.group(horizontal=True):\n                        dpg.add_text(\"Checkpoint: \")\n\n                        def callback_save(sender, app_data):\n                            self.trainer.save_checkpoint(full=True, best=False)\n                            dpg.set_value(\"_log_ckpt\", \"saved \" + os.path.basename(self.trainer.stats[\"checkpoints\"][-1]))\n                            self.trainer.epoch += 1 # use epoch to indicate different calls.\n\n                        dpg.add_button(label=\"save\", tag=\"_button_save\", callback=callback_save)\n                        dpg.bind_item_theme(\"_button_save\", theme_button)\n\n                        dpg.add_text(\"\", tag=\"_log_ckpt\")\n\n                    # save mesh\n                    with dpg.group(horizontal=True):\n                        dpg.add_text(\"Marching Cubes: \")\n\n                        def callback_mesh(sender, app_data):\n                            self.trainer.save_mesh()\n                            dpg.set_value(\"_log_mesh\", \"saved \" + f'{self.trainer.name}_{self.trainer.epoch}.ply')\n                            self.trainer.epoch += 1 # use epoch to indicate different calls.\n\n                        dpg.add_button(label=\"mesh\", tag=\"_button_mesh\", callback=callback_mesh)\n                        dpg.bind_item_theme(\"_button_mesh\", theme_button)\n\n                        dpg.add_text(\"\", tag=\"_log_mesh\")                        \n\n                    with dpg.group(horizontal=True):\n     
                   dpg.add_text(\"\", tag=\"_log_train_log\")\n\n            \n            # rendering options\n            with dpg.collapsing_header(label=\"Options\", default_open=True):\n\n                # dynamic rendering resolution\n                with dpg.group(horizontal=True):\n\n                    def callback_set_dynamic_resolution(sender, app_data):\n                        if self.dynamic_resolution:\n                            self.dynamic_resolution = False\n                            self.downscale = 1\n                        else:\n                            self.dynamic_resolution = True\n                        self.need_update = True\n\n                    dpg.add_checkbox(label=\"dynamic resolution\", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)\n                    dpg.add_text(f\"{self.W}x{self.H}\", tag=\"_log_resolution\")\n\n                # mode combo\n                def callback_change_mode(sender, app_data):\n                    self.mode = app_data\n                    self.need_update = True\n                \n                dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)\n\n                # bg_color picker\n                def callback_change_bg(sender, app_data):\n                    self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]\n                    self.need_update = True\n\n                dpg.add_color_edit((255, 255, 255), label=\"Background Color\", width=200, tag=\"_color_editor\", no_alpha=True, callback=callback_change_bg)\n\n                # fov slider\n                def callback_set_fovy(sender, app_data):\n                    self.cam.fovy = app_data\n                    self.need_update = True\n\n                dpg.add_slider_int(label=\"FoV (vertical)\", min_value=1, max_value=120, format=\"%d deg\", default_value=self.cam.fovy, callback=callback_set_fovy)\n\n            
    # dt_gamma slider\n                def callback_set_dt_gamma(sender, app_data):\n                    self.opt.dt_gamma = app_data\n                    self.need_update = True\n\n                dpg.add_slider_float(label=\"dt_gamma\", min_value=0, max_value=0.1, format=\"%.5f\", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)\n\n                # max_steps slider\n                def callback_set_max_steps(sender, app_data):\n                    self.opt.max_steps = app_data\n                    self.need_update = True\n\n                dpg.add_slider_int(label=\"max steps\", min_value=1, max_value=1024, format=\"%d\", default_value=self.opt.max_steps, callback=callback_set_max_steps)\n\n                # aabb slider\n                def callback_set_aabb(sender, app_data, user_data):\n                    # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)\n                    self.trainer.model.aabb_infer[user_data] = app_data\n\n                    # also change train aabb ? 
[better not...]\n                    #self.trainer.model.aabb_train[user_data] = app_data\n\n                    self.need_update = True\n\n                dpg.add_separator()\n                dpg.add_text(\"Axis-aligned bounding box:\")\n\n                with dpg.group(horizontal=True):\n                    dpg.add_slider_float(label=\"x\", width=150, min_value=-self.opt.bound, max_value=0, format=\"%.2f\", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)\n                    dpg.add_slider_float(label=\"\", width=150, min_value=0, max_value=self.opt.bound, format=\"%.2f\", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)\n\n                with dpg.group(horizontal=True):\n                    dpg.add_slider_float(label=\"y\", width=150, min_value=-self.opt.bound, max_value=0, format=\"%.2f\", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)\n                    dpg.add_slider_float(label=\"\", width=150, min_value=0, max_value=self.opt.bound, format=\"%.2f\", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)\n\n                with dpg.group(horizontal=True):\n                    dpg.add_slider_float(label=\"z\", width=150, min_value=-self.opt.bound, max_value=0, format=\"%.2f\", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)\n                    dpg.add_slider_float(label=\"\", width=150, min_value=0, max_value=self.opt.bound, format=\"%.2f\", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)\n\n                # light dir\n                def callback_set_light_dir(sender, app_data, user_data):\n                    self.light_dir[user_data] = app_data\n                    self.need_update = True\n\n                dpg.add_separator()\n                dpg.add_text(\"Plane Light Direction:\")\n\n                with dpg.group(horizontal=True):\n                    dpg.add_slider_float(label=\"theta\", min_value=0, max_value=180, 
format=\"%.2f\", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)\n\n                with dpg.group(horizontal=True):\n                    dpg.add_slider_float(label=\"phi\", min_value=0, max_value=360, format=\"%.2f\", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)\n\n                # ambient ratio\n                def callback_set_abm_ratio(sender, app_data):\n                    self.ambient_ratio = app_data\n                    self.need_update = True\n\n                dpg.add_slider_float(label=\"ambient\", min_value=0, max_value=1.0, format=\"%.5f\", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)\n\n                # shading mode\n                def callback_change_shading(sender, app_data):\n                    self.shading = app_data\n                    self.need_update = True\n                \n                dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)\n\n\n            # debug info\n            if self.debug:\n                with dpg.collapsing_header(label=\"Debug\"):\n                    # pose\n                    dpg.add_separator()\n                    dpg.add_text(\"Camera Pose:\")\n                    dpg.add_text(str(self.cam.pose), tag=\"_log_pose\")\n\n\n        ### register camera handler\n\n        def callback_camera_drag_rotate(sender, app_data):\n\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            dx = app_data[1]\n            dy = app_data[2]\n\n            self.cam.orbit(dx, dy)\n            self.need_update = True\n\n            if self.debug:\n                dpg.set_value(\"_log_pose\", str(self.cam.pose))\n\n\n        def callback_camera_wheel_scale(sender, app_data):\n\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            delta = app_data\n\n          
  self.cam.scale(delta)\n            self.need_update = True\n\n            if self.debug:\n                dpg.set_value(\"_log_pose\", str(self.cam.pose))\n\n\n        def callback_camera_drag_pan(sender, app_data):\n\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            dx = app_data[1]\n            dy = app_data[2]\n\n            self.cam.pan(dx, dy)\n            self.need_update = True\n\n            if self.debug:\n                dpg.set_value(\"_log_pose\", str(self.cam.pose))\n\n\n        with dpg.handler_registry():\n            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)\n            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)\n            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan)\n\n        \n        dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)\n        \n        # TODO: seems dearpygui doesn't support resizing texture...\n        # def callback_resize(sender, app_data):\n        #     self.W = app_data[0]\n        #     self.H = app_data[1]\n        #     # how to reload texture ???\n\n        # dpg.set_viewport_resize_callback(callback_resize)\n\n        ### global theme\n        with dpg.theme() as theme_no_padding:\n            with dpg.theme_component(dpg.mvAll):\n                # set all padding to 0 to avoid scroll bar\n                dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)\n                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)\n                dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)\n        \n        dpg.bind_item_theme(\"_primary_window\", theme_no_padding)\n\n        dpg.setup_dearpygui()\n\n        #dpg.show_metrics()\n\n        dpg.show_viewport()\n\n\n    def render(self):\n\n        while 
dpg.is_dearpygui_running():\n            # update texture every frame\n            if self.training:\n                self.train_step()\n            self.test_step()\n            dpg.render_dearpygui_frame()"
  },
  {
    "path": "nerf/network.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom activation import trunc_exp\nfrom .renderer import NeRFRenderer\n\nimport numpy as np\nfrom encoding import get_encoder\n\nfrom .utils import safe_normalize\n\n# TODO: not sure about the details...\nclass ResBlock(nn.Module):\n    def __init__(self, dim_in, dim_out, bias=True):\n        super().__init__()\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n\n        self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)\n        self.norm = nn.LayerNorm(self.dim_out)\n        self.activation = nn.SiLU(inplace=True)\n\n        if self.dim_in != self.dim_out:\n            self.skip = nn.Linear(self.dim_in, self.dim_out, bias=False)\n        else:\n            self.skip = None\n\n    def forward(self, x):\n        # x: [B, C]\n        identity = x\n\n        out = self.dense(x)\n        out = self.norm(out)\n\n        if self.skip is not None:\n            identity = self.skip(identity)\n\n        out += identity\n        out = self.activation(out)\n\n        return out\n\nclass BasicBlock(nn.Module):\n    def __init__(self, dim_in, dim_out, bias=True):\n        super().__init__()\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n\n        self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)\n        self.activation = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        # x: [B, C]\n\n        out = self.dense(x)\n        out = self.activation(out)\n\n        return out    \n\nclass MLP(nn.Module):\n    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True, block=BasicBlock):\n        super().__init__()\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n        self.dim_hidden = dim_hidden\n        self.num_layers = num_layers\n\n        net = []\n        for l in range(num_layers):\n            if l == 0:\n                net.append(BasicBlock(self.dim_in, self.dim_hidden, bias=bias))\n            elif l 
!= num_layers - 1:\n                net.append(block(self.dim_hidden, self.dim_hidden, bias=bias))\n            else:\n                net.append(nn.Linear(self.dim_hidden, self.dim_out, bias=bias))\n\n        self.net = nn.ModuleList(net)\n        \n    \n    def forward(self, x):\n\n        for l in range(self.num_layers):\n            x = self.net[l](x)\n            \n        return x\n\n\nclass NeRFNetwork(NeRFRenderer):\n    def __init__(self, \n                 opt,\n                 num_layers=5, # 5 in paper\n                 hidden_dim=64, # 128 in paper\n                 num_layers_bg=2, # 3 in paper\n                 hidden_dim_bg=32, # 64 in paper\n                 encoding='frequency_torch', # pure pytorch\n                 ):\n        \n        super().__init__(opt)\n\n        self.num_layers = num_layers\n        self.hidden_dim = hidden_dim\n        self.encoder, self.in_dim = get_encoder(encoding, input_dim=3, multires=12)\n        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True, block=ResBlock)\n\n        self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else F.softplus\n\n        # background network\n        if self.opt.bg_radius > 0:\n            self.num_layers_bg = num_layers_bg   \n            self.hidden_dim_bg = hidden_dim_bg\n            self.encoder_bg, self.in_dim_bg = get_encoder(encoding, input_dim=3, multires=4)\n            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)\n            \n        else:\n            self.bg_net = None\n\n    def common_forward(self, x):\n        # x: [N, 3], in [-bound, bound]\n\n        # sigma\n        enc = self.encoder(x, bound=self.bound, max_level=self.max_level)\n\n        h = self.sigma_net(enc)\n\n        sigma = self.density_activation(h[..., 0] + self.density_blob(x))\n        albedo = torch.sigmoid(h[..., 1:])\n\n        return sigma, albedo\n    \n    # ref: 
https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192\n    def finite_difference_normal(self, x, epsilon=1e-2):\n        # x: [N, 3]\n        dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))\n        dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))\n        \n        normal = torch.stack([\n            0.5 * (dx_pos - dx_neg) / epsilon, \n            0.5 * (dy_pos - dy_neg) / epsilon, \n            0.5 * (dz_pos - dz_neg) / epsilon\n        ], dim=-1)\n\n        return -normal\n    \n    def normal(self, x):\n    \n        with torch.enable_grad():\n            with torch.cuda.amp.autocast(enabled=False):\n                x.requires_grad_(True)\n                sigma, albedo = self.common_forward(x)\n                # query gradient\n                normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]\n        \n        # normal = self.finite_difference_normal(x)\n        normal = safe_normalize(normal)\n        normal = torch.nan_to_num(normal)\n\n        return normal\n        \n    def forward(self, x, d, l=None, ratio=1, shading='albedo'):\n        # x: [N, 3], in [-bound, bound]\n        # d: [N, 3], view direction, normalized in [-1, 1]\n        # l: [3], plane light direction, normalized in [-1, 1]\n        # ratio: scalar, ambient 
ratio, 1 == no shading (albedo only), 0 == only shading (textureless)\n\n        if shading == 'albedo':\n            # no need to query normal\n            sigma, color = self.common_forward(x)\n            normal = None\n        \n        else:\n            # query normal\n\n            # sigma, albedo = self.common_forward(x)\n            # normal = self.normal(x)\n        \n            with torch.enable_grad():\n                with torch.cuda.amp.autocast(enabled=False):\n                    x.requires_grad_(True)\n                    sigma, albedo = self.common_forward(x)\n                    # query gradient\n                    normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]\n            normal = safe_normalize(normal)\n            normal = torch.nan_to_num(normal)\n\n            # lambertian shading\n            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]\n\n            if shading == 'textureless':\n                color = lambertian.unsqueeze(-1).repeat(1, 3)\n            elif shading == 'normal':\n                color = (normal + 1) / 2\n            else: # 'lambertian'\n                color = albedo * lambertian.unsqueeze(-1)\n            \n        return sigma, color, normal\n\n      \n    def density(self, x):\n        # x: [N, 3], in [-bound, bound]\n        \n        sigma, albedo = self.common_forward(x)\n        \n        return {\n            'sigma': sigma,\n            'albedo': albedo,\n        }\n\n\n    def background(self, d):\n\n        h = self.encoder_bg(d) # [N, C]\n        \n        h = self.bg_net(h)\n\n        # sigmoid activation for rgb\n        rgbs = torch.sigmoid(h)\n\n        return rgbs\n\n    # optimizer utils\n    def get_params(self, lr):\n\n        params = [\n            # {'params': self.encoder.parameters(), 'lr': lr * 10},\n            {'params': self.sigma_net.parameters(), 'lr': lr},\n        ]        \n\n        if self.opt.bg_radius > 0:\n         
   # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})\n            params.append({'params': self.bg_net.parameters(), 'lr': lr})\n        \n        if self.opt.dmtet and not self.opt.lock_geo:\n            params.append({'params': self.sdf, 'lr': lr})\n            params.append({'params': self.deform, 'lr': lr})\n\n        return params"
  },
  {
    "path": "nerf/network_grid.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom activation import trunc_exp, biased_softplus\nfrom .renderer import NeRFRenderer\n\nimport numpy as np\nfrom encoding import get_encoder\n\nfrom .utils import safe_normalize\n\nclass MLP(nn.Module):\n    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):\n        super().__init__()\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n        self.dim_hidden = dim_hidden\n        self.num_layers = num_layers\n\n        net = []\n        for l in range(num_layers):\n            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))\n\n        self.net = nn.ModuleList(net)\n    \n    def forward(self, x):\n        for l in range(self.num_layers):\n            x = self.net[l](x)\n            if l != self.num_layers - 1:\n                x = F.relu(x, inplace=True)\n        return x\n\n\nclass NeRFNetwork(NeRFRenderer):\n    def __init__(self, \n                 opt,\n                 num_layers=3,\n                 hidden_dim=64,\n                 num_layers_bg=2,\n                 hidden_dim_bg=32,\n                 ):\n        \n        super().__init__(opt)\n\n        self.num_layers = num_layers\n        self.hidden_dim = hidden_dim\n\n        self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')\n\n        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)\n        # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)\n\n        self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus\n\n        # background network\n        if self.opt.bg_radius > 0:\n            self.num_layers_bg = num_layers_bg   \n            self.hidden_dim_bg = hidden_dim_bg\n            \n            # use a 
very simple network to avoid it learning the prompt...\n            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)\n            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)\n            \n        else:\n            self.bg_net = None\n\n    def common_forward(self, x):\n\n        # sigma\n        enc = self.encoder(x, bound=self.bound, max_level=self.max_level)\n\n        h = self.sigma_net(enc)\n\n        sigma = self.density_activation(h[..., 0] + self.density_blob(x))\n        albedo = torch.sigmoid(h[..., 1:])\n\n        return sigma, albedo\n    \n    # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192\n    def finite_difference_normal(self, x, epsilon=1e-2):\n        # x: [N, 3]\n        dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))\n        dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))\n        \n        normal = torch.stack([\n            0.5 * (dx_pos - dx_neg) / epsilon, \n            0.5 * (dy_pos - dy_neg) / epsilon, \n            0.5 * (dz_pos - dz_neg) / epsilon\n        ], dim=-1)\n\n        return -normal\n\n    def normal(self, x):\n        normal = self.finite_difference_normal(x)\n        normal = safe_normalize(normal)\n        normal = torch.nan_to_num(normal)\n   
     return normal\n    \n    def forward(self, x, d, l=None, ratio=1, shading='albedo'):\n        # x: [N, 3], in [-bound, bound]\n        # d: [N, 3], view direction, normalized in [-1, 1]\n        # l: [3], plane light direction, normalized in [-1, 1]\n        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)\n\n        sigma, albedo = self.common_forward(x)\n\n        if shading == 'albedo':\n            normal = None\n            color = albedo\n        \n        else: # lambertian shading\n\n            # normal = self.normal_net(enc)\n            normal = self.normal(x)\n\n            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]\n\n            if shading == 'textureless':\n                color = lambertian.unsqueeze(-1).repeat(1, 3)\n            elif shading == 'normal':\n                color = (normal + 1) / 2\n            else: # 'lambertian'\n                color = albedo * lambertian.unsqueeze(-1)\n            \n        return sigma, color, normal\n\n      \n    def density(self, x):\n        # x: [N, 3], in [-bound, bound]\n        \n        sigma, albedo = self.common_forward(x)\n        \n        return {\n            'sigma': sigma,\n            'albedo': albedo,\n        }\n\n\n    def background(self, d):\n\n        h = self.encoder_bg(d) # [N, C]\n        \n        h = self.bg_net(h)\n\n        # sigmoid activation for rgb\n        rgbs = torch.sigmoid(h)\n\n        return rgbs\n\n    # optimizer utils\n    def get_params(self, lr):\n\n        params = [\n            {'params': self.encoder.parameters(), 'lr': lr * 10},\n            {'params': self.sigma_net.parameters(), 'lr': lr},\n            # {'params': self.normal_net.parameters(), 'lr': lr},\n        ]        \n\n        if self.opt.bg_radius > 0:\n            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})\n            params.append({'params': self.bg_net.parameters(), 'lr': lr})\n        
\n        if self.opt.dmtet and not self.opt.lock_geo:\n            params.append({'params': self.sdf, 'lr': lr})\n            params.append({'params': self.deform, 'lr': lr})\n\n        return params"
  },
  {
    "path": "nerf/network_grid_taichi.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom activation import trunc_exp\nfrom .renderer import NeRFRenderer\n\nimport numpy as np\nfrom encoding import get_encoder\n\nfrom .utils import safe_normalize\n\nclass MLP(nn.Module):\n    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):\n        super().__init__()\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n        self.dim_hidden = dim_hidden\n        self.num_layers = num_layers\n\n        net = []\n        for l in range(num_layers):\n            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))\n\n        self.net = nn.ModuleList(net)\n    \n    def forward(self, x):\n        for l in range(self.num_layers):\n            x = self.net[l](x)\n            if l != self.num_layers - 1:\n                x = F.relu(x, inplace=True)\n        return x\n\n\nclass NeRFNetwork(NeRFRenderer):\n    def __init__(self, \n                 opt,\n                 num_layers=2,\n                 hidden_dim=32,\n                 num_layers_bg=2,\n                 hidden_dim_bg=16,\n                 ):\n        \n        super().__init__(opt)\n\n        self.num_layers = num_layers\n        self.hidden_dim = hidden_dim\n\n        self.encoder, self.in_dim = get_encoder('hashgrid_taichi', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')\n\n        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)\n        # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)\n\n        self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else F.softplus\n\n        # background network\n        if self.opt.bg_radius > 0:\n            self.num_layers_bg = num_layers_bg\n            self.hidden_dim_bg = hidden_dim_bg\n            # use a very simple network to avoid it 
learning the prompt...\n            self.encoder_bg, self.in_dim_bg = get_encoder('frequency_torch', input_dim=3, multires=4) # TODO: freq encoder can be replaced by a Taichi implementation\n            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)\n            \n        else:\n            self.bg_net = None\n\n    def common_forward(self, x):\n\n        # sigma\n        enc = self.encoder(x, bound=self.bound)\n\n        h = self.sigma_net(enc)\n\n        sigma = self.density_activation(h[..., 0] + self.density_blob(x))\n        albedo = torch.sigmoid(h[..., 1:])\n\n        return sigma, albedo\n    \n    # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192\n    def finite_difference_normal(self, x, epsilon=1e-2):\n        # x: [N, 3]\n        dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))\n        dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))\n        dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))\n        \n        normal = torch.stack([\n            0.5 * (dx_pos - dx_neg) / epsilon, \n            0.5 * (dy_pos - dy_neg) / epsilon, \n            0.5 * (dz_pos - dz_neg) / epsilon\n        ], dim=-1)\n\n        return -normal\n    \n    def normal(self, x):\n        normal = self.finite_difference_normal(x)\n        normal = safe_normalize(normal)\n        normal = 
torch.nan_to_num(normal)\n        return normal\n    \n    def forward(self, x, d, l=None, ratio=1, shading='albedo'):\n        # x: [N, 3], in [-bound, bound]\n        # d: [N, 3], view direction, normalized in [-1, 1]\n        # l: [3], plane light direction, normalized in [-1, 1]\n        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)\n\n        sigma, albedo = self.common_forward(x)\n\n        if shading == 'albedo':\n            normal = None\n            color = albedo\n        \n        else: # lambertian shading\n            # normal = self.normal_net(enc)\n            normal = self.normal(x)\n\n            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]\n\n            if shading == 'textureless':\n                color = lambertian.unsqueeze(-1).repeat(1, 3)\n            elif shading == 'normal':\n                color = (normal + 1) / 2\n            else: # 'lambertian'\n                color = albedo * lambertian.unsqueeze(-1)\n            \n        return sigma, color, normal\n\n      \n    def density(self, x):\n        # x: [N, 3], in [-bound, bound]\n        \n        sigma, albedo = self.common_forward(x)\n        \n        return {\n            'sigma': sigma,\n            'albedo': albedo,\n        }\n\n\n    def background(self, d):\n\n        h = self.encoder_bg(d) # [N, C]\n        \n        h = self.bg_net(h)\n\n        # sigmoid activation for rgb\n        rgbs = torch.sigmoid(h)\n\n        return rgbs\n\n    # optimizer utils\n    def get_params(self, lr):\n\n        params = [\n            {'params': self.encoder.parameters(), 'lr': lr * 10},\n            {'params': self.sigma_net.parameters(), 'lr': lr},\n            # {'params': self.normal_net.parameters(), 'lr': lr},\n        ]        \n\n        if self.opt.bg_radius > 0:\n            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})\n            params.append({'params': 
self.bg_net.parameters(), 'lr': lr})\n        \n        if self.opt.dmtet and not self.opt.lock_geo:\n            params.append({'params': self.sdf, 'lr': lr})\n            params.append({'params': self.deform, 'lr': lr})\n\n        return params"
  },
  {
    "path": "nerf/network_grid_tcnn.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom activation import trunc_exp, biased_softplus\nfrom .renderer import NeRFRenderer\n\nimport numpy as np\nfrom encoding import get_encoder\n\nfrom .utils import safe_normalize\n\nimport tinycudann as tcnn\n\nclass MLP(nn.Module):\n    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):\n        super().__init__()\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n        self.dim_hidden = dim_hidden\n        self.num_layers = num_layers\n\n        net = []\n        for l in range(num_layers):\n            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))\n\n        self.net = nn.ModuleList(net)\n    \n    def forward(self, x):\n        for l in range(self.num_layers):\n            x = self.net[l](x)\n            if l != self.num_layers - 1:\n                x = F.relu(x, inplace=True)\n        return x\n\n\nclass NeRFNetwork(NeRFRenderer):\n    def __init__(self, \n                 opt,\n                 num_layers=3,\n                 hidden_dim=64,\n                 num_layers_bg=2,\n                 hidden_dim_bg=32,\n                 ):\n        \n        super().__init__(opt)\n\n        self.num_layers = num_layers\n        self.hidden_dim = hidden_dim\n\n        self.encoder = tcnn.Encoding(\n            n_input_dims=3,\n            encoding_config={\n                \"otype\": \"HashGrid\",\n                \"n_levels\": 16,\n                \"n_features_per_level\": 2,\n                \"log2_hashmap_size\": 19,\n                \"base_resolution\": 16,\n                \"interpolation\": \"Smoothstep\",\n                \"per_level_scale\": np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)),\n            },\n            dtype=torch.float32, # ENHANCE: default float16 seems unstable...\n        )\n        self.in_dim = self.encoder.n_output_dims\n      
  # use torch MLP, as tcnn MLP doesn't implement second-order derivatives\n        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)\n\n        self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus\n\n        # background network\n        if self.opt.bg_radius > 0:\n            self.num_layers_bg = num_layers_bg   \n            self.hidden_dim_bg = hidden_dim_bg\n            \n            # use a very simple network to avoid it learning the prompt...\n            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)\n            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)\n            \n        else:\n            self.bg_net = None\n\n    def common_forward(self, x):\n\n        # sigma\n        enc = self.encoder((x + self.bound) / (2 * self.bound)).float()\n        h = self.sigma_net(enc)\n\n        sigma = self.density_activation(h[..., 0] + self.density_blob(x))\n        albedo = torch.sigmoid(h[..., 1:])\n\n        return sigma, albedo\n    \n    def normal(self, x):\n    \n        with torch.enable_grad():\n            with torch.cuda.amp.autocast(enabled=False):\n                x.requires_grad_(True)\n                sigma, albedo = self.common_forward(x)\n                # query gradient\n                normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]\n        \n        # normal = self.finite_difference_normal(x)\n        normal = safe_normalize(normal)\n        normal = torch.nan_to_num(normal)\n\n        return normal\n    \n    def forward(self, x, d, l=None, ratio=1, shading='albedo'):\n        # x: [N, 3], in [-bound, bound]\n        # d: [N, 3], view direction, normalized in [-1, 1]\n        # l: [3], plane light direction, normalized in [-1, 1]\n        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)\n\n\n        if shading == 'albedo':\n            
sigma, albedo = self.common_forward(x)\n            normal = None\n            color = albedo\n        \n        else: # lambertian shading\n            with torch.enable_grad():\n                with torch.cuda.amp.autocast(enabled=False):\n                    x.requires_grad_(True)\n                    sigma, albedo = self.common_forward(x)\n                    normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]\n            normal = safe_normalize(normal)\n            normal = torch.nan_to_num(normal)\n\n            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]\n\n            if shading == 'textureless':\n                color = lambertian.unsqueeze(-1).repeat(1, 3)\n            elif shading == 'normal':\n                color = (normal + 1) / 2\n            else: # 'lambertian'\n                color = albedo * lambertian.unsqueeze(-1)\n            \n        return sigma, color, normal\n\n      \n    def density(self, x):\n        # x: [N, 3], in [-bound, bound]\n        \n        sigma, albedo = self.common_forward(x)\n        \n        return {\n            'sigma': sigma,\n            'albedo': albedo,\n        }\n\n\n    def background(self, d):\n\n        h = self.encoder_bg(d) # [N, C]\n        \n        h = self.bg_net(h)\n\n        # sigmoid activation for rgb\n        rgbs = torch.sigmoid(h)\n\n        return rgbs\n\n    # optimizer utils\n    def get_params(self, lr):\n\n        params = [\n            {'params': self.encoder.parameters(), 'lr': lr * 10},\n            {'params': self.sigma_net.parameters(), 'lr': lr},\n        ]        \n\n        if self.opt.bg_radius > 0:\n            params.append({'params': self.bg_net.parameters(), 'lr': lr})\n        \n        if self.opt.dmtet and not self.opt.lock_geo:\n            params.append({'params': self.sdf, 'lr': lr})\n            params.append({'params': self.deform, 'lr': lr})\n\n        return params"
  },
  {
    "path": "nerf/provider.py",
    "content": "import os\nimport cv2\nimport glob\nimport json\nimport tqdm\nimport random\nimport numpy as np\nfrom scipy.spatial.transform import Slerp, Rotation\n\nimport trimesh\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader\n\nfrom .utils import get_rays, safe_normalize\n\nDIR_COLORS = np.array([\n    [255, 0, 0, 255], # front\n    [0, 255, 0, 255], # side\n    [0, 0, 255, 255], # back\n    [255, 255, 0, 255], # side\n    [255, 0, 255, 255], # overhead\n    [0, 255, 255, 255], # bottom\n], dtype=np.uint8)\n\ndef visualize_poses(poses, dirs, size=0.1):\n    # poses: [B, 4, 4], dirs: [B]\n\n    axes = trimesh.creation.axis(axis_length=4)\n    sphere = trimesh.creation.icosphere(radius=1)\n    objects = [axes, sphere]\n\n    for pose, dir in zip(poses, dirs):\n        # a camera is visualized with 8 line segments.\n        pos = pose[:3, 3]\n        a = pos + size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2]\n        b = pos - size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2]\n        c = pos - size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2]\n        d = pos + size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2]\n\n        segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])\n        segs = trimesh.load_path(segs)\n\n        # different color for different dirs\n        segs.colors = DIR_COLORS[[dir]].repeat(len(segs.entities), 0)\n\n        objects.append(segs)\n\n    trimesh.Scene(objects).show()\n\ndef get_view_direction(thetas, phis, overhead, front):\n    #                   phis: [B,];          thetas: [B,]\n    # front = 0             [-front/2, front/2)\n    # side (cam left) = 1   [front/2, 180-front/2)\n    # back = 2              [180-front/2, 180+front/2)\n    # side (cam right) = 3  [180+front/2, 360-front/2)\n    # top = 4               [0, overhead]\n    # bottom = 5            [180-overhead, 180]\n    res = 
torch.zeros(thetas.shape[0], dtype=torch.long)\n    # first determine by phis\n    phis = phis % (2 * np.pi)\n    res[(phis < front / 2) | (phis >= 2 * np.pi - front / 2)] = 0\n    res[(phis >= front / 2) & (phis < np.pi - front / 2)] = 1\n    res[(phis >= np.pi - front / 2) & (phis < np.pi + front / 2)] = 2\n    res[(phis >= np.pi + front / 2) & (phis < 2 * np.pi - front / 2)] = 3\n    # override by thetas\n    res[thetas <= overhead] = 4\n    res[thetas >= (np.pi - overhead)] = 5\n    return res\n\n\ndef rand_poses(size, device, opt, radius_range=[1, 1.5], theta_range=[0, 120], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5):\n    ''' generate random poses from an orbit camera\n    Args:\n        size: batch size of generated poses.\n        device: where to allocate the output.\n        opt: options; the jitter_* fields control pose jittering.\n        radius_range: [min, max], camera radius\n        theta_range: [min, max], polar angle in degrees, should be in [0, 180]\n        phi_range: [min, max], azimuth angle in degrees, should be in [0, 360]\n        uniform_sphere_rate: probability of instead sampling centers uniformly on the upper hemisphere (ignores theta_range/phi_range)\n    Return:\n        poses: [size, 4, 4]\n        dirs: [size], view direction labels (None if return_dirs is False)\n        thetas, phis: [size], polar and azimuth angles in degrees\n        radius: [size], camera radii\n    '''\n\n    theta_range = np.array(theta_range) / 180 * np.pi\n    phi_range = np.array(phi_range) / 180 * np.pi\n    angle_overhead = angle_overhead / 180 * np.pi\n    angle_front = angle_front / 180 * np.pi\n\n    radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]\n\n    if random.random() < uniform_sphere_rate:\n        unit_centers = F.normalize(\n            torch.stack([\n                torch.randn(size, device=device),\n                torch.abs(torch.randn(size, device=device)),\n                torch.randn(size, device=device),\n            ], dim=-1), p=2, dim=1\n        )\n        thetas = torch.acos(unit_centers[:,1])\n        phis = torch.atan2(unit_centers[:,0], unit_centers[:,2])\n        phis[phis < 0] += 2 * np.pi\n        centers = unit_centers * radius.unsqueeze(-1)\n    else:\n        thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]\n        
phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]\n        phis[phis < 0] += 2 * np.pi\n\n        centers = torch.stack([\n            radius * torch.sin(thetas) * torch.sin(phis),\n            radius * torch.cos(thetas),\n            radius * torch.sin(thetas) * torch.cos(phis),\n        ], dim=-1) # [B, 3]\n\n    targets = 0\n\n    # jitters\n    if opt.jitter_pose:\n        jit_center = opt.jitter_center # 0.015  # was 0.2\n        jit_target = opt.jitter_target\n        centers += torch.rand_like(centers) * jit_center - jit_center/2.0\n        targets += torch.randn_like(centers) * jit_target\n\n    # lookat\n    forward_vector = safe_normalize(centers - targets)\n    up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)\n    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))\n\n    if opt.jitter_pose:\n        up_noise = torch.randn_like(up_vector) * opt.jitter_up\n    else:\n        up_noise = 0\n\n    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)\n\n    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)\n    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)\n    poses[:, :3, 3] = centers\n\n    if return_dirs:\n        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)\n    else:\n        dirs = None\n\n    # back to degree\n    thetas = thetas / np.pi * 180\n    phis = phis / np.pi * 180\n\n    return poses, dirs, thetas, phis, radius\n\n\ndef circle_poses(device, radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), return_dirs=False, angle_overhead=30, angle_front=60):\n\n    theta = theta / 180 * np.pi\n    phi = phi / 180 * np.pi\n    angle_overhead = angle_overhead / 180 * np.pi\n    angle_front = angle_front / 180 * np.pi\n\n    centers = torch.stack([\n        radius * torch.sin(theta) * torch.sin(phi),\n        
radius * torch.cos(theta),\n        radius * torch.sin(theta) * torch.cos(phi),\n    ], dim=-1) # [B, 3]\n\n    # lookat\n    forward_vector = safe_normalize(centers)\n    up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(len(centers), 1)\n    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))\n    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))\n\n    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(len(centers), 1, 1)\n    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)\n    poses[:, :3, 3] = centers\n\n    if return_dirs:\n        dirs = get_view_direction(theta, phi, angle_overhead, angle_front)\n    else:\n        dirs = None\n\n    return poses, dirs\n\n\nclass NeRFDataset:\n    def __init__(self, opt, device, type='train', H=256, W=256, size=100):\n        super().__init__()\n\n        self.opt = opt\n        self.device = device\n        self.type = type # train, all, val, test, six_views\n\n        self.H = H\n        self.W = W\n        self.size = size\n\n        self.training = self.type in ['train', 'all']\n\n        # principal point: cx is along the width (x), cy along the height (y)\n        self.cx = self.W / 2\n        self.cy = self.H / 2\n\n        self.near = self.opt.min_near\n        self.far = 1000 # infinite\n\n        # [debug] visualize poses\n        # poses, dirs, _, _, _ = rand_poses(100, self.device, opt, radius_range=self.opt.radius_range, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, uniform_sphere_rate=1)\n        # visualize_poses(poses.detach().cpu().numpy(), dirs.detach().cpu().numpy())\n\n    def get_default_view_data(self):\n\n        H = int(self.opt.known_view_scale * self.H)\n        W = int(self.opt.known_view_scale * self.W)\n        cx = W / 2\n        cy = H / 2\n\n        radii = torch.FloatTensor(self.opt.ref_radii).to(self.device)\n        thetas = torch.FloatTensor(self.opt.ref_polars).to(self.device)\n        phis = 
torch.FloatTensor(self.opt.ref_azimuths).to(self.device)\n        poses, dirs = circle_poses(self.device, radius=radii, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)\n        fov = self.opt.default_fovy\n        focal = H / (2 * np.tan(np.deg2rad(fov) / 2))\n        intrinsics = np.array([focal, focal, cx, cy])\n\n        projection = torch.tensor([\n            [2*focal/W, 0, 0, 0],\n            [0, -2*focal/H, 0, 0],\n            [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],\n            [0, 0, -1, 0]\n        ], dtype=torch.float32, device=self.device).unsqueeze(0).repeat(len(radii), 1, 1)\n\n        mvp = projection @ torch.inverse(poses) # [B, 4, 4]\n\n        # sample a low-resolution but full image\n        rays = get_rays(poses, intrinsics, H, W, -1)\n\n        data = {\n            'H': H,\n            'W': W,\n            'rays_o': rays['rays_o'],\n            'rays_d': rays['rays_d'],\n            'dir': dirs,\n            'mvp': mvp,\n            'polar': self.opt.ref_polars,\n            'azimuth': self.opt.ref_azimuths,\n            'radius': self.opt.ref_radii,\n        }\n\n        return data\n\n    def collate(self, index):\n\n        B = len(index)\n\n        if self.training:\n            # random pose on the fly\n            poses, dirs, thetas, phis, radius = rand_poses(B, self.device, self.opt, radius_range=self.opt.radius_range, theta_range=self.opt.theta_range, phi_range=self.opt.phi_range, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, uniform_sphere_rate=self.opt.uniform_sphere_rate)\n\n            # random focal\n            fov = random.random() * (self.opt.fovy_range[1] - self.opt.fovy_range[0]) + self.opt.fovy_range[0]\n\n        elif self.type == 'six_views':\n            # six views\n            thetas_six = [90, 90,  90,  90, 1e-3, 179.999]\n            phis_six =   [ 
0, 90, 180, -90,    0,       0]\n            thetas = torch.FloatTensor([thetas_six[index[0]]]).to(self.device)\n            phis = torch.FloatTensor([phis_six[index[0]]]).to(self.device)\n            radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)\n            poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)\n\n            # fixed focal\n            fov = self.opt.default_fovy\n\n        else:\n            # circle pose\n            thetas = torch.FloatTensor([self.opt.default_polar]).to(self.device)\n            phis = torch.FloatTensor([(index[0] / self.size) * 360]).to(self.device)\n            radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)\n            poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)\n\n            # fixed focal\n            fov = self.opt.default_fovy\n\n        focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))\n        intrinsics = np.array([focal, focal, self.cx, self.cy])\n\n        projection = torch.tensor([\n            [2*focal/self.W, 0, 0, 0],\n            [0, -2*focal/self.H, 0, 0],\n            [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],\n            [0, 0, -1, 0]\n        ], dtype=torch.float32, device=self.device).unsqueeze(0)\n\n        mvp = projection @ torch.inverse(poses) # [1, 4, 4]\n\n        # sample a low-resolution but full image\n        rays = get_rays(poses, intrinsics, self.H, self.W, -1)\n\n        # delta polar/azimuth/radius to default view\n        delta_polar = thetas - self.opt.default_polar\n        delta_azimuth = phis - self.opt.default_azimuth\n        delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]\n        delta_radius = radius - 
self.opt.default_radius\n\n        data = {\n            'H': self.H,\n            'W': self.W,\n            'rays_o': rays['rays_o'],\n            'rays_d': rays['rays_d'],\n            'dir': dirs,\n            'mvp': mvp,\n            'polar': delta_polar,\n            'azimuth': delta_azimuth,\n            'radius': delta_radius,\n        }\n\n        return data\n\n    def dataloader(self, batch_size=None):\n        batch_size = batch_size or self.opt.batch_size\n        loader = DataLoader(list(range(self.size)), batch_size=batch_size, collate_fn=self.collate, shuffle=self.training, num_workers=0)\n        loader._data = self\n        return loader"
  },
  {
    "path": "nerf/renderer.py",
    "content": "import os\nimport math\nimport cv2\nimport trimesh\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport nvdiffrast.torch as dr\n\nimport mcubes\nimport raymarching\nfrom meshutils import decimate_mesh, clean_mesh, poisson_mesh_reconstruction\nfrom .utils import custom_meshgrid, safe_normalize\n\n\ndef sample_pdf(bins, weights, n_samples, det=False):\n    # This implementation is from NeRF\n    # bins: [B, T], old_z_vals\n    # weights: [B, T - 1], bin weights.\n    # return: [B, n_samples], new_z_vals\n\n    # Get pdf\n    weights = weights + 1e-5  # prevent nans\n    pdf = weights / torch.sum(weights, -1, keepdim=True)\n    cdf = torch.cumsum(pdf, -1)\n    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)\n    # Take uniform samples\n    if det:\n        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)\n        u = u.expand(list(cdf.shape[:-1]) + [n_samples])\n    else:\n        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)\n\n    # Invert CDF\n    u = u.contiguous()\n    inds = torch.searchsorted(cdf, u, right=True)\n    below = torch.max(torch.zeros_like(inds - 1), inds - 1)\n    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)\n    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)\n\n    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]\n    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)\n    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)\n\n    denom = (cdf_g[..., 1] - cdf_g[..., 0])\n    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)\n    t = (u - cdf_g[..., 0]) / denom\n    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])\n\n    return samples\n\n@torch.cuda.amp.autocast(enabled=False)\ndef near_far_from_bound(rays_o, rays_d, bound, type='cube', min_near=0.05):\n    # rays: [B, N, 
3], [B, N, 3]\n    # bound: int, radius for ball or half-edge-length for cube\n    # return near [B, N, 1], far [B, N, 1]\n\n    radius = rays_o.norm(dim=-1, keepdim=True)\n\n    if type == 'sphere':\n        near = radius - bound # [B, N, 1]\n        far = radius + bound\n\n    elif type == 'cube':\n        tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3]\n        tmax = (bound - rays_o) / (rays_d + 1e-15)\n        near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0]\n        far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0]\n        # if far < near, means no intersection, set both near and far to inf (1e9 here)\n        mask = far < near\n        near[mask] = 1e9\n        far[mask] = 1e9\n        # restrict near to a minimal value\n        near = torch.clamp(near, min=min_near)\n\n    return near, far\n\n\ndef plot_pointcloud(pc, color=None):\n    # pc: [N, 3]\n    # color: [N, 3/4]\n    print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))\n    pc = trimesh.PointCloud(pc, color)\n    # axis\n    axes = trimesh.creation.axis(axis_length=4)\n    # sphere\n    sphere = trimesh.creation.icosphere(radius=1)\n    trimesh.Scene([pc, axes, sphere]).show()\n\n\nclass DMTet():\n    def __init__(self, device):\n        self.device = device\n        self.triangle_table = torch.tensor([\n            [-1, -1, -1, -1, -1, -1],\n            [ 1,  0,  2, -1, -1, -1],\n            [ 4,  0,  3, -1, -1, -1],\n            [ 1,  4,  2,  1,  3,  4],\n            [ 3,  1,  5, -1, -1, -1],\n            [ 2,  3,  0,  2,  5,  3],\n            [ 1,  4,  0,  1,  5,  4],\n            [ 4,  2,  5, -1, -1, -1],\n            [ 4,  5,  2, -1, -1, -1],\n            [ 4,  1,  0,  4,  5,  1],\n            [ 3,  2,  0,  3,  5,  2],\n            [ 1,  3,  5, -1, -1, -1],\n            [ 4,  1,  2,  4,  3,  1],\n            [ 3,  0,  4, -1, -1, -1],\n            [ 2,  0,  1, -1, -1, -1],\n            [-1, -1, -1, -1, -1, -1]\n        ], 
dtype=torch.long, device=device)\n        self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device=device)\n        self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=device)\n    \n    def sort_edges(self, edges_ex2):\n        with torch.no_grad():\n            order = (edges_ex2[:,0] > edges_ex2[:,1]).long()\n            order = order.unsqueeze(dim=1)\n\n            a = torch.gather(input=edges_ex2, index=order, dim=1)      \n            b = torch.gather(input=edges_ex2, index=1-order, dim=1)  \n\n        return torch.stack([a, b],-1)\n\n    def __call__(self, pos_nx3, sdf_n, tet_fx4):\n        # pos_nx3: [N, 3]\n        # sdf_n:   [N]\n        # tet_fx4: [F, 4]\n\n        with torch.no_grad():\n            occ_n = sdf_n > 0\n            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)\n            occ_sum = torch.sum(occ_fx4, -1) # [F,]\n            valid_tets = (occ_sum>0) & (occ_sum<4)\n            occ_sum = occ_sum[valid_tets]\n\n            # find all vertices\n            all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)\n            all_edges = self.sort_edges(all_edges)\n            unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True)  \n            \n            unique_edges = unique_edges.long()\n            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1\n            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1\n            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device=self.device)\n            idx_map = mapping[idx_map] # map edges to verts\n\n            interp_v = unique_edges[mask_edges]\n\n        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)\n        edges_to_interp_sdf[:,-1] *= -1\n\n        denominator = 
edges_to_interp_sdf.sum(1,keepdim = True)\n\n        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator\n        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)\n\n        idx_map = idx_map.reshape(-1,6)\n\n        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=self.device))\n        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)\n        num_triangles = self.num_triangles_table[tetindex]\n\n        # Generate triangle indices\n        faces = torch.cat((\n            torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),\n            torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),\n        ), dim=0)\n\n        return verts, faces\n\ndef compute_edge_to_face_mapping(attr_idx):\n    with torch.no_grad():\n        # Get unique edges\n        # Create all edges, packed by triangle\n        all_edges = torch.cat((\n            torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),\n            torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),\n            torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),\n        ), dim=-1).view(-1, 2)\n\n        # Swap edge order so min index is always first\n        order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)\n        sorted_edges = torch.cat((\n            torch.gather(all_edges, 1, order),\n            torch.gather(all_edges, 1, 1 - order)\n        ), dim=-1)\n\n        # Eliminate duplicates and return inverse mapping\n        unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)\n\n        tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda()\n\n        tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda()\n\n        # Compute edge to face table\n        mask0 = order[:,0] == 0\n        mask1 = order[:,0] == 1\n    
    tris_per_edge[idx_map[mask0], 0] = tris[mask0]\n        tris_per_edge[idx_map[mask1], 1] = tris[mask1]\n\n        return tris_per_edge\n\n@torch.cuda.amp.autocast(enabled=False)\ndef normal_consistency(face_normals, t_pos_idx):\n\n    tris_per_edge = compute_edge_to_face_mapping(t_pos_idx)\n\n    # Fetch normals for both faces sharing an edge\n    n0 = face_normals[tris_per_edge[:, 0], :]\n    n1 = face_normals[tris_per_edge[:, 1], :]\n\n    # Compute error metric based on normal difference\n    term = torch.clamp(torch.sum(n0 * n1, -1, keepdim=True), min=-1.0, max=1.0)\n    term = (1.0 - term)\n\n    return torch.mean(torch.abs(term))\n\n\ndef laplacian_uniform(verts, faces):\n\n    V = verts.shape[0]\n    F = faces.shape[0]\n\n    # Neighbor indices\n    ii = faces[:, [1, 2, 0]].flatten()\n    jj = faces[:, [2, 0, 1]].flatten()\n    adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1)\n    adj_values = torch.ones(adj.shape[1], device=verts.device, dtype=torch.float)\n\n    # Diagonal indices\n    diag_idx = adj[0]\n\n    # Build the sparse matrix\n    idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)\n    values = torch.cat((-adj_values, adj_values))\n\n    # The coalesce operation sums the duplicate indices, resulting in the\n    # correct diagonal\n    return torch.sparse_coo_tensor(idx, values, (V,V)).coalesce()\n\n\n@torch.cuda.amp.autocast(enabled=False)\ndef laplacian_smooth_loss(verts, faces):\n    with torch.no_grad():\n        L = laplacian_uniform(verts, faces.long())\n    loss = L.mm(verts)\n    loss = loss.norm(dim=1)\n    loss = loss.mean()\n    return loss\n\n\nclass NeRFRenderer(nn.Module):\n    def __init__(self, opt):\n        super().__init__()\n\n        self.opt = opt\n        self.bound = opt.bound\n        self.cascade = 1 + math.ceil(math.log2(opt.bound))\n        self.grid_size = 128\n        self.max_level = None\n        self.dmtet = opt.dmtet\n        self.cuda_ray = opt.cuda_ray\n 
       self.taichi_ray = opt.taichi_ray\n        self.min_near = opt.min_near\n        self.density_thresh = opt.density_thresh\n\n        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)\n        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.\n        aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])\n        aabb_infer = aabb_train.clone()\n        self.register_buffer('aabb_train', aabb_train)\n        self.register_buffer('aabb_infer', aabb_infer)\n\n        self.glctx = None\n\n        # extra state for cuda raymarching\n        if self.cuda_ray:\n            # density grid\n            density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]\n            density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]\n            self.register_buffer('density_grid', density_grid)\n            self.register_buffer('density_bitfield', density_bitfield)\n            self.mean_density = 0\n            self.iter_density = 0\n        \n        if self.dmtet:\n            # load dmtet vertices\n            tets = np.load('tets/{}_tets.npz'.format(self.opt.tet_grid_size))\n            self.verts = - torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * 2 # covers [-1, 1]\n            self.indices  = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')\n            self.tet_scale = torch.tensor([1, 1, 1], dtype=torch.float32, device='cuda')\n            self.dmtet_model = DMTet('cuda')\n\n            # vert sdf and deform\n            sdf = torch.nn.Parameter(torch.zeros_like(self.verts[..., 0]), requires_grad=True)\n            self.register_parameter('sdf', sdf)\n            deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)\n            self.register_parameter('deform', 
deform)\n\n            edges = torch.tensor([0,1, 0,2, 0,3, 1,2, 1,3, 2,3], dtype=torch.long, device=\"cuda\") # six edges for each tetrahedron.\n            all_edges = self.indices[:,edges].reshape(-1,2) # [M * 6, 2]\n            all_edges_sorted = torch.sort(all_edges, dim=1)[0]\n            self.all_edges = torch.unique(all_edges_sorted, dim=0)\n\n            if self.opt.h <= 2048 and self.opt.w <= 2048:\n                self.glctx = dr.RasterizeCudaContext()\n            else:\n                self.glctx = dr.RasterizeGLContext()\n        \n        if self.taichi_ray:\n            from einops import rearrange\n            from taichi_modules import RayMarcherTaichi\n            from taichi_modules import VolumeRendererTaichi\n            from taichi_modules import RayAABBIntersector as RayAABBIntersectorTaichi\n            from taichi_modules import raymarching_test as raymarching_test_taichi\n            from taichi_modules import composite_test as composite_test_fw\n            from taichi_modules import packbits as packbits_taichi\n            self.rearrange = rearrange\n            self.packbits_taichi = packbits_taichi\n            self.ray_aabb_intersector = RayAABBIntersectorTaichi\n            self.raymarching_test_taichi = raymarching_test_taichi\n            self.composite_test_fw = composite_test_fw\n            self.ray_marching = RayMarcherTaichi(batch_size=4096) # TODO: hard encoded batch size\n            self.volume_render = VolumeRendererTaichi(batch_size=4096) # TODO: hard encoded batch size\n            # density grid\n            density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]\n            density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]\n            self.register_buffer('density_grid', density_grid)\n            self.register_buffer('density_bitfield', density_bitfield)\n            self.mean_density = 0\n            self.iter_density = 
0\n    \n    @torch.no_grad()\n    def density_blob(self, x):\n        # x: [B, N, 3]\n        \n        d = (x ** 2).sum(-1)\n        \n        if self.opt.density_activation == 'exp':\n            g = self.opt.blob_density * torch.exp(- d / (2 * self.opt.blob_radius ** 2))\n        else:\n            g = self.opt.blob_density * (1 - torch.sqrt(d) / self.opt.blob_radius)\n\n        return g\n    \n    def forward(self, x, d):\n        raise NotImplementedError()\n\n    def density(self, x):\n        raise NotImplementedError()\n\n    def reset_extra_state(self):\n        if not (self.cuda_ray or self.taichi_ray):\n            return \n        # density grid\n        self.density_grid.zero_()\n        self.mean_density = 0\n        self.iter_density = 0\n\n    @torch.no_grad()\n    def export_mesh(self, path, resolution=None, decimate_target=-1, S=128):\n\n        if self.opt.dmtet:\n\n            sdf = self.sdf\n            deform = torch.tanh(self.deform) / self.opt.tet_grid_size\n\n            vertices, triangles = self.dmtet_model(self.verts + deform, sdf, self.indices)\n\n            vertices = vertices.detach().cpu().numpy()\n            triangles = triangles.detach().cpu().numpy()\n\n        else:\n\n            if resolution is None:\n                resolution = self.grid_size\n\n            if self.cuda_ray:\n                density_thresh = min(self.mean_density, self.density_thresh) \\\n                    if np.greater(self.mean_density, 0) else self.density_thresh\n            else:\n                density_thresh = self.density_thresh\n            \n            # TODO: use a larger thresh to extract a surface mesh from the density field, but this value is very empirical...\n            if self.opt.density_activation == 'softplus':\n                density_thresh = density_thresh * 25\n            \n            sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32)\n\n            # query\n            X = torch.linspace(-1, 1, 
resolution).split(S)\n            Y = torch.linspace(-1, 1, resolution).split(S)\n            Z = torch.linspace(-1, 1, resolution).split(S)\n\n            for xi, xs in enumerate(X):\n                for yi, ys in enumerate(Y):\n                    for zi, zs in enumerate(Z):\n                        xx, yy, zz = custom_meshgrid(xs, ys, zs)\n                        pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]\n                        val = self.density(pts.to(self.aabb_train.device))\n                        sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]\n\n            print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')\n\n            vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)\n            vertices = vertices / (resolution - 1.0) * 2 - 1\n\n        # clean\n        vertices = vertices.astype(np.float32)\n        triangles = triangles.astype(np.int32)\n        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)\n        \n        # decimation\n        if decimate_target > 0 and triangles.shape[0] > decimate_target:\n            vertices, triangles = decimate_mesh(vertices, triangles, decimate_target)\n\n        v = torch.from_numpy(vertices).contiguous().float().to(self.aabb_train.device)\n        f = torch.from_numpy(triangles).contiguous().int().to(self.aabb_train.device)\n\n        # mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...\n        # mesh.export(os.path.join(path, f'mesh.ply'))\n\n        def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):\n            # v, f: torch Tensor\n            device = v.device\n            v_np = v.cpu().numpy() # [N, 3]\n            f_np = f.cpu().numpy() # [M, 3]\n\n            print(f'[INFO] 
running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')\n\n            # unwrap uvs\n            import xatlas\n            import nvdiffrast.torch as dr\n            from sklearn.neighbors import NearestNeighbors\n            from scipy.ndimage import binary_dilation, binary_erosion\n\n            atlas = xatlas.Atlas()\n            atlas.add_mesh(v_np, f_np)\n            chart_options = xatlas.ChartOptions()\n            chart_options.max_iterations = 4 # for faster unwrap...\n            atlas.generate(chart_options=chart_options)\n            vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]\n\n            # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]\n\n            vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)\n            ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)\n\n            # render uv maps\n            uv = vt * 2.0 - 1.0 # uvs to range [-1, 1]\n            uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]\n\n            if ssaa > 1:\n                h = int(h0 * ssaa)\n                w = int(w0 * ssaa)\n            else:\n                h, w = h0, w0\n            \n            if self.glctx is None:\n                if h <= 2048 and w <= 2048:\n                    self.glctx = dr.RasterizeCudaContext()\n                else:\n                    self.glctx = dr.RasterizeGLContext()\n\n            rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]\n            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]\n            mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]\n\n            # masked query \n            xyzs = xyzs.view(-1, 3)\n            mask = (mask > 0).view(-1)\n            \n            feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)\n\n            if mask.any():\n                xyzs = 
xyzs[mask] # [M, 3]\n\n                # batched inference to avoid OOM\n                all_feats = []\n                head = 0\n                while head < xyzs.shape[0]:\n                    tail = min(head + 640000, xyzs.shape[0])\n                    results_ = self.density(xyzs[head:tail])\n                    all_feats.append(results_['albedo'].float())\n                    head += 640000\n\n                feats[mask] = torch.cat(all_feats, dim=0)\n            \n            feats = feats.view(h, w, -1)\n            mask = mask.view(h, w)\n\n            # quantize [0.0, 1.0] to [0, 255]\n            feats = feats.cpu().numpy()\n            feats = (feats * 255).astype(np.uint8)\n\n            ### NN search as an antialiasing ...\n            mask = mask.cpu().numpy()\n\n            inpaint_region = binary_dilation(mask, iterations=3)\n            inpaint_region[mask] = 0\n\n            search_region = mask.copy()\n            not_search_region = binary_erosion(search_region, iterations=2)\n            search_region[not_search_region] = 0\n\n            search_coords = np.stack(np.nonzero(search_region), axis=-1)\n            inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)\n\n            knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)\n            _, indices = knn.kneighbors(inpaint_coords)\n\n            feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]\n\n            feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)\n\n            # do ssaa after the NN search, in numpy\n            if ssaa > 1:\n                feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)\n\n            cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)\n\n            # save obj (v, vt, f /)\n            obj_file = os.path.join(path, f'{name}mesh.obj')\n            mtl_file = os.path.join(path, f'{name}mesh.mtl')\n\n            print(f'[INFO] writing obj mesh to {obj_file}')\n            
with open(obj_file, \"w\") as fp:\n                fp.write(f'mtllib {name}mesh.mtl \\n')\n                \n                print(f'[INFO] writing vertices {v_np.shape}')\n                for v in v_np:\n                    fp.write(f'v {v[0]} {v[1]} {v[2]} \\n')\n            \n                print(f'[INFO] writing vertices texture coords {vt_np.shape}')\n                for v in vt_np:\n                    fp.write(f'vt {v[0]} {1 - v[1]} \\n') \n\n                print(f'[INFO] writing faces {f_np.shape}')\n                fp.write(f'usemtl mat0 \\n')\n                for i in range(len(f_np)):\n                    fp.write(f\"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \\n\")\n\n            with open(mtl_file, \"w\") as fp:\n                fp.write(f'newmtl mat0 \\n')\n                fp.write(f'Ka 1.000000 1.000000 1.000000 \\n')\n                fp.write(f'Kd 1.000000 1.000000 1.000000 \\n')\n                fp.write(f'Ks 0.000000 0.000000 0.000000 \\n')\n                fp.write(f'Tr 1.000000 \\n')\n                fp.write(f'illum 1 \\n')\n                fp.write(f'Ns 0.000000 \\n')\n                fp.write(f'map_Kd {name}albedo.png \\n')\n\n        _export(v, f)\n\n    def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):\n        # rays_o, rays_d: [B, N, 3]\n        # bg_color: [BN, 3] in range [0, 1]\n        # return: image: [B, N, 3], depth: [B, N]\n\n        prefix = rays_o.shape[:-1]\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0] # N = B * N, in fact\n        device = rays_o.device\n\n        results = {}\n\n        # choose aabb\n        aabb = self.aabb_train if self.training else self.aabb_infer\n\n        # sample steps\n        # nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)\n        # 
nears.unsqueeze_(-1)\n        # fars.unsqueeze_(-1)\n        nears, fars = near_far_from_bound(rays_o, rays_d, self.bound, type='sphere', min_near=self.min_near)\n\n        # random sample light_d if not provided\n        if light_d is None:\n            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)\n            light_d = safe_normalize(rays_o + torch.randn(3, device=rays_o.device)) # [N, 3]\n\n        #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')\n\n        z_vals = torch.linspace(0.0, 1.0, self.opt.num_steps, device=device).unsqueeze(0) # [1, T]\n        z_vals = z_vals.expand((N, self.opt.num_steps)) # [N, T]\n        z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]\n\n        # perturb z_vals\n        sample_dist = (fars - nears) / self.opt.num_steps\n        if perturb:\n            z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist\n            #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.\n\n        # generate xyzs\n        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]\n        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.\n\n        #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())\n\n        # query SDF and RGB\n        density_outputs = self.density(xyzs.reshape(-1, 3))\n\n        #sigmas = density_outputs['sigma'].view(N, self.opt.num_steps) # [N, T]\n        for k, v in density_outputs.items():\n            density_outputs[k] = v.view(N, self.opt.num_steps, -1)\n\n        # upsample z_vals (nerf-like)\n        if self.opt.upsample_steps > 0:\n            with torch.no_grad():\n\n                deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]\n                deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)\n\n                alphas 
= 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T]\n                alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]\n                weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]\n\n                # sample new z_vals\n                z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]\n                new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], self.opt.upsample_steps, det=not self.training).detach() # [N, t]\n\n                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]\n                new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.\n\n            # only forward new points to save computation\n            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))\n            #new_sigmas = new_density_outputs['sigma'].view(N, self.opt.upsample_steps) # [N, t]\n            for k, v in new_density_outputs.items():\n                new_density_outputs[k] = v.view(N, self.opt.upsample_steps, -1)\n\n            # re-order\n            z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]\n            z_vals, z_index = torch.sort(z_vals, dim=1)\n\n            xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]\n            xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))\n\n            for k in density_outputs:\n                tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)\n                density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))\n\n        deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]\n        deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)\n        alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t]\n        alphas_shifted = 
torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]\n        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]\n\n        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)\n        light_d = light_d.view(-1, 1, 3).expand_as(xyzs)\n        for k, v in density_outputs.items():\n            density_outputs[k] = v.view(-1, v.shape[-1])\n\n        dirs = safe_normalize(dirs)\n        sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d.reshape(-1, 3), ratio=ambient_ratio, shading=shading)\n        rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]\n        if normals is not None:\n            normals = normals.view(N, -1, 3)\n\n        # calculate weight_sum (mask)\n        weights_sum = weights.sum(dim=-1) # [N]\n        \n        # calculate depth \n        depth = torch.sum(weights * z_vals, dim=-1)\n\n        # calculate color\n        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]\n\n        # mix background color\n        if bg_color is None:\n            if self.opt.bg_radius > 0:\n                # use the bg model to calculate bg_color\n                bg_color = self.background(rays_d) # [N, 3]\n            else:\n                bg_color = 1\n            \n        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n\n        image = image.view(*prefix, 3)\n        depth = depth.view(*prefix)\n        weights_sum = weights_sum.reshape(*prefix)\n\n        if self.training:\n            if self.opt.lambda_orient > 0 and normals is not None:\n                # orientation loss\n                loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2\n                results['loss_orient'] = loss_orient.sum(-1).mean()\n            \n            if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:\n                normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)\n                
results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()\n            \n            if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None:\n                normal_image = torch.sum(weights.unsqueeze(-1) * (normals + 1) / 2, dim=-2) # [N, 3], in [0, 1]\n                results['normal_image'] = normal_image\n        \n        results['image'] = image\n        results['depth'] = depth\n        results['weights'] = weights\n        results['weights_sum'] = weights_sum\n\n        return results\n\n\n    def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, binarize=False, **kwargs):\n        # rays_o, rays_d: [B, N, 3]\n        # return: image: [B, N, 3], depth: [B, N]\n\n        prefix = rays_o.shape[:-1]\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0] # B * N, in fact\n        device = rays_o.device\n\n        # pre-calculate near far\n        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer)\n\n        # random sample light_d if not provided\n        if light_d is None:\n            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)\n            light_d = safe_normalize(rays_o + torch.randn(3, device=rays_o.device)) # [N, 3]\n\n        results = {}\n\n        if self.training:\n            xyzs, dirs, ts, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb, self.opt.dt_gamma, self.opt.max_steps)\n            dirs = safe_normalize(dirs)\n\n            if light_d.shape[0] > 1:\n                flatten_rays = raymarching.flatten_rays(rays, xyzs.shape[0]).long()\n                light_d = light_d[flatten_rays]\n            \n            sigmas, rgbs, normals = 
self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)\n            weights, weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ts, rays, T_thresh, binarize)\n            \n            # normals related regularizations\n            if self.opt.lambda_orient > 0 and normals is not None:\n                # orientation loss \n                loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2\n                results['loss_orient'] = loss_orient.mean()\n            \n            if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:\n                normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)\n                results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()\n            \n            if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None:\n                _, _, _, normal_image = raymarching.composite_rays_train(sigmas.detach(), (normals + 1) / 2, ts, rays, T_thresh, binarize)\n                results['normal_image'] = normal_image\n            \n            # weights normalization\n            results['weights'] = weights\n\n        else:\n           \n            # allocate outputs \n            dtype = torch.float32\n            \n            weights_sum = torch.zeros(N, dtype=dtype, device=device)\n            depth = torch.zeros(N, dtype=dtype, device=device)\n            image = torch.zeros(N, 3, dtype=dtype, device=device)\n            \n            n_alive = N\n            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]\n            rays_t = nears.clone() # [N]\n\n            step = 0\n            \n            while step < self.opt.max_steps: # hard coded max step\n\n                # count alive rays \n                n_alive = rays_alive.shape[0]\n\n                # exit loop\n                if n_alive <= 0:\n                    break\n\n                # decide 
compact_steps\n                n_step = max(min(N // n_alive, 8), 1)\n\n                xyzs, dirs, ts = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb if step == 0 else False, self.opt.dt_gamma, self.opt.max_steps)\n                dirs = safe_normalize(dirs)\n                sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)\n                raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh, binarize)\n\n                rays_alive = rays_alive[rays_alive >= 0]\n                #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')\n\n                step += n_step\n\n        # mix background color\n        if bg_color is None:\n            if self.opt.bg_radius > 0:\n                # use the bg model to calculate bg_color\n                bg_color = self.background(rays_d) # [N, 3]\n            else:\n                bg_color = 1\n\n        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n        image = image.view(*prefix, 3)\n\n        depth = depth.view(*prefix)\n\n        weights_sum = weights_sum.reshape(*prefix)\n\n        results['image'] = image\n        results['depth'] = depth\n        results['weights_sum'] = weights_sum\n        \n        return results\n\n    @torch.no_grad()\n    def init_tet(self, mesh=None):\n\n        if mesh is not None:\n            # normalize mesh\n            scale = 0.8 / np.array(mesh.bounds[1] - mesh.bounds[0]).max()\n            center = np.array(mesh.bounds[1] + mesh.bounds[0]) / 2\n            mesh.vertices = (mesh.vertices - center) * scale\n\n            # init scale\n            # self.tet_scale = torch.from_numpy(np.abs(mesh.vertices).max(axis=0) + 1e-1).to(self.verts.dtype).cuda()\n            self.tet_scale = 
torch.from_numpy(np.array([np.abs(mesh.vertices).max()]) + 1e-1).to(self.verts.dtype).cuda()\n            self.verts = self.verts * self.tet_scale\n\n            # init sdf\n            import cubvh\n            BVH = cubvh.cuBVH(mesh.vertices, mesh.faces)\n            sdf, _, _ = BVH.signed_distance(self.verts, return_uvw=False, mode='watertight')\n            sdf *= -10 # INNER is POSITIVE, also make it stronger\n            self.sdf.data += sdf.to(self.sdf.data.dtype).clamp(-1, 1)\n\n        else:\n\n            if self.cuda_ray:\n                density_thresh = min(self.mean_density, self.density_thresh)\n            else:\n                density_thresh = self.density_thresh\n        \n            if self.opt.density_activation == 'softplus':\n                density_thresh = density_thresh * 25\n\n            # init scale\n            sigma = self.density(self.verts)['sigma'] # verts covers [-1, 1] now\n            mask = sigma > density_thresh\n            valid_verts = self.verts[mask]\n            self.tet_scale = valid_verts.abs().amax(dim=0) + 1e-1\n            self.verts = self.verts * self.tet_scale\n\n            # init sigma\n            sigma = self.density(self.verts)['sigma'] # new verts\n            self.sdf.data += (sigma - density_thresh).clamp(-1, 1)\n\n        print(f'[INFO] init dmtet: scale = {self.tet_scale}')\n\n\n    def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, **kwargs):\n        # mvp: [B, 4, 4]\n\n        device = mvp.device\n        campos = rays_o[:, 0, :] # only need one ray per batch\n\n        # random sample light_d if not provided\n        if light_d is None:\n            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)\n            light_d = safe_normalize(campos + torch.randn_like(campos)).view(-1, 1, 1, 3) # [B, 1, 1, 3]\n\n        results = {}\n\n        # get mesh\n        sdf = self.sdf\n        deform = 
torch.tanh(self.deform) / self.opt.tet_grid_size\n\n        verts, faces = self.dmtet_model(self.verts + deform, sdf, self.indices)\n\n        # get normals\n        i0, i1, i2 = faces[:, 0], faces[:, 1], faces[:, 2]\n        v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :]\n\n        faces = faces.int()\n        \n        face_normals = torch.cross(v1 - v0, v2 - v0)\n        face_normals = safe_normalize(face_normals)\n        \n        vn = torch.zeros_like(verts)\n        vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)\n        vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)\n        vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)\n\n        vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))\n\n        # rasterization\n        verts_clip = torch.bmm(F.pad(verts, pad=(0, 1), mode='constant', value=1.0).unsqueeze(0).repeat(mvp.shape[0], 1, 1), \n                               mvp.permute(0,2,1)).float()  # [B, N, 4]\n        rast, rast_db = dr.rasterize(self.glctx, verts_clip, faces, (h, w))\n        \n        alpha = (rast[..., 3:] > 0).float()\n        xyzs, _ = dr.interpolate(verts.unsqueeze(0), rast, faces) # [B, H, W, 3]\n        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, faces)\n        normal = safe_normalize(normal)\n\n        xyzs = xyzs.view(-1, 3)\n        mask = (rast[..., 3:] > 0).view(-1).detach()\n\n        # do the lighting here since we have normal from mesh now.\n        albedo = torch.zeros_like(xyzs, dtype=torch.float32)\n        if mask.any():\n            masked_albedo = self.density(xyzs[mask])['albedo']\n            albedo[mask] = masked_albedo.float()\n        albedo = albedo.view(-1, h, w, 3)\n\n        # these two modes lead to no parameters to optimize if using --lock_geo.\n        if self.opt.lock_geo and shading in ['textureless', 'normal']:\n            shading = 'lambertian'\n\n        if 
shading == 'albedo':\n            color = albedo\n        elif shading == 'textureless':\n            lambertian = ambient_ratio + (1 - ambient_ratio)  * (normal * light_d).sum(-1).float().clamp(min=0)\n            color = lambertian.unsqueeze(-1).repeat(1, 1, 1, 3)\n        elif shading == 'normal':\n            color = (normal + 1) / 2\n        else: # 'lambertian'\n            lambertian = ambient_ratio + (1 - ambient_ratio)  * (normal * light_d).sum(-1).float().clamp(min=0)\n            color = albedo * lambertian.unsqueeze(-1)\n\n        color = dr.antialias(color, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3]\n        alpha = dr.antialias(alpha, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 1]\n\n        # mix background color\n        if bg_color is None:\n            if self.opt.bg_radius > 0:\n                # use the bg model to calculate bg_color\n                bg_color = self.background(rays_d) # [N, 3]\n            else:\n                bg_color = 1\n        \n        if torch.is_tensor(bg_color) and len(bg_color.shape) > 1:\n            bg_color = bg_color.view(-1, h, w, 3)\n        \n        depth = rast[:, :, :, [2]] # [B, H, W]\n        color = color + (1 - alpha) * bg_color\n\n        results['depth'] = depth        \n        results['image'] = color\n        results['weights_sum'] = alpha.squeeze(-1)\n\n        if self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0:\n            normal_image = dr.antialias((normal + 1) / 2, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3]\n            results['normal_image'] = normal_image\n        \n        # regularizations\n        if self.training:\n            if self.opt.lambda_mesh_normal > 0:\n                results['normal_loss'] = normal_consistency(face_normals, faces)\n            if self.opt.lambda_mesh_laplacian > 0:\n                results['lap_loss'] = laplacian_smooth_loss(verts, faces)\n\n        return results\n\n    def run_taichi(self, rays_o, rays_d, 
light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, **kwargs):\n        # rays_o, rays_d: [B, N, 3], assumes B == 1\n        # return: image: [B, N, 3], depth: [B, N]\n\n        prefix = rays_o.shape[:-1]\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0] # N = B * N, in fact\n        device = rays_o.device\n\n        # pre-calculate near far\n        exp_step_factor = kwargs.get('exp_step_factor', 0.)\n        MAX_SAMPLES = 1024\n        NEAR_DISTANCE = 0.01\n        center = torch.zeros(1, 3)\n        half_size = torch.ones(1, 3)\n        _, hits_t, _ = self.ray_aabb_intersector.apply(rays_o, rays_d, center, half_size, 1)\n        hits_t[(hits_t[:, 0, 0] >= 0) & (hits_t[:, 0, 0] < NEAR_DISTANCE), 0, 0] = NEAR_DISTANCE\n\n        # TODO: should sample different light_d for each batch... but taichi end doesn't have a flatten_ray implemented currently...\n        # random sample light_d if not provided\n        if light_d is None:\n            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)\n            light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))\n            light_d = safe_normalize(light_d)\n\n        results = {}\n\n        if self.training:\n            rays_a, xyzs, dirs, deltas, ts, _ = self.ray_marching(rays_o, rays_d, hits_t[:, 0], self.density_bitfield, self.cascade, self.bound, exp_step_factor, self.grid_size, MAX_SAMPLES)\n            dirs = safe_normalize(dirs)\n            # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())\n            sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)\n            _, weights_sum, depth, image, weights = self.volume_render(sigmas, rgbs, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4))\n            \n            # normals related regularizations\n            if 
self.opt.lambda_orient > 0 and normals is not None:\n                # orientation loss \n                loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2\n                results['loss_orient'] = loss_orient.mean()\n            \n            if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:\n                normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)\n                results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()\n            \n            if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None:\n                _, _, _, normal_image, _ = self.volume_render(sigmas.detach(), (normals + 1) / 2, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4))\n                results['normal_image'] = normal_image\n            \n            # weights normalization\n            results['weights'] = weights\n\n        else:\n        \n            # allocate outputs \n            dtype = torch.float32\n            \n            weights_sum = torch.zeros(N, dtype=dtype, device=device)\n            depth = torch.zeros(N, dtype=dtype, device=device)\n            image = torch.zeros(N, 3, dtype=dtype, device=device)\n            \n            n_alive = N\n            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]\n            rays_t = hits_t[:, 0, 0]\n            step = 0\n            \n            min_samples = 1 if exp_step_factor == 0 else 4\n\n            while step < self.opt.max_steps: # hard coded max step\n\n                # count alive rays \n                n_alive = rays_alive.shape[0]\n\n                # exit loop\n                if n_alive <= 0:\n                    break\n\n                # decide compact_steps\n                # n_step = max(min(N // n_alive, 8), 1)\n                n_step = max(min(N // n_alive, 64), min_samples)\n\n                xyzs, dirs, deltas, ts, N_eff_samples = \\\n                
self.raymarching_test_taichi(rays_o, rays_d, hits_t[:, 0], rays_alive,\n                                    self.density_bitfield, self.cascade,\n                                    self.bound, exp_step_factor,\n                                    self.grid_size, MAX_SAMPLES, n_step)\n\n                xyzs = self.rearrange(xyzs, 'n1 n2 c -> (n1 n2) c')\n                dirs = self.rearrange(dirs, 'n1 n2 c -> (n1 n2) c')\n                dirs = safe_normalize(dirs)\n                valid_mask = ~torch.all(dirs == 0, dim=1)\n                if valid_mask.sum() == 0:\n                    break\n\n                sigmas = torch.zeros(len(xyzs), device=device)\n                rgbs = torch.zeros(len(xyzs), 3, device=device)\n                normals = torch.zeros(len(xyzs), 3, device=device)\n\n                sigmas[valid_mask], _rgbs, _normals = self(xyzs[valid_mask], dirs[valid_mask], light_d, ratio=ambient_ratio, shading=shading)\n                rgbs[valid_mask] = _rgbs.float()\n                # scatter the valid-sample normals back into the padded buffer, so the reshape below always sees n1 * n_step rows\n                if _normals is not None:\n                    normals[valid_mask] = _normals.float()\n                sigmas = self.rearrange(sigmas, '(n1 n2) -> n1 n2', n2=n_step)\n                rgbs = self.rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=n_step)\n                normals = self.rearrange(normals, '(n1 n2) c -> n1 n2 c', n2=n_step)\n\n                self.composite_test_fw(sigmas, rgbs, deltas, ts, hits_t[:,0], rays_alive,\n                                    kwargs.get('T_threshold', 1e-4), N_eff_samples,\n                                    weights_sum, depth, image)\n\n                rays_alive = rays_alive[rays_alive >= 0]\n\n                step += n_step\n\n        # mix background color\n        if bg_color is None:\n            if self.opt.bg_radius > 0:\n                # use the bg model to calculate bg_color\n                bg_color = self.background(rays_d) # [N, 3]\n            else:\n                bg_color = 1\n\n        image = image + self.rearrange(1 - weights_sum, 'n -> n 1') * bg_color\n        image = 
image.view(*prefix, 3)\n\n        depth = depth.view(*prefix)\n\n        weights_sum = weights_sum.reshape(*prefix)\n\n        results['image'] = image\n        results['depth'] = depth\n        results['weights_sum'] = weights_sum\n        \n        return results\n\n\n    @torch.no_grad()\n    def update_extra_state(self, decay=0.95, S=128):\n        # call before each epoch to update extra states.\n\n        if not (self.cuda_ray or self.taichi_ray):\n            return \n        \n        ### update density grid\n        tmp_grid = - torch.ones_like(self.density_grid)\n        \n        X = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)\n        Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)\n        Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)\n\n        for xs in X:\n            for ys in Y:\n                for zs in Z:\n                    \n                    # construct points\n                    xx, yy, zz = custom_meshgrid(xs, ys, zs)\n                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)\n                    indices = raymarching.morton3D(coords).long() # [N]\n                    xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]\n\n                    # cascading\n                    for cas in range(self.cascade):\n                        bound = min(2 ** cas, self.bound)\n                        half_grid_size = bound / self.grid_size\n                        # scale to current cascade's resolution\n                        cas_xyzs = xyzs * (bound - half_grid_size)\n                        # add noise in [-hgs, hgs]\n                        cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size\n                        # query density\n                        sigmas = 
self.density(cas_xyzs)['sigma'].reshape(-1).detach()\n                        # assign \n                        tmp_grid[cas, indices] = sigmas\n        # ema update\n        valid_mask = self.density_grid >= 0\n        self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])\n        self.mean_density = torch.mean(self.density_grid[valid_mask]).item()\n        self.iter_density += 1\n\n        # convert to bitfield\n        density_thresh = min(self.mean_density, self.density_thresh)\n        if self.cuda_ray:\n            self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)\n        elif self.taichi_ray:\n            self.packbits_taichi(self.density_grid.reshape(-1).contiguous(), density_thresh, self.density_bitfield)\n\n        # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f}')\n\n\n    def render(self, rays_o, rays_d, mvp, h, w, staged=False, max_ray_batch=4096, **kwargs):\n        # rays_o, rays_d: [B, N, 3]\n        # return: pred_rgb: [B, N, 3]\n        B, N = rays_o.shape[:2]\n        device = rays_o.device\n\n        if self.dmtet:\n            results = self.run_dmtet(rays_o, rays_d, mvp, h, w, **kwargs)\n        elif self.cuda_ray:\n            results = self.run_cuda(rays_o, rays_d, **kwargs)\n        elif self.taichi_ray:\n            results = self.run_taichi(rays_o, rays_d, **kwargs)\n        else:\n            if staged:\n                depth = torch.empty((B, N), device=device)\n                image = torch.empty((B, N, 3), device=device)\n                weights_sum = torch.empty((B, N), device=device)\n\n                for b in range(B):\n                    head = 0\n                    while head < N:\n                        tail = min(head + max_ray_batch, N)\n         
               results_ = self.run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)\n                        depth[b:b+1, head:tail] = results_['depth']\n                        weights_sum[b:b+1, head:tail] = results_['weights_sum']\n                        image[b:b+1, head:tail] = results_['image']\n                        head += max_ray_batch\n                \n                results = {}\n                results['depth'] = depth\n                results['image'] = image\n                results['weights_sum'] = weights_sum\n\n            else:\n                results = self.run(rays_o, rays_d, **kwargs)\n\n        return results\n"
  },
  {
    "path": "nerf/utils.py",
    "content": "import os\nimport gc\nimport glob\nimport tqdm\nimport math\nimport imageio\nimport psutil\nfrom pathlib import Path\nimport random\nimport shutil\nimport warnings\nimport tensorboardX\n\nimport numpy as np\n\nimport time\n\nimport cv2\nimport matplotlib.pyplot as plt\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nimport torch.distributed as dist\nimport torchvision.transforms.functional as TF\nfrom torchmetrics import PearsonCorrCoef\n\nfrom rich.console import Console\nfrom torch_ema import ExponentialMovingAverage\n\nfrom packaging import version as pver\n\ndef adjust_text_embeddings(embeddings, azimuth, opt):\n    text_z_list = []\n    weights_list = []\n    K = 0\n    for b in range(azimuth.shape[0]):\n        text_z_, weights_ = get_pos_neg_text_embeddings(embeddings, azimuth[b], opt)\n        K = max(K, weights_.shape[0])\n        text_z_list.append(text_z_)\n        weights_list.append(weights_)\n\n    # Interleave text_embeddings from different dirs to form a batch\n    text_embeddings = []\n    for i in range(K):\n        for text_z in text_z_list:\n            # if uneven length, pad with the first embedding\n            text_embeddings.append(text_z[i] if i < len(text_z) else text_z[0])\n    text_embeddings = torch.stack(text_embeddings, dim=0) # [B * K, 77, 768]\n\n    # Interleave weights from different dirs to form a batch\n    weights = []\n    for i in range(K):\n        for weights_ in weights_list:\n            weights.append(weights_[i] if i < len(weights_) else torch.zeros_like(weights_[0]))\n    weights = torch.stack(weights, dim=0) # [B * K]\n    return text_embeddings, weights\n\ndef get_pos_neg_text_embeddings(embeddings, azimuth_val, opt):\n    if azimuth_val >= -90 and azimuth_val < 90:\n        if azimuth_val >= 0:\n            r = 1 - azimuth_val / 90\n        else:\n            r = 1 + azimuth_val / 90\n        start_z = embeddings['front']\n        end_z = 
embeddings['side']\n        # if random.random() < 0.3:\n        #     r = r + random.gauss(0, 0.08)\n        pos_z = r * start_z + (1 - r) * end_z\n        text_z = torch.cat([pos_z, embeddings['front'], embeddings['side']], dim=0)\n        if r > 0.8:\n            front_neg_w = 0.0\n        else:\n            front_neg_w = math.exp(-r * opt.front_decay_factor) * opt.negative_w\n        if r < 0.2:\n            side_neg_w = 0.0\n        else:\n            side_neg_w = math.exp(-(1-r) * opt.side_decay_factor) * opt.negative_w\n\n        weights = torch.tensor([1.0, front_neg_w, side_neg_w])\n    else:\n        if azimuth_val >= 0:\n            r = 1 - (azimuth_val - 90) / 90\n        else:\n            r = 1 + (azimuth_val + 90) / 90\n        start_z = embeddings['side']\n        end_z = embeddings['back']\n        # if random.random() < 0.3:\n        #     r = r + random.gauss(0, 0.08)\n        pos_z = r * start_z + (1 - r) * end_z\n        text_z = torch.cat([pos_z, embeddings['side'], embeddings['front']], dim=0)\n        front_neg_w = opt.negative_w \n        if r > 0.8:\n            side_neg_w = 0.0\n        else:\n            side_neg_w = math.exp(-r * opt.side_decay_factor) * opt.negative_w / 2\n\n        weights = torch.tensor([1.0, side_neg_w, front_neg_w])\n    return text_z, weights.to(text_z.device)\n\ndef custom_meshgrid(*args):\n    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid\n    if pver.parse(torch.__version__) < pver.parse('1.10'):\n        return torch.meshgrid(*args)\n    else:\n        return torch.meshgrid(*args, indexing='ij')\n\ndef safe_normalize(x, eps=1e-20):\n    return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))\n\n@torch.cuda.amp.autocast(enabled=False)\ndef get_rays(poses, intrinsics, H, W, N=-1, error_map=None):\n    ''' get rays\n    Args:\n        poses: [B, 4, 4], cam2world\n        intrinsics: [4]\n        H, W, N: int\n        error_map: [B, 
128 * 128], sample probability based on training error\n    Returns:\n        rays_o, rays_d: [B, N, 3]\n        inds: [B, N]\n    '''\n\n    device = poses.device\n    B = poses.shape[0]\n    fx, fy, cx, cy = intrinsics\n\n    i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))\n    i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5\n    j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5\n\n    results = {}\n\n    if N > 0:\n        N = min(N, H*W)\n\n        if error_map is None:\n            inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate\n            inds = inds.expand([B, N])\n        else:\n\n            # weighted sample on a low-reso grid\n            inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128)\n\n            # map to the original resolution with random perturb.\n            inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... 
anyway.\n            sx, sy = H / 128, W / 128\n            inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)\n            inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)\n            inds = inds_x * W + inds_y\n\n            results['inds_coarse'] = inds_coarse # need this when updating error_map\n\n        i = torch.gather(i, -1, inds)\n        j = torch.gather(j, -1, inds)\n\n        results['inds'] = inds\n\n    else:\n        inds = torch.arange(H*W, device=device).expand([B, H*W])\n\n    zs = - torch.ones_like(i)\n    xs = - (i - cx) / fx * zs\n    ys = (j - cy) / fy * zs\n    directions = torch.stack((xs, ys, zs), dim=-1)\n    # directions = safe_normalize(directions)\n    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)\n\n    rays_o = poses[..., :3, 3] # [B, 3]\n    rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]\n\n    results['rays_o'] = rays_o\n    results['rays_d'] = rays_d\n\n    return results\n\n\ndef seed_everything(seed):\n    random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    #torch.backends.cudnn.deterministic = True\n    #torch.backends.cudnn.benchmark = True\n\n\n@torch.jit.script\ndef linear_to_srgb(x):\n    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)\n\n\n@torch.jit.script\ndef srgb_to_linear(x):\n    return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)\n\n\nclass Trainer(object):\n    def __init__(self,\n\t\t         argv, # command line args\n                 name, # name of this experiment\n                 opt, # extra conf\n                 model, # network\n                 guidance, # guidance network\n                 criterion=None, # loss function, if None, assume inline implementation in train_step\n                 optimizer=None, # optimizer\n                 
ema_decay=None, # if use EMA, set the decay\n                 lr_scheduler=None, # scheduler\n                 metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.\n                 local_rank=0, # which GPU am I\n                 world_size=1, # total num of GPUs\n                 device=None, # device to use, usually setting to None is OK. (auto choose device)\n                 mute=False, # whether to mute all print\n                 fp16=False, # amp optimize level\n                 max_keep_ckpt=2, # max num of saved ckpts in disk\n                 workspace='workspace', # workspace to save logs & ckpts\n                 best_mode='min', # the smaller/larger result, the better\n                 use_loss_as_metric=True, # use loss as the first metric\n                 report_metric_at_train=False, # also report metrics at training\n                 use_checkpoint=\"latest\", # which ckpt to use at init time\n                 use_tensorboardX=True, # whether to use tensorboard for logging\n                 scheduler_update_every_step=False, # whether to call scheduler.step() after every train step\n                 ):\n\n        self.argv = argv\n        self.name = name\n        self.opt = opt\n        self.mute = mute\n        self.metrics = metrics\n        self.local_rank = local_rank\n        self.world_size = world_size\n        self.workspace = workspace\n        self.ema_decay = ema_decay\n        self.fp16 = fp16\n        self.best_mode = best_mode\n        self.use_loss_as_metric = use_loss_as_metric\n        self.report_metric_at_train = report_metric_at_train\n        self.max_keep_ckpt = max_keep_ckpt\n        self.use_checkpoint = use_checkpoint\n        self.use_tensorboardX = use_tensorboardX\n        self.time_stamp = time.strftime(\"%Y-%m-%d_%H-%M-%S\")\n        self.scheduler_update_every_step = scheduler_update_every_step\n        self.device = device if device is not None else 
torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')\n        self.console = Console()\n\n        model.to(self.device)\n        if self.world_size > 1:\n            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])\n        self.model = model\n\n        # guide model\n        self.guidance = guidance\n        self.embeddings = {}\n\n        # text prompt / images\n        if self.guidance is not None:\n            for key in self.guidance:\n                for p in self.guidance[key].parameters():\n                    p.requires_grad = False\n                self.embeddings[key] = {}\n            self.prepare_embeddings()\n\n        if isinstance(criterion, nn.Module):\n            criterion.to(self.device)\n        self.criterion = criterion\n\n        if self.opt.images is not None:\n            self.pearson = PearsonCorrCoef().to(self.device)\n\n        if optimizer is None:\n            self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam\n        else:\n            self.optimizer = optimizer(self.model)\n\n        if lr_scheduler is None:\n            self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler\n        else:\n            self.lr_scheduler = lr_scheduler(self.optimizer)\n\n        if ema_decay is not None:\n            self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)\n        else:\n            self.ema = None\n\n        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)\n\n        # variable init\n        self.total_train_t = 0\n        self.epoch = 0\n        self.global_step = 0\n        self.local_step = 0\n        self.stats = {\n            \"loss\": [],\n            \"valid_loss\": [],\n            \"results\": [], # metrics[0], or valid_loss\n            \"checkpoints\": [], # record 
path of saved ckpt, to automatically remove old ckpt\n            \"best_result\": None,\n        }\n\n        # auto fix\n        if len(metrics) == 0 or self.use_loss_as_metric:\n            self.best_mode = 'min'\n\n        # workspace prepare\n        self.log_ptr = None\n        if self.workspace is not None:\n            os.makedirs(self.workspace, exist_ok=True)\n            self.log_path = os.path.join(workspace, f\"log_{self.name}.txt\")\n            self.log_ptr = open(self.log_path, \"a+\")\n\n            self.ckpt_path = os.path.join(self.workspace, 'checkpoints')\n            self.best_path = f\"{self.ckpt_path}/{self.name}.pth\"\n            os.makedirs(self.ckpt_path, exist_ok=True)\n\n            # Save a copy of image_config in the experiment workspace\n            if opt.image_config is not None:\n                shutil.copyfile(opt.image_config, os.path.join(self.workspace, os.path.basename(opt.image_config)))\n\n            # Save a copy of images in the experiment workspace\n            if opt.images is not None:\n                for image_file in opt.images:\n                    shutil.copyfile(image_file, os.path.join(self.workspace, os.path.basename(image_file)))\n\n        self.log(f'[INFO] Cmdline: {self.argv}')\n        self.log(f'[INFO] opt: {self.opt}')\n        self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {\"fp16\" if self.fp16 else \"fp32\"} | {self.workspace}')\n        self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')\n\n        if self.workspace is not None:\n            if self.use_checkpoint == \"scratch\":\n                self.log(\"[INFO] Training from scratch ...\")\n            elif self.use_checkpoint == \"latest\":\n                self.log(\"[INFO] Loading latest checkpoint ...\")\n                self.load_checkpoint()\n            elif self.use_checkpoint == \"latest_model\":\n                self.log(\"[INFO] Loading latest checkpoint 
(model only)...\")\n                self.load_checkpoint(model_only=True)\n            elif self.use_checkpoint == \"best\":\n                if os.path.exists(self.best_path):\n                    self.log(\"[INFO] Loading best checkpoint ...\")\n                    self.load_checkpoint(self.best_path)\n                else:\n                    self.log(f\"[INFO] {self.best_path} not found, loading latest ...\")\n                    self.load_checkpoint()\n            else: # path to ckpt\n                self.log(f\"[INFO] Loading {self.use_checkpoint} ...\")\n                self.load_checkpoint(self.use_checkpoint)\n\n    # calculate the text embs.\n    @torch.no_grad()\n    def prepare_embeddings(self):\n\n        # text embeddings (stable-diffusion)\n        if self.opt.text is not None:\n\n            if 'SD' in self.guidance:\n                self.embeddings['SD']['default'] = self.guidance['SD'].get_text_embeds([self.opt.text])\n                self.embeddings['SD']['uncond'] = self.guidance['SD'].get_text_embeds([self.opt.negative])\n\n                for d in ['front', 'side', 'back']:\n                    self.embeddings['SD'][d] = self.guidance['SD'].get_text_embeds([f\"{self.opt.text}, {d} view\"])\n\n            if 'IF' in self.guidance:\n                self.embeddings['IF']['default'] = self.guidance['IF'].get_text_embeds([self.opt.text])\n                self.embeddings['IF']['uncond'] = self.guidance['IF'].get_text_embeds([self.opt.negative])\n\n                for d in ['front', 'side', 'back']:\n                    self.embeddings['IF'][d] = self.guidance['IF'].get_text_embeds([f\"{self.opt.text}, {d} view\"])\n\n            if 'clip' in self.guidance:\n                self.embeddings['clip']['text'] = self.guidance['clip'].get_text_embeds(self.opt.text)\n\n        if self.opt.images is not None:\n\n            h = int(self.opt.known_view_scale * self.opt.h)\n            w = int(self.opt.known_view_scale * self.opt.w)\n\n            # load 
processed image\n            for image in self.opt.images:\n                assert image.endswith('_rgba.png') # the rest of this code assumes that the _rgba image has been passed.\n            rgbas = [cv2.cvtColor(cv2.imread(image, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA) for image in self.opt.images]\n            rgba_hw = np.stack([cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in rgbas])\n            rgb_hw = rgba_hw[..., :3] * rgba_hw[..., 3:] + (1 - rgba_hw[..., 3:])\n            self.rgb = torch.from_numpy(rgb_hw).permute(0,3,1,2).contiguous().to(self.device)\n            self.mask = torch.from_numpy(rgba_hw[..., 3] > 0.5).to(self.device)\n            print(f'[INFO] dataset: load image prompt {self.opt.images} {self.rgb.shape}')\n\n            # load depth\n            depth_paths = [image.replace('_rgba.png', '_depth.png') for image in self.opt.images]\n            depths = [cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) for depth_path in depth_paths]\n            depth = np.stack([cv2.resize(depth, (w, h), interpolation=cv2.INTER_AREA) for depth in depths])\n            self.depth = torch.from_numpy(depth.astype(np.float32) / 255).to(self.device)  # TODO: this should be mapped to FP16\n            print(f'[INFO] dataset: load depth prompt {depth_paths} {self.depth.shape}')\n\n            # load normal   # TODO: don't load if normal loss is 0\n            normal_paths = [image.replace('_rgba.png', '_normal.png') for image in self.opt.images]\n            normals = [cv2.imread(normal_path, cv2.IMREAD_UNCHANGED) for normal_path in normal_paths]\n            normal = np.stack([cv2.resize(normal, (w, h), interpolation=cv2.INTER_AREA) for normal in normals])\n            self.normal = torch.from_numpy(normal.astype(np.float32) / 255).to(self.device)\n            print(f'[INFO] dataset: load normal prompt {normal_paths} {self.normal.shape}')\n\n            # encode embeddings for zero123\n            if 'zero123' in 
self.guidance:\n                rgba_256 = np.stack([cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in rgbas])\n                rgbs_256 = rgba_256[..., :3] * rgba_256[..., 3:] + (1 - rgba_256[..., 3:])\n                rgb_256 = torch.from_numpy(rgbs_256).permute(0,3,1,2).contiguous().to(self.device)\n                guidance_embeds = self.guidance['zero123'].get_img_embeds(rgb_256)\n                self.embeddings['zero123']['default'] = {\n                    'zero123_ws' : self.opt.zero123_ws,\n                    'c_crossattn' : guidance_embeds[0],\n                    'c_concat' : guidance_embeds[1],\n                    'ref_polars' : self.opt.ref_polars,\n                    'ref_azimuths' : self.opt.ref_azimuths,\n                    'ref_radii' : self.opt.ref_radii,\n                }\n\n            if 'clip' in self.guidance:\n                self.embeddings['clip']['image'] = self.guidance['clip'].get_img_embeds(self.rgb)\n\n\n    def __del__(self):\n        if self.log_ptr:\n            self.log_ptr.close()\n\n\n    def log(self, *args, **kwargs):\n        if self.local_rank == 0:\n            if not self.mute:\n                #print(*args)\n                self.console.print(*args, **kwargs)\n            if self.log_ptr:\n                print(*args, file=self.log_ptr)\n                self.log_ptr.flush() # write immediately to file\n\n    ### ------------------------------\n\n    def train_step(self, data, save_guidance_path:Path=None):\n        \"\"\"\n            Args:\n                save_guidance_path: an image that combines the NeRF render, the added latent noise,\n                    the denoised result and optionally the fully-denoised image.\n        \"\"\"\n\n        # perform RGBD loss instead of SDS if is image-conditioned\n        do_rgbd_loss = self.opt.images is not None and \\\n            (self.global_step % self.opt.known_view_interval == 0)\n\n        # override random camera with 
fixed known camera\n        if do_rgbd_loss:\n            data = self.default_view_data\n\n        # experiment iterations ratio\n        # i.e. what proportion of this experiment have we completed (in terms of iterations) so far?\n        exp_iter_ratio = (self.global_step - self.opt.exp_start_iter) / (self.opt.exp_end_iter - self.opt.exp_start_iter)\n\n        # progressively relaxing view range\n        if self.opt.progressive_view:\n            r = min(1.0, self.opt.progressive_view_init_ratio + 2.0*exp_iter_ratio)\n            self.opt.phi_range = [self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[0] * r,\n                                  self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[1] * r]\n            self.opt.theta_range = [self.opt.default_polar * (1 - r) + self.opt.full_theta_range[0] * r,\n                                    self.opt.default_polar * (1 - r) + self.opt.full_theta_range[1] * r]\n            self.opt.radius_range = [self.opt.default_radius * (1 - r) + self.opt.full_radius_range[0] * r,\n                                    self.opt.default_radius * (1 - r) + self.opt.full_radius_range[1] * r]\n            self.opt.fovy_range = [self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[0] * r,\n                                    self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[1] * r]\n\n        # progressively increase max_level\n        if self.opt.progressive_level:\n            self.model.max_level = min(1.0, 0.25 + 2.0*exp_iter_ratio)\n\n        rays_o = data['rays_o'] # [B, N, 3]\n        rays_d = data['rays_d'] # [B, N, 3]\n        mvp = data['mvp'] # [B, 4, 4]\n\n        B, N = rays_o.shape[:2]\n        H, W = data['H'], data['W']\n\n        # When ref_data has B images > opt.batch_size\n        if B > self.opt.batch_size:\n            # choose batch_size images out of those B images\n            choice = torch.randperm(B)[:self.opt.batch_size]\n            B = self.opt.batch_size\n            
rays_o = rays_o[choice]\n            rays_d = rays_d[choice]\n            mvp = mvp[choice]\n\n        if do_rgbd_loss:\n            ambient_ratio = 1.0\n            shading = 'lambertian' # use lambertian instead of albedo to get normal\n            as_latent = False\n            binarize = False\n            bg_color = torch.rand((B * N, 3), device=rays_o.device)\n\n            # add camera noise to avoid grid-like artifact\n            if self.opt.known_view_noise_scale > 0:\n                noise_scale = self.opt.known_view_noise_scale #* (1 - self.global_step / self.opt.iters)\n                rays_o = rays_o + torch.randn(3, device=self.device) * noise_scale\n                rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale\n\n        elif exp_iter_ratio <= self.opt.latent_iter_ratio:\n            ambient_ratio = 1.0\n            shading = 'normal'\n            as_latent = True\n            binarize = False\n            bg_color = None\n\n        else:\n            if exp_iter_ratio <= self.opt.albedo_iter_ratio:\n                ambient_ratio = 1.0\n                shading = 'albedo'\n            else:\n                # random shading\n                ambient_ratio = self.opt.min_ambient_ratio + (1.0-self.opt.min_ambient_ratio) * random.random()\n                rand = random.random()\n                if rand >= (1.0 - self.opt.textureless_ratio):\n                    shading = 'textureless'\n                else:\n                    shading = 'lambertian'\n\n            as_latent = False\n\n            # random weights binarization (like mobile-nerf) [NOT WORKING NOW]\n            # binarize_thresh = min(0.5, -0.5 + self.global_step / self.opt.iters)\n            # binarize = random.random() < binarize_thresh\n            binarize = False\n\n            # random background\n            rand = random.random()\n            if self.opt.bg_radius > 0 and rand > 0.5:\n                bg_color = None # use bg_net\n            else:\n            
    bg_color = torch.rand(3).to(self.device) # single color random bg\n\n        outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, binarize=binarize)\n        pred_depth = outputs['depth'].reshape(B, 1, H, W)\n        pred_mask = outputs['weights_sum'].reshape(B, 1, H, W)\n        if 'normal_image' in outputs:\n            pred_normal = outputs['normal_image'].reshape(B, H, W, 3)\n\n        if as_latent:\n            # abuse normal & mask as latent code for faster geometry initialization (ref: fantasia3D)\n            pred_rgb = torch.cat([outputs['image'], outputs['weights_sum'].unsqueeze(-1)], dim=-1).reshape(B, H, W, 4).permute(0, 3, 1, 2).contiguous() # [B, 4, H, W]\n        else:\n            pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]\n\n        # known view loss\n        if do_rgbd_loss:\n            gt_mask = self.mask # [B, H, W]\n            gt_rgb = self.rgb   # [B, 3, H, W]\n            gt_normal = self.normal # [B, H, W, 3]\n            gt_depth = self.depth   # [B, H, W]\n\n            if len(gt_rgb) > self.opt.batch_size:\n                gt_mask = gt_mask[choice]\n                gt_rgb = gt_rgb[choice]\n                gt_normal = gt_normal[choice]\n                gt_depth = gt_depth[choice]\n\n            # color loss\n            gt_rgb = gt_rgb * gt_mask[:, None].float() + bg_color.reshape(B, H, W, 3).permute(0,3,1,2).contiguous() * (1 - gt_mask[:, None].float())\n            loss = self.opt.lambda_rgb * F.mse_loss(pred_rgb, gt_rgb)\n\n            # mask loss\n            loss = loss + self.opt.lambda_mask * F.mse_loss(pred_mask[:, 0], gt_mask.float())\n\n            # normal loss\n            if self.opt.lambda_normal > 0 and 'normal_image' in outputs:\n                valid_gt_normal = 1 - 2 * gt_normal[gt_mask] # [B, 3]\n                valid_pred_normal = 2 * pred_normal[gt_mask] - 1 # [B, 
3]\n\n                lambda_normal = self.opt.lambda_normal * min(1, self.global_step / self.opt.iters)\n                loss = loss + lambda_normal * (1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean())\n\n            # relative depth loss\n            if self.opt.lambda_depth > 0:\n                valid_gt_depth = gt_depth[gt_mask] # [B,]\n                valid_pred_depth = pred_depth[:, 0][gt_mask] # [B,]\n                lambda_depth = self.opt.lambda_depth * min(1, self.global_step / self.opt.iters)\n                loss = loss + lambda_depth * (1 - self.pearson(valid_pred_depth, valid_gt_depth))\n\n                # # scale-invariant\n                # with torch.no_grad():\n                #     A = torch.cat([valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1) # [B, 2]\n                #     X = torch.linalg.lstsq(A, valid_pred_depth).solution # [2, 1]\n                #     valid_gt_depth = A @ X # [B, 1]\n                # lambda_depth = self.opt.lambda_depth #* min(1, self.global_step / self.opt.iters)\n                # loss = loss + lambda_depth * F.mse_loss(valid_pred_depth, valid_gt_depth)\n\n        # novel view loss\n        else:\n\n            loss = 0\n\n            if 'SD' in self.guidance:\n                # interpolate text_z\n                azimuth = data['azimuth'] # [-180, 180]\n\n                # ENHANCE: remove loop to handle batch size > 1\n                text_z = [self.embeddings['SD']['uncond']] * azimuth.shape[0]\n                if self.opt.perpneg:\n\n                    text_z_comp, weights = adjust_text_embeddings(self.embeddings['SD'], azimuth, self.opt)\n                    text_z.append(text_z_comp)\n\n                else:                \n                    for b in range(azimuth.shape[0]):\n                        if azimuth[b] >= -90 and azimuth[b] < 90:\n                            if azimuth[b] >= 0:\n                                r = 1 - azimuth[b] / 90\n                            else:\n   
                             r = 1 + azimuth[b] / 90\n                            start_z = self.embeddings['SD']['front']\n                            end_z = self.embeddings['SD']['side']\n                        else:\n                            if azimuth[b] >= 0:\n                                r = 1 - (azimuth[b] - 90) / 90\n                            else:\n                                r = 1 + (azimuth[b] + 90) / 90\n                            start_z = self.embeddings['SD']['side']\n                            end_z = self.embeddings['SD']['back']\n                        text_z.append(r * start_z + (1 - r) * end_z)\n\n                text_z = torch.cat(text_z, dim=0)\n                if self.opt.perpneg:\n                    loss = loss + self.guidance['SD'].train_step_perpneg(text_z, weights, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance,\n                                                    save_guidance_path=save_guidance_path)\n                else:\n                    loss = loss + self.guidance['SD'].train_step(text_z, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance,\n                                                                save_guidance_path=save_guidance_path)\n\n            if 'IF' in self.guidance:\n                # interpolate text_z\n                azimuth = data['azimuth'] # [-180, 180]\n\n                # ENHANCE: remove loop to handle batch size > 1\n                text_z = [self.embeddings['IF']['uncond']] * azimuth.shape[0]\n                if self.opt.perpneg:\n                    text_z_comp, weights = adjust_text_embeddings(self.embeddings['IF'], azimuth, self.opt)\n                    text_z.append(text_z_comp)\n                else:\n                    for b in range(azimuth.shape[0]):\n                        if azimuth[b] >= -90 and azimuth[b] < 90:\n                            if azimuth[b] >= 
0:\n                                r = 1 - azimuth[b] / 90\n                            else:\n                                r = 1 + azimuth[b] / 90\n                            start_z = self.embeddings['IF']['front']\n                            end_z = self.embeddings['IF']['side']\n                        else:\n                            if azimuth[b] >= 0:\n                                r = 1 - (azimuth[b] - 90) / 90\n                            else:\n                                r = 1 + (azimuth[b] + 90) / 90\n                            start_z = self.embeddings['IF']['side']\n                            end_z = self.embeddings['IF']['back']\n                        text_z.append(r * start_z + (1 - r) * end_z)\n\n                text_z = torch.cat(text_z, dim=0)\n\n                if self.opt.perpneg:\n                    loss = loss + self.guidance['IF'].train_step_perpneg(text_z, weights, pred_rgb, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance)\n                else:\n                    loss = loss + self.guidance['IF'].train_step(text_z, pred_rgb, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance)\n                    \n            if 'zero123' in self.guidance:\n\n                polar = data['polar']\n                azimuth = data['azimuth']\n                radius = data['radius']\n\n                loss = loss + self.guidance['zero123'].train_step(self.embeddings['zero123']['default'], pred_rgb, polar, azimuth, radius, guidance_scale=self.opt.guidance_scale,\n                                                                  as_latent=as_latent, grad_scale=self.opt.lambda_guidance, save_guidance_path=save_guidance_path)\n\n            if 'clip' in self.guidance:\n\n                # read azimuth here too: a CLIP-only run never enters the SD/IF/zero123 branches that set it\n                azimuth = data['azimuth'] # [-180, 180]\n\n                # empirical, far view should apply smaller CLIP loss\n                lambda_guidance = 10 * (1 - abs(azimuth) / 180) * self.opt.lambda_guidance\n\n                loss = loss + 
self.guidance['clip'].train_step(self.embeddings['clip'], pred_rgb, grad_scale=lambda_guidance)\n\n        # regularizations\n        if not self.opt.dmtet:\n\n            if self.opt.lambda_opacity > 0:\n                loss_opacity = (outputs['weights_sum'] ** 2).mean()\n                loss = loss + self.opt.lambda_opacity * loss_opacity\n\n            if self.opt.lambda_entropy > 0:\n                alphas = outputs['weights'].clamp(1e-5, 1 - 1e-5)\n                # alphas = alphas ** 2 # skewed entropy, favors 0 over 1\n                loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()\n                lambda_entropy = self.opt.lambda_entropy * min(1, 2 * self.global_step / self.opt.iters)\n                loss = loss + lambda_entropy * loss_entropy\n\n            if self.opt.lambda_2d_normal_smooth > 0 and 'normal_image' in outputs:\n                # pred_vals = outputs['normal_image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()\n                # smoothed_vals = TF.gaussian_blur(pred_vals.detach(), kernel_size=9)\n                # loss_smooth = F.mse_loss(pred_vals, smoothed_vals)\n                # total-variation\n                loss_smooth = (pred_normal[:, 1:, :, :] - pred_normal[:, :-1, :, :]).square().mean() + \\\n                              (pred_normal[:, :, 1:, :] - pred_normal[:, :, :-1, :]).square().mean()\n                loss = loss + self.opt.lambda_2d_normal_smooth * loss_smooth\n\n            if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:\n                loss_orient = outputs['loss_orient']\n                loss = loss + self.opt.lambda_orient * loss_orient\n\n            if self.opt.lambda_3d_normal_smooth > 0 and 'loss_normal_perturb' in outputs:\n                loss_normal_perturb = outputs['loss_normal_perturb']\n                loss = loss + self.opt.lambda_3d_normal_smooth * loss_normal_perturb\n\n        else:\n\n            if self.opt.lambda_mesh_normal > 0:\n   
             loss = loss + self.opt.lambda_mesh_normal * outputs['normal_loss']\n\n            if self.opt.lambda_mesh_laplacian > 0:\n                loss = loss + self.opt.lambda_mesh_laplacian * outputs['lap_loss']\n\n        return pred_rgb, pred_depth, loss\n\n    def post_train_step(self):\n\n        # unscale grad before modifying it!\n        # ref: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping\n        self.scaler.unscale_(self.optimizer)\n\n        # clip grad\n        if self.opt.grad_clip >= 0:\n            torch.nn.utils.clip_grad_value_(self.model.parameters(), self.opt.grad_clip)\n\n        if not self.opt.dmtet and self.opt.backbone == 'grid':\n\n            if self.opt.lambda_tv > 0:\n                lambda_tv = min(1.0, self.global_step / (0.5 * self.opt.iters)) * self.opt.lambda_tv\n                self.model.encoder.grad_total_variation(lambda_tv, None, self.model.bound)\n            if self.opt.lambda_wd > 0:\n                self.model.encoder.grad_weight_decay(self.opt.lambda_wd)\n\n    def eval_step(self, data):\n\n        rays_o = data['rays_o'] # [B, N, 3]\n        rays_d = data['rays_d'] # [B, N, 3]\n        mvp = data['mvp']\n\n        B, N = rays_o.shape[:2]\n        H, W = data['H'], data['W']\n\n        shading = data['shading'] if 'shading' in data else 'albedo'\n        ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0\n        light_d = data['light_d'] if 'light_d' in data else None\n\n        outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading)\n        pred_rgb = outputs['image'].reshape(B, H, W, 3)\n        pred_depth = outputs['depth'].reshape(B, H, W)\n\n        # dummy\n        loss = torch.zeros([1], device=pred_rgb.device, dtype=pred_rgb.dtype)\n\n        return pred_rgb, pred_depth, loss\n\n    def test_step(self, data, bg_color=None, perturb=False):\n        rays_o 
= data['rays_o'] # [B, N, 3]\n        rays_d = data['rays_d'] # [B, N, 3]\n        mvp = data['mvp']\n\n        B, N = rays_o.shape[:2]\n        H, W = data['H'], data['W']\n\n        if bg_color is not None:\n            bg_color = bg_color.to(rays_o.device)\n\n        shading = data['shading'] if 'shading' in data else 'albedo'\n        ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0\n        light_d = data['light_d'] if 'light_d' in data else None\n\n        outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, bg_color=bg_color)\n\n        pred_rgb = outputs['image'].reshape(B, H, W, 3)\n        pred_depth = outputs['depth'].reshape(B, H, W)\n\n        return pred_rgb, pred_depth, None\n\n    def save_mesh(self, loader=None, save_path=None):\n\n        if save_path is None:\n            save_path = os.path.join(self.workspace, 'mesh')\n\n        self.log(f\"==> Saving mesh to {save_path}\")\n\n        os.makedirs(save_path, exist_ok=True)\n\n        self.model.export_mesh(save_path, resolution=self.opt.mcubes_resolution, decimate_target=self.opt.decimate_target)\n\n        self.log(f\"==> Finished saving mesh.\")\n\n    ### ------------------------------\n\n    def train(self, train_loader, valid_loader, test_loader, max_epochs):\n\n        if self.use_tensorboardX and self.local_rank == 0:\n            self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, \"run\", self.name))\n\n        start_t = time.time()\n\n        for epoch in range(self.epoch + 1, max_epochs + 1):\n            self.epoch = epoch\n\n            self.train_one_epoch(train_loader, max_epochs)\n\n            if self.workspace is not None and self.local_rank == 0:\n                self.save_checkpoint(full=True, best=False)\n\n            if self.epoch % self.opt.eval_interval == 0:\n                self.evaluate_one_epoch(valid_loader)\n                
self.save_checkpoint(full=False, best=True)\n\n            if self.epoch % self.opt.test_interval == 0 or self.epoch == max_epochs:\n                self.test(test_loader)\n\n        end_t = time.time()\n\n        self.total_train_t = end_t - start_t + self.total_train_t\n\n        self.log(f\"[INFO] training takes {(self.total_train_t)/ 60:.4f} minutes.\")\n\n        if self.use_tensorboardX and self.local_rank == 0:\n            self.writer.close()\n\n    def evaluate(self, loader, name=None):\n        self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX\n        self.evaluate_one_epoch(loader, name)\n        self.use_tensorboardX = use_tensorboardX\n\n    def test(self, loader, save_path=None, name=None, write_video=True):\n\n        if save_path is None:\n            save_path = os.path.join(self.workspace, 'results')\n\n        if name is None:\n            name = f'{self.name}_ep{self.epoch:04d}'\n\n        os.makedirs(save_path, exist_ok=True)\n\n        self.log(f\"==> Start Test, save results to {save_path}\")\n\n        pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')\n        self.model.eval()\n\n        if write_video:\n            all_preds = []\n            all_preds_depth = []\n\n        with torch.no_grad():\n\n            for i, data in enumerate(loader):\n\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    preds, preds_depth, _ = self.test_step(data)\n\n                pred = preds[0].detach().cpu().numpy()\n                pred = (pred * 255).astype(np.uint8)\n\n                pred_depth = preds_depth[0].detach().cpu().numpy()\n                pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min() + 1e-6)\n                pred_depth = (pred_depth * 255).astype(np.uint8)\n\n                if write_video:\n                    all_preds.append(pred)\n                    
all_preds_depth.append(pred_depth)\n                else:\n                    cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))\n                    cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth)\n\n                pbar.update(loader.batch_size)\n\n        if write_video:\n            all_preds = np.stack(all_preds, axis=0)\n            all_preds_depth = np.stack(all_preds_depth, axis=0)\n\n            imageio.mimwrite(os.path.join(save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, macro_block_size=1)\n            imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1)\n\n        self.log(f\"==> Finished Test.\")\n\n    # [GUI] train text step.\n    def train_gui(self, train_loader, step=16):\n\n        self.model.train()\n\n        total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)\n\n        loader = iter(train_loader)\n\n        for _ in range(step):\n\n            # mimic an infinite loop dataloader (in case the total dataset is smaller than step)\n            try:\n                data = next(loader)\n            except StopIteration:\n                loader = iter(train_loader)\n                data = next(loader)\n\n            # update grid every 16 steps\n            if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    self.model.update_extra_state()\n\n            self.global_step += 1\n\n            self.optimizer.zero_grad()\n\n            with torch.cuda.amp.autocast(enabled=self.fp16):\n                pred_rgbs, pred_depths, loss = self.train_step(data)\n\n            self.scaler.scale(loss).backward()\n            self.post_train_step()\n            self.scaler.step(self.optimizer)\n            self.scaler.update()\n\n            if 
self.scheduler_update_every_step:\n                self.lr_scheduler.step()\n\n            total_loss += loss.detach()\n\n        if self.ema is not None:\n            self.ema.update()\n\n        average_loss = total_loss.item() / step\n\n        if not self.scheduler_update_every_step:\n            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n                self.lr_scheduler.step(average_loss)\n            else:\n                self.lr_scheduler.step()\n\n        outputs = {\n            'loss': average_loss,\n            'lr': self.optimizer.param_groups[0]['lr'],\n        }\n\n        return outputs\n\n\n    # [GUI] test on a single image\n    def test_gui(self, pose, intrinsics, mvp, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):\n\n        # render resolution (may need to downscale for a better frame rate)\n        rH = int(H * downscale)\n        rW = int(W * downscale)\n        intrinsics = intrinsics * downscale\n\n        pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)\n        mvp = torch.from_numpy(mvp).unsqueeze(0).to(self.device)\n\n        rays = get_rays(pose, intrinsics, rH, rW, -1)\n\n        # from degree theta/phi to 3D normalized vec\n        light_d = np.deg2rad(light_d)\n        light_d = np.array([\n            np.sin(light_d[0]) * np.sin(light_d[1]),\n            np.cos(light_d[0]),\n            np.sin(light_d[0]) * np.cos(light_d[1]),\n        ], dtype=np.float32)\n        light_d = torch.from_numpy(light_d).to(self.device)\n\n        data = {\n            'rays_o': rays['rays_o'],\n            'rays_d': rays['rays_d'],\n            'mvp': mvp,\n            'H': rH,\n            'W': rW,\n            'light_d': light_d,\n            'ambient_ratio': ambient_ratio,\n            'shading': shading,\n        }\n\n        self.model.eval()\n\n        if self.ema is not None:\n            self.ema.store()\n            self.ema.copy_to()\n\n        with 
torch.no_grad():\n            with torch.cuda.amp.autocast(enabled=self.fp16):\n                # here spp is used as perturb random seed!\n                preds, preds_depth, _ = self.test_step(data, bg_color=bg_color, perturb=False if spp == 1 else spp)\n\n        if self.ema is not None:\n            self.ema.restore()\n\n        # interpolation to the original resolution\n        if downscale != 1:\n            # have to permute twice with torch...\n            preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()\n            preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)\n\n        outputs = {\n            'image': preds[0].detach().cpu().numpy(),\n            'depth': preds_depth[0].detach().cpu().numpy(),\n        }\n\n        return outputs\n\n    def train_one_epoch(self, loader, max_epochs):\n        self.log(f\"==> [{time.strftime('%Y-%m-%d_%H-%M-%S')}] Start Training {self.workspace} Epoch {self.epoch}/{max_epochs}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...\")\n\n        total_loss = 0\n        if self.local_rank == 0 and self.report_metric_at_train:\n            for metric in self.metrics:\n                metric.clear()\n\n        self.model.train()\n\n        # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs\n        # ref: https://pytorch.org/docs/stable/data.html\n        if self.world_size > 1:\n            loader.sampler.set_epoch(self.epoch)\n\n        if self.local_rank == 0:\n            pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')\n\n        self.local_step = 0\n\n        if self.opt.save_guidance:\n            save_guidance_folder = Path(self.workspace) / 'guidance'\n            save_guidance_folder.mkdir(parents=True, exist_ok=True)\n\n        for data in loader:\n\n            # update 
grid every 16 steps\n            if (self.model.cuda_ray or self.model.taichi_ray) and self.global_step % self.opt.update_extra_interval == 0:\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    self.model.update_extra_state()\n\n            self.local_step += 1\n            self.global_step += 1\n\n            self.optimizer.zero_grad()\n\n            with torch.cuda.amp.autocast(enabled=self.fp16):\n                if self.opt.save_guidance and (self.global_step % self.opt.save_guidance_interval == 0):\n                    save_guidance_path = save_guidance_folder / f'step_{self.global_step:07d}.png'\n                else:\n                    save_guidance_path = None\n                pred_rgbs, pred_depths, loss = self.train_step(data, save_guidance_path=save_guidance_path)\n\n            # hooked grad clipping for RGB space\n            if self.opt.grad_clip_rgb >= 0:\n                def _hook(grad):\n                    if self.opt.fp16:\n                        # correctly handle the scale\n                        grad_scale = self.scaler._get_scale_async()\n                        return grad.clamp(grad_scale * -self.opt.grad_clip_rgb, grad_scale * self.opt.grad_clip_rgb)\n                    else:\n                        return grad.clamp(-self.opt.grad_clip_rgb, self.opt.grad_clip_rgb)\n                pred_rgbs.register_hook(_hook)\n                # pred_rgbs.retain_grad()\n\n            self.scaler.scale(loss).backward()\n\n            self.post_train_step()\n            self.scaler.step(self.optimizer)\n            self.scaler.update()\n\n            if self.scheduler_update_every_step:\n                self.lr_scheduler.step()\n\n            loss_val = loss.item()\n            total_loss += loss_val\n\n            if self.local_rank == 0:\n                # if self.report_metric_at_train:\n                #     for metric in self.metrics:\n                #         metric.update(preds, truths)\n\n                
if self.use_tensorboardX:\n                    self.writer.add_scalar(\"train/loss\", loss_val, self.global_step)\n                    self.writer.add_scalar(\"train/lr\", self.optimizer.param_groups[0]['lr'], self.global_step)\n\n                if self.scheduler_update_every_step:\n                    pbar.set_description(f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}\")\n                else:\n                    pbar.set_description(f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})\")\n                pbar.update(loader.batch_size)\n\n        if self.ema is not None:\n            self.ema.update()\n\n        average_loss = total_loss / self.local_step\n        self.stats[\"loss\"].append(average_loss)\n\n        if self.local_rank == 0:\n            pbar.close()\n            if self.report_metric_at_train:\n                for metric in self.metrics:\n                    self.log(metric.report(), style=\"red\")\n                    if self.use_tensorboardX:\n                        metric.write(self.writer, self.epoch, prefix=\"train\")\n                    metric.clear()\n\n        if not self.scheduler_update_every_step:\n            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n                self.lr_scheduler.step(average_loss)\n            else:\n                self.lr_scheduler.step()\n\n        cpu_mem, gpu_mem = get_CPU_mem(), get_GPU_mem()[0]\n        self.log(f\"==> [{time.strftime('%Y-%m-%d_%H-%M-%S')}] Finished Epoch {self.epoch}/{max_epochs}. 
CPU={cpu_mem:.1f}GB, GPU={gpu_mem:.1f}GB.\")\n\n\n    def evaluate_one_epoch(self, loader, name=None):\n        self.log(f\"++> Evaluate {self.workspace} at epoch {self.epoch} ...\")\n\n        if name is None:\n            name = f'{self.name}_ep{self.epoch:04d}'\n\n        total_loss = 0\n        if self.local_rank == 0:\n            for metric in self.metrics:\n                metric.clear()\n\n        self.model.eval()\n\n        if self.ema is not None:\n            self.ema.store()\n            self.ema.copy_to()\n\n        if self.local_rank == 0:\n            pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')\n\n        with torch.no_grad():\n            self.local_step = 0\n\n            for data in loader:\n                self.local_step += 1\n\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    preds, preds_depth, loss = self.eval_step(data)\n\n                # all_gather/reduce the statistics (NCCL only support all_*)\n                if self.world_size > 1:\n                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)\n                    loss = loss / self.world_size\n\n                    preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]\n                    dist.all_gather(preds_list, preds)\n                    preds = torch.cat(preds_list, dim=0)\n\n                    preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]\n                    dist.all_gather(preds_depth_list, preds_depth)\n                    preds_depth = torch.cat(preds_depth_list, dim=0)\n\n                loss_val = loss.item()\n                total_loss += loss_val\n\n                # only rank = 0 will perform evaluation.\n                if self.local_rank == 0:\n\n                    # save 
image\n                    save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')\n                    save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png')\n\n                    #self.log(f\"==> Saving validation image to {save_path}\")\n                    os.makedirs(os.path.dirname(save_path), exist_ok=True)\n\n                    pred = preds[0].detach().cpu().numpy()\n                    pred = (pred * 255).astype(np.uint8)\n\n                    pred_depth = preds_depth[0].detach().cpu().numpy()\n                    pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min() + 1e-6)\n                    pred_depth = (pred_depth * 255).astype(np.uint8)\n\n                    cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))\n                    cv2.imwrite(save_path_depth, pred_depth)\n\n                    pbar.set_description(f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})\")\n                    pbar.update(loader.batch_size)\n\n\n        average_loss = total_loss / self.local_step\n        self.stats[\"valid_loss\"].append(average_loss)\n\n        if self.local_rank == 0:\n            pbar.close()\n            if not self.use_loss_as_metric and len(self.metrics) > 0:\n                result = self.metrics[0].measure()\n                self.stats[\"results\"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result\n            else:\n                self.stats[\"results\"].append(average_loss) # if no metric, choose best by min loss\n\n            for metric in self.metrics:\n                self.log(metric.report(), style=\"blue\")\n                if self.use_tensorboardX:\n                    metric.write(self.writer, self.epoch, prefix=\"evaluate\")\n                metric.clear()\n\n        if self.ema is not None:\n            self.ema.restore()\n\n        self.log(f\"++> Evaluate epoch 
{self.epoch} Finished.\")\n\n    def save_checkpoint(self, name=None, full=False, best=False):\n\n        if name is None:\n            name = f'{self.name}_ep{self.epoch:04d}'\n\n        state = {\n            'epoch': self.epoch,\n            'global_step': self.global_step,\n            'stats': self.stats,\n        }\n\n        if self.model.cuda_ray:\n            state['mean_density'] = self.model.mean_density\n\n        if self.opt.dmtet:\n            state['tet_scale'] = self.model.tet_scale.cpu().numpy()\n\n        if full:\n            state['optimizer'] = self.optimizer.state_dict()\n            state['lr_scheduler'] = self.lr_scheduler.state_dict()\n            state['scaler'] = self.scaler.state_dict()\n            if self.ema is not None:\n                state['ema'] = self.ema.state_dict()\n\n        if not best:\n\n            state['model'] = self.model.state_dict()\n\n            file_path = f\"{name}.pth\"\n\n            self.stats[\"checkpoints\"].append(file_path)\n\n            if len(self.stats[\"checkpoints\"]) > self.max_keep_ckpt:\n                old_ckpt = os.path.join(self.ckpt_path, self.stats[\"checkpoints\"].pop(0))\n                if os.path.exists(old_ckpt):\n                    os.remove(old_ckpt)\n\n            torch.save(state, os.path.join(self.ckpt_path, file_path))\n\n        else:\n            if len(self.stats[\"results\"]) > 0:\n                # always save best since loss cannot reflect performance.\n                if True:\n                    # self.log(f\"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}\")\n                    # self.stats[\"best_result\"] = self.stats[\"results\"][-1]\n\n                    # save ema results\n                    if self.ema is not None:\n                        self.ema.store()\n                        self.ema.copy_to()\n\n                    state['model'] = self.model.state_dict()\n\n                    if self.ema is not None:\n               
         self.ema.restore()\n\n                    torch.save(state, self.best_path)\n            else:\n                self.log(f\"[WARN] no evaluated results found, skip saving best checkpoint.\")\n\n    def load_checkpoint(self, checkpoint=None, model_only=False):\n        if checkpoint is None:\n            checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))\n            if checkpoint_list:\n                checkpoint = checkpoint_list[-1]\n                self.log(f\"[INFO] Latest checkpoint is {checkpoint}\")\n            else:\n                self.log(\"[WARN] No checkpoint found, model randomly initialized.\")\n                return\n\n        checkpoint_dict = torch.load(checkpoint, map_location=self.device)\n\n        if 'model' not in checkpoint_dict:\n            self.model.load_state_dict(checkpoint_dict)\n            self.log(\"[INFO] loaded model.\")\n            return\n\n        missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)\n        self.log(\"[INFO] loaded model.\")\n        if len(missing_keys) > 0:\n            self.log(f\"[WARN] missing keys: {missing_keys}\")\n        if len(unexpected_keys) > 0:\n            self.log(f\"[WARN] unexpected keys: {unexpected_keys}\")\n\n        if self.ema is not None and 'ema' in checkpoint_dict:\n            try:\n                self.ema.load_state_dict(checkpoint_dict['ema'])\n                self.log(\"[INFO] loaded EMA.\")\n            except:\n                self.log(\"[WARN] failed to load EMA.\")\n\n        if self.model.cuda_ray:\n            if 'mean_density' in checkpoint_dict:\n                self.model.mean_density = checkpoint_dict['mean_density']\n\n        if self.opt.dmtet:\n            if 'tet_scale' in checkpoint_dict:\n                new_scale = torch.from_numpy(checkpoint_dict['tet_scale']).to(self.device)\n                self.model.verts *= new_scale / self.model.tet_scale\n                self.model.tet_scale = 
new_scale\n\n        if model_only:\n            return\n\n        self.stats = checkpoint_dict['stats']\n        self.epoch = checkpoint_dict['epoch']\n        self.global_step = checkpoint_dict['global_step']\n        self.log(f\"[INFO] load at epoch {self.epoch}, global step {self.global_step}\")\n\n        if self.optimizer and 'optimizer' in checkpoint_dict:\n            try:\n                self.optimizer.load_state_dict(checkpoint_dict['optimizer'])\n                self.log(\"[INFO] loaded optimizer.\")\n            except:\n                self.log(\"[WARN] Failed to load optimizer.\")\n\n        if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:\n            try:\n                self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])\n                self.log(\"[INFO] loaded scheduler.\")\n            except:\n                self.log(\"[WARN] Failed to load scheduler.\")\n\n        if self.scaler and 'scaler' in checkpoint_dict:\n            try:\n                self.scaler.load_state_dict(checkpoint_dict['scaler'])\n                self.log(\"[INFO] loaded scaler.\")\n            except:\n                self.log(\"[WARN] Failed to load scaler.\")\n\n\ndef get_CPU_mem():\n    return psutil.Process(os.getpid()).memory_info().rss /1024**3\n\n\ndef get_GPU_mem():\n    num = torch.cuda.device_count()\n    mem, mems = 0, []\n    for i in range(num):\n        mem_free, mem_total = torch.cuda.mem_get_info(i)\n        mems.append(int(((mem_total - mem_free)/1024**3)*1000)/1000)\n        mem += mems[-1]\n    return mem, mems\n"
  },
  {
    "path": "optimizer.py",
    "content": "# Copyright 2022 Garena Online Private Limited\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import List\n\nimport torch\nfrom torch import Tensor\nfrom torch.optim.optimizer import Optimizer\n\n\nclass Adan(Optimizer):\n    \"\"\"\n    Implements a pytorch variant of Adan\n    Adan was proposed in\n    Adan: Adaptive Nesterov Momentum Algorithm for\n        Faster Optimizing Deep Models[J].arXiv preprint arXiv:2208.06677, 2022.\n    https://arxiv.org/abs/2208.06677\n    Arguments:\n        params (iterable): iterable of parameters to optimize or\n            dicts defining parameter groups.\n        lr (float, optional): learning rate. (default: 1e-3)\n        betas (Tuple[float, float, float], optional): coefficients used for\n            first- and second-order moments. (default: (0.98, 0.92, 0.99))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability. (default: 1e-8)\n        weight_decay (float, optional): decoupled weight decay\n            (L2 penalty) (default: 0)\n        max_grad_norm (float, optional): value used to clip\n            global grad norm (default: 0.0 no clip)\n        no_prox (bool): how to perform the decoupled weight decay\n            (default: False)\n        foreach (bool): if True, use the torch._foreach implementation.\n            It's faster but uses slightly more memory. 
(default: True)\n    \"\"\"\n    def __init__(self,\n                 params,\n                 lr=1e-3,\n                 betas=(0.98, 0.92, 0.99),\n                 eps=1e-8,\n                 weight_decay=0.0,\n                 max_grad_norm=0.0,\n                 no_prox=False,\n                 foreach: bool = True):\n        if not 0.0 <= max_grad_norm:\n            raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))\n        if not 0.0 <= lr:\n            raise ValueError('Invalid learning rate: {}'.format(lr))\n        if not 0.0 <= eps:\n            raise ValueError('Invalid epsilon value: {}'.format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError('Invalid beta parameter at index 0: {}'.format(\n                betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError('Invalid beta parameter at index 1: {}'.format(\n                betas[1]))\n        if not 0.0 <= betas[2] < 1.0:\n            raise ValueError('Invalid beta parameter at index 2: {}'.format(\n                betas[2]))\n        defaults = dict(lr=lr,\n                        betas=betas,\n                        eps=eps,\n                        weight_decay=weight_decay,\n                        max_grad_norm=max_grad_norm,\n                        no_prox=no_prox,\n                        foreach=foreach)\n        super().__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super(Adan, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault('no_prox', False)\n\n    @torch.no_grad()\n    def restart_opt(self):\n        for group in self.param_groups:\n            group['step'] = 0\n            for p in group['params']:\n                if p.requires_grad:\n                    state = self.state[p]\n                    # State initialization\n\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = 
torch.zeros_like(p)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p)\n                    # Exponential moving average of gradient difference\n                    state['exp_avg_diff'] = torch.zeros_like(p)\n\n    @torch.no_grad()\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\"\"\"\n\n        loss = None\n        if closure is not None:\n            with torch.enable_grad():\n                loss = closure()\n\n        if self.defaults['max_grad_norm'] > 0:\n            device = self.param_groups[0]['params'][0].device\n            global_grad_norm = torch.zeros(1, device=device)\n\n            max_grad_norm = torch.tensor(self.defaults['max_grad_norm'],\n                                         device=device)\n            for group in self.param_groups:\n\n                for p in group['params']:\n                    if p.grad is not None:\n                        grad = p.grad\n                        global_grad_norm.add_(grad.pow(2).sum())\n\n            global_grad_norm = torch.sqrt(global_grad_norm)\n\n            clip_global_grad_norm = torch.clamp(\n                max_grad_norm / (global_grad_norm + group['eps']),\n                max=1.0).item()\n        else:\n            clip_global_grad_norm = 1.0\n\n        for group in self.param_groups:\n            params_with_grad = []\n            grads = []\n            exp_avgs = []\n            exp_avg_sqs = []\n            exp_avg_diffs = []\n            neg_pre_grads = []\n\n            beta1, beta2, beta3 = group['betas']\n            # assume same step across group now to simplify things\n            # per parameter step can be easily support\n            # by making it tensor, or pass list into kernel\n            if 'step' in group:\n                group['step'] += 1\n            else:\n                group['step'] = 1\n\n            bias_correction1 = 1.0 - 
beta1**group['step']\n            bias_correction2 = 1.0 - beta2**group['step']\n            bias_correction3 = 1.0 - beta3**group['step']\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                params_with_grad.append(p)\n                grads.append(p.grad)\n\n                state = self.state[p]\n                if len(state) == 0:\n                    state['exp_avg'] = torch.zeros_like(p)\n                    state['exp_avg_sq'] = torch.zeros_like(p)\n                    state['exp_avg_diff'] = torch.zeros_like(p)\n\n                if 'neg_pre_grad' not in state or group['step'] == 1:\n                    state['neg_pre_grad'] = p.grad.clone().mul_(\n                        -clip_global_grad_norm)\n\n                exp_avgs.append(state['exp_avg'])\n                exp_avg_sqs.append(state['exp_avg_sq'])\n                exp_avg_diffs.append(state['exp_avg_diff'])\n                neg_pre_grads.append(state['neg_pre_grad'])\n\n            kwargs = dict(\n                params=params_with_grad,\n                grads=grads,\n                exp_avgs=exp_avgs,\n                exp_avg_sqs=exp_avg_sqs,\n                exp_avg_diffs=exp_avg_diffs,\n                neg_pre_grads=neg_pre_grads,\n                beta1=beta1,\n                beta2=beta2,\n                beta3=beta3,\n                bias_correction1=bias_correction1,\n                bias_correction2=bias_correction2,\n                bias_correction3_sqrt=math.sqrt(bias_correction3),\n                lr=group['lr'],\n                weight_decay=group['weight_decay'],\n                eps=group['eps'],\n                no_prox=group['no_prox'],\n                clip_global_grad_norm=clip_global_grad_norm,\n            )\n\n            if group['foreach']:\n                _multi_tensor_adan(**kwargs)\n            else:\n                _single_tensor_adan(**kwargs)\n\n        return loss\n\n\ndef _single_tensor_adan(\n    
params: List[Tensor],\n    grads: List[Tensor],\n    exp_avgs: List[Tensor],\n    exp_avg_sqs: List[Tensor],\n    exp_avg_diffs: List[Tensor],\n    neg_pre_grads: List[Tensor],\n    *,\n    beta1: float,\n    beta2: float,\n    beta3: float,\n    bias_correction1: float,\n    bias_correction2: float,\n    bias_correction3_sqrt: float,\n    lr: float,\n    weight_decay: float,\n    eps: float,\n    no_prox: bool,\n    clip_global_grad_norm: Tensor,\n):\n    for i, param in enumerate(params):\n        grad = grads[i]\n        exp_avg = exp_avgs[i]\n        exp_avg_sq = exp_avg_sqs[i]\n        exp_avg_diff = exp_avg_diffs[i]\n        neg_grad_or_diff = neg_pre_grads[i]\n\n        grad.mul_(clip_global_grad_norm)\n\n        # for memory saving, we use `neg_grad_or_diff`\n        # to get some temp variable in a inplace way\n        neg_grad_or_diff.add_(grad)\n\n        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # m_t\n        exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff,\n                                      alpha=1 - beta2)  # diff_t\n\n        neg_grad_or_diff.mul_(beta2).add_(grad)\n        exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff,\n                                        neg_grad_or_diff,\n                                        value=1 - beta3)  # n_t\n\n        denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps)\n        step_size_diff = lr * beta2 / bias_correction2\n        step_size = lr / bias_correction1\n\n        if no_prox:\n            param.mul_(1 - lr * weight_decay)\n            param.addcdiv_(exp_avg, denom, value=-step_size)\n            param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)\n        else:\n            param.addcdiv_(exp_avg, denom, value=-step_size)\n            param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)\n            param.div_(1 + lr * weight_decay)\n\n        neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)\n\n\ndef _multi_tensor_adan(\n    params: List[Tensor],\n    grads: 
List[Tensor],\n    exp_avgs: List[Tensor],\n    exp_avg_sqs: List[Tensor],\n    exp_avg_diffs: List[Tensor],\n    neg_pre_grads: List[Tensor],\n    *,\n    beta1: float,\n    beta2: float,\n    beta3: float,\n    bias_correction1: float,\n    bias_correction2: float,\n    bias_correction3_sqrt: float,\n    lr: float,\n    weight_decay: float,\n    eps: float,\n    no_prox: bool,\n    clip_global_grad_norm: Tensor,\n):\n    if len(params) == 0:\n        return\n\n    torch._foreach_mul_(grads, clip_global_grad_norm)\n\n    # for memory saving, we use `neg_pre_grads`\n    # to get some temp variable in a inplace way\n    torch._foreach_add_(neg_pre_grads, grads)\n\n    torch._foreach_mul_(exp_avgs, beta1)\n    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)  # m_t\n\n    torch._foreach_mul_(exp_avg_diffs, beta2)\n    torch._foreach_add_(exp_avg_diffs, neg_pre_grads,\n                        alpha=1 - beta2)  # diff_t\n\n    torch._foreach_mul_(neg_pre_grads, beta2)\n    torch._foreach_add_(neg_pre_grads, grads)\n    torch._foreach_mul_(exp_avg_sqs, beta3)\n    torch._foreach_addcmul_(exp_avg_sqs,\n                            neg_pre_grads,\n                            neg_pre_grads,\n                            value=1 - beta3)  # n_t\n\n    denom = torch._foreach_sqrt(exp_avg_sqs)\n    torch._foreach_div_(denom, bias_correction3_sqrt)\n    torch._foreach_add_(denom, eps)\n\n    step_size_diff = lr * beta2 / bias_correction2\n    step_size = lr / bias_correction1\n\n    if no_prox:\n        torch._foreach_mul_(params, 1 - lr * weight_decay)\n        torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)\n        torch._foreach_addcdiv_(params,\n                                exp_avg_diffs,\n                                denom,\n                                value=-step_size_diff)\n    else:\n        torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)\n        torch._foreach_addcdiv_(params,\n                                
exp_avg_diffs,\n                                denom,\n                                value=-step_size_diff)\n        torch._foreach_div_(params, 1 + lr * weight_decay)\n    torch._foreach_zero_(neg_pre_grads)\n    torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)"
  },
  {
    "path": "preprocess_image.py",
    "content": "import os\nimport sys\nimport cv2\nimport argparse\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision import transforms\nfrom PIL import Image\n\nclass BackgroundRemoval():\n    def __init__(self, device='cuda'):\n\n        from carvekit.api.high import HiInterface\n        self.interface = HiInterface(\n            object_type=\"object\",  # Can be \"object\" or \"hairs-like\".\n            batch_size_seg=5,\n            batch_size_matting=1,\n            device=device,\n            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net\n            matting_mask_size=2048,\n            trimap_prob_threshold=231,\n            trimap_dilation=30,\n            trimap_erosion_iters=5,\n            fp16=True,\n        )\n\n    @torch.no_grad()\n    def __call__(self, image):\n        # image: [H, W, 3] array in [0, 255].\n        image = Image.fromarray(image)\n\n        image = self.interface([image])[0]\n        image = np.array(image)\n\n        return image\n\nclass BLIP2():\n    def __init__(self, device='cuda'):\n        self.device = device\n        from transformers import AutoProcessor, Blip2ForConditionalGeneration\n        self.processor = AutoProcessor.from_pretrained(\"Salesforce/blip2-opt-2.7b\")\n        self.model = Blip2ForConditionalGeneration.from_pretrained(\"Salesforce/blip2-opt-2.7b\", torch_dtype=torch.float16).to(device)\n\n    @torch.no_grad()\n    def __call__(self, image):\n        image = Image.fromarray(image)\n        inputs = self.processor(image, return_tensors=\"pt\").to(self.device, torch.float16)\n\n        generated_ids = self.model.generate(**inputs, max_new_tokens=20)\n        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()\n\n        return generated_text\n\n\nclass DPT():\n    def __init__(self, task='depth', device='cuda'):\n\n        self.task = task\n        
self.device = device\n\n        from dpt import DPTDepthModel\n\n        if task == 'depth':\n            path = 'pretrained/omnidata/omnidata_dpt_depth_v2.ckpt'\n            self.model = DPTDepthModel(backbone='vitb_rn50_384')\n            self.aug = transforms.Compose([\n                transforms.Resize((384, 384)),\n                transforms.ToTensor(),\n                transforms.Normalize(mean=0.5, std=0.5)\n            ])\n\n        else: # normal\n            path = 'pretrained/omnidata/omnidata_dpt_normal_v2.ckpt'\n            self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3)\n            self.aug = transforms.Compose([\n                transforms.Resize((384, 384)),\n                transforms.ToTensor()\n            ])\n\n        # load model\n        checkpoint = torch.load(path, map_location='cpu')\n        if 'state_dict' in checkpoint:\n            state_dict = {}\n            for k, v in checkpoint['state_dict'].items():\n                state_dict[k[6:]] = v\n        else:\n            state_dict = checkpoint\n        self.model.load_state_dict(state_dict)\n        self.model.eval().to(device)\n\n\n    @torch.no_grad()\n    def __call__(self, image):\n        # image: np.ndarray, uint8, [H, W, 3]\n        H, W = image.shape[:2]\n        image = Image.fromarray(image)\n\n        image = self.aug(image).unsqueeze(0).to(self.device)\n\n        if self.task == 'depth':\n            depth = self.model(image).clamp(0, 1)\n            depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False)\n            depth = depth.squeeze(1).cpu().numpy()\n            return depth\n        else:\n            normal = self.model(image).clamp(0, 1)\n            normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False)\n            normal = normal.cpu().numpy()\n            return normal\n\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('path', 
type=str, help=\"path to image (png, jpeg, etc.)\")\n    parser.add_argument('--size', default=256, type=int, help=\"output resolution\")\n    parser.add_argument('--border_ratio', default=0.2, type=float, help=\"output border ratio\")\n    parser.add_argument('--recenter', type=bool, default=True, help=\"recenter, potentially not helpful for multiview zero123\")\n    parser.add_argument('--dont_recenter', dest='recenter', action='store_false')\n    opt = parser.parse_args()\n\n    out_dir = os.path.dirname(opt.path)\n    out_rgba = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_rgba.png')\n    out_depth = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_depth.png')\n    out_normal = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_normal.png')\n    out_caption = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_caption.txt')\n\n    # load image\n    print(f'[INFO] loading image...')\n    image = cv2.imread(opt.path, cv2.IMREAD_UNCHANGED)\n    if image.shape[-1] == 4:\n        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)\n    else:\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n\n    # carve background\n    print(f'[INFO] background removal...')\n    carved_image = BackgroundRemoval()(image) # [H, W, 4]\n    mask = carved_image[..., -1] > 0\n\n    # predict depth\n    print(f'[INFO] depth estimation...')\n    dpt_depth_model = DPT(task='depth')\n    depth = dpt_depth_model(image)[0]\n    depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9)\n    depth[~mask] = 0\n    depth = (depth * 255).astype(np.uint8)\n    del dpt_depth_model\n\n    # predict normal\n    print(f'[INFO] normal estimation...')\n    dpt_normal_model = DPT(task='normal')\n    normal = dpt_normal_model(image)[0]\n    normal = (normal * 255).astype(np.uint8).transpose(1, 2, 0)\n    normal[~mask] = 0\n    del dpt_normal_model\n\n    # recenter\n    if opt.recenter:\n        
print(f'[INFO] recenter...')\n        final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8)\n        final_depth = np.zeros((opt.size, opt.size), dtype=np.uint8)\n        final_normal = np.zeros((opt.size, opt.size, 3), dtype=np.uint8)\n\n        coords = np.nonzero(mask)\n        x_min, x_max = coords[0].min(), coords[0].max()\n        y_min, y_max = coords[1].min(), coords[1].max()\n        h = x_max - x_min\n        w = y_max - y_min\n        desired_size = int(opt.size * (1 - opt.border_ratio))\n        scale = desired_size / max(h, w)\n        h2 = int(h * scale)\n        w2 = int(w * scale)\n        x2_min = (opt.size - h2) // 2\n        x2_max = x2_min + h2\n        y2_min = (opt.size - w2) // 2\n        y2_max = y2_min + w2\n        final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)\n        final_depth[x2_min:x2_max, y2_min:y2_max] = cv2.resize(depth[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)\n        final_normal[x2_min:x2_max, y2_min:y2_max] = cv2.resize(normal[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)\n\n    else:\n        final_rgba = carved_image\n        final_depth = depth\n        final_normal = normal\n\n    # write output\n    cv2.imwrite(out_rgba, cv2.cvtColor(final_rgba, cv2.COLOR_RGBA2BGRA))\n    cv2.imwrite(out_depth, final_depth)\n    cv2.imwrite(out_normal, final_normal)\n\n    # predict caption (it's too slow... use your brain instead)\n    # print(f'[INFO] captioning...')\n    # blip2 = BLIP2()\n    # caption = blip2(image)\n    # with open(out_caption, 'w') as f:\n    #     f.write(caption)\n\n"
  },
  {
    "path": "pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-04\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: \"image_target\"\n    cond_stage_key: \"image_cond\"\n    image_size: 32\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: hybrid\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n\n    scheduler_config: # 100 warmup steps\n      target: ldm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [ 100 ]\n        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases\n        f_start: [ 1.e-6 ]\n        f_max: [ 1. ]\n        f_min: [ 1. ]\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder\n\n\n# data:\n#   target: 
ldm.data.simple.ObjaverseDataModuleFromConfig\n#   params:\n#     root_dir: 'views_whole_sphere'\n#     batch_size: 192\n#     num_workers: 16\n#     total_view: 4\n#     train:\n#       validation: False\n#       image_transforms:\n#         size: 256\n\n#     validation:\n#       validation: True\n#       image_transforms:\n#         size: 256\n\n\n# lightning:\n#   find_unused_parameters: false\n#   metrics_over_trainsteps_checkpoint: True\n#   modelcheckpoint:\n#     params:\n#       every_n_train_steps: 5000\n#   callbacks:\n#     image_logger:\n#       target: main.ImageLogger\n#       params:\n#         batch_frequency: 500\n#         max_images: 32\n#         increase_log_steps: False\n#         log_first_step: True\n#         log_images_kwargs:\n#           use_ema_scope: False\n#           inpaint: False\n#           plot_progressive_rows: False\n#           plot_diffusion_rows: False\n#           N: 32\n#           unconditional_scale: 3.0\n#           unconditional_label: [\"\"]\n\n#   trainer:\n#     benchmark: True\n#     val_check_interval: 5000000 # really sorry\n#     num_sanity_val_steps: 0\n#     accumulate_grad_batches: 1\n"
  },
  {
    "path": "raymarching/__init__.py",
    "content": "from .raymarching import *"
  },
  {
    "path": "raymarching/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    '-O3', '-std=c++14',\n    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',\n]\n\nif os.name == \"posix\":\n    c_flags = ['-O3', '-std=c++14']\nelif os.name == \"nt\":\n    c_flags = ['/O2', '/std:c++17']\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n        for program_files in [r\"C:\\\\Program Files (x86)\", r\"C:\\\\Program Files\"]:\n            for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n                paths = sorted(glob.glob(r\"%s\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\" % (program_files, edition)), reverse=True)\n                if paths:\n                    return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(name='_raymarching',\n                extra_cflags=c_flags,\n                extra_cuda_cflags=nvcc_flags,\n                sources=[os.path.join(_src_path, 'src', f) for f in [\n                    'raymarching.cu',\n                    'bindings.cpp',\n                ]],\n                )\n\n__all__ = ['_backend']"
  },
  {
    "path": "raymarching/raymarching.py",
    "content": "import numpy as np\nimport time\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n# lazy building: \n# `import raymarching` will not immediately build the extension, only if you actually call any functions.\n\nBACKEND = None\n\ndef get_backend():\n    global BACKEND\n\n    if BACKEND is None:\n        try:\n            import _raymarching as _backend\n        except ImportError:\n            from .backend import _backend\n\n        BACKEND = _backend\n    \n    return BACKEND\n\n# ----------------------------------------\n# utils\n# ----------------------------------------\n\nclass _near_far_from_aabb(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):\n        ''' near_far_from_aabb, CUDA implementation\n        Calculate rays' intersection time (near and far) with aabb\n        Args:\n            rays_o: float, [N, 3]\n            rays_d: float, [N, 3]\n            aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)\n            min_near: float, scalar\n        Returns:\n            nears: float, [N]\n            fars: float, [N]\n        '''\n        if not rays_o.is_cuda: rays_o = rays_o.cuda()\n        if not rays_d.is_cuda: rays_d = rays_d.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0] # num rays\n\n        nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)\n        fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)\n\n        get_backend().near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)\n\n        return nears, fars\n\nnear_far_from_aabb = _near_far_from_aabb.apply\n\n\nclass _sph_from_ray(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, rays_o, rays_d, radius):\n        ''' sph_from_ray, CUDA 
implementation\n        get spherical coordinate on the background sphere from rays.\n        Assume rays_o are inside the Sphere(radius).\n        Args:\n            rays_o: [N, 3]\n            rays_d: [N, 3]\n            radius: scalar, float\n        Return:\n            coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)\n        '''\n        if not rays_o.is_cuda: rays_o = rays_o.cuda()\n        if not rays_d.is_cuda: rays_d = rays_d.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0] # num rays\n\n        coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)\n\n        get_backend().sph_from_ray(rays_o, rays_d, radius, N, coords)\n\n        return coords\n\nsph_from_ray = _sph_from_ray.apply\n\n\nclass _morton3D(Function):\n    @staticmethod\n    def forward(ctx, coords):\n        ''' morton3D, CUDA implementation\n        Args:\n            coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) \n            TODO: check if the coord range is valid! 
(current 128 is safe)\n        Returns:\n            indices: [N], int32, in [0, 128^3)\n            \n        '''\n        if not coords.is_cuda: coords = coords.cuda()\n        \n        N = coords.shape[0]\n\n        indices = torch.empty(N, dtype=torch.int32, device=coords.device)\n        \n        get_backend().morton3D(coords.int(), N, indices)\n\n        return indices\n\nmorton3D = _morton3D.apply\n\nclass _morton3D_invert(Function):\n    @staticmethod\n    def forward(ctx, indices):\n        ''' morton3D_invert, CUDA implementation\n        Args:\n            indices: [N], int32, in [0, 128^3)\n        Returns:\n            coords: [N, 3], int32, in [0, 128)\n            \n        '''\n        if not indices.is_cuda: indices = indices.cuda()\n        \n        N = indices.shape[0]\n\n        coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)\n        \n        get_backend().morton3D_invert(indices.int(), N, coords)\n\n        return coords\n\nmorton3D_invert = _morton3D_invert.apply\n\n\nclass _packbits(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, grid, thresh, bitfield=None):\n        ''' packbits, CUDA implementation\n        Pack up the density grid into a bit field to accelerate ray marching.\n        Args:\n            grid: float, [C, H * H * H], assume H % 2 == 0\n            thresh: float, threshold\n        Returns:\n            bitfield: uint8, [C, H * H * H / 8]\n        '''\n        if not grid.is_cuda: grid = grid.cuda()\n        grid = grid.contiguous()\n\n        C = grid.shape[0]\n        H3 = grid.shape[1]\n        N = C * H3 // 8\n\n        if bitfield is None:\n            bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)\n\n        get_backend().packbits(grid, N, thresh, bitfield)\n\n        return bitfield\n\npackbits = _packbits.apply\n\n\nclass _flatten_rays(Function):\n    @staticmethod\n    def forward(ctx, rays, M):\n        ''' flatten rays\n     
   Args:\n            rays: [N, 2], all rays' (point_offset, point_count),\n            M: scalar, int, count of points (we cannot get this info from rays unfortunately...)\n        Returns:\n            res: [M], flattened ray index.\n        '''\n        if not rays.is_cuda: rays = rays.cuda()\n        rays = rays.contiguous()\n\n        N = rays.shape[0]\n\n        res = torch.zeros(M, dtype=torch.int, device=rays.device)\n\n        get_backend().flatten_rays(rays, N, M, res)\n\n        return res\n\nflatten_rays = _flatten_rays.apply\n\n# ----------------------------------------\n# train functions\n# ----------------------------------------\n\nclass _march_rays_train(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, perturb=False, dt_gamma=0, max_steps=1024, contract=False):\n        ''' march rays to generate points (forward only)\n        Args:\n            rays_o/d: float, [N, 3]\n            bound: float, scalar\n            density_bitfield: uint8: [CHHH // 8]\n            C: int\n            H: int\n            nears/fars: float, [N]\n            perturb: bool\n            dt_gamma: float, called cone_angle in instant-ngp; exponentially accelerates ray marching if > 0. 
(very significant speed-up, but generally leads to worse quality)\n            max_steps: int, max number of sampled points along each ray; also affects the minimum step size.\n        Returns:\n            xyzs: float, [M, 3], all generated points' coords (all rays concatenated; use `rays` to extract the points belonging to each ray).\n            dirs: float, [M, 3], all generated points' view dirs.\n            ts: float, [M, 2], all generated points' ts.\n            rays: int32, [N, 2], all rays' (point_offset, point_count), e.g., xyzs[rays[i, 0]:(rays[i, 0] + rays[i, 1])] --> points belonging to ray i.\n        '''\n\n        if not rays_o.is_cuda: rays_o = rays_o.cuda()\n        if not rays_d.is_cuda: rays_d = rays_d.cuda()\n        if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()\n        \n        rays_o = rays_o.float().contiguous().view(-1, 3)\n        rays_d = rays_d.float().contiguous().view(-1, 3)\n        density_bitfield = density_bitfield.contiguous()\n\n        N = rays_o.shape[0] # num rays\n        \n        step_counter = torch.zeros(1, dtype=torch.int32, device=rays_o.device) # point counter\n        \n        if perturb:\n            noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)\n        else:\n            noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)\n        \n        # first pass: write rays, get total number of points M to render\n        rays = torch.empty(N, 2, dtype=torch.int32, device=rays_o.device) # offset, num_steps\n        get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, None, None, None, rays, step_counter, noises)\n\n        # allocate based on M\n        M = step_counter.item()\n        # print(M, N)\n        # print(rays[:, 0].max())\n\n        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, 
device=rays_o.device)\n        ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)\n\n        # second pass: write outputs\n        get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, xyzs, dirs, ts, rays, step_counter, noises)\n\n        return xyzs, dirs, ts, rays\n\nmarch_rays_train = _march_rays_train.apply\n\n\nclass _composite_rays_train(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, sigmas, rgbs, ts, rays, T_thresh=1e-4, binarize=False):\n        ''' composite rays' rgbs, according to the ray marching formula.\n        Args:\n            sigmas: float, [M,]\n            rgbs: float, [M, 3]\n            ts: float, [M, 2]\n            rays: int32, [N, 2], all rays' (point_offset, point_count)\n            T_thresh: float, stop compositing a ray once its transmittance falls below this value\n            binarize: bool, round each alpha to 0 or 1 (threshold 0.5)\n        Returns:\n            weights: float, [M]\n            weights_sum: float, [N,], the alpha channel\n            depth: float, [N,], the depth\n            image: float, [N, 3], the RGB channel (after multiplying alpha!)\n        '''\n        \n        sigmas = sigmas.float().contiguous()\n        rgbs = rgbs.float().contiguous()\n\n        M = sigmas.shape[0]\n        N = rays.shape[0]\n\n        weights = torch.zeros(M, dtype=sigmas.dtype, device=sigmas.device) # may leave unmodified, so init with 0\n        weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)\n\n        depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)\n        image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)\n\n        get_backend().composite_rays_train_forward(sigmas, rgbs, ts, rays, M, N, T_thresh, binarize, weights, weights_sum, depth, image)\n\n        ctx.save_for_backward(sigmas, rgbs, ts, rays, weights_sum, depth, image)\n        ctx.dims = [M, N, T_thresh, binarize]\n\n        return weights, weights_sum, depth, image\n    \n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_weights, grad_weights_sum, grad_depth, grad_image):\n 
       \n        grad_weights = grad_weights.contiguous()\n        grad_weights_sum = grad_weights_sum.contiguous()\n        grad_depth = grad_depth.contiguous()\n        grad_image = grad_image.contiguous()\n\n        sigmas, rgbs, ts, rays, weights_sum, depth, image = ctx.saved_tensors\n        M, N, T_thresh, binarize = ctx.dims\n   \n        grad_sigmas = torch.zeros_like(sigmas)\n        grad_rgbs = torch.zeros_like(rgbs)\n\n        get_backend().composite_rays_train_backward(grad_weights, grad_weights_sum, grad_depth, grad_image, sigmas, rgbs, ts, rays, weights_sum, depth, image, M, N, T_thresh, binarize, grad_sigmas, grad_rgbs)\n\n        return grad_sigmas, grad_rgbs, None, None, None, None\n\n\ncomposite_rays_train = _composite_rays_train.apply\n\n# ----------------------------------------\n# infer functions\n# ----------------------------------------\n\nclass _march_rays(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, perturb=False, dt_gamma=0, max_steps=1024, contract=False):\n        ''' march rays to generate points (forward only, for inference)\n        Args:\n            n_alive: int, number of alive rays\n            n_step: int, how many steps we march\n            rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)\n            rays_t: float, [N], the alive rays' current t, we only use the first n_alive.\n            rays_o/d: float, [N, 3]\n            bound: float, scalar\n            density_bitfield: uint8, [C * H * H * H // 8]\n            C: int\n            H: int\n            near/far: float, [N]\n            perturb: bool, whether to use random per-ray noise (seeding with an int is currently commented out).\n            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. 
(very significant effect, but generally leads to worse rendering quality)\n            max_steps: int, max number of sampled points along each ray; also determines the minimum step size (dt_min).\n            contract: bool, whether to apply L-inf scene contraction to points outside the unit cube.\n        Returns:\n            xyzs: float, [n_alive * n_step, 3], all generated points' coords\n            dirs: float, [n_alive * n_step, 3], all generated points' view dirs.\n            ts: float, [n_alive * n_step, 2], all generated points' ts\n        '''\n        \n        if not rays_o.is_cuda: rays_o = rays_o.cuda()\n        if not rays_d.is_cuda: rays_d = rays_d.cuda()\n        \n        rays_o = rays_o.float().contiguous().view(-1, 3)\n        rays_d = rays_d.float().contiguous().view(-1, 3)\n\n        M = n_alive * n_step\n        \n        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth\n\n        if perturb:\n            # torch.manual_seed(perturb) # test_gui uses spp index as seed\n            noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)\n        else:\n            noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)\n\n        get_backend().march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, contract, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, ts, noises)\n\n        return xyzs, dirs, ts\n\nmarch_rays = _march_rays.apply\n\n\nclass _composite_rays(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float\n    def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh=1e-2, binarize=False):\n        ''' composite rays' rgbs, according to the ray marching formula. 
(for inference)\n        Args:\n            n_alive: int, number of alive rays\n            n_step: int, how many steps we march\n            rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)\n            rays_t: float, [N], the alive rays' time\n            sigmas: float, [n_alive * n_step,]\n            rgbs: float, [n_alive * n_step, 3]\n            ts: float, [n_alive * n_step, 2]\n        In-place Outputs:\n            weights_sum: float, [N,], the alpha channel\n            depth: float, [N,], the depth value\n            image: float, [N, 3], the RGB channel (after multiplying alpha!)\n        '''\n        sigmas = sigmas.float().contiguous()\n        rgbs = rgbs.float().contiguous()\n        get_backend().composite_rays(n_alive, n_step, T_thresh, binarize, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image)\n        return tuple()\n\n\ncomposite_rays = _composite_rays.apply"
  },
  {
    "path": "raymarching/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    '-O3', '-std=c++14',\n    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',\n]\n\nif os.name == \"posix\":\n    c_flags = ['-O3', '-std=c++14']\nelif os.name == \"nt\":\n    c_flags = ['/O2', '/std:c++17']\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n        for program_files in [r\"C:\\\\Program Files (x86)\", r\"C:\\\\Program Files\"]:\n            for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n                paths = sorted(glob.glob(r\"%s\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\" % (program_files, edition)), reverse=True)\n                if paths:\n                    return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n'''\nUsage:\n\npython setup.py build_ext --inplace # build extensions locally, do not install (can only be used from the parent directory)\n\npython setup.py install # build extensions and install (copy) to site-packages.\npip install . # ditto but better (e.g., dependency & metadata handling)\n\npython setup.py develop # build extensions and install in development mode (linked to this source directory).\npip install -e . 
# ditto but better (e.g., dependency & metadata handling)\n\n'''\nsetup(\n    name='raymarching', # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name='_raymarching', # extension name, import this to use CUDA API\n            sources=[os.path.join(_src_path, 'src', f) for f in [\n                'raymarching.cu',\n                'bindings.cpp',\n            ]],\n            extra_compile_args={\n                'cxx': c_flags,\n                'nvcc': nvcc_flags,\n            }\n        ),\n    ],\n    cmdclass={\n        'build_ext': BuildExtension,\n    }\n)"
  },
  {
    "path": "raymarching/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"raymarching.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    // utils\n    m.def(\"flatten_rays\", &flatten_rays, \"flatten_rays (CUDA)\");\n    m.def(\"packbits\", &packbits, \"packbits (CUDA)\");\n    m.def(\"near_far_from_aabb\", &near_far_from_aabb, \"near_far_from_aabb (CUDA)\");\n    m.def(\"sph_from_ray\", &sph_from_ray, \"sph_from_ray (CUDA)\");\n    m.def(\"morton3D\", &morton3D, \"morton3D (CUDA)\");\n    m.def(\"morton3D_invert\", &morton3D_invert, \"morton3D_invert (CUDA)\");\n    // train\n    m.def(\"march_rays_train\", &march_rays_train, \"march_rays_train (CUDA)\");\n    m.def(\"composite_rays_train_forward\", &composite_rays_train_forward, \"composite_rays_train_forward (CUDA)\");\n    m.def(\"composite_rays_train_backward\", &composite_rays_train_backward, \"composite_rays_train_backward (CUDA)\");\n    // infer\n    m.def(\"march_rays\", &march_rays, \"march rays (CUDA)\");\n    m.def(\"composite_rays\", &composite_rays, \"composite rays (CUDA)\");\n}"
  },
  {
    "path": "raymarching/src/raymarching.cu",
    "content": "#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/torch.h>\n\n#include <cstdio>\n#include <stdint.h>\n#include <stdexcept>\n#include <limits>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x \" must be a floating tensor\")\n\n\ninline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }\ninline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }\ninline constexpr __device__ float PI() { return 3.141592653589793f; }\ninline constexpr __device__ float RPI() { return 0.3183098861837907f; }\n\n\ntemplate <typename T>\ninline __host__ __device__ T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\ninline __host__ __device__ float signf(const float x) {\n    return copysignf(1.0, x);\n}\n\ninline __host__ __device__ float clamp(const float x, const float min, const float max) {\n    return fminf(max, fmaxf(min, x));\n}\n\ninline __host__ __device__ void swapf(float& a, float& b) {\n    float c = a; a = b; b = c;\n}\n\ninline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {\n    const float mx = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));\n    int exponent;\n    frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...\n    return fminf(max_cascade - 1, fmaxf(0, exponent));\n}\n\ninline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {\n    const float mx = dt * H * 0.5;\n    int exponent;\n    
frexpf(mx, &exponent);\n    return fminf(max_cascade - 1, fmaxf(0, exponent));\n}\n\ninline __host__ __device__ uint32_t __expand_bits(uint32_t v)\n{\n\tv = (v * 0x00010001u) & 0xFF0000FFu;\n\tv = (v * 0x00000101u) & 0x0F00F00Fu;\n\tv = (v * 0x00000011u) & 0xC30C30C3u;\n\tv = (v * 0x00000005u) & 0x49249249u;\n\treturn v;\n}\n\ninline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)\n{\n\tuint32_t xx = __expand_bits(x);\n\tuint32_t yy = __expand_bits(y);\n\tuint32_t zz = __expand_bits(z);\n\treturn xx | (yy << 1) | (zz << 2);\n}\n\ninline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)\n{\n\tx = x & 0x49249249;\n\tx = (x | (x >> 2)) & 0xc30c30c3;\n\tx = (x | (x >> 4)) & 0x0f00f00f;\n\tx = (x | (x >> 8)) & 0xff0000ff;\n\tx = (x | (x >> 16)) & 0x0000ffff;\n\treturn x;\n}\n\n\n////////////////////////////////////////////////////\n/////////////           utils          /////////////\n////////////////////////////////////////////////////\n\n// rays_o/d: [N, 3]\n// nears/fars: [N]\n// scalar_t should always be float in use.\ntemplate <typename scalar_t>\n__global__ void kernel_near_far_from_aabb(\n    const scalar_t * __restrict__ rays_o,\n    const scalar_t * __restrict__ rays_d,\n    const scalar_t * __restrict__ aabb,\n    const uint32_t N,\n    const float min_near,\n    scalar_t * nears, scalar_t * fars\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    rays_o += n * 3;\n    rays_d += n * 3;\n\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n\n    // get near far (assume cube scene)\n    float near = (aabb[0] - ox) * rdx;\n    float far = (aabb[3] - ox) * rdx;\n    if (near > far) swapf(near, far);\n\n    float near_y = (aabb[1] - oy) * rdy;\n    float far_y = (aabb[4] - oy) * rdy;\n    if (near_y > 
far_y) swapf(near_y, far_y);\n\n    if (near > far_y || near_y > far) {\n        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();\n        return;\n    }\n\n    if (near_y > near) near = near_y;\n    if (far_y < far) far = far_y;\n\n    float near_z = (aabb[2] - oz) * rdz;\n    float far_z = (aabb[5] - oz) * rdz;\n    if (near_z > far_z) swapf(near_z, far_z);\n\n    if (near > far_z || near_z > far) {\n        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();\n        return;\n    }\n\n    if (near_z > near) near = near_z;\n    if (far_z < far) far = far_z;\n\n    if (near < min_near) near = min_near;\n\n    nears[n] = near;\n    fars[n] = far;\n}\n\n\nvoid near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    rays_o.scalar_type(), \"near_far_from_aabb\", ([&] {\n        kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());\n    }));\n}\n\n\n// rays_o/d: [N, 3]\n// radius: float\n// coords: [N, 2]\ntemplate <typename scalar_t>\n__global__ void kernel_sph_from_ray(\n    const scalar_t * __restrict__ rays_o,\n    const scalar_t * __restrict__ rays_d,\n    const float radius,\n    const uint32_t N,\n    scalar_t * coords\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    rays_o += n * 3;\n    rays_d += n * 3;\n    coords += n * 2;\n\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    // const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n\n    // solve t from || o + td || = radius\n    const float A = dx 
* dx + dy * dy + dz * dz;\n    const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2\n    const float C = ox * ox + oy * oy + oz * oz - radius * radius;\n\n    const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)\n\n    // solve theta, phi (assume y is the up axis)\n    const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;\n    const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)\n    const float phi = atan2(z, x); // [-PI, PI)\n\n    // normalize to [-1, 1]\n    coords[0] = 2 * theta * RPI() - 1;\n    coords[1] = phi * RPI();\n}\n\n\nvoid sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    rays_o.scalar_type(), \"sph_from_ray\", ([&] {\n        kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());\n    }));\n}\n\n\n// coords: int32, [N, 3]\n// indices: int32, [N]\n__global__ void kernel_morton3D(\n    const int * __restrict__ coords,\n    const uint32_t N,\n    int * indices\n) {\n    // parallel\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    coords += n * 3;\n    indices[n] = __morton3D(coords[0], coords[1], coords[2]);\n}\n\n\nvoid morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {\n    static constexpr uint32_t N_THREAD = 128;\n    kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());\n}\n\n\n// indices: int32, [N]\n// coords: int32, [N, 3]\n__global__ void kernel_morton3D_invert(\n    const int * __restrict__ indices,\n    const uint32_t N,\n    int * coords\n) {\n    // parallel\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n 
   coords += n * 3;\n\n    const int ind = indices[n];\n\n    coords[0] = __morton3D_invert(ind >> 0);\n    coords[1] = __morton3D_invert(ind >> 1);\n    coords[2] = __morton3D_invert(ind >> 2);\n}\n\n\nvoid morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {\n    static constexpr uint32_t N_THREAD = 128;\n    kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());\n}\n\n\n// grid: float, [C, H, H, H]\n// N: int, C * H * H * H / 8\n// density_thresh: float\n// bitfield: uint8, [N]\ntemplate <typename scalar_t>\n__global__ void kernel_packbits(\n    const scalar_t * __restrict__ grid,\n    const uint32_t N,\n    const float density_thresh,\n    uint8_t * bitfield\n) {\n    // parallel per byte\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    grid += n * 8;\n\n    uint8_t bits = 0;\n\n    #pragma unroll\n    for (uint8_t i = 0; i < 8; i++) {\n        bits |= (grid[i] > density_thresh) ? 
((uint8_t)1 << i) : 0;\n    }\n\n    bitfield[n] = bits;\n}\n\n\nvoid packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    grid.scalar_type(), \"packbits\", ([&] {\n        kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());\n    }));\n}\n\n\n__global__ void kernel_flatten_rays(\n    const int * __restrict__ rays,\n    const uint32_t N, const uint32_t M,\n    int * res\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    uint32_t offset = rays[n * 2];\n    uint32_t num_steps = rays[n * 2 + 1];\n\n    // write to res\n    res += offset;\n    for (uint32_t i = 0; i < num_steps; i++) res[i] = n;\n}\n\nvoid flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    kernel_flatten_rays<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays.data_ptr<int>(), N, M, res.data_ptr<int>());\n}\n\n////////////////////////////////////////////////////\n/////////////         training         /////////////\n////////////////////////////////////////////////////\n\n// rays_o/d: [N, 3]\n// grid: uint8, [C * H * H * H / 8]\n// xyzs, dirs, ts: [M, 3], [M, 3], [M, 2]\n// rays: [N, 2], offset, num_steps\ntemplate <typename scalar_t>\n__global__ void kernel_march_rays_train(\n    const scalar_t * __restrict__ rays_o,\n    const scalar_t * __restrict__ rays_d,  \n    const uint8_t * __restrict__ grid,\n    const float bound, const bool contract,\n    const float dt_gamma, const uint32_t max_steps,\n    const uint32_t N, const uint32_t C, const uint32_t H,\n    const scalar_t* __restrict__ nears, \n    const scalar_t* __restrict__ fars,\n    scalar_t * xyzs, scalar_t * dirs, scalar_t * ts,\n    int * 
rays,\n    int * counter,\n    const scalar_t* __restrict__ noises\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // is first pass running.\n    const bool first_pass = (xyzs == nullptr);\n\n    // locate\n    rays_o += n * 3;\n    rays_d += n * 3;\n    rays += n * 2;\n\n    uint32_t num_steps = max_steps;\n\n    if (!first_pass) {\n        uint32_t point_index = rays[0];\n        num_steps = rays[1];\n        xyzs += point_index * 3;\n        dirs += point_index * 3;\n        ts += point_index * 2;\n    }\n\n    // ray marching\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n    const float rH = 1 / (float)H;\n    const float H3 = H * H * H;\n\n    const float near = nears[n];\n    const float far = fars[n];\n    const float noise = noises[n];\n\n    const float dt_min = 2 * SQRT3() / max_steps;\n    const float dt_max = 2 * SQRT3() * bound / H;\n    // const float dt_max = 1e10f;\n    \n    float t0 = near;\n    t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;\n    float t = t0;\n    uint32_t step = 0;\n\n    //if (t < far) printf(\"valid ray %d t=%f near=%f far=%f \\n\", n, t, near, far);\n    \n    while (t < far && step < num_steps) {\n        // current point\n        const float x = clamp(ox + t * dx, -bound, bound);\n        const float y = clamp(oy + t * dy, -bound, bound);\n        const float z = clamp(oz + t * dz, -bound, bound);\n\n        float dt = clamp(t * dt_gamma, dt_min, dt_max);\n\n        // get mip level\n        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]\n\n        const float mip_bound = fminf(scalbnf(1.0f, level), bound);\n        const float mip_rbound = 1 / mip_bound;\n\n        // contraction\n        float cx = x, cy = y, cz = z;\n        const float mag = 
fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));\n        if (contract && mag > 1) {\n            // L-INF norm\n            const float Linf_scale = (2 - 1 / mag) / mag;\n            cx *= Linf_scale;\n            cy *= Linf_scale;\n            cz *= Linf_scale;\n        }\n        \n        // convert to nearest grid position\n        const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int nz = clamp(0.5 * (cz * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n\n        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);\n        const bool occ = grid[index / 8] & (1 << (index % 8));\n\n        // if occupied, advance a small step, and write to output\n        //if (n == 0) printf(\"t=%f density=%f vs thresh=%f step=%d\\n\", t, density, density_thresh, step);\n\n        if (occ) {\n            step++;\n            t += dt;\n            if (!first_pass) {\n                xyzs[0] = cx; // write contracted coordinates!\n                xyzs[1] = cy;\n                xyzs[2] = cz;\n                dirs[0] = dx;\n                dirs[1] = dy;\n                dirs[2] = dz;\n                ts[0] = t;\n                ts[1] = dt;\n                xyzs += 3;\n                dirs += 3;\n                ts += 2;\n            }\n        // contraction case: cannot apply voxel skipping.\n        } else if (contract && mag > 1) {\n            t += dt;\n        // else, skip a large step (basically skip to the next voxel)\n        } else {\n            // calc distance to next voxel\n            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx;\n            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy;\n            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz;\n\n            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, 
tz)));\n            // step until next voxel\n            do { \n                dt = clamp(t * dt_gamma, dt_min, dt_max);\n                t += dt;\n            } while (t < tt);\n        }\n    }\n\n    //printf(\"[n=%d] step=%d, near=%f, far=%f, dt=%f, num_steps=%f\\n\", n, step, near, far, dt_min, (far - near) / dt_min);\n\n    // write rays\n    if (first_pass) {\n        uint32_t point_index = atomicAdd(counter, step);\n        rays[0] = point_index;\n        rays[1] = step;\n    }\n}\n\nvoid march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional<at::Tensor> xyzs, at::optional<at::Tensor> dirs, at::optional<at::Tensor> ts, at::Tensor rays, at::Tensor counter, at::Tensor noises) {\n\n    static constexpr uint32_t N_THREAD = 128;\n    \n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    rays_o.scalar_type(), \"march_rays_train\", ([&] {\n        kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, contract, dt_gamma, max_steps, N, C, H, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(),\n            xyzs.has_value() ? xyzs.value().data_ptr<scalar_t>() : nullptr,\n            dirs.has_value() ? dirs.value().data_ptr<scalar_t>() : nullptr,\n            ts.has_value() ? 
ts.value().data_ptr<scalar_t>() : nullptr,\n            rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());\n    }));\n}\n\n\n// sigmas: [M]\n// rgbs: [M, 3]\n// ts: [M, 2]\n// rays: [N, 2], offset, num_steps\n// weights: [M]\n// weights_sum: [N], final pixel alpha\n// depth: [N,]\n// image: [N, 3]\ntemplate <typename scalar_t>\n__global__ void kernel_composite_rays_train_forward(\n    const scalar_t * __restrict__ sigmas,\n    const scalar_t * __restrict__ rgbs,  \n    const scalar_t * __restrict__ ts,\n    const int * __restrict__ rays,\n    const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize,\n    scalar_t * weights,\n    scalar_t * weights_sum,\n    scalar_t * depth,\n    scalar_t * image\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate \n    uint32_t offset = rays[n * 2];\n    uint32_t num_steps = rays[n * 2 + 1];\n\n    // empty ray, or ray that exceed max step count.\n    if (num_steps == 0 || offset + num_steps > M) {\n        weights_sum[n] = 0;\n        depth[n] = 0;\n        image[n * 3] = 0;\n        image[n * 3 + 1] = 0;\n        image[n * 3 + 2] = 0;\n        return;\n    }\n\n    ts += offset * 2;\n    weights += offset;\n    sigmas += offset;\n    rgbs += offset * 3;\n\n    // accumulate \n    uint32_t step = 0;\n\n    float T = 1.0f;\n    float r = 0, g = 0, b = 0, ws = 0, d = 0;\n\n    while (step < num_steps) {\n\n        const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);\n        const float alpha = binarize ? (real_alpha > 0.5 ? 
1.0 : 0.0) : real_alpha;\n        const float weight = alpha * T;\n\n        weights[0] = weight;\n\n        r += weight * rgbs[0];\n        g += weight * rgbs[1];\n        b += weight * rgbs[2];\n        ws += weight;\n        d += weight * ts[0];\n        \n        T *= 1.0f - alpha;\n\n        // stop once the remaining transmittance is below the threshold\n        if (T < T_thresh) break;\n\n        //printf(\"[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\\n\", n, step, alpha, weight, T, sum_delta, d);\n\n        // locate\n        weights++;\n        sigmas++;\n        rgbs += 3;\n        ts += 2;\n\n        step++;\n    }\n\n    //printf(\"[n=%d] rgb=(%f, %f, %f), d=%f\\n\", n, r, g, b, d);\n\n    // write\n    weights_sum[n] = ws; // weights_sum\n    depth[n] = d;\n    image[n * 3] = r;\n    image[n * 3 + 1] = g;\n    image[n * 3 + 2] = b;\n}\n\n\nvoid composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    sigmas.scalar_type(), \"composite_rays_train_forward\", ([&] {\n        kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ts.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, binarize, weights.data_ptr<scalar_t>(), weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());\n    }));\n}\n\n\n// grad_weights: [M,]\n// grad_weights_sum: [N,]\n// grad_image: [N, 3]\n// grad_depth: [N,]\n// sigmas: [M]\n// rgbs: [M, 3]\n// ts: [M, 2]\n// rays: [N, 2], offset, num_steps\n// weights_sum: [N,], weights_sum here \n// depth: [N,]\n// image: [N, 3]\n// grad_sigmas: [M]\n// grad_rgbs: [M, 3]\ntemplate <typename scalar_t>\n__global__ void 
kernel_composite_rays_train_backward(\n    const scalar_t * __restrict__ grad_weights,\n    const scalar_t * __restrict__ grad_weights_sum,\n    const scalar_t * __restrict__ grad_depth,\n    const scalar_t * __restrict__ grad_image,\n    const scalar_t * __restrict__ sigmas,\n    const scalar_t * __restrict__ rgbs, \n    const scalar_t * __restrict__ ts,\n    const int * __restrict__ rays,\n    const scalar_t * __restrict__ weights_sum,\n    const scalar_t * __restrict__ depth,\n    const scalar_t * __restrict__ image,\n    const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize,\n    scalar_t * grad_sigmas,\n    scalar_t * grad_rgbs\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate \n    uint32_t offset = rays[n * 2];\n    uint32_t num_steps = rays[n * 2 + 1];\n\n    if (num_steps == 0 || offset + num_steps > M) return;\n\n    grad_weights += offset;\n    grad_weights_sum += n;\n    grad_depth += n;\n    grad_image += n * 3;\n    weights_sum += n;\n    depth += n;\n    image += n * 3;\n    sigmas += offset;\n    rgbs += offset * 3;\n    ts += offset * 2;\n    grad_sigmas += offset;\n    grad_rgbs += offset * 3;\n\n    // accumulate \n    uint32_t step = 0;\n    \n    float T = 1.0f;\n    const float r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], d_final = depth[0];\n    float r = 0, g = 0, b = 0, ws = 0, d = 0;\n\n    while (step < num_steps) {\n        \n        const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);\n        const float alpha = binarize ? (real_alpha > 0.5 ? 
1.0 : 0.0) : real_alpha;\n        const float weight = alpha * T;\n\n        r += weight * rgbs[0];\n        g += weight * rgbs[1];\n        b += weight * rgbs[2];\n        ws += weight;\n        d += weight * ts[0];\n\n        T *= 1.0f - alpha;\n        \n        // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.\n        // write grad_rgbs\n        grad_rgbs[0] = grad_image[0] * weight;\n        grad_rgbs[1] = grad_image[1] * weight;\n        grad_rgbs[2] = grad_image[2] * weight;\n\n        // write grad_sigmas\n        grad_sigmas[0] = ts[1] * (\n            grad_image[0] * (T * rgbs[0] - (r_final - r)) + \n            grad_image[1] * (T * rgbs[1] - (g_final - g)) + \n            grad_image[2] * (T * rgbs[2] - (b_final - b)) +\n            (grad_weights_sum[0] + grad_weights[0]) * (T - (ws_final - ws)) + \n            grad_depth[0] * (T * ts[0] - (d_final - d))\n        );\n\n        //printf(\"[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\\n\", n, step, T, grad_sigmas[0], r_final, r);\n        // minimal remaining transmittance\n        if (T < T_thresh) break;\n        \n        // locate\n        sigmas++;\n        rgbs += 3;\n        ts += 2;\n        grad_weights++;\n        grad_sigmas++;\n        grad_rgbs += 3;\n\n        step++;\n    }\n}\n\n\nvoid composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    grad_image.scalar_type(), \"composite_rays_train_backward\", ([&] {\n        
kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights.data_ptr<scalar_t>(), grad_weights_sum.data_ptr<scalar_t>(), grad_depth.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ts.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, binarize, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());\n    }));\n}\n\n\n////////////////////////////////////////////////////\n/////////////          inference       /////////////\n////////////////////////////////////////////////////\n\ntemplate <typename scalar_t>\n__global__ void kernel_march_rays(\n    const uint32_t n_alive, \n    const uint32_t n_step, \n    const int* __restrict__ rays_alive, \n    const scalar_t* __restrict__ rays_t, \n    const scalar_t* __restrict__ rays_o, \n    const scalar_t* __restrict__ rays_d, \n    const float bound, const bool contract,\n    const float dt_gamma, const uint32_t max_steps,\n    const uint32_t C, const uint32_t H,\n    const uint8_t * __restrict__ grid,\n    const scalar_t* __restrict__ nears,\n    const scalar_t* __restrict__ fars,\n    scalar_t* xyzs, scalar_t* dirs, scalar_t* ts,\n    const scalar_t* __restrict__ noises\n) {\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= n_alive) return;\n\n    const int index = rays_alive[n]; // ray id\n    const float noise = noises[n];\n    \n    // locate\n    rays_o += index * 3;\n    rays_d += index * 3;\n    xyzs += n * n_step * 3;\n    dirs += n * n_step * 3;\n    ts += n * n_step * 2;\n    \n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n    const float rH = 1 / (float)H;\n    const float H3 = H * H * H;\n    \n    const float near = nears[index], far = 
fars[index];\n\n    const float dt_min = 2 * SQRT3() / max_steps;\n    const float dt_max = 2 * SQRT3() * bound / H;\n    // const float dt_max = 1e10f;\n\n    // march for n_step steps, record points\n    float t = rays_t[index];\n    t += clamp(t * dt_gamma, dt_min, dt_max) * noise;\n    uint32_t step = 0;\n\n    while (t < far && step < n_step) {\n        // current point\n        const float x = clamp(ox + t * dx, -bound, bound);\n        const float y = clamp(oy + t * dy, -bound, bound);\n        const float z = clamp(oz + t * dz, -bound, bound);\n\n        float dt = clamp(t * dt_gamma, dt_min, dt_max);\n\n        // get mip level\n        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]\n\n        const float mip_bound = fminf(scalbnf(1, level), bound);\n        const float mip_rbound = 1 / mip_bound;\n        \n        // contraction\n        float cx = x, cy = y, cz = z;\n        const float mag = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));\n        if (contract && mag > 1) {\n            // L-INF norm\n            const float Linf_scale = (2 - 1 / mag) / mag;\n            cx *= Linf_scale;\n            cy *= Linf_scale;\n            cz *= Linf_scale;\n        }\n        \n        // convert to nearest grid position\n        const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int nz = clamp(0.5 * (cz * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n\n        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);\n        const bool occ = grid[index / 8] & (1 << (index % 8));\n\n        // if occupied, advance a small step, and write to output\n        if (occ) {\n            // write step\n            xyzs[0] = cx;\n            xyzs[1] = cy;\n            xyzs[2] = cz;\n            dirs[0] = dx;\n            dirs[1] = dy;\n            dirs[2] = dz;\n            // calc dt\n            
t += dt;\n            ts[0] = t;\n            ts[1] = dt;\n            // step\n            xyzs += 3;\n            dirs += 3;\n            ts += 2;\n            step++;\n\n        // contraction case\n        } else if (contract && mag > 1) {\n            t += dt;\n        // else, skip a large step (basically skip a voxel grid)\n        } else {\n            // calc distance to next voxel\n            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx;\n            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy;\n            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz;\n            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));\n            // step until next voxel\n            do { \n                dt = clamp(t * dt_gamma, dt_min, dt_max);\n                t += dt;\n            } while (t < tt);\n        }\n    }\n}\n\n\nvoid march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises) {\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    rays_o.scalar_type(), \"march_rays\", ([&] {\n        kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, contract, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), ts.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());\n    
}));\n}\n\n\ntemplate <typename scalar_t>\n__global__ void kernel_composite_rays(\n    const uint32_t n_alive, \n    const uint32_t n_step, \n    const float T_thresh, const bool binarize,\n    int* rays_alive, \n    scalar_t* rays_t, \n    const scalar_t* __restrict__ sigmas, \n    const scalar_t* __restrict__ rgbs, \n    const scalar_t* __restrict__ ts, \n    scalar_t* weights_sum, scalar_t* depth, scalar_t* image\n) {\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= n_alive) return;\n\n    const int index = rays_alive[n]; // ray id\n    \n    // locate \n    sigmas += n * n_step;\n    rgbs += n * n_step * 3;\n    ts += n * n_step * 2;\n    \n    rays_t += index;\n    weights_sum += index;\n    depth += index;\n    image += index * 3;\n\n    float t;\n    float d = depth[0], r = image[0], g = image[1], b = image[2], weight_sum = weights_sum[0];\n\n    // accumulate \n    uint32_t step = 0;\n    while (step < n_step) {\n        \n        // ray is terminated if t == 0\n        if (ts[0] == 0) break;\n        \n        const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);\n        const float alpha = binarize ? (real_alpha > 0.5 ? 
1.0 : 0.0) : real_alpha;\n\n        /* \n        T_0 = 1; T_i = \\prod_{j=0}^{i-1} (1 - alpha_j)\n        w_i = alpha_i * T_i\n        --> \n        T_i = 1 - \\sum_{j=0}^{i-1} w_j\n        */\n        const float T = 1 - weight_sum;\n        const float weight = alpha * T;\n        weight_sum += weight;\n\n        t = ts[0];\n        d += weight * t; // real depth\n        r += weight * rgbs[0];\n        g += weight * rgbs[1];\n        b += weight * rgbs[2];\n\n        //printf(\"[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\\n\", n, step, alpha, weight, T, sum_delta, d);\n\n        // ray is terminated if T is too small\n        // use a larger bound to further accelerate inference\n        if (T < T_thresh) break;\n\n        // locate\n        sigmas++;\n        rgbs += 3;\n        ts += 2;\n        step++;\n    }\n\n    //printf(\"[n=%d] rgb=(%f, %f, %f), d=%f\\n\", n, r, g, b, d);\n\n    // rays_alive = -1 means ray is terminated early.\n    if (step < n_step) {\n        rays_alive[n] = -1;\n    } else {\n        rays_t[0] = t;\n    }\n\n    weights_sum[0] = weight_sum; // this is the thing I needed!\n    depth[0] = d;\n    image[0] = r;\n    image[1] = g;\n    image[2] = b;\n}\n\n\nvoid composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights, at::Tensor depth, at::Tensor image) {\n    static constexpr uint32_t N_THREAD = 128;\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    image.scalar_type(), \"composite_rays\", ([&] {\n        kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, binarize, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ts.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());\n    }));\n}"
  },
  {
    "path": "raymarching/src/raymarching.h",
    "content": "#pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n\nvoid near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);\nvoid sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);\nvoid morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);\nvoid morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);\nvoid packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);\nvoid flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res);\n\nvoid march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional<at::Tensor> xyzs, at::optional<at::Tensor> dirs, at::optional<at::Tensor> ts, at::Tensor rays, at::Tensor counter, at::Tensor noises);\nvoid composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);\nvoid composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs);\n\nvoid march_rays(const uint32_t n_alive, const uint32_t n_step, const 
at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises);\nvoid composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);"
  },
  {
    "path": "readme.md",
    "content": "# Stable-Dreamfusion\n\nA PyTorch implementation of the text-to-3D model **Dreamfusion**, powered by the [Stable Diffusion](https://github.com/CompVis/stable-diffusion) text-to-2D model.\n\n**ADVERTISEMENT: Please check out [threestudio](https://github.com/threestudio-project/threestudio) for recent improvements and a better implementation of 3D content generation!**\n\n**NEWS (2023.6.12)**:\n\n* Support for [Perp-Neg](https://perp-neg.github.io/) to alleviate the multi-head problem in Text-to-3D.\n* Support for Perp-Neg with both [Stable Diffusion](https://github.com/CompVis/stable-diffusion) and [DeepFloyd-IF](https://github.com/deep-floyd/IF).\n\nhttps://user-images.githubusercontent.com/25863658/236712982-9f93bd32-83bf-423a-bb7c-f73df7ece2e3.mp4\n\nhttps://user-images.githubusercontent.com/25863658/232403162-51b69000-a242-4b8c-9cd9-4242b09863fa.mp4\n\n### [Update Logs](assets/update_logs.md)\n\n### Colab notebooks:\n* Instant-NGP backbone (`-O`): [![Instant-NGP Backbone](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MXT3yfOFvO0ooKEfiUUvTKwUkrrlCHpF?usp=sharing)\n\n* Vanilla NeRF backbone (`-O2`): [![Vanilla Backbone](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mvfxG-S_n_gZafWoattku7rLJ2kPoImL?usp=sharing)\n\n# Important Notice\nThis project is a **work-in-progress** and differs from the paper in many ways. **The current generation quality cannot match the results from the original paper, and many prompts still fail badly!**\n\n## Notable differences from the paper\n* Since the Imagen model is not publicly available, we use [Stable Diffusion](https://github.com/CompVis/stable-diffusion) to replace it (implementation from [diffusers](https://github.com/huggingface/diffusers)). Unlike Imagen, Stable Diffusion is a latent diffusion model, which diffuses in a latent space instead of the original image space. 
Therefore, the loss also has to propagate back through the VAE encoder, which adds extra training time.\n* We use the [multi-resolution grid encoder](https://github.com/NVlabs/instant-ngp/) to implement the NeRF backbone (implementation from [torch-ngp](https://github.com/ashawkey/torch-ngp)), which enables much faster rendering (~10FPS at 800x800).\n* We use the [Adan](https://github.com/sail-sg/Adan) optimizer by default.\n\n# Install\n\n```bash\ngit clone https://github.com/ashawkey/stable-dreamfusion.git\ncd stable-dreamfusion\n```\n\n### Optional: create a Python virtual environment\n\nTo avoid Python package conflicts, we recommend using a virtual environment, e.g., conda or venv:\n\n```bash\npython -m venv venv_stable-dreamfusion\nsource venv_stable-dreamfusion/bin/activate # you need to repeat this step for every new terminal\n```\n\n### Install with pip\n\n```bash\npip install -r requirements.txt\n```\n\n### Download pre-trained models\n\nTo use image-conditioned 3D generation, you need to download some pretrained checkpoints manually:\n* [Zero-1-to-3](https://github.com/cvlab-columbia/zero123) for the diffusion backend.\n    We use `zero123-xl.ckpt` by default, and it is hard-coded in `guidance/zero123_utils.py`.\n    ```bash\n    cd pretrained/zero123\n    wget https://zero123.cs.columbia.edu/assets/zero123-xl.ckpt\n    ```\n* [Omnidata](https://github.com/EPFL-VILAB/omnidata/tree/main/omnidata_tools/torch) for depth and normal prediction.\n    These checkpoints are hard-coded in `preprocess_image.py`.\n    ```bash\n    mkdir pretrained/omnidata\n    cd pretrained/omnidata\n    # assume gdown is installed\n    gdown '1Jrh-bRnJEjyMCS7f-WsaFlccfPjJPPHI&confirm=t' # omnidata_dpt_depth_v2.ckpt\n    gdown '1wNxVO4vVbDEMEpnAi_jwQObf2MFodcBR&confirm=t' # omnidata_dpt_normal_v2.ckpt\n    ```\n\nTo use [DeepFloyd-IF](https://github.com/deep-floyd/IF), you need to accept the usage conditions from [hugging 
face](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) and log in with `huggingface-cli login` on the command line.\n\nFor DMTet, we port the pre-generated `32/64/128` resolution tetrahedron grids under `tets`.\nThe 256 resolution one can be found [here](https://drive.google.com/file/d/1lgvEKNdsbW5RS4gVxJbgBS4Ac92moGSa/view?usp=sharing).\n\n### Build extension (optional)\nBy default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime.\nWe also provide `setup.py` to build each extension:\n```bash\ncd stable-dreamfusion\n\n# install all extension modules\nbash scripts/install_ext.sh\n\n# if you want to install manually, here is an example:\npip install ./raymarching # install to python path (you still need the raymarching/ folder, since this only installs the built extension.)\n```\n\n### Taichi backend (optional)\nUse the [Taichi](https://github.com/taichi-dev/taichi) backend for Instant-NGP. It achieves performance comparable to the CUDA implementation while requiring **no CUDA** build. Install Taichi with pip:\n```bash\npip install -i https://pypi.taichi.graphics/simple/ taichi-nightly\n```\n\n### Troubleshooting:\n* We assume you are working with the latest version of all dependencies. If you meet any problems with a specific dependency, please try to upgrade it first (e.g., `pip install -U diffusers`). If the problem persists, [reporting a bug issue](https://github.com/ashawkey/stable-dreamfusion/issues/new?assignees=&labels=bug&template=bug_report.yaml&title=%3Ctitle%3E) will be appreciated!\n* `[F glutil.cpp:338] eglInitialize() failed Aborted (core dumped)`: this usually indicates a problem with the OpenGL installation. 
Try re-installing the Nvidia driver, or use nvidia-docker as suggested in https://github.com/ashawkey/stable-dreamfusion/issues/131 if you are using a headless server.\n* `TypeError: xxx_forward(): incompatible function arguments`: this happens when we update the CUDA source and you used `setup.py` to install the extensions earlier. Try to re-install the corresponding extension (e.g., `pip install ./gridencoder`).\n\n### Tested environments\n* Ubuntu 22 with torch 1.12 & CUDA 11.6 on a V100.\n\n# Usage\n\nThe first run will take some time to compile the CUDA extensions.\n\n```bash\n#### stable-dreamfusion setting\n\n### Instant-NGP NeRF Backbone\n# + faster rendering speed\n# + less GPU memory (~16G)\n# - need to build CUDA extensions (a CUDA-free Taichi backend is available)\n\n## train with text prompt (with the default settings)\n# `-O` equals `--cuda_ray --fp16`\n# `--cuda_ray` enables instant-ngp-like occupancy grid based acceleration.\npython main.py --text \"a hamburger\" --workspace trial -O\n\n# reduce stable-diffusion memory usage with `--vram_O`\n# enable various vram savings (https://huggingface.co/docs/diffusers/optimization/fp16).\npython main.py --text \"a hamburger\" --workspace trial -O --vram_O\n\n# You can collect arguments in a file. You can override arguments by specifying them after `--file`. 
Note that quoted strings can't be loaded from .args files...\npython main.py --file scripts/res64.args --workspace trial_awesome_hamburger --text \"a photo of an awesome hamburger\"\n\n# use CUDA-free Taichi backend with `--backbone grid_taichi`\npython3 main.py --text \"a hamburger\" --workspace trial -O --backbone grid_taichi\n\n# choose stable-diffusion version (support 1.5, 2.0 and 2.1, default is 2.1 now)\npython main.py --text \"a hamburger\" --workspace trial -O --sd_version 1.5\n\n# use a custom stable-diffusion checkpoint from hugging face:\npython main.py --text \"a hamburger\" --workspace trial -O --hf_key andite/anything-v4.0\n\n# use DeepFloyd-IF for guidance (experimental):\npython main.py --text \"a hamburger\" --workspace trial -O --IF\npython main.py --text \"a hamburger\" --workspace trial -O --IF --vram_O # requires ~24G GPU memory\n\n# we also support negative text prompt now:\npython main.py --text \"a rose\" --negative \"red\" --workspace trial -O\n\n## after the training is finished:\n# test (exporting 360 degree video)\npython main.py --workspace trial -O --test\n# also save a mesh (with obj, mtl, and png texture)\npython main.py --workspace trial -O --test --save_mesh\n# test with a GUI (free view control!)\npython main.py --workspace trial -O --test --gui\n\n### Vanilla NeRF backbone\n# + pure pytorch, no need to build extensions!\n# - slow rendering speed\n# - more GPU memory\n\n## train\n# `-O2` equals `--backbone vanilla`\npython main.py --text \"a hotdog\" --workspace trial2 -O2\n\n# if CUDA OOM, try to reduce NeRF sampling steps (--num_steps and --upsample_steps)\npython main.py --text \"a hotdog\" --workspace trial2 -O2 --num_steps 64 --upsample_steps 0\n\n## test\npython main.py --workspace trial2 -O2 --test\npython main.py --workspace trial2 -O2 --test --save_mesh\npython main.py --workspace trial2 -O2 --test --gui # not recommended, FPS will be low.\n\n### DMTet finetuning\n\n## use --dmtet and --init_with <nerf checkpoint> to 
finetune the mesh at higher resolution\npython main.py -O --text \"a hamburger\" --workspace trial_dmtet --dmtet --iters 5000 --init_with trial/checkpoints/df.pth\n\n## init dmtet with a mesh to generate texture\n# requires cubvh: pip install git+https://github.com/ashawkey/cubvh\n# remove --lock_geo to also finetune geometry, but results may be worse.\npython main.py -O --text \"a white bunny with red eyes\" --workspace trial_dmtet_mesh --dmtet --iters 5000 --init_with ./data/bunny.obj --lock_geo\n\n## test & export the mesh\npython main.py -O --text \"a hamburger\" --workspace trial_dmtet --dmtet --iters 5000 --test --save_mesh\n\n## gui to visualize dmtet\npython main.py -O --text \"a hamburger\" --workspace trial_dmtet --dmtet --iters 5000 --test --gui\n\n### Image-conditioned 3D Generation\n\n## preprocess input image\n# note: the quality of image-to-3D results depends on zero-1-to-3's capability. For best performance, the input image should contain a single front-facing object, have a square aspect ratio, and have a resolution below 1024 pixels. 
Check the examples under ./data.\n# this exports `<image>_rgba.png`, `<image>_depth.png`, and `<image>_normal.png` to the directory containing the input image.\npython preprocess_image.py <image>.png\npython preprocess_image.py <image>.png --border_ratio 0.4 # increase border_ratio if the center object appears too large and results are unsatisfying.\n\n## zero123 train\n# pass in the processed <image>_rgba.png via --image and do NOT pass in --text to enable the zero-1-to-3 backend.\npython main.py -O --image <image>_rgba.png --workspace trial_image --iters 5000\n\n# if the image is not exactly front-view (elevation = 0), adjust default_polar (we use polar from 0 to 180 to represent elevation from 90 to -90)\npython main.py -O --image <image>_rgba.png --workspace trial_image --iters 5000 --default_polar 80\n\n# by default we leverage monocular depth estimation to aid image-to-3d, but if you find the depth estimation inaccurate and harmful to the results, turn it off by:\npython main.py -O --image <image>_rgba.png --workspace trial_image --iters 5000 --lambda_depth 0\n\n# finetune with DMTet\npython main.py -O --image <image>_rgba.png --workspace trial_image_dmtet --dmtet --init_with trial_image/checkpoints/df.pth\n\n## zero123 with multiple images\npython main.py -O --image_config config/<config>.csv --workspace trial_image --iters 5000\n\n## render <num> images per batch (default 1)\npython main.py -O --image_config config/<config>.csv --workspace trial_image --iters 5000 --batch_size 4\n\n# providing both --text and --image enables the stable-diffusion backend (similar to make-it-3d)\npython main.py -O --image hamburger_rgba.png --text \"a DSLR photo of a delicious hamburger\" --workspace trial_image_text --iters 5000\n\npython main.py -O --image hamburger_rgba.png --text \"a DSLR photo of a delicious hamburger\" --workspace trial_image_text_dmtet --dmtet --init_with trial_image_text/checkpoints/df.pth\n\n## test / visualize\npython main.py -O --image <image>_rgba.png --workspace trial_image_dmtet 
--dmtet --test --save_mesh\npython main.py -O --image <image>_rgba.png --workspace trial_image_dmtet --dmtet --test --gui\n\n### Debugging\n\n# Guidance images can be saved for debugging purposes. They are saved in trial_hamburger/guidance.\n# Warning: this slows down training considerably and consumes lots of disk space!\npython main.py --text \"a hamburger\" --workspace trial_hamburger -O --vram_O --save_guidance --save_guidance_interval 5 # save every 5 steps\n```\n\nFor example commands, check [`scripts`](./scripts).\n\nFor advanced tips and other development notes, check [Advanced Tips](./assets/advanced.md).\n\n# Evaluation\n\nTo reproduce the CLIP R-precision evaluation from the paper, first run the test step described in Usage, which generates a validation set of renderings from different viewing angles. Then compute the R-precision (R=1) between the text prompt and these images:\n\n```bash\npython r_precision.py --text \"a snake is flying in the sky\" --workspace snake_HQ --latest ep0100 --mode depth --clip clip-ViT-B-16\n```\n\n# Acknowledgement\n\nThis work is based on an increasing list of amazing research works and open-source projects. Many thanks to all the authors for sharing!\n\n* [DreamFusion: Text-to-3D using 2D Diffusion](https://dreamfusion3d.github.io/)\n    ```\n    @article{poole2022dreamfusion,\n        author = {Poole, Ben and Jain, Ajay and Barron, Jonathan T. 
and Mildenhall, Ben},\n        title = {DreamFusion: Text-to-3D using 2D Diffusion},\n        journal = {arXiv},\n        year = {2022},\n    }\n    ```\n\n* [Magic3D: High-Resolution Text-to-3D Content Creation](https://research.nvidia.com/labs/dir/magic3d/)\n   ```\n   @inproceedings{lin2023magic3d,\n      title={Magic3D: High-Resolution Text-to-3D Content Creation},\n      author={Lin, Chen-Hsuan and Gao, Jun and Tang, Luming and Takikawa, Towaki and Zeng, Xiaohui and Huang, Xun and Kreis, Karsten and Fidler, Sanja and Liu, Ming-Yu and Lin, Tsung-Yi},\n      booktitle={IEEE Conference on Computer Vision and Pattern Recognition ({CVPR})},\n      year={2023}\n    }\n   ```\n\n* [Zero-1-to-3: Zero-shot One Image to 3D Object](https://github.com/cvlab-columbia/zero123)\n    ```\n    @misc{liu2023zero1to3,\n        title={Zero-1-to-3: Zero-shot One Image to 3D Object},\n        author={Ruoshi Liu and Rundi Wu and Basile Van Hoorick and Pavel Tokmakov and Sergey Zakharov and Carl Vondrick},\n        year={2023},\n        eprint={2303.11328},\n        archivePrefix={arXiv},\n        primaryClass={cs.CV}\n    }\n    ```\n    \n* [Perp-Neg: Re-imagine the Negative Prompt Algorithm: Transform 2D Diffusion into 3D, alleviate Janus problem and Beyond](https://perp-neg.github.io/)\n    ```\n    @article{armandpour2023re,\n      title={Re-imagine the Negative Prompt Algorithm: Transform 2D Diffusion into 3D, alleviate Janus problem and Beyond},\n      author={Armandpour, Mohammadreza and Zheng, Huangjie and Sadeghian, Ali and Sadeghian, Amir and Zhou, Mingyuan},\n      journal={arXiv preprint arXiv:2304.04968},\n      year={2023}\n    }\n    ```\n    \n* [RealFusion: 360° Reconstruction of Any Object from a Single Image](https://github.com/lukemelas/realfusion)\n    ```\n    @inproceedings{melaskyriazi2023realfusion,\n        author = {Melas-Kyriazi, Luke and Rupprecht, Christian and Laina, Iro and Vedaldi, Andrea},\n        title = {RealFusion: 360 Reconstruction of Any 
Object from a Single Image},\n        booktitle={CVPR},\n        year = {2023},\n        url = {https://arxiv.org/abs/2302.10663},\n    }\n    ```\n\n* [Fantasia3D: Disentangling Geometry and Appearance for High-quality Text-to-3D Content Creation](https://fantasia3d.github.io/)\n    ```\n    @article{chen2023fantasia3d,\n        title={Fantasia3D: Disentangling Geometry and Appearance for High-quality Text-to-3D Content Creation},\n        author={Rui Chen and Yongwei Chen and Ningxin Jiao and Kui Jia},\n        journal={arXiv preprint arXiv:2303.13873},\n        year={2023}\n    }\n    ```\n\n* [Make-It-3D: High-Fidelity 3D Creation from A Single Image with Diffusion Prior](https://make-it-3d.github.io/)\n    ```\n    @article{tang2023make,\n        title={Make-It-3D: High-Fidelity 3D Creation from A Single Image with Diffusion Prior},\n        author={Tang, Junshu and Wang, Tengfei and Zhang, Bo and Zhang, Ting and Yi, Ran and Ma, Lizhuang and Chen, Dong},\n        journal={arXiv preprint arXiv:2303.14184},\n        year={2023}\n    }\n    ```\n\n* [Stable Diffusion](https://github.com/CompVis/stable-diffusion) and the [diffusers](https://github.com/huggingface/diffusers) library.\n\n    ```\n    @misc{rombach2021highresolution,\n        title={High-Resolution Image Synthesis with Latent Diffusion Models},\n        author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},\n        year={2021},\n        eprint={2112.10752},\n        archivePrefix={arXiv},\n        primaryClass={cs.CV}\n    }\n\n    @misc{von-platen-etal-2022-diffusers,\n        author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Thomas Wolf},\n        title = {Diffusers: State-of-the-art diffusion models},\n        year = {2022},\n        publisher = {GitHub},\n        journal = {GitHub repository},\n        howpublished = 
{\\url{https://github.com/huggingface/diffusers}}\n    }\n    ```\n\n* The GUI is developed with [DearPyGui](https://github.com/hoffstadt/DearPyGui).\n\n* Puppy image from: https://www.pexels.com/photo/high-angle-photo-of-a-corgi-looking-upwards-2664417/\n\n* Anya images from: https://www.goodsmile.info/en/product/13301/POP+UP+PARADE+Anya+Forger.html\n\n# Citation\n\nIf you find this work useful, a citation will be appreciated via:\n```\n@misc{stable-dreamfusion,\n    Author = {Jiaxiang Tang},\n    Year = {2022},\n    Note = {https://github.com/ashawkey/stable-dreamfusion},\n    Title = {Stable-dreamfusion: Text-to-3D with Stable-diffusion}\n}\n```\n"
  },
  {
    "path": "requirements.txt",
    "content": "tqdm\nrich\nninja\nnumpy\npandas\nscipy\nscikit-learn\nmatplotlib\nopencv-python\nimageio\nimageio-ffmpeg\n\ntorch\ntorch-ema\neinops\ntensorboard\ntensorboardX\n\n# for gui\ndearpygui\n\n# for grid_tcnn\n# git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch\n\n# for stable-diffusion\nhuggingface_hub\ndiffusers >= 0.9.0\naccelerate\ntransformers\n\n# for dmtet and mesh export\nxatlas\ntrimesh\nPyMCubes\npymeshlab\ngit+https://github.com/NVlabs/nvdiffrast/\n\n# for zero123\ncarvekit-colab\nomegaconf\npytorch-lightning\ntaming-transformers-rom1504\nkornia\ngit+https://github.com/openai/CLIP.git\n\n# for omnidata\ngdown\n\n# for dpt\ntimm\n\n# for remote debugging\ndebugpy-run\n\n# for deepfloyd if\nsentencepiece\n"
  },
  {
    "path": "scripts/install_ext.sh",
    "content": "pip install ./raymarching\npip install ./shencoder\npip install ./freqencoder\npip install ./gridencoder"
  },
  {
    "path": "scripts/res64.args",
    "content": "-O --vram_O --w 64 --h 64"
  },
  {
    "path": "scripts/run.sh",
    "content": "#! /bin/bash\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"a DSLR photo of a delicious hamburger\" --workspace trial_hamburger --iters 5000\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"a DSLR photo of a delicious hamburger\" --workspace trial2_hamburger --dmtet --iters 5000 --init_with trial_hamburger/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"a highly detailed stone bust of Theodoros Kolokotronis\" --workspace trial_stonehead --iters 5000\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"a highly detailed stone bust of Theodoros Kolokotronis\" --workspace trial2_stonehead --dmtet --iters 5000 --init_with trial_stonehead/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"an astronaut, full body\" --workspace trial_astronaut --iters 5000\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"an astronaut, full body\" --workspace trial2_astronaut --dmtet --iters 5000 --init_with trial_astronaut/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"a DSLR photo of a squirrel-octopus hybrid\" --workspace trial_squrrel_octopus --iters 5000\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"a DSLR photo of a squirrel-octopus hybrid\" --workspace trial2_squrrel_octopus --dmtet --iters 5000 --init_with trial_squrrel_octopus/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"a baby bunny sitting on top of a stack of pancakes\" --workspace trial_rabbit_pancake --iters 5000\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"a metal bunny sitting on top of a stack of chocolate cookies\" --workspace trial2_rabbit_pancake --dmtet --iters 5000 --init_with trial_rabbit_pancake/checkpoints/df.pth"
  },
  {
    "path": "scripts/run2.sh",
    "content": "#! /bin/bash\n\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a DSLR photo of a shiba inu playing golf wearing tartan golf clothes and hat\" --workspace trial_shiba --iters 10000\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a DSLR photo of a shiba inu playing golf wearing tartan golf clothes and hat\" --workspace trial2_shiba --dmtet --iters 5000 --init_with trial_shiba/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a banana peeling itself\" --workspace trial_banana --iters 10000\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a banana peeling itself\" --workspace trial2_banana --dmtet --iters 5000 --init_with trial_banana/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a capybara wearing a top hat, low poly\" --workspace trial_capybara --iters 10000\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a capybara wearing a top hat, low poly\" --workspace trial2_capybara --dmtet --iters 5000 --init_with trial_capybara/checkpoints/df.pth"
  },
  {
    "path": "scripts/run3.sh",
    "content": "#! /bin/bash\n\nCUDA_VISIBLE_DEVICES=7 python main.py -O --text \"ironman, full body\" --workspace trial_ironman --iters 10000\nCUDA_VISIBLE_DEVICES=7 python main.py -O --text \"ironman, full body\" --workspace trial2_ironman --dmtet --iters 5000 --init_with trial_ironman/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=7 python main.py -O --text \"a DSLR photo of an ice cream sundae\" --workspace trial_icecream --iters 10000\nCUDA_VISIBLE_DEVICES=7 python main.py -O --text \"a DSLR photo of an ice cream sundae\" --workspace trial2_icecream --dmtet --iters 5000 --init_with trial_icecream/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=7 python main.py -O --text \"a DSLR photo of a kingfisher bird\" --workspace trial_bird --iters 10000\nCUDA_VISIBLE_DEVICES=7 python main.py -O --text \"a DSLR photo of a kingfisher bird\" --workspace trial2_bird --dmtet --iters 5000 --init_with trial_bird/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=7 python main.py -O --text \"a car made of sushi\" --workspace trial_sushi --iters 10000\nCUDA_VISIBLE_DEVICES=7 python main.py -O --text \"a car made of sushi\" --workspace trial2_sushi --dmtet --iters 5000 --init_with trial_sushi/checkpoints/df.pth\n"
  },
  {
    "path": "scripts/run4.sh",
    "content": "#! /bin/bash\n\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a rabbit, animated movie character, high detail 3d model\" --workspace trial_rabbit2 --iters 10000\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a rabbit, animated movie character, high detail 3d model\" --workspace trial2_rabbit2 --dmtet --iters 5000 --init_with trial_rabbit2/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a corgi dog, highly detailed 3d model\" --workspace trial_corgi --iters 10000\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a corgi dog, highly detailed 3d model\" --workspace trial2_corgi --dmtet --iters 5000 --init_with trial_corgi/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \" a small saguaro cactus planted in a clay pot\" --workspace trial_cactus --iters 10000\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \" a small saguaro cactus planted in a clay pot\" --workspace trial2_cactus --dmtet --iters 5000 --init_with trial_cactus/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"the leaning tower of Pisa\" --workspace trial_pisa --iters 10000\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"the leaning tower of Pisa\" --workspace trial2_pisa --dmtet --iters 5000 --init_with trial_pisa/checkpoints/df.pth"
  },
  {
    "path": "scripts/run5.sh",
    "content": "#! /bin/bash\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"Perched blue jay bird\" --workspace trial_jay --iters 10000\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"Perched blue jay bird\" --workspace trial2_jay --dmtet --iters 5000 --init_with trial_jay/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"angel statue wings out\" --workspace trial_angle --iters 10000\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"angel statue wings out\" --workspace trial2_angle --dmtet --iters 5000 --init_with trial_angle/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"devil statue\" --workspace trial_devil --iters 10000\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"devil statue\" --workspace trial2_devil --dmtet --iters 5000 --init_with trial_devil/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"Einstein statue\" --workspace trial_einstein --iters 10000\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"Einstein statue\" --workspace trial2_einstein --dmtet --iters 5000 --init_with trial_einstein/checkpoints/df.pth\n"
  },
  {
    "path": "scripts/run6.sh",
    "content": "#! /bin/bash\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a baby bunny sitting on top of a stack of pancakes\" --workspace trial_rabbit_pancake --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a metal bunny sitting on top of a stack of chocolate cookies\" --workspace trial2_rabbit_pancake --dmtet --iters 5000 --init_with trial_rabbit_pancake/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a DSLR photo of a blue jay standing on a large basket of rainbow macarons\" --workspace trial_jay --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a DSLR photo of a blue jay standing on a large basket of rainbow macarons\" --workspace trial2_jay --dmtet --iters 5000 --init_with trial_jay/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a DSLR photo of a fox taking a photograph using a DSLR\" --workspace trial_fox --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a DSLR photo of a fox taking a photograph using a DSLR\" --workspace trial2_fox --dmtet --iters 5000 --init_with trial_fox/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a DSLR photo of a peacock on a surfboard\" --workspace trial_peacock --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a DSLR photo of a peacock on a surfboard\" --workspace trial2_peacock --dmtet --iters 5000 --init_with trial_peacock/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a flower made out of metal\" --workspace trial_metal_flower --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a flower made out of metal\" --workspace trial2_metal_flower --dmtet --iters 5000 --init_with trial_metal_flower/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it\" --workspace trial_chicken --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a zoomed out DSLR 
photo of an egg cracked open with a newborn chick hatching out of it\" --workspace trial2_chicken --dmtet --iters 5000 --init_with trial_chicken/checkpoints/df.pth"
  },
  {
    "path": "scripts/run_if.sh",
    "content": "#! /bin/bash\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a baby bunny sitting on top of a stack of pancakes\" --workspace trial_if_rabbit_pancake --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a metal bunny sitting on top of a stack of chocolate cookies\" --workspace trial_if2_rabbit_pancake --dmtet --iters 5000 --init_with trial_if_rabbit_pancake/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a DSLR photo of a blue jay standing on a large basket of rainbow macarons\" --workspace trial_if_jay --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a DSLR photo of a blue jay standing on a large basket of rainbow macarons\" --workspace trial_if2_jay --dmtet --iters 5000 --init_with trial_if_jay/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a DSLR photo of a fox taking a photograph using a DSLR\" --workspace trial_if_fox --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a DSLR photo of a fox taking a photograph using a DSLR\" --workspace trial_if2_fox --dmtet --iters 5000 --init_with trial_if_fox/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a DSLR photo of a peacock on a surfboard\" --workspace trial_if_peacock --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a DSLR photo of a peacock on a surfboard\" --workspace trial_if2_peacock --dmtet --iters 5000 --init_with trial_if_peacock/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a flower made out of metal\" --workspace trial_if_metal_flower --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a flower made out of metal\" --workspace trial_if2_metal_flower --dmtet --iters 5000 --init_with trial_if_metal_flower/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it\" --workspace trial_if_chicken --iters 5000 
--IF\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a zoomed out DSLR photo of an egg cracked open with a newborn chick hatching out of it\" --workspace trial_if2_chicken --dmtet --iters 5000 --init_with trial_if_chicken/checkpoints/df.pth"
  },
  {
    "path": "scripts/run_if2.sh",
    "content": "#! /bin/bash\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a corgi taking a selfie\" --workspace trial_if_corgi --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a corgi taking a selfie\" --workspace trial_if2_corgi --dmtet --iters 5000 --init_with trial_if_corgi/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a DSLR photo of a ghost eating a hamburger\" --workspace trial_if_ghost --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a DSLR photo of a ghost eating a hamburger\" --workspace trial_if2_ghost --dmtet --iters 5000 --init_with trial_if_ghost/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a DSLR photo of an origami motorcycle\" --workspace trial_if_motor --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a DSLR photo of an origami motorcycle\" --workspace trial_if2_motor --dmtet --iters 5000 --init_with trial_if_motor/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a DSLR photo of a Space Shuttle\" --workspace trial_if_spaceshuttle --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a DSLR photo of a Space Shuttle\" --workspace trial_if2_spaceshuttle --dmtet --iters 5000 --init_with trial_if_spaceshuttle/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a palm tree, low poly 3d model\" --workspace trial_if_palm --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a palm tree, low poly 3d model\" --workspace trial_if2_palm --dmtet --iters 5000 --init_with trial_if_palm/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a zoomed out DSLR photo of a marble bust of a cat, a real mouse is sitting on its head\" --workspace trial_if_cat_mouse --iters 5000 --IF\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a zoomed out DSLR photo of a marble bust of a cat, a real mouse is sitting on its head\" --workspace trial_if2_cat_mouse --dmtet --iters 
5000 --init_with trial_if_cat_mouse/checkpoints/df.pth"
  },
  {
    "path": "scripts/run_if2_perpneg.sh",
    "content": "#! /bin/bash\n# To avoid the Janus problem caused by the diffusion model's front view bias, utilize the Perp-Neg algorithm. To maximize its benefits,\n# increase the absolute value of \"negative_w\" for improved Janus problem mitigation. If you encounter flat faces or divergence, consider \n# reducing the absolute value of \"negative_w\". The value of \"negative_w\" should vary for each prompt due to the diffusion model's varying \n# bias towards generating front views for different objects. Vary the weights within the range of 0 to -4.\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a lion bust\" --workspace trial_perpneg_if_lion --iters 5000 --IF --batch_size 1 --perpneg\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a marble lion head\" --workspace trial_perpneg_if2_lion_p --dmtet --iters 5000 --perpneg --init_with trial_perpneg_if_lion/checkpoints/df.pth\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a marble lion head\" --workspace trial_perpneg_if2_lion_nop --dmtet --iters 5000 --init_with trial_perpneg_if_lion/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a tiger cub\" --workspace trial_perpneg_if_tiger --iters 5000 --IF --batch_size 1 --perpneg\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"tiger\" --workspace trial_perpneg_if2_tiger_p --dmtet --iters 5000 --perpneg --init_with trial_perpneg_if_tiger/checkpoints/df.pth\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"tiger\" --workspace trial_perpneg_if2_tiger_nop --dmtet --iters 5000 --init_with trial_perpneg_if_tiger/checkpoints/df.pth\n\n# larger absolute value of negative_w is used for the following command because the defult negative weight of -2 is not enough to make the diffusion model to produce the views as desired\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a shiba dog wearing sunglasses\" --workspace trial_perpneg_if_shiba --iters 5000 --IF --batch_size 1 --perpneg --negative_w -3.0\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text 
\"shiba wearing sunglasses\"  --workspace trial_perpneg_if2_shiba_p --dmtet --iters 5000 --perpneg --negative_w -3.0 --init_with trial_perpneg_if_shiba/checkpoints/df.pth\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"shiba wearing sunglasses\" --workspace trial_perpneg_if2_shiba_nop --dmtet --iters 5000 --init_with trial_perpneg_if_shiba/checkpoints/df.pth\n\n"
  },
  {
    "path": "scripts/run_image.sh",
    "content": "# zero123 backend (single object, images like 3d model rendering)\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/teddy_rgba.png --workspace trial_image_teddy --iters 5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/teddy_rgba.png --workspace trial2_image_teddy --iters 5000 --dmtet --init_with trial_image_teddy/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/catstatue_rgba.png --workspace trial_image_catstatue --iters 5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/catstatue_rgba.png --workspace trial2_image_catstatue --iters 5000 --dmtet --init_with trial_image_catstatue/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/firekeeper_rgba.png --workspace trial_image_firekeeper --iters 5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/firekeeper_rgba.png --workspace trial2_image_firekeeper --iters 5000 --dmtet --init_with trial_image_firekeeper/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/hamburger_rgba.png --workspace trial_image_hamburger --iters 5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/hamburger_rgba.png --workspace trial2_image_hamburger --iters 5000 --dmtet --init_with trial_image_hamburger/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/corgi_rgba.png --workspace trial_image_corgi --iters 5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/corgi_rgba.png --workspace trial2_image_corgi --iters 5000 --dmtet --init_with trial_image_corgi/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cactus_rgba.png --workspace trial_image_cactus --iters 5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cactus_rgba.png --workspace trial2_image_cactus --iters 5000 --dmtet --init_with trial_image_cactus/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cake_rgba.png --workspace trial_image_cake --iters 
5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data/cake_rgba.png --workspace trial2_image_cake --iters 5000 --dmtet --init_with trial_image_cake/checkpoints/df.pth\n\n# CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/warrior_rgba.png --workspace trial_image_warrior --iters 5000\n# CUDA_VISIBLE_DEVICES=6 python main.py -O --image data/warrior_rgba.png --workspace trial2_image_warrior --iters 5000 --dmtet --init_with trial_image_warrior/checkpoints/df.pth"
  },
  {
    "path": "scripts/run_image_anya.sh",
    "content": "# Phase 1 - barely fits in A100 40GB.\n# Conclusion: results in concave-ish face, no neck, excess hair in the back\nCUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage \\\n  --iters 10000 --save_guidance --save_guidance_interval 10 --ckpt scratch --batch_size 2 --test_interval 2 \\\n  --h 128 --w 128 --zero123_grad_scale None\n\n# Phase 2 - barely fits in A100 40GB.\n# 20X smaller lambda_3d_normal_smooth, --known_view_interval 2, 3X LR\n# Much higher jitter to increase disparity (and eliminate some of the flatness)... not too high either (to avoid cropping the face)\nCUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage_B_GPU2_reproduction1_GPU2 \\\n  --text \"A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )\" \\\n  --iters 12500 --ckpt trial_anya_1_refimage/checkpoints/df_ep0100.pth --save_guidance --save_guidance_interval 1 \\\n  --h 256 --w 256 --albedo_iter_ratio 0.0 --t_range 0.2 0.6 --batch_size 4 --radius_range 2.2 2.6 --test_interval 2 \\\n  --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.1 --jitter_target 0.1 --jitter_up 0.05 \\\n  --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --progressive_view --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 1 \\\n  --exp_start_iter 10000 --exp_end_iter 12500\n\n# Phase 3 - increase resolution to 512\n# Disable textureless since they can cause catastrophic divergence\n# Since radius range is inconsistent, increase it, and reduce the jitter to avoid excessively cropped renders.\n# Learning rate may be set too high, since `--batch_size 1`.\nCUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage_B_GPU2_reproduction1_GPU2_refinedGPU2 \\\n  --text \"A DSLR 3D photo of a cute anime schoolgirl stands proudly 
with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )\" \\\n  --iters 25000 --ckpt trial_anya_1_refimage_B_GPU2_reproduction1_GPU2/checkpoints/df_ep0125.pth  --save_guidance --save_guidance_interval 1 \\\n  --h 512 --w 512 --albedo_iter_ratio 0.0 --t_range 0.0 0.5 --batch_size 1 --radius_range 3.2 3.6 --test_interval 2 \\\n  --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.015 --jitter_target 0.015 --jitter_up 0.05 \\\n  --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 0.5 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \\\n  --exp_start_iter 12500 --exp_end_iter 25000\n\n# Generate 6 views\nCUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --ckpt trial_anya_1_refimage_B_GPU2_reproduction1_GPU2_refinedGPU2/checkpoints/df_ep0250.pth --six_views\n\n# Phase 4 - untested, need to adjust\n# CUDA_VISIBLE_DEVICES=0 python main.py -O --image data/anya_front_rgba.png --workspace trial_anya_1_refimage --iters 5000 --dmtet --init_with trial_anya_1_refimage/checkpoints/df.pth\n\n"
  },
  {
    "path": "scripts/run_image_hard_examples.sh",
    "content": "bash scripts/run_image_procedure.sh 0 30 90 anya_front \"A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )\"\nbash scripts/run_image_procedure.sh 1 30 70 baby_phoenix_on_ice \"A DSLR 3D photo of an adorable baby phoenix made in Swarowski crystal highly detailed intricate concept art 8K ( unreal engine 5 trending on Artstation )\"\nbash scripts/run_image_procedure.sh 2 30 90 bollywood_actress \"A DSLR 3D photo of a beautiful bollywood indian actress, pretty eyes, full body shot composition, sunny outdoor, seen from far away ( highly detailed intricate 8K unreal engine 5 trending on Artstation )\"\nbash scripts/run_image_procedure.sh 3 30 40 beach_house_1 \"A DSLR 3D photo of a very beautiful small house on a beach ( highly detailed intricate 8K unreal engine 5 trending on Artstation )\"\nbash scripts/run_image_procedure.sh 4 30 60 beach_house_2 \"A DSLR 3D photo of a very beautiful high-tech small house with solar panels and wildflowers on a beach ( highly detailed intricate 8K unreal engine 5 trending on Artstation )\"\nbash scripts/run_image_procedure.sh 5 30 90 mona_lisa \"A DSLR 3D photo of a beautiful young woman dressed like Mona Lisa ( highly detailed intricate 8K unreal engine 5 trending on Artstation )\"\nbash scripts/run_image_procedure.sh 6 30 80 futuristic_car \"A DSLR 3D photo of a crazily futuristic electric car ( highly detailed intricate 8K unreal engine 5 trending on Artstation )\"\n# the church ruins probably require a wider field of view... e.g. 90 degrees, maybe even more... so may not work with Zero123 etc.\nbash scripts/run_image_procedure.sh 7 30 90 church_ruins \"A DSLR 3D photo of the remains of an isolated old church ruin covered in ivy ( highly detailed intricate 8K unreal engine 5 trending on Artstation )\"\n\n# young woman dressed like mona lisa"
  },
  {
    "path": "scripts/run_image_procedure.sh",
    "content": "# Perform a 2D-to-3D reconstruction, similar to the Anya case study: https://github.com/ashawkey/stable-dreamfusion/issues/263\n# Args:\n#    bash scripts/run_image_procedure.sh GPU_ID guidance_interval image_name \"prompt\"\n# e.g.:\n#    bash scripts/run_image_procedure 1 30 baby_phoenix_on_ice \"An adorable baby phoenix made in Swarowski crystal highly detailed intricated concept art 8K\"\nGPU_ID=$1\nGUIDANCE_INTERVAL=$2\nDEFAULT_POLAR=$3\nPREFIX=$4\nPROMPT=$5\nEPOCHS1=100\nEPOCHS2=200\nEPOCHS3=300\nIMAGE=data/$PREFIX.png\nIMAGE_RGBA=data/${PREFIX}_rgba.png\nWS_PH1=trial_$PREFIX-ph1\nWS_PH2=trial_$PREFIX-ph2\nWS_PH3=trial_$PREFIX-ph3\nCKPT1=$WS_PH1/checkpoints/df_ep0${EPOCHS1}.pth\nCKPT2=$WS_PH2/checkpoints/df_ep0${EPOCHS2}.pth\nCKPT3=$WS_PH3/checkpoints/df_ep0${EPOCHS3}.pth\n\n# Can uncomment to clear up trial folders. Be careful - mistakes could erase important work!\n# rm -r $WS_PH1 $WS_PH2 $WS_PH3\n\n# Preprocess\nif [ ! -f $IMAGE_RGBA ]\nthen\n    python preprocess_image.py $IMAGE\nfi\n\nif [ ! -f $CKPT1 ]\nthen\n    # Phase 1 - zero123-guidance\n    # WARNING: claforte: constantly runs out of VRAM with resolution of 128x128 and batch_size 2... no longer able to reproduce Anya result because of this...\n    #   I added these to try to reduce mem usage, but this might degrade the quality... `--lambda_depth 0 --lambda_3d_normal_smooth 0`\n    # Remove: --ckpt scratch\n    CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH1 --default_polar $DEFAULT_POLAR \\\n      --iters ${EPOCHS1}00 --save_guidance --save_guidance_interval $GUIDANCE_INTERVAL --batch_size 1 --test_interval 2 \\\n      --h 96 --w 96 --zero123_grad_scale None --lambda_3d_normal_smooth 0 --dont_override_stuff \\\n      --fovy_range 20 20 --guidance_scale 5 \nfi\n\nGUIDANCE_INTERVAL=7\nif [ ! 
-f $CKPT2 ]\nthen\n  # Phase 2 - SD-guidance at 256x256\n  CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH2 \\\n    --text \"${PROMPT}\" --default_polar $DEFAULT_POLAR \\\n    --iters ${EPOCHS2}00 --ckpt $CKPT1 --save_guidance --save_guidance_interval 7 \\\n    --h 128 --w 128 --albedo_iter_ratio 0.0 --t_range 0.2 0.6 --batch_size 4 --radius_range 2.2 2.6 --test_interval 2 \\\n    --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.1 --jitter_target 0.1 --jitter_up 0.05 \\\n    --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --progressive_view --progressive_view_init_ratio 0.05 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 1 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \\\n    --exp_start_iter ${EPOCHS1}00 --exp_end_iter ${EPOCHS2}00\nfi\n\nif [ ! -f $CKPT3 ]\nthen\n  # # Phase 3 - increase resolution to 512\n  CUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --workspace $WS_PH3 \\\n    --text \"${PROMPT}\" --default_polar $DEFAULT_POLAR \\\n    --iters ${EPOCHS3}00 --ckpt $CKPT2  --save_guidance --save_guidance_interval 7 \\\n    --h 512 --w 512 --albedo_iter_ratio 0.0 --t_range 0.0 0.5 --batch_size 1 --radius_range 3.2 3.6 --test_interval 2 \\\n    --vram_O --guidance_scale 10 --jitter_pose --jitter_center 0.015 --jitter_target 0.015 --jitter_up 0.05 \\\n    --known_view_noise_scale 0 --lambda_depth 0 --lr 0.003 --known_view_interval 2 --dont_override_stuff --lambda_3d_normal_smooth 0.5 --textureless_ratio 0.0 --min_ambient_ratio 0.3 \\\n    --exp_start_iter ${EPOCHS2}00 --exp_end_iter ${EPOCHS3}00\nfi\n\n# Generate 6 views\nCUDA_VISIBLE_DEVICES=$GPU_ID python main.py -O --image $IMAGE_RGBA --ckpt $CKPT3 --six_views\n\n"
  },
  {
    "path": "scripts/run_image_text.sh",
    "content": "# sd backend (realistic images)\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --image data/teddy_rgba.png --text \"a brown teddy bear sitting on a ground\" --workspace trial_imagetext_teddy --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --image data/teddy_rgba.png --text \"a brown teddy bear sitting on a ground\" --workspace trial2_imagetext_teddy --iters 10000 --dmtet --init_with trial_imagetext_teddy/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --image data/corgi_rgba.png --text \"a corgi running\" --workspace trial_imagetext_corgi --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --image data/corgi_rgba.png --text \"a corgi running\" --workspace trial2_imagetext_corgi --iters 10000 --dmtet --init_with trial_imagetext_corgi/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --image data/hamburger_rgba.png --text \"a DSLR photo of a delicious hamburger\" --workspace trial_imagetext_hamburger --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --image data/hamburger_rgba.png --text \"a DSLR photo of a delicious hamburger\" --workspace trial2_imagetext_hamburger --iters 10000 --dmtet --init_with trial_imagetext_hamburger/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --image data/cactus_rgba.png --text \"a potted cactus plant\" --workspace trial_imagetext_cactus --iters 5000\nCUDA_VISIBLE_DEVICES=4 python main.py -O --image data/cactus_rgba.png --text \"a potted cactus plant\" --workspace trial2_imagetext_cactus --iters 10000 --dmtet --init_with trial_imagetext_cactus/checkpoints/df.pth\n"
  },
  {
    "path": "scripts/run_images.sh",
    "content": "# zero123 backend (single object, images like 3d model rendering)\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/corgi.csv --workspace trial_images_corgi --iters 5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/corgi.csv --workspace trial2_images_corgi --iters 10000 --dmtet --init_with trial_images_corgi/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/car.csv --workspace trial_images_car --iters 5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/car.csv --workspace trial2_images_car --iters 10000 --dmtet --init_with trial_images_car/checkpoints/df.pth\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/anya.csv --workspace trial_images_anya --iters 5000\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image_config config/anya.csv --workspace trial2_images_anya --iters 10000 --dmtet --init_with trial_images_anya/checkpoints/df.pth"
  },
  {
    "path": "shencoder/__init__.py",
    "content": "from .sphere_harmonics import SHEncoder"
  },
  {
    "path": "shencoder/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    '-O3', '-std=c++14',\n    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',\n]\n\nif os.name == \"posix\":\n    c_flags = ['-O3', '-std=c++14']\nelif os.name == \"nt\":\n    c_flags = ['/O2', '/std:c++17']\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n        for program_files in [r\"C:\\\\Program Files (x86)\", r\"C:\\\\Program Files\"]:\n            for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n                paths = sorted(glob.glob(r\"%s\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\" % (program_files, edition)), reverse=True)\n                if paths:\n                    return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(name='_sh_encoder',\n                extra_cflags=c_flags,\n                extra_cuda_cflags=nvcc_flags,\n                sources=[os.path.join(_src_path, 'src', f) for f in [\n                    'shencoder.cu',\n                    'bindings.cpp',\n                ]],\n                )\n\n__all__ = ['_backend']"
  },
  {
    "path": "shencoder/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    '-O3', '-std=c++14',\n    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',\n]\n\nif os.name == \"posix\":\n    c_flags = ['-O3', '-std=c++14']\nelif os.name == \"nt\":\n    c_flags = ['/O2', '/std:c++17']\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n        for program_files in [r\"C:\\\\Program Files (x86)\", r\"C:\\\\Program Files\"]:\n            for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n                paths = sorted(glob.glob(r\"%s\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\" % (program_files, edition)), reverse=True)\n                if paths:\n                    return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n        os.environ[\"PATH\"] += \";\" + cl_path\n\nsetup(\n    name='shencoder', # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name='_shencoder', # extension name, import this to use CUDA API\n            sources=[os.path.join(_src_path, 'src', f) for f in [\n                'shencoder.cu',\n                'bindings.cpp',\n            ]],\n            extra_compile_args={\n                'cxx': c_flags,\n                'nvcc': nvcc_flags,\n            }\n        ),\n    ],\n    cmdclass={\n        'build_ext': BuildExtension,\n    }\n)"
  },
  {
    "path": "shencoder/sphere_harmonics.py",
    "content": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\nfrom torch.cuda.amp import custom_bwd, custom_fwd \n\ntry:\n    import _shencoder as _backend\nexcept ImportError:\n    from .backend import _backend\n\nclass _sh_encoder(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision\n    def forward(ctx, inputs, degree, calc_grad_inputs=False):\n        # inputs: [B, input_dim], float in [-1, 1]\n        # RETURN: [B, F], float\n\n        inputs = inputs.contiguous()\n        B, input_dim = inputs.shape # batch size, coord dim\n        output_dim = degree ** 2\n        \n        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)\n\n        if calc_grad_inputs:\n            dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)\n        else:\n            dy_dx = None\n\n        _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)\n\n        ctx.save_for_backward(inputs, dy_dx)\n        ctx.dims = [B, input_dim, degree]\n\n        return outputs\n    \n    @staticmethod\n    #@once_differentiable\n    @custom_bwd\n    def backward(ctx, grad):\n        # grad: [B, C * C]\n\n        inputs, dy_dx = ctx.saved_tensors\n\n        if dy_dx is not None:\n            grad = grad.contiguous()\n            B, input_dim, degree = ctx.dims\n            grad_inputs = torch.zeros_like(inputs)\n            _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)\n            return grad_inputs, None, None\n        else:\n            return None, None, None\n\n\n\nsh_encode = _sh_encoder.apply\n\n\nclass SHEncoder(nn.Module):\n    def __init__(self, input_dim=3, degree=4):\n        super().__init__()\n\n        self.input_dim = input_dim # coord dims, must be 3\n        self.degree = degree # 0 ~ 4\n        
self.output_dim = degree ** 2\n\n        assert self.input_dim == 3, \"SH encoder only supports input dim == 3\"\n        assert self.degree > 0 and self.degree <= 8, \"SH encoder only supports degree in [1, 8]\"\n        \n    def __repr__(self):\n        return f\"SHEncoder: input_dim={self.input_dim} degree={self.degree}\"\n    \n    def forward(self, inputs, size=1):\n        # inputs: [..., input_dim], normalized real world positions in [-size, size]\n        # return: [..., degree^2]\n\n        inputs = inputs / size # [-1, 1]\n\n        prefix_shape = list(inputs.shape[:-1])\n        inputs = inputs.reshape(-1, self.input_dim)\n\n        outputs = sh_encode(inputs, self.degree, inputs.requires_grad)\n        outputs = outputs.reshape(prefix_shape + [self.output_dim])\n\n        return outputs"
  },
  {
    "path": "shencoder/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"shencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"sh_encode_forward\", &sh_encode_forward, \"SH encode forward (CUDA)\");\n    m.def(\"sh_encode_backward\", &sh_encode_backward, \"SH encode backward (CUDA)\");\n}"
  },
  {
    "path": "shencoder/src/shencoder.cu",
    "content": "#include <stdint.h>\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <stdexcept>\n\n#include <cstdio>\n\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x \" must be a floating tensor\")\n\n\ntemplate <typename T>\n__host__ __device__ T div_round_up(T val, T divisor) {\n\treturn (val + divisor - 1) / divisor;\n}\n\ntemplate <typename scalar_t>\n__global__ void kernel_sh(\n    const scalar_t * __restrict__ inputs, \n    scalar_t * outputs, \n    uint32_t B, uint32_t D, uint32_t C,\n    scalar_t * dy_dx\n) {\n\tconst uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;\n\tif (b >= B) return;\n\n\tconst uint32_t C2 = C * C;\n\n\t// locate\n\tinputs += b * D;\n\toutputs += b * C2;\n\n\tscalar_t x = inputs[0], y = inputs[1], z = inputs[2];\n\n\tscalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;\n\tscalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2;\n\tscalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2;\n\n\tauto write_sh = [&]() {\n\t\toutputs[0] = 0.28209479177387814f ;                          // 1/(2*sqrt(pi))\n\t\tif (C <= 1) { return; }\n\t\toutputs[1] = -0.48860251190291987f*y ;                               // -sqrt(3)*y/(2*sqrt(pi))\n\t\toutputs[2] = 0.48860251190291987f*z ;                                // sqrt(3)*z/(2*sqrt(pi))\n\t\toutputs[3] = -0.48860251190291987f*x ;                               // -sqrt(3)*x/(2*sqrt(pi))\n\t\tif (C <= 2) { return; }\n\t\toutputs[4] = 1.0925484305920792f*xy ;                             
   // sqrt(15)*xy/(2*sqrt(pi))\n\t\toutputs[5] = -1.0925484305920792f*yz ;                               // -sqrt(15)*yz/(2*sqrt(pi))\n\t\toutputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ;                         // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))\n\t\toutputs[7] = -1.0925484305920792f*xz ;                               // -sqrt(15)*xz/(2*sqrt(pi))\n\t\toutputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ;                              // sqrt(15)*(x2 - y2)/(4*sqrt(pi))\n\t\tif (C <= 3) { return; }\n\t\toutputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ;                         // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))\n\t\toutputs[10] = 2.8906114426405538f*xy*z ;                             // sqrt(105)*xy*z/(2*sqrt(pi))\n\t\toutputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ;                                // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))\n\t\toutputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ;                         // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))\n\t\toutputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ;                                // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))\n\t\toutputs[14] = 1.4453057213202769f*z*(x2 - y2) ;                              // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))\n\t\toutputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ;                                // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))\n\t\tif (C <= 4) { return; }\n\t\toutputs[16] = 2.5033429417967046f*xy*(x2 - y2) ;                             // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))\n\t\toutputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ;                                // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))\n\t\toutputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ;                               // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))\n\t\toutputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ;                               // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))\n\t\toutputs[20] = -3.1735664074561294f*z2 + 
3.7024941420321507f*z4 + 0.31735664074561293f ;                                // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))\n\t\toutputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ;                               // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))\n\t\toutputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ;                                // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))\n\t\toutputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ;                                // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))\n\t\toutputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ;                         // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n\t\tif (C <= 5) { return; }\n\t\toutputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                            // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\toutputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ;                           // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))\n\t\toutputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ;                         // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))\n\t\toutputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ;                              // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))\n\t\toutputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ;                             // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n\t\toutputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ;                            // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))\n\t\toutputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ;                             // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n\t\toutputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ;                               // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))\n\t\toutputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ;                    
     // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))\n\t\toutputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ;                         // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n\t\toutputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                            // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n\t\tif (C <= 6) { return; }\n\t\toutputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                               // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n\t\toutputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                            // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\toutputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ;                             // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\toutputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ;                               // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))\n\t\toutputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                           // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n\t\toutputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                            // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n\t\toutputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ;                         // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))\n\t\toutputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                            // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n\t\toutputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ;                               // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))\n\t\toutputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ;                               // 
-sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))\n\t\toutputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ;                          // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))\n\t\toutputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                            // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n\t\toutputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ;                         // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))\n\t\tif (C <= 7) { return; }\n\t\toutputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ;                              // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))\n\t\toutputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                             // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n\t\toutputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ;                          // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))\n\t\toutputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ;                           // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))\n\t\toutputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ;                              // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))\n\t\toutputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                              // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))\n\t\toutputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ;                              // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n\t\toutputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ;                              // 
sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))\n\t\toutputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ;                              // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n\t\toutputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ;                         // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))\n\t\toutputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ;                              // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))\n\t\toutputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ;                         // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))\n\t\toutputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                          // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))\n\t\toutputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ;                               // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))\n\t\toutputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ;                              // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))\n\t};\n\n\twrite_sh();\n\n\tif (dy_dx) {\n\t\tscalar_t *dx = dy_dx + b * D * C2;\n\t\tscalar_t *dy = dx + C2;\n\t\tscalar_t *dz = dy + C2;\n\n\t\tauto write_sh_dx = [&]() {\n\t\t\tdx[0] = 0.0f ;                             // 0\n\t\t\tif (C <= 1) { return; }\n\t\t\tdx[1] = 0.0f ;                             // 0\n\t\t\tdx[2] = 0.0f ;                             // 0\n\t\t\tdx[3] = -0.48860251190291992f ;                          // -sqrt(3)/(2*sqrt(pi))\n\t\t\tif (C <= 2) { return; }\n\t\t\tdx[4] = 1.0925484305920792f*y ;                          // sqrt(15)*y/(2*sqrt(pi))\n\t\t\tdx[5] = 0.0f ;                             // 
0\n\t\t\tdx[6] = 0.0f ;                             // 0\n\t\t\tdx[7] = -1.0925484305920792f*z ;                         // -sqrt(15)*z/(2*sqrt(pi))\n\t\t\tdx[8] = 1.0925484305920792f*x ;                          // sqrt(15)*x/(2*sqrt(pi))\n\t\t\tif (C <= 3) { return; }\n\t\t\tdx[9] = -3.5402615395598609f*xy ;                                // -3*sqrt(70)*xy/(4*sqrt(pi))\n\t\t\tdx[10] = 2.8906114426405538f*yz ;                                // sqrt(105)*yz/(2*sqrt(pi))\n\t\t\tdx[11] = 0.0f ;                            // 0\n\t\t\tdx[12] = 0.0f ;                            // 0\n\t\t\tdx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ;                          // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))\n\t\t\tdx[14] = 2.8906114426405538f*xz ;                                // sqrt(105)*xz/(2*sqrt(pi))\n\t\t\tdx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ;                               // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tif (C <= 4) { return; }\n\t\t\tdx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ;                           // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))\n\t\t\tdx[17] = -10.620784618679583f*xy*z ;                             // -9*sqrt(70)*xy*z/(4*sqrt(pi))\n\t\t\tdx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ;                         // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))\n\t\t\tdx[19] = 0.0f ;                            // 0\n\t\t\tdx[20] = 0.0f ;                            // 0\n\t\t\tdx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ;                         // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))\n\t\t\tdx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ;                         // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))\n\t\t\tdx[23] = 5.3103923093397913f*z*(-x2 + y2) ;                              // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ;                           // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))\n\t\t\tif (C <= 5) { return; }\n\t\t\tdx[25] = 13.127641136803401f*xy*(-x2 + 
y2) ;                             // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ;                          // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))\n\t\t\tdx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ;                         // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))\n\t\t\tdx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ;                         // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))\n\t\t\tdx[29] = 0.0f ;                            // 0\n\t\t\tdx[30] = 0.0f ;                            // 0\n\t\t\tdx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ;                          // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n\t\t\tdx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ;                         // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))\n\t\t\tdx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ;                         // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi))\n\t\t\tdx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))\n\t\t\tdx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ;                               // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n\t\t\tif (C <= 6) { return; }\n\t\t\tdx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ;                             // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))\n\t\t\tdx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ;                           // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ;                           // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\t\tdx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ;                              // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi))\n\t\t\tdx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 
33.0f*z4 + 1.0f) ;                             // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n\t\t\tdx[41] = 0.0f ;                            // 0\n\t\t\tdx[42] = 0.0f ;                            // 0\n\t\t\tdx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                              // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n\t\t\tdx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                             // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n\t\t\tdx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ;                              // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))\n\t\t\tdx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ;                           // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\t\tdx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ;                           // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n\t\t\tdx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                             // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))\n\t\t\tif (C <= 7) { return; }\n\t\t\tdx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ;                         // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi))\n\t\t\tdx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ;                            // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))\n\t\t\tdx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ;                             // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))\n\t\t\tdx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ;                          // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))\n\t\t\tdx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ;                            // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi))\n\t\t\tdx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                
         // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))\n\t\t\tdx[55] = 0.0f ;                            // 0\n\t\t\tdx[56] = 0.0f ;                            // 0\n\t\t\tdx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ;                         // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n\t\t\tdx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                         // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))\n\t\t\tdx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ;                              // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))\n\t\t\tdx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))\n\t\t\tdx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ;                              // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))\n\t\t\tdx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                            // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))\n\t\t\tdx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ;                         // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))\n\t\t};\n\n\t\tauto write_sh_dy = [&]() {\n\t\t\tdy[0] = 0.0f ;                             // 0\n\t\t\tif (C <= 1) { return; }\n\t\t\tdy[1] = -0.48860251190291992f ;                          // -sqrt(3)/(2*sqrt(pi))\n\t\t\tdy[2] = 0.0f ;                             // 0\n\t\t\tdy[3] = 0.0f ;                             // 0\n\t\t\tif (C <= 2) { return; }\n\t\t\tdy[4] = 1.0925484305920792f*x ;                          // 
sqrt(15)*x/(2*sqrt(pi))\n\t\t\tdy[5] = -1.0925484305920792f*z ;                         // -sqrt(15)*z/(2*sqrt(pi))\n\t\t\tdy[6] = 0.0f ;                             // 0\n\t\t\tdy[7] = 0.0f ;                             // 0\n\t\t\tdy[8] = -1.0925484305920792f*y ;                         // -sqrt(15)*y/(2*sqrt(pi))\n\t\t\tif (C <= 3) { return; }\n\t\t\tdy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ;                                // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdy[10] = 2.8906114426405538f*xz ;                                // sqrt(105)*xz/(2*sqrt(pi))\n\t\t\tdy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ;                          // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))\n\t\t\tdy[12] = 0.0f ;                            // 0\n\t\t\tdy[13] = 0.0f ;                            // 0\n\t\t\tdy[14] = -2.8906114426405538f*yz ;                               // -sqrt(105)*yz/(2*sqrt(pi))\n\t\t\tdy[15] = 3.5402615395598609f*xy ;                                // 3*sqrt(70)*xy/(4*sqrt(pi))\n\t\t\tif (C <= 4) { return; }\n\t\t\tdy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ;                           // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))\n\t\t\tdy[17] = 5.3103923093397913f*z*(-x2 + y2) ;                              // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ;                         // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))\n\t\t\tdy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ;                         // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))\n\t\t\tdy[20] = 0.0f ;                            // 0\n\t\t\tdy[21] = 0.0f ;                            // 0\n\t\t\tdy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ;                         // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))\n\t\t\tdy[23] = 10.620784618679583f*xy*z ;                              // 9*sqrt(70)*xy*z/(4*sqrt(pi))\n\t\t\tdy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ;                          // 3*sqrt(35)*y*(-3*x2 + 
y2)/(4*sqrt(pi))\n\t\t\tif (C <= 5) { return; }\n\t\t\tdy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ;                               // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n\t\t\tdy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))\n\t\t\tdy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ;                         // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))\n\t\t\tdy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ;                         // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))\n\t\t\tdy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ;                          // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n\t\t\tdy[30] = 0.0f ;                            // 0\n\t\t\tdy[31] = 0.0f ;                            // 0\n\t\t\tdy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ;                         // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))\n\t\t\tdy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ;                         // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))\n\t\t\tdy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ;                         // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))\n\t\t\tdy[35] = 13.127641136803401f*xy*(x2 - y2) ;                              // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))\n\t\t\tif (C <= 6) { return; }\n\t\t\tdy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                             // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))\n\t\t\tdy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ;                           // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n\t\t\tdy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ;                           // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\t\tdy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ;                              // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 
3)/(32*sqrt(pi))\n\t\t\tdy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                             // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n\t\t\tdy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                              // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n\t\t\tdy[42] = 0.0f ;                            // 0\n\t\t\tdy[43] = 0.0f ;                            // 0\n\t\t\tdy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ;                              // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))\n\t\t\tdy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ;                              // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))\n\t\t\tdy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ;                          // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\t\tdy[47] = 47.332383244635047f*xy*z*(x2 - y2) ;                            // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))\n\t\t\tdy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                              // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\t\tif (C <= 7) { return; }\n\t\t\tdy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ;                         // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))\n\t\t\tdy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                            // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))\n\t\t\tdy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ;                                // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))\n\t\t\tdy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))\n\t\t\tdy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 
27.0f*z2 + 3.0f) ;                             // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))\n\t\t\tdy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                         // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))\n\t\t\tdy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ;                         // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n\t\t\tdy[56] = 0.0f ;                            // 0\n\t\t\tdy[57] = 0.0f ;                            // 0\n\t\t\tdy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ;                          // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))\n\t\t\tdy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ;                           // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))\n\t\t\tdy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ;                         // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))\n\t\t\tdy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ;                              // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))\n\t\t\tdy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                             // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\t\tdy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                                // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n\t\t};\n\n\t\tauto write_sh_dz = [&]() {\n\t\t\tdz[0] = 0.0f ;                             // 0\n\t\t\tif (C <= 1) { return; }\n\t\t\tdz[1] = 0.0f ;                             // 0\n\t\t\tdz[2] = 0.48860251190291992f ;                           // sqrt(3)/(2*sqrt(pi))\n\t\t\tdz[3] = 0.0f ;                             // 0\n\t\t\tif (C <= 2) { return; }\n\t\t\tdz[4] = 0.0f ;                             // 0\n\t\t\tdz[5] = -1.0925484305920792f*y ;                         // 
-sqrt(15)*y/(2*sqrt(pi))\n\t\t\tdz[6] = 1.8923493915151202f*z ;                          // 3*sqrt(5)*z/(2*sqrt(pi))\n\t\t\tdz[7] = -1.0925484305920792f*x ;                         // -sqrt(15)*x/(2*sqrt(pi))\n\t\t\tdz[8] = 0.0f ;                             // 0\n\t\t\tif (C <= 3) { return; }\n\t\t\tdz[9] = 0.0f ;                             // 0\n\t\t\tdz[10] = 2.8906114426405538f*xy ;                                // sqrt(105)*xy/(2*sqrt(pi))\n\t\t\tdz[11] = -4.5704579946446566f*yz ;                               // -5*sqrt(42)*yz/(4*sqrt(pi))\n\t\t\tdz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ;                            // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))\n\t\t\tdz[13] = -4.5704579946446566f*xz ;                               // -5*sqrt(42)*xz/(4*sqrt(pi))\n\t\t\tdz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ;                                // sqrt(105)*(x2 - y2)/(4*sqrt(pi))\n\t\t\tdz[15] = 0.0f ;                            // 0\n\t\t\tif (C <= 4) { return; }\n\t\t\tdz[16] = 0.0f ;                            // 0\n\t\t\tdz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ;                          // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))\n\t\t\tdz[18] = 13.246445740605839f*xy*z ;                              // 21*sqrt(5)*xy*z/(2*sqrt(pi))\n\t\t\tdz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ;                          // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))\n\t\t\tdz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ;                          // (105*z**3 - 45*z)/(4*sqrt(pi))\n\t\t\tdz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ;                          // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))\n\t\t\tdz[22] = 6.6232228703029197f*z*(x2 - y2) ;                               // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))\n\t\t\tdz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ;                          // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))\n\t\t\tdz[24] = 0.0f ;                            // 0\n\t\t\tif (C <= 5) { return; 
}\n\t\t\tdz[25] = 0.0f ;                            // 0\n\t\t\tdz[26] = 8.3026492595241645f*xy*(x2 - y2) ;                              // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))\n\t\t\tdz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ;                         // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))\n\t\t\tdz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ;                         // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))\n\t\t\tdz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ;                         // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))\n\t\t\tdz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ;                           // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi))\n\t\t\tdz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ;                         // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))\n\t\t\tdz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ;                          // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))\n\t\t\tdz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ;                         // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))\n\t\t\tdz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ;                            // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n\t\t\tdz[35] = 0.0f ;                            // 0\n\t\t\tif (C <= 6) { return; }\n\t\t\tdz[36] = 0.0f ;                            // 0\n\t\t\tdz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                              // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\t\tdz[38] = 44.401711264127719f*xy*z*(x2 - y2) ;                            // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))\n\t\t\tdz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ;                          // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi))\n\t\t\tdz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ;                              // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))\n\t\t\tdz[41] = 
2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ;                               // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))\n\t\t\tdz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ;                              // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))\n\t\t\tdz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ;                               // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))\n\t\t\tdz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ;                               // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi))\n\t\t\tdz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ;                          // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi))\n\t\t\tdz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ;                           // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n\t\t\tdz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                              // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n\t\t\tdz[48] = 0.0f ;                            // 0\n\t\t\tif (C <= 7) { return; }\n\t\t\tdz[49] = 0.0f ;                            // 0\n\t\t\tdz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                                // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n\t\t\tdz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                             // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\t\tdz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ;                              // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))\n\t\t\tdz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ;                         // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))\n\t\t\tdz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ;                            // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))\n\t\t\tdz[55] = 
1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ;                           // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))\n\t\t\tdz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ;                           // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))\n\t\t\tdz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ;                           // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))\n\t\t\tdz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ;                         // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))\n\t\t\tdz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ;                         // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))\n\t\t\tdz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ;                            // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))\n\t\t\tdz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                             // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n\t\t\tdz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ;                            // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))\n\t\t\tdz[63] = 0.0f ;                            // 0\n\t\t};\n\t\twrite_sh_dx();\n\t\twrite_sh_dy();\n\t\twrite_sh_dz();\n\t}\n}\n\n\ntemplate <typename scalar_t>\n__global__ void kernel_sh_backward(\n    const scalar_t * __restrict__ grad,\n\tconst scalar_t * __restrict__ inputs,\n    uint32_t B, uint32_t D, uint32_t C,\n    const scalar_t * __restrict__ dy_dx,\n    scalar_t * grad_inputs\n) {\n\tconst uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;\n\tconst uint32_t b = t / D;\n\tif (b >= B) return;\n\n\tconst uint32_t d = t - b * D;\n\tconst uint32_t C2 = C * C;\n\n\t// 
locate\n\tgrad += b * C2;\n\tdy_dx += b * D * C2 + d * C2;\n\n\t// chain rule: dL/dx_d = sum over SH channels of dL/dy_ch * dy_ch/dx_d\n\tfor (uint32_t ch = 0; ch < C2; ch++) {\n\t\tgrad_inputs[t] += grad[ch] * dy_dx[ch];\n\t\t//printf(\"t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\\n\", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);\n\t}\n\n}\n\n// inputs: [B, D], float, in [-1, 1]\n// outputs: [B, C * C], float\ntemplate <typename scalar_t>\nvoid sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) {\n\tstatic constexpr uint32_t N_THREADS = 256;\n\tkernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx);\n}\n\n\ntemplate <typename scalar_t>\nvoid sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {\n\tstatic constexpr uint32_t N_THREADS = 256;\n\tkernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);\n}\n\n\nvoid sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) {\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(outputs);\n    // CHECK_CUDA(dy_dx);\n    \n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(outputs);\n    // CHECK_CONTIGUOUS(dy_dx);\n\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(outputs);\n    // CHECK_IS_FLOATING(dy_dx);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    inputs.scalar_type(), \"sh_encode_forward_cuda\", ([&] {\n\t\tsh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ? 
dy_dx.value().data_ptr<scalar_t>() : nullptr);\n    }));\t\n}\n\nvoid sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {    \n    CHECK_CUDA(grad);\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(dy_dx);\n    CHECK_CUDA(grad_inputs);\n    \n    CHECK_CONTIGUOUS(grad);\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(dy_dx);\n    CHECK_CONTIGUOUS(grad_inputs);\n\n    CHECK_IS_FLOATING(grad);\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(dy_dx);\n    CHECK_IS_FLOATING(grad_inputs);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    grad.scalar_type(), \"sh_encode_backward_cuda\", ([&] {\n    \tsh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());\n    }));\t\n}"
  },
  {
    "path": "shencoder/src/shencoder.h",
    "content": "# pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// inputs: [B, D], float, in [-1, 1]\n// outputs: [B, F], float\n\nvoid sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx);\nvoid sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);"
  },
  {
    "path": "taichi_modules/__init__.py",
    "content": "from .ray_march import RayMarcherTaichi, raymarching_test\nfrom .volume_train import VolumeRendererTaichi\nfrom .intersection import RayAABBIntersector\nfrom .volume_render_test import composite_test\nfrom .utils import packbits"
  },
  {
    "path": "taichi_modules/hash_encoder.py",
    "content": "import numpy as np\nimport taichi as ti\nimport torch\nfrom taichi.math import uvec3\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\nfrom .utils import (data_type, ti2torch, ti2torch_grad, ti2torch_grad_vec,\n                    ti2torch_vec, torch2ti, torch2ti_grad, torch2ti_grad_vec,\n                    torch2ti_vec, torch_type)\n\nhalf2 = ti.types.vector(n=2, dtype=ti.f16)\n\n\n@ti.kernel\ndef random_initialize(data: ti.types.ndarray()):\n    for I in ti.grouped(data):\n        data[I] = (ti.random() * 2.0 - 1.0) * 1e-4\n\n\n@ti.kernel\ndef ti_copy(data1: ti.template(), data2: ti.template()):\n    for I in ti.grouped(data1):\n        data1[I] = data2[I]\n\n\n@ti.kernel\ndef ti_copy_array(data1: ti.types.ndarray(), data2: ti.types.ndarray()):\n    for I in ti.grouped(data1):\n        data1[I] = data2[I]\n\n\n@ti.kernel\ndef ti_copy_field_array(data1: ti.template(), data2: ti.types.ndarray()):\n    for I in ti.grouped(data1):\n        data1[I] = data2[I]\n\n\n@ti.func\ndef fast_hash(pos_grid_local):\n    result = ti.uint32(0)\n    # primes = uvec3(ti.uint32(1), ti.uint32(1958374283), ti.uint32(2654435761))\n    primes = uvec3(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861))\n    for i in ti.static(range(3)):\n        result ^= ti.uint32(pos_grid_local[i]) * primes[i]\n    return result\n\n\n@ti.func\ndef under_hash(pos_grid_local, resolution):\n    result = ti.uint32(0)\n    stride = ti.uint32(1)\n    for i in ti.static(range(3)):\n        result += ti.uint32(pos_grid_local[i] * stride)\n        stride *= resolution\n    return result\n\n\n@ti.func\ndef grid_pos2hash_index(indicator, pos_grid_local, resolution, map_size):\n    hash_result = ti.uint32(0)\n    if indicator == 1:\n        hash_result = under_hash(pos_grid_local, resolution)\n    else:\n        hash_result = fast_hash(pos_grid_local)\n\n    return hash_result % map_size\n\n\n@ti.kernel\ndef hash_encode_kernel(\n        xyzs: ti.template(), table: ti.template(),\n       
 xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),\n        hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,\n        per_level_scale: ti.f32):\n\n    # get hash table embedding\n    ti.loop_config(block_dim=16)\n    for i, level in ti.ndrange(B, 16):\n        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])\n\n        scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0\n        resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1\n\n        offset = offsets[level] * 2\n\n        pos = xyz * scale + 0.5\n        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)\n        pos -= pos_grid_uint\n\n        indicator = hash_map_indicator[level]\n        map_size = hash_map_sizes_field[level]\n\n        local_feature_0 = 0.0\n        local_feature_1 = 0.0\n\n        for idx in ti.static(range(8)):\n            w = 1.\n            pos_grid_local = uvec3(0)\n\n            for d in ti.static(range(3)):\n                if (idx & (1 << d)) == 0:\n                    pos_grid_local[d] = pos_grid_uint[d]\n                    w *= 1 - pos[d]\n                else:\n                    pos_grid_local[d] = pos_grid_uint[d] + 1\n                    w *= pos[d]\n\n            index = grid_pos2hash_index(indicator, pos_grid_local, resolution,\n                                        map_size)\n            index_table = offset + index * 2\n            index_table_int = ti.cast(index_table, ti.int32)\n            local_feature_0 += w * table[index_table_int]\n            local_feature_1 += w * table[index_table_int + 1]\n\n        xyzs_embedding[i, level * 2] = local_feature_0\n        xyzs_embedding[i, level * 2 + 1] = local_feature_1\n\n\n@ti.kernel\ndef hash_encode_kernel_half2(\n        xyzs: ti.template(), table: ti.template(),\n        xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),\n        hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,\n        per_level_scale: ti.f16):\n\n    # get 
hash table embedding\n    ti.loop_config(block_dim=32)\n    for i, level in ti.ndrange(B, 16):\n        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])\n\n        scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0\n        resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1\n\n        offset = offsets[level]\n\n        pos = xyz * scale + 0.5\n        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)\n        pos -= pos_grid_uint\n\n        indicator = hash_map_indicator[level]\n        map_size = hash_map_sizes_field[level]\n\n        local_feature = half2(0.0)\n        for idx in ti.static(range(8)):\n            w = ti.f32(1.0)\n            pos_grid_local = uvec3(0)\n\n            for d in ti.static(range(3)):\n                if (idx & (1 << d)) == 0:\n                    pos_grid_local[d] = pos_grid_uint[d]\n                    w *= 1 - pos[d]\n                else:\n                    pos_grid_local[d] = pos_grid_uint[d] + 1\n                    w *= pos[d]\n\n            index = grid_pos2hash_index(indicator, pos_grid_local, resolution,\n                                        map_size)\n\n            index_table = offset + index\n            index_table_int = ti.cast(index_table, ti.int32)\n\n            local_feature += w * table[index_table_int]\n        xyzs_embedding[i, level] = local_feature\n\n\nclass HashEncoderTaichi(torch.nn.Module):\n\n    def __init__(self,\n                 b=1.3195079565048218,\n                 batch_size=8192,\n                 data_type=data_type,\n                 half2_opt=False):\n        super(HashEncoderTaichi, self).__init__()\n\n        self.per_level_scale = b\n        if batch_size < 2048:\n            batch_size = 2048\n\n        # per_level_scale = 1.3195079565048218\n        print(\"per_level_scale: \", b)\n        self.offsets = ti.field(ti.i32, shape=(16, ))\n        self.hash_map_sizes_field = ti.field(ti.uint32, shape=(16, ))\n        self.hash_map_indicator = ti.field(ti.i32, shape=(16, ))\n   
     base_res = 16\n        max_params = 2**19\n        offset_ = 0\n        hash_map_sizes = []\n        for i in range(16):\n            resolution = int(\n                np.ceil(base_res * np.exp(i * np.log(self.per_level_scale)) -\n                        1.0)) + 1\n            params_in_level = resolution**3\n            params_in_level = int(resolution**\n                                  3) if params_in_level % 8 == 0 else int(\n                                      (params_in_level + 8 - 1) / 8) * 8\n            params_in_level = min(max_params, params_in_level)\n            self.offsets[i] = offset_\n            hash_map_sizes.append(params_in_level)\n            self.hash_map_indicator[\n                i] = 1 if resolution**3 <= params_in_level else 0\n            offset_ += params_in_level\n        print(\"offset_: \", offset_)\n        size = np.uint32(np.array(hash_map_sizes))\n        self.hash_map_sizes_field.from_numpy(size)\n\n        self.total_hash_size = offset_ * 2\n        print(\"total_hash_size: \", self.total_hash_size)\n\n        self.hash_table = torch.nn.Parameter(torch.zeros(self.total_hash_size,\n                                                         dtype=torch_type),\n                                             requires_grad=True)\n        random_initialize(self.hash_table)\n\n        if half2_opt:\n            assert self.total_hash_size % 2 == 0\n            self.parameter_fields = half2.field(shape=(self.total_hash_size //\n                                                       2, ),\n                                                needs_grad=True)\n            self.output_fields = half2.field(shape=(batch_size * 1024, 16),\n                                             needs_grad=True)\n\n            self.torch2ti = torch2ti_vec\n            self.ti2torch = ti2torch_vec\n            self.ti2torch_grad = ti2torch_grad_vec\n            self.torch2ti_grad = torch2ti_grad_vec\n\n            self._hash_encode_kernel = 
hash_encode_kernel_half2\n        else:\n            self.parameter_fields = ti.field(data_type,\n                                             shape=(self.total_hash_size, ),\n                                             needs_grad=True)\n            self.output_fields = ti.field(dtype=data_type,\n                                          shape=(batch_size * 1024, 32),\n                                          needs_grad=True)\n            self.torch2ti = torch2ti\n            self.ti2torch = ti2torch\n            self.ti2torch_grad = ti2torch_grad\n            self.torch2ti_grad = torch2ti_grad\n\n            self._hash_encode_kernel = hash_encode_kernel\n\n        self.input_fields = ti.field(dtype=data_type,\n                                     shape=(batch_size * 1024, 3),\n                                     needs_grad=True)\n        self.output_dim = 32  # the output dim: num levels (16) x features per level (2)\n        self.register_buffer(\n            'hash_grad', torch.zeros(self.total_hash_size, dtype=torch_type))\n        self.register_buffer(\n            'output_embedding',\n            torch.zeros(batch_size * 1024, 32, dtype=torch_type))\n\n        class _module_function(torch.autograd.Function):\n\n            @staticmethod\n            @custom_fwd(cast_inputs=torch_type)\n            def forward(ctx, input_pos, params):\n                output_embedding = self.output_embedding[:input_pos.\n                                                         shape[0]].contiguous(\n                                                         )\n                torch2ti(self.input_fields, input_pos.contiguous())\n                self.torch2ti(self.parameter_fields, params.contiguous())\n\n                self._hash_encode_kernel(\n                    self.input_fields,\n                    self.parameter_fields,\n                    self.output_fields,\n                    self.hash_map_indicator,\n                    self.hash_map_sizes_field,\n                    
self.offsets,\n                    input_pos.shape[0],\n                    self.per_level_scale,\n                )\n                self.ti2torch(self.output_fields, output_embedding)\n\n                return output_embedding\n\n            @staticmethod\n            @custom_bwd\n            def backward(ctx, doutput):\n\n                self.zero_grad()\n\n                self.torch2ti_grad(self.output_fields, doutput.contiguous())\n                self._hash_encode_kernel.grad(\n                    self.input_fields,\n                    self.parameter_fields,\n                    self.output_fields,\n                    self.hash_map_indicator,\n                    self.hash_map_sizes_field,\n                    self.offsets,\n                    doutput.shape[0],\n                    self.per_level_scale,\n                )\n                self.ti2torch_grad(self.parameter_fields,\n                                   self.hash_grad.contiguous())\n                return None, self.hash_grad\n\n        self._module_function = _module_function\n\n    def zero_grad(self):\n        self.parameter_fields.grad.fill(0.)\n\n    def forward(self, positions, bound=1):\n        positions = (positions + bound) / (2 * bound) \n        return self._module_function.apply(positions, self.hash_table)\n"
  },
  {
    "path": "taichi_modules/intersection.py",
    "content": "import taichi as ti\nimport torch\nfrom taichi.math import vec3\nfrom torch.cuda.amp import custom_fwd\n\nfrom .utils import NEAR_DISTANCE\n\n\n@ti.kernel\ndef simple_ray_aabb_intersec_taichi_forward(\n        hits_t: ti.types.ndarray(ndim=2),\n        rays_o: ti.types.ndarray(ndim=2),\n        rays_d: ti.types.ndarray(ndim=2),\n        centers: ti.types.ndarray(ndim=2),\n        half_sizes: ti.types.ndarray(ndim=2)):\n\n    for r in ti.ndrange(hits_t.shape[0]):\n        ray_o = vec3([rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]])\n        ray_d = vec3([rays_d[r, 0], rays_d[r, 1], rays_d[r, 2]])\n        inv_d = 1.0 / ray_d\n\n        center = vec3([centers[0, 0], centers[0, 1], centers[0, 2]])\n        half_size = vec3(\n            [half_sizes[0, 0], half_sizes[0, 1], half_sizes[0, 2]])\n\n        # slab test: ray distances to the two bounding planes on each axis\n        t_min = (center - half_size - ray_o) * inv_d\n        t_max = (center + half_size - ray_o) * inv_d\n\n        # entry is the latest near-plane crossing, exit the earliest far-plane crossing\n        _t1 = ti.min(t_min, t_max)\n        _t2 = ti.max(t_min, t_max)\n        t1 = _t1.max()\n        t2 = _t2.min()\n\n        if t2 > 0.0:\n            hits_t[r, 0, 0] = ti.max(t1, NEAR_DISTANCE)\n            hits_t[r, 0, 1] = t2\n\n\nclass RayAABBIntersector(torch.autograd.Function):\n    \"\"\"\n    Computes the intersections of rays with an axis-aligned bounding box.\n\n    This Taichi version only tests the first box (centers[0], half_sizes[0])\n    and keeps a single hit interval per ray.\n\n    Inputs:\n        rays_o: (N_rays, 3) ray origins\n        rays_d: (N_rays, 3) ray directions\n        centers: (N_voxels, 3) voxel centers; only centers[0] is used\n        half_sizes: (N_voxels, 3) voxel half sizes; only half_sizes[0] is used\n        max_hits: unused; one hit interval is always returned\n\n    Outputs (a 3-tuple):\n        None in place of hits_cnt (not computed)\n        hits_t: (N_rays, 1, 2) near and far t's (left at -1 when the far t is not positive)\n        None in place of hits_voxel_idx (not computed)\n    \"\"\"\n\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, 
rays_o, rays_d, center, half_size, max_hits):\n        hits_t = (torch.zeros(\n            rays_o.size(0), 1, 2, device=rays_o.device, dtype=torch.float32) -\n                  1).contiguous()\n\n        simple_ray_aabb_intersec_taichi_forward(hits_t, rays_o, rays_d, center,\n                                                half_size)\n\n        return None, hits_t, None\n"
  },
  {
    "path": "taichi_modules/ray_march.py",
    "content": "import taichi as ti\nimport torch\nfrom taichi.math import vec3\nfrom torch.cuda.amp import custom_fwd\n\nfrom .utils import __morton3D, calc_dt, mip_from_dt, mip_from_pos\n\n\n@ti.kernel\ndef raymarching_train(rays_o: ti.types.ndarray(ndim=2),\n                      rays_d: ti.types.ndarray(ndim=2),\n                      hits_t: ti.types.ndarray(ndim=2),\n                      density_bitfield: ti.types.ndarray(ndim=1),\n                      noise: ti.types.ndarray(ndim=1),\n                      counter: ti.types.ndarray(ndim=1),\n                      rays_a: ti.types.ndarray(ndim=2),\n                      xyzs: ti.types.ndarray(ndim=2),\n                      dirs: ti.types.ndarray(ndim=2),\n                      deltas: ti.types.ndarray(ndim=1),\n                      ts: ti.types.ndarray(ndim=1), cascades: int,\n                      grid_size: int, scale: float, exp_step_factor: float,\n                      max_samples: float):\n\n    # ti.loop_config(block_dim=256)\n    for r in noise:\n        ray_o = vec3(rays_o[r, 0], rays_o[r, 1], rays_o[r, 2])\n        ray_d = vec3(rays_d[r, 0], rays_d[r, 1], rays_d[r, 2])\n        d_inv = 1.0 / ray_d\n\n        t1, t2 = hits_t[r, 0], hits_t[r, 1]\n\n        grid_size3 = grid_size**3\n        grid_size_inv = 1.0 / grid_size\n\n        if t1 >= 0:\n            dt = calc_dt(t1, exp_step_factor, grid_size, scale)\n            t1 += dt * noise[r]\n\n        t = t1\n        N_samples = 0\n\n        while (0 <= t) & (t < t2) & (N_samples < max_samples):\n            xyz = ray_o + t * ray_d\n            dt = calc_dt(t, exp_step_factor, grid_size, scale)\n            mip = ti.max(mip_from_pos(xyz, cascades),\n                         mip_from_dt(dt, grid_size, cascades))\n\n            # mip_bound = 0.5\n            # mip_bound = ti.min(ti.pow(2., mip - 1), scale)\n            mip_bound = scale\n            mip_bound_inv = 1 / mip_bound\n\n            nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * 
grid_size,\n                                 0.0, grid_size - 1.0)\n            # nxyz = ti.ceil(nxyz)\n\n            idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32))\n            occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8))\n            # idx = __morton3D(ti.cast(nxyz, ti.uint32))\n            # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32))\n\n            if occ:\n                t += dt\n                N_samples += 1\n            else:\n                # t += dt\n                txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) *\n                         grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv\n\n                t_target = t + ti.max(0, txyz.min())\n                t += calc_dt(t, exp_step_factor, grid_size, scale)\n                while t < t_target:\n                    t += calc_dt(t, exp_step_factor, grid_size, scale)\n\n        start_idx = ti.atomic_add(counter[0], N_samples)\n        ray_count = ti.atomic_add(counter[1], 1)\n\n        rays_a[ray_count, 0] = r\n        rays_a[ray_count, 1] = start_idx\n        rays_a[ray_count, 2] = N_samples\n\n        t = t1\n        samples = 0\n\n        while (t < t2) & (samples < N_samples):\n            xyz = ray_o + t * ray_d\n            dt = calc_dt(t, exp_step_factor, grid_size, scale)\n            mip = ti.max(mip_from_pos(xyz, cascades),\n                         mip_from_dt(dt, grid_size, cascades))\n\n            # mip_bound = 0.5\n            # mip_bound = ti.min(ti.pow(2., mip - 1), scale)\n            mip_bound = scale\n            mip_bound_inv = 1 / mip_bound\n\n            nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size,\n                                 0.0, grid_size - 1.0)\n            # nxyz = ti.ceil(nxyz)\n\n            idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32))\n            occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8))\n            # idx = __morton3D(ti.cast(nxyz, ti.uint32))\n  
          # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32))\n\n            if occ:\n                s = start_idx + samples\n                xyzs[s, 0] = xyz[0]\n                xyzs[s, 1] = xyz[1]\n                xyzs[s, 2] = xyz[2]\n                dirs[s, 0] = ray_d[0]\n                dirs[s, 1] = ray_d[1]\n                dirs[s, 2] = ray_d[2]\n                ts[s] = t\n                deltas[s] = dt\n                t += dt\n                samples += 1\n            else:\n                # t += dt\n                txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) *\n                         grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv\n\n                t_target = t + ti.max(0, txyz.min())\n                t += calc_dt(t, exp_step_factor, grid_size, scale)\n                while t < t_target:\n                    t += calc_dt(t, exp_step_factor, grid_size, scale)\n\n\n@ti.kernel\ndef raymarching_train_backword(segments: ti.types.ndarray(ndim=2),\n                               ts: ti.types.ndarray(ndim=1),\n                               dL_drays_o: ti.types.ndarray(ndim=2),\n                               dL_drays_d: ti.types.ndarray(ndim=2),\n                               dL_dxyzs: ti.types.ndarray(ndim=2),\n                               dL_ddirs: ti.types.ndarray(ndim=2)):\n\n    for s in segments:\n        index = segments[s]\n        dxyz = dL_dxyzs[index]\n        ddir = dL_ddirs[index]\n\n        dL_drays_o[s] = dxyz\n        dL_drays_d[s] = dxyz * ts[index] + ddir\n\n\nclass RayMarcherTaichi(torch.nn.Module):\n\n    def __init__(self, batch_size=8192):\n        super(RayMarcherTaichi, self).__init__()\n\n        self.register_buffer('rays_a',\n                             torch.zeros(batch_size, 3, dtype=torch.int32))\n        self.register_buffer(\n            'xyzs', torch.zeros(batch_size * 1024, 3, dtype=torch.float32))\n        self.register_buffer(\n            'dirs', torch.zeros(batch_size * 1024, 3, 
dtype=torch.float32))\n        self.register_buffer(\n            'deltas', torch.zeros(batch_size * 1024, dtype=torch.float32))\n        self.register_buffer(\n            'ts', torch.zeros(batch_size * 1024, dtype=torch.float32))\n\n        # self.register_buffer('dL_drays_o', torch.zeros(batch_size, dtype=torch.float32))\n        # self.register_buffer('dL_drays_d', torch.zeros(batch_size, dtype=torch.float32))\n\n        class _module_function(torch.autograd.Function):\n\n            @staticmethod\n            @custom_fwd(cast_inputs=torch.float32)\n            def forward(ctx, rays_o, rays_d, hits_t, density_bitfield,\n                        cascades, scale, exp_step_factor, grid_size,\n                        max_samples):\n                # noise to perturb the first sample of each ray\n                noise = torch.rand_like(rays_o[:, 0])\n                counter = torch.zeros(2,\n                                      device=rays_o.device,\n                                      dtype=torch.int32)\n\n                raymarching_train(\\\n                    rays_o, rays_d,\n                    hits_t.contiguous(),\n                    density_bitfield, noise, counter,\n                    self.rays_a.contiguous(),\n                    self.xyzs.contiguous(),\n                    self.dirs.contiguous(),\n                    self.deltas.contiguous(),\n                    self.ts.contiguous(),\n                    cascades, grid_size, scale,\n                    exp_step_factor, max_samples)\n\n                # ti.sync()\n\n                total_samples = counter[0]  # total samples for all rays\n                # remove redundant output\n                xyzs = self.xyzs[:total_samples]\n                dirs = self.dirs[:total_samples]\n                deltas = self.deltas[:total_samples]\n                ts = self.ts[:total_samples]\n\n                return self.rays_a, xyzs, dirs, deltas, ts, total_samples\n\n                # @staticmethod\n               
 # @custom_bwd\n                # def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs, dL_ddeltas, dL_dts,\n                #              dL_dtotal_samples):\n                #     rays_a, ts = ctx.saved_tensors\n                #     # rays_a = rays_a.contiguous()\n                #     ts = ts.contiguous()\n                #     segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1] + rays_a[-1:, 2]])\n                #     dL_drays_o = torch.zeros_like(rays_a[:, 0])\n                #     dL_drays_d = torch.zeros_like(rays_a[:, 0])\n                #     raymarching_train_backword(segments.contiguous(), ts, dL_drays_o,\n                #                                dL_drays_d, dL_dxyzs, dL_ddirs)\n                #     # ti.sync()\n                #     # dL_drays_o = segment_csr(dL_dxyzs, segments)\n                #     # dL_drays_d = \\\n                #     #     segment_csr(dL_dxyzs*rearrange(ts, 'n -> n 1')+dL_ddirs, segments)\n\n                #     return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None\n\n        self._module_function = _module_function\n\n    def forward(self, rays_o, rays_d, hits_t, density_bitfield, cascades,\n                scale, exp_step_factor, grid_size, max_samples):\n        return self._module_function.apply(rays_o, rays_d, hits_t,\n                                           density_bitfield, cascades, scale,\n                                           exp_step_factor, grid_size,\n                                           max_samples)\n\n\n@ti.kernel\ndef raymarching_test_kernel(\n        rays_o: ti.types.ndarray(ndim=2),\n        rays_d: ti.types.ndarray(ndim=2),\n        hits_t: ti.types.ndarray(ndim=2),\n        alive_indices: ti.types.ndarray(ndim=1),\n        density_bitfield: ti.types.ndarray(ndim=1),\n        cascades: int,\n        grid_size: int,\n        scale: float,\n        exp_step_factor: float,\n        N_samples: int,\n        max_samples: int,\n        xyzs: ti.types.ndarray(ndim=3),  # (N_rays, N_samples, 3)\n       
 dirs: ti.types.ndarray(ndim=3),  # (N_rays, N_samples, 3)\n        deltas: ti.types.ndarray(ndim=2),  # (N_rays, N_samples)\n        ts: ti.types.ndarray(ndim=2),  # (N_rays, N_samples)\n        N_eff_samples: ti.types.ndarray(ndim=1),\n):\n\n    for n in alive_indices:\n        r = alive_indices[n]\n        grid_size3 = grid_size**3\n        grid_size_inv = 1.0 / grid_size\n\n        ray_o = vec3(rays_o[r, 0], rays_o[r, 1], rays_o[r, 2])\n        ray_d = vec3(rays_d[r, 0], rays_d[r, 1], rays_d[r, 2])\n        d_inv = 1.0 / ray_d\n\n        t = hits_t[r, 0]\n        t2 = hits_t[r, 1]\n\n        s = 0\n\n        while (0 <= t) & (t < t2) & (s < N_samples):\n            xyz = ray_o + t * ray_d\n            dt = calc_dt(t, exp_step_factor, grid_size, scale)\n            mip = ti.max(mip_from_pos(xyz, cascades),\n                         mip_from_dt(dt, grid_size, cascades))\n\n            # mip_bound = 0.5\n            # mip_bound = ti.min(ti.pow(2., mip - 1), scale)\n            mip_bound = scale\n            mip_bound_inv = 1 / mip_bound\n\n            nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size,\n                                 0.0, grid_size - 1.0)\n            # nxyz = ti.ceil(nxyz)\n\n            idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32))\n            occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8))\n\n            if occ:\n                xyzs[n, s, 0] = xyz[0]\n                xyzs[n, s, 1] = xyz[1]\n                xyzs[n, s, 2] = xyz[2]\n                dirs[n, s, 0] = ray_d[0]\n                dirs[n, s, 1] = ray_d[1]\n                dirs[n, s, 2] = ray_d[2]\n                ts[n, s] = t\n                deltas[n, s] = dt\n                t += dt\n                hits_t[r, 0] = t\n                s += 1\n\n            else:\n                txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) *\n                         grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv\n\n                t_target = t + ti.max(0, txyz.min())\n                t += calc_dt(t, exp_step_factor, 
grid_size, scale)\n                while t < t_target:\n                    t += calc_dt(t, exp_step_factor, grid_size, scale)\n\n        N_eff_samples[n] = s\n\n\ndef raymarching_test(rays_o, rays_d, hits_t, alive_indices, density_bitfield,\n                     cascades, scale, exp_step_factor, grid_size, max_samples,\n                     N_samples):\n\n    N_rays = alive_indices.size(0)\n    xyzs = torch.zeros(N_rays,\n                       N_samples,\n                       3,\n                       device=rays_o.device,\n                       dtype=rays_o.dtype)\n    dirs = torch.zeros(N_rays,\n                       N_samples,\n                       3,\n                       device=rays_o.device,\n                       dtype=rays_o.dtype)\n    deltas = torch.zeros(N_rays,\n                         N_samples,\n                         device=rays_o.device,\n                         dtype=rays_o.dtype)\n    ts = torch.zeros(N_rays,\n                     N_samples,\n                     device=rays_o.device,\n                     dtype=rays_o.dtype)\n    N_eff_samples = torch.zeros(N_rays,\n                                device=rays_o.device,\n                                dtype=torch.int32)\n\n    raymarching_test_kernel(rays_o, rays_d, hits_t, alive_indices,\n                            density_bitfield, cascades, grid_size, scale,\n                            exp_step_factor, N_samples, max_samples, xyzs,\n                            dirs, deltas, ts, N_eff_samples)\n\n    # ti.sync()\n\n    return xyzs, dirs, deltas, ts, N_eff_samples\n"
  },
  {
    "path": "taichi_modules/utils.py",
    "content": "import taichi as ti\nimport torch\nfrom taichi.math import uvec3\n\ntaichi_block_size = 128\n\ndata_type = ti.f32\ntorch_type = torch.float32\n\nMAX_SAMPLES = 1024\nNEAR_DISTANCE = 0.01\nSQRT3 = 1.7320508075688772\nSQRT3_MAX_SAMPLES = SQRT3 / 1024\nSQRT3_2 = 1.7320508075688772 * 2\n\n\n@ti.func\ndef scalbn(x, exponent):\n    return x * ti.math.pow(2, exponent)\n\n\n@ti.func\ndef calc_dt(t, exp_step_factor, grid_size, scale):\n    return ti.math.clamp(t * exp_step_factor, SQRT3_MAX_SAMPLES,\n                         SQRT3_2 * scale / grid_size)\n\n\n@ti.func\ndef frexp_bit(x):\n    exponent = 0\n    if x != 0.0:\n        # frac = ti.abs(x)\n        bits = ti.bit_cast(x, ti.u32)\n        exponent = ti.i32((bits & ti.u32(0x7f800000)) >> 23) - 127\n        # exponent = (ti.i32(bits & ti.u32(0x7f800000)) >> 23) - 127\n        bits &= ti.u32(0x7fffff)\n        bits |= ti.u32(0x3f800000)\n        frac = ti.bit_cast(bits, ti.f32)\n        if frac < 0.5:\n            exponent -= 1\n        elif frac > 1.0:\n            exponent += 1\n    return exponent\n\n\n@ti.func\ndef mip_from_pos(xyz, cascades):\n    mx = ti.abs(xyz).max()\n    # _, exponent = _frexp(mx)\n    exponent = frexp_bit(ti.f32(mx)) + 1\n    # frac, exponent = ti.frexp(ti.f32(mx))\n    return ti.min(cascades - 1, ti.max(0, exponent))\n\n\n@ti.func\ndef mip_from_dt(dt, grid_size, cascades):\n    # _, exponent = _frexp(dt*grid_size)\n    exponent = frexp_bit(ti.f32(dt * grid_size))\n    # frac, exponent = ti.frexp(ti.f32(dt*grid_size))\n    return ti.min(cascades - 1, ti.max(0, exponent))\n\n\n@ti.func\ndef __expand_bits(v):\n    v = (v * ti.uint32(0x00010001)) & ti.uint32(0xFF0000FF)\n    v = (v * ti.uint32(0x00000101)) & ti.uint32(0x0F00F00F)\n    v = (v * ti.uint32(0x00000011)) & ti.uint32(0xC30C30C3)\n    v = (v * ti.uint32(0x00000005)) & ti.uint32(0x49249249)\n    return v\n\n\n@ti.func\ndef __morton3D(xyz):\n    xyz = __expand_bits(xyz)\n    return xyz[0] | (xyz[1] << 1) | (xyz[2] << 
2)\n\n\n@ti.func\ndef __morton3D_invert(x):\n    x = x & (0x49249249)\n    x = (x | (x >> 2)) & ti.uint32(0xc30c30c3)\n    x = (x | (x >> 4)) & ti.uint32(0x0f00f00f)\n    x = (x | (x >> 8)) & ti.uint32(0xff0000ff)\n    x = (x | (x >> 16)) & ti.uint32(0x0000ffff)\n    return ti.int32(x)\n\n\n@ti.kernel\ndef morton3D_invert_kernel(indices: ti.types.ndarray(ndim=1),\n                           coords: ti.types.ndarray(ndim=2)):\n    for i in indices:\n        ind = ti.uint32(indices[i])\n        coords[i, 0] = __morton3D_invert(ind >> 0)\n        coords[i, 1] = __morton3D_invert(ind >> 1)\n        coords[i, 2] = __morton3D_invert(ind >> 2)\n\n\ndef morton3D_invert(indices):\n    coords = torch.zeros(indices.size(0),\n                         3,\n                         device=indices.device,\n                         dtype=torch.int32)\n    morton3D_invert_kernel(indices.contiguous(), coords)\n    ti.sync()\n    return coords\n\n\n@ti.kernel\ndef morton3D_kernel(xyzs: ti.types.ndarray(ndim=2),\n                    indices: ti.types.ndarray(ndim=1)):\n    for s in indices:\n        xyz = uvec3([xyzs[s, 0], xyzs[s, 1], xyzs[s, 2]])\n        indices[s] = ti.cast(__morton3D(xyz), ti.int32)\n\n\ndef morton3D(coords1):\n    indices = torch.zeros(coords1.size(0),\n                          device=coords1.device,\n                          dtype=torch.int32)\n    morton3D_kernel(coords1.contiguous(), indices)\n    ti.sync()\n    return indices\n\n\n@ti.kernel\ndef packbits(density_grid: ti.types.ndarray(ndim=1),\n             density_threshold: float,\n             density_bitfield: ti.types.ndarray(ndim=1)):\n\n    for n in density_bitfield:\n        bits = ti.uint8(0)\n\n        for i in ti.static(range(8)):\n            bits |= (ti.uint8(1) << i) if (\n                density_grid[8 * n + i] > density_threshold) else ti.uint8(0)\n\n        density_bitfield[n] = bits\n\n\n@ti.kernel\ndef torch2ti(field: ti.template(), data: ti.types.ndarray()):\n    for I in 
ti.grouped(data):\n        field[I] = data[I]\n\n\n@ti.kernel\ndef ti2torch(field: ti.template(), data: ti.types.ndarray()):\n    for I in ti.grouped(data):\n        data[I] = field[I]\n\n\n@ti.kernel\ndef ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):\n    for I in ti.grouped(grad):\n        grad[I] = field.grad[I]\n\n\n@ti.kernel\ndef torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):\n    for I in ti.grouped(grad):\n        field.grad[I] = grad[I]\n\n\n@ti.kernel\ndef torch2ti_vec(field: ti.template(), data: ti.types.ndarray()):\n    for I in range(data.shape[0] // 2):\n        field[I] = ti.Vector([data[I * 2], data[I * 2 + 1]])\n\n\n@ti.kernel\ndef ti2torch_vec(field: ti.template(), data: ti.types.ndarray()):\n    for i, j in ti.ndrange(data.shape[0], data.shape[1] // 2):\n        data[i, j * 2] = field[i, j][0]\n        data[i, j * 2 + 1] = field[i, j][1]\n\n\n@ti.kernel\ndef ti2torch_grad_vec(field: ti.template(), grad: ti.types.ndarray()):\n    for I in range(grad.shape[0] // 2):\n        grad[I * 2] = field.grad[I][0]\n        grad[I * 2 + 1] = field.grad[I][1]\n\n\n@ti.kernel\ndef torch2ti_grad_vec(field: ti.template(), grad: ti.types.ndarray()):\n    for i, j in ti.ndrange(grad.shape[0], grad.shape[1] // 2):\n        field.grad[i, j][0] = grad[i, j * 2]\n        field.grad[i, j][1] = grad[i, j * 2 + 1]\n\n\ndef extract_model_state_dict(ckpt_path,\n                             model_name='model',\n                             prefixes_to_ignore=[]):\n    checkpoint = torch.load(ckpt_path, map_location='cpu')\n    checkpoint_ = {}\n    if 'state_dict' in checkpoint:  # if it's a pytorch-lightning checkpoint\n        checkpoint = checkpoint['state_dict']\n    for k, v in checkpoint.items():\n        if not k.startswith(model_name):\n            continue\n        k = k[len(model_name) + 1:]\n        for prefix in prefixes_to_ignore:\n            if k.startswith(prefix):\n                break\n        else:\n            
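checkpoint_[k] = v\n    return checkpoint_\n\n\ndef load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]):\n    if not ckpt_path:\n        return\n    model_dict = model.state_dict()\n    checkpoint_ = extract_model_state_dict(ckpt_path, model_name,\n                                           prefixes_to_ignore)\n    model_dict.update(checkpoint_)\n    model.load_state_dict(model_dict)\n\n\ndef depth2img(depth):\n    # Colorize a NumPy depth map with the Turbo colormap (returns a uint8 BGR image).\n    # cv2 and numpy are imported locally so the Taichi helpers above do not require OpenCV.\n    import cv2\n    import numpy as np\n\n    # The small epsilon avoids a division by zero for a constant depth map.\n    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)\n    depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),\n                                  cv2.COLORMAP_TURBO)\n\n    return depth_img"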
  },
  {
    "path": "taichi_modules/volume_render_test.py",
    "content": "import taichi as ti\n\n\n@ti.kernel\ndef composite_test(\n    sigmas: ti.types.ndarray(ndim=2), rgbs: ti.types.ndarray(ndim=3),\n    deltas: ti.types.ndarray(ndim=2), ts: ti.types.ndarray(ndim=2),\n    hits_t: ti.types.ndarray(ndim=2),\n    alive_indices: ti.types.ndarray(ndim=1), T_threshold: float,\n    N_eff_samples: ti.types.ndarray(ndim=1),\n    opacity: ti.types.ndarray(ndim=1),\n    depth: ti.types.ndarray(ndim=1), rgb: ti.types.ndarray(ndim=2)):\n\n    for n in alive_indices:\n        samples = N_eff_samples[n]\n        if samples == 0:\n            alive_indices[n] = -1\n        else:\n            r = alive_indices[n]\n\n            T = 1 - opacity[r]\n\n            # Accumulate in local variables and write each output once after the loop.\n            rgb_temp_0 = 0.0\n            rgb_temp_1 = 0.0\n            rgb_temp_2 = 0.0\n            depth_temp = 0.0\n            opacity_temp = 0.0\n\n            for s in range(samples):\n                a = 1.0 - ti.exp(-sigmas[n, s] * deltas[n, s])\n                w = a * T\n\n                rgb_temp_0 += w * rgbs[n, s, 0]\n                rgb_temp_1 += w * rgbs[n, s, 1]\n                rgb_temp_2 += w * rgbs[n, s, 2]\n                depth_temp += w * ts[n, s]\n                opacity_temp += w\n                T *= 1.0 - a\n\n                if T <= T_threshold:\n                    alive_indices[n] = -1\n                    break\n\n            rgb[r, 0] += rgb_temp_0\n            rgb[r, 1] += rgb_temp_1\n            rgb[r, 2] += rgb_temp_2\n            depth[r] += depth_temp\n            opacity[r] += opacity_temp\n"
  },
  {
    "path": "taichi_modules/volume_train.py",
    "content": "import taichi as ti\nimport torch\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\nfrom .utils import (data_type, ti2torch, ti2torch_grad, torch2ti,\n                    torch2ti_grad, torch_type)\n\n\n@ti.kernel\ndef composite_train_fw_array(\n        sigmas: ti.types.ndarray(),\n        rgbs: ti.types.ndarray(),\n        deltas: ti.types.ndarray(),\n        ts: ti.types.ndarray(),\n        rays_a: ti.types.ndarray(),\n        T_threshold: float,\n        total_samples: ti.types.ndarray(),\n        opacity: ti.types.ndarray(),\n        depth: ti.types.ndarray(),\n        rgb: ti.types.ndarray(),\n        ws: ti.types.ndarray(),\n):\n\n    for n in opacity:\n        ray_idx = rays_a[n, 0]\n        start_idx = rays_a[n, 1]\n        N_samples = rays_a[n, 2]\n\n        T = 1.0\n        samples = 0\n        while samples < N_samples:\n            s = start_idx + samples\n            a = 1.0 - ti.exp(-sigmas[s] * deltas[s])\n            w = a * T\n\n            rgb[ray_idx, 0] += w * rgbs[s, 0]\n            rgb[ray_idx, 1] += w * rgbs[s, 1]\n            rgb[ray_idx, 2] += w * rgbs[s, 2]\n            depth[ray_idx] += w * ts[s]\n            opacity[ray_idx] += w\n            ws[s] = w\n            T *= 1.0 - a\n\n            # if T<T_threshold:\n            #     break\n            samples += 1\n\n        total_samples[ray_idx] = samples\n\n\n@ti.kernel\ndef composite_train_fw(sigmas: ti.template(), rgbs: ti.template(),\n                       deltas: ti.template(), ts: ti.template(),\n                       rays_a: ti.template(), T_threshold: float,\n                       T: ti.template(), total_samples: ti.template(),\n                       opacity: ti.template(), depth: ti.template(),\n                       rgb: ti.template(), ws: ti.template()):\n\n    ti.loop_config(block_dim=256)\n    for n in opacity:\n        ray_idx = ti.i32(rays_a[n, 0])\n        start_idx = ti.i32(rays_a[n, 1])\n        N_samples = ti.i32(rays_a[n, 2])\n\n        
rgb[ray_idx, 0] = 0.0\n        rgb[ray_idx, 1] = 0.0\n        rgb[ray_idx, 2] = 0.0\n        depth[ray_idx] = 0.0\n        opacity[ray_idx] = 0.0\n        total_samples[ray_idx] = 0\n\n        T[start_idx] = 1.0\n        # T_ = 1.0\n        # samples = 0\n        # while samples<N_samples:\n        for sample_ in range(N_samples):\n            # T_ = T[ray_idx, samples]\n            s = start_idx + sample_\n            T_ = T[s]\n            if T_ > T_threshold:\n                # s = start_idx + sample_\n                a = 1.0 - ti.exp(-sigmas[s] * deltas[s])\n                w = a * T_\n                rgb[ray_idx, 0] += w * rgbs[s, 0]\n                rgb[ray_idx, 1] += w * rgbs[s, 1]\n                rgb[ray_idx, 2] += w * rgbs[s, 2]\n                depth[ray_idx] += w * ts[s]\n                opacity[ray_idx] += w\n                ws[s] = w\n                # T_ *= (1.0-a)\n                T[s + 1] = T_ * (1.0 - a)\n                # if T[s+1]>=T_threshold:\n                # samples += 1\n                total_samples[ray_idx] += 1\n            else:\n                T[s + 1] = 0.0\n\n        # total_samples[ray_idx] = N_samples\n\n\n@ti.kernel\ndef check_value(\n        fields: ti.template(),\n        array: ti.types.ndarray(),\n        checker: ti.types.ndarray(),\n):\n    for I in ti.grouped(array):\n        if fields[I] == array[I]:\n            checker[I] = 1\n\n\nclass VolumeRendererTaichi(torch.nn.Module):\n\n    def __init__(self, batch_size=8192, data_type=data_type):\n        super(VolumeRendererTaichi, self).__init__()\n        # samples level\n        self.sigmas_fields = ti.field(dtype=data_type,\n                                      shape=(batch_size * 1024, ),\n                                      needs_grad=True)\n        self.rgbs_fields = ti.field(dtype=data_type,\n                                    shape=(batch_size * 1024, 3),\n                                    needs_grad=True)\n        self.deltas_fields = 
ti.field(dtype=data_type,\n                                      shape=(batch_size * 1024, ),\n                                      needs_grad=True)\n        self.ts_fields = ti.field(dtype=data_type,\n                                  shape=(batch_size * 1024, ),\n                                  needs_grad=True)\n        self.ws_fields = ti.field(dtype=data_type,\n                                  shape=(batch_size * 1024, ),\n                                  needs_grad=True)\n        self.T = ti.field(dtype=data_type,\n                          shape=(batch_size * 1024),\n                          needs_grad=True)\n\n        # rays level\n        self.rays_a_fields = ti.field(dtype=ti.i64, shape=(batch_size, 3))\n        self.total_samples_fields = ti.field(dtype=ti.i64,\n                                             shape=(batch_size, ))\n        self.opacity_fields = ti.field(dtype=data_type,\n                                       shape=(batch_size, ),\n                                       needs_grad=True)\n        self.depth_fields = ti.field(dtype=data_type,\n                                     shape=(batch_size, ),\n                                     needs_grad=True)\n        self.rgb_fields = ti.field(dtype=data_type,\n                                   shape=(batch_size, 3),\n                                   needs_grad=True)\n\n        # preallocate tensor\n        self.register_buffer('total_samples',\n                             torch.zeros(batch_size, dtype=torch.int64))\n        self.register_buffer('rgb', torch.zeros(batch_size,\n                                                3,\n                                                dtype=torch_type))\n        self.register_buffer('opacity',\n                             torch.zeros(batch_size, dtype=torch_type))\n        self.register_buffer('depth', torch.zeros(batch_size,\n                                                  dtype=torch_type))\n        self.register_buffer('ws',\n               
              torch.zeros(batch_size * 1024, dtype=torch_type))\n\n        self.register_buffer('sigma_grad',\n                             torch.zeros(batch_size * 1024, dtype=torch_type))\n        self.register_buffer(\n            'rgb_grad', torch.zeros(batch_size * 1024, 3, dtype=torch_type))\n\n        class _module_function(torch.autograd.Function):\n\n            @staticmethod\n            @custom_fwd(cast_inputs=torch_type)\n            def forward(ctx, sigmas, rgbs, deltas, ts, rays_a, T_threshold):\n                ctx.T_threshold = T_threshold\n                ctx.samples_size = sigmas.shape[0]\n\n                ws = self.ws[:sigmas.shape[0]]\n\n                torch2ti(self.sigmas_fields, sigmas.contiguous())\n                torch2ti(self.rgbs_fields, rgbs.contiguous())\n                torch2ti(self.deltas_fields, deltas.contiguous())\n                torch2ti(self.ts_fields, ts.contiguous())\n                torch2ti(self.rays_a_fields, rays_a.contiguous())\n                # The kernel only writes ws for samples it composites, so clear\n                # weights left over from the previous batch first.\n                self.ws_fields.fill(0.)\n                composite_train_fw(self.sigmas_fields, self.rgbs_fields,\n                                   self.deltas_fields, self.ts_fields,\n                                   self.rays_a_fields, T_threshold, self.T,\n                                   self.total_samples_fields,\n                                   self.opacity_fields, self.depth_fields,\n                                   self.rgb_fields, self.ws_fields)\n                ti2torch(self.total_samples_fields, self.total_samples)\n                ti2torch(self.opacity_fields, self.opacity)\n                ti2torch(self.depth_fields, self.depth)\n                ti2torch(self.rgb_fields, self.rgb)\n                # Copy the per-sample weights back; otherwise ws would stay all zeros.\n                ti2torch(self.ws_fields, ws)\n\n                return self.total_samples.sum(\n                ), self.opacity, self.depth, self.rgb, ws\n\n            @staticmethod\n            @custom_bwd\n            def backward(ctx, 
dL_dtotal_samples, dL_dopacity, dL_ddepth,\n                         dL_drgb, dL_dws):\n\n                T_threshold = ctx.T_threshold\n                samples_size = ctx.samples_size\n\n                sigma_grad = self.sigma_grad[:samples_size].contiguous()\n                rgb_grad = self.rgb_grad[:samples_size].contiguous()\n\n                self.zero_grad()\n\n                torch2ti_grad(self.opacity_fields, dL_dopacity.contiguous())\n                torch2ti_grad(self.depth_fields, dL_ddepth.contiguous())\n                torch2ti_grad(self.rgb_fields, dL_drgb.contiguous())\n                torch2ti_grad(self.ws_fields, dL_dws.contiguous())\n                composite_train_fw.grad(self.sigmas_fields, self.rgbs_fields,\n                                        self.deltas_fields, self.ts_fields,\n                                        self.rays_a_fields, T_threshold,\n                                        self.T, self.total_samples_fields,\n                                        self.opacity_fields, self.depth_fields,\n                                        self.rgb_fields, self.ws_fields)\n                ti2torch_grad(self.sigmas_fields, sigma_grad)\n                ti2torch_grad(self.rgbs_fields, rgb_grad)\n\n                return sigma_grad, rgb_grad, None, None, None, None\n\n        self._module_function = _module_function\n\n    def zero_grad(self):\n        self.sigmas_fields.grad.fill(0.)\n        self.rgbs_fields.grad.fill(0.)\n        self.T.grad.fill(0.)\n\n\n    def forward(self, sigmas, rgbs, deltas, ts, rays_a, T_threshold):\n        return self._module_function.apply(sigmas, rgbs, deltas, ts, rays_a,\n                                           T_threshold)\n"
  },
  {
    "path": "tets/README.md",
    "content": "Place the tet grid files in this folder.\nWe provide a few example grids; see the main README.md for a download link.\n\nYou can also generate your own grids with [Quartet](https://github.com/crawforddoran/quartet).\nSee the `generate_tets.py` script for an example.\n"
  },
  {
    "path": "tets/generate_tets.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport numpy as np\n\n\n'''\nThis code segment shows how to use Quartet (https://github.com/crawforddoran/quartet)\nto generate a tet grid:\n1) Download, compile and run Quartet as described in the link above. Example usage: `quartet meshes/cube.obj 0.5 cube_5.tet`\n2) Run the function below. Because `res` is formatted with %f, res=32 writes `meshes/cube_32.000000_tet.tet`.\n'''\n\ndef generate_tetrahedron_grid_file(res=32, root='..'):\n    frac = 1.0 / res\n    command = 'cd %s/quartet; ' % (root) + \\\n                './quartet_release meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (frac, res, res)\n    os.system(command)\n\n\n'''\nThis code segment shows how to convert a Quartet .tet file to a compressed .npz file.\n'''\ndef convert_from_quartet_to_npz(quartetfile='cube_32_tet.tet', npzfile='32_tets.npz'):\n\n    # Read only the header line here; np.loadtxt reopens the file below.\n    with open(quartetfile, 'r') as f:\n        header = f.readline()\n    # The header stores the vertex and tet counts in fields 1 and 2.\n    numvertices = int(header.split()[1])\n    numtets     = int(header.split()[2])\n    print(numvertices, numtets)\n\n    # load vertices\n    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices)\n    vertices = vertices - 0.5\n    print(vertices.shape, vertices.min(), vertices.max())\n\n    # load indices\n    indices = np.loadtxt(quartetfile, dtype=int, skiprows=1+numvertices, max_rows=numtets)\n    print(indices.shape)\n\n    np.savez_compressed(npzfile, vertices=vertices, indices=indices)\n\n\nif __name__ == '__main__':\n    import argparse\n    parser = 
argparse.ArgumentParser()\n    parser.add_argument('--res', type=int, default=32)\n    parser.add_argument('--root', type=str, default='..')\n    args = parser.parse_args()\n\n    generate_tetrahedron_grid_file(res=args.res, root=args.root)\n    convert_from_quartet_to_npz(quartetfile=os.path.join(args.root, 'quartet', 'meshes', f'cube_{args.res}.000000_tet.tet'), npzfile=os.path.join('./tets', f'{args.res}_tets.npz'))"
  }
]