[
  {
    "path": ".github/workflows/pypi-publish.yml",
    "content": "name: Upload Python Package\n\non:\n  release:\n    types: [created]\n  workflow_dispatch:\n\njobs:\n  deploy:\n\n    runs-on: ubuntu-latest\n\n    environment:\n      name: pypi\n      url: https://pypi.org/project/threefiner/\n    permissions:\n      id-token: write  # IMPORTANT: this permission is mandatory for trusted publishing\n\n    steps:\n    - uses: actions/checkout@v3\n    - name: Set up Python\n      uses: actions/setup-python@v3\n      with:\n        python-version: '3.10'\n    # prepare distributions in dist/\n    - name: Install dependencies and Build\n      run: |\n        python -m pip install --upgrade pip\n        pip install setuptools wheel\n        python setup.py sdist bdist_wheel\n    # publish by trusted publishers (need to first setup in pypi.org projects-manage-publishing!)\n    # ref: https://github.com/marketplace/actions/pypi-publish\n    - name: Publish package distributions to PyPI\n      uses: pypa/gh-action-pypi-publish@release/v1"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__\ntmp*\ndata_*\nlogs\nlogs*\nvideos*\n\n\n*.egg-info\nbuild/"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "recursive-include threefiner/lights *"
  },
  {
    "path": "gradio_app.py",
    "content": "import os\nimport tyro\nimport tqdm\nimport torch\nimport gradio as gr\n\nimport kiui\n\nfrom threefiner.opt import config_defaults, config_doc, check_options\nfrom threefiner.gui import GUI\n\nGRADIO_SAVE_PATH_MESH = 'gradio_output.glb'\nGRADIO_SAVE_PATH_VIDEO = 'gradio_output.mp4'\n\nopt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc))\n\n# hacks for not loading mesh at initialization\nopt.save = GRADIO_SAVE_PATH_MESH\nopt.prompt = ''\nopt.text_dir = True\nopt.front_dir = '+z'\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\ngui = GUI(opt)\n\n# process function\ndef process(input_model, input_text, input_dir, iters):\n\n    # set front facing direction (map from gradio model3D's mysterious coordinate system to OpenGL...)\n    opt.text_dir = True\n    if input_dir == 'front':\n        opt.front_dir = '-z'\n    elif input_dir == 'back':\n        opt.front_dir = '+z'\n    elif input_dir == 'left':\n        opt.front_dir = '+x'\n    elif input_dir == 'right':\n        opt.front_dir = '-x'\n    elif input_dir == 'up':\n        opt.front_dir = '+y'\n    elif input_dir == 'down':\n        opt.front_dir = '-y'\n    else:\n        # turn off text_dir\n        opt.text_dir = False\n        opt.front_dir = '+z'\n    \n    # set mesh path\n    opt.mesh = input_model\n\n    # load mesh!\n    gui.renderer = gui.renderer_class(opt, device).to(device)\n\n    # set prompt\n    gui.prompt = opt.positive_prompt + ', ' + input_text\n\n    # train\n    gui.prepare_train() # update optimizer and prompt embeddings\n    for i in tqdm.trange(iters):\n        gui.train_step()\n\n    # save mesh & video\n    gui.save_model(GRADIO_SAVE_PATH_MESH)\n    gui.save_model(GRADIO_SAVE_PATH_VIDEO)\n    \n    # return 3d model & video\n    return GRADIO_SAVE_PATH_MESH, GRADIO_SAVE_PATH_VIDEO\n\n# gradio UI\nblock = gr.Blocks().queue()\nwith block:\n    gr.Markdown(\"\"\"\n    ## Threefiner: Text-guided mesh refinement.\n    
\"\"\")\n\n    with gr.Row(variant='panel'):\n        with gr.Column(scale=1):\n            input_model = gr.Model3D(label=\"input mesh\")\n            input_text = gr.Text(label=\"prompt\")\n            input_dir = gr.Radio(['front', 'back', 'left', 'right', 'up', 'down'], label=\"front-facing direction\")\n            iters = gr.Slider(minimum=100, maximum=1000, step=100, value=400, label=\"training iterations\")\n            button_gen = gr.Button(\"Refine!\")\n        \n        with gr.Column(scale=1):\n            output_model = gr.Model3D(label=\"output mesh\")\n            output_video = gr.Video(label=\"output video\")\n\n        button_gen.click(process, inputs=[input_model, input_text, input_dir, iters], outputs=[output_model, output_video])\n    \nblock.launch(server_name=\"0.0.0.0\", share=True)"
  },
  {
    "path": "readme.md",
    "content": "<p align=\"center\">\n    <picture>\n    <img alt=\"logo\" src=\"assets/threefiner_icon.png\" width=\"20%\">\n    </picture>\n    </br>\n    <b>Threefiner</b>\n</p>\n\nAn interface for text-guided mesh refinement.\n\nhttps://github.com/3DTopia/threefiner/assets/25863658/a4abe725-b542-4a4a-a6d4-e4c4821f7d96\n\n### Features\n* **Mesh in, mesh out**: we support `ply` with vertex colors, `obj`, and single object `glb/gltf` with textures!\n* **Easy to use**: both a CLI and a GUI is available.\n* **Performant**: Refine your texture in 1 minute with Deepfloyd-IF-II.\n\n### Install\n\nWe rely on `torch` and several CUDA extensions, please make sure you install them correctly first!\n```bash\n# tiny-cuda-nn\npip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch\n\n# nvdiffrast\npip install git+https://github.com/NVlabs/nvdiffrast\n\n# [optional, will use pysdf if unavailable] cubvh:\npip install git+https://github.com/ashawkey/cubvh\n```\n\nTo use [Deepfloyd-IF](https://github.com/deep-floyd/IF) models, please log in to your huggingface and accept the [license](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0).\n\nTo install this package:\n```bash\n# install from pypi\npip install threefiner\n\n# install from github\npip install git+https://github.com/3DTopia/threefiner\n\n# local install\ngit clone https://github.com/3DTopia/threefiner\ncd threefiner\npip install .\n```\n\n### Usage\n\n```bash\n### command line interface\nthreefiner --help\n# this is short for\npython -m threefiner.cli --help\n\n### refine a coarse mesh ('input.obj') using Stable-diffusion and save to 'logs/hamburger.glb'\nthreefiner sd --mesh input.obj --prompt 'a hamburger' --outdir logs --save hamburger.glb\n\n### if the initial texture is good, we recommend using IF2 for refinement.\n# by default, it will save to './name_fine.glb'\nthreefiner if2 --mesh name.glb --prompt 'description'\n\n### if the initial texture is not good, we recommend using SD or IF 
first.\nthreefiner sd --mesh name.glb --prompt 'description'\nthreefiner if --mesh name.glb --prompt 'description'\n\n### if the initial geometry is good, you can fix the geometry.\nthreefiner sd_fixgeo --mesh name.glb --prompt 'description'\nthreefiner if_fixgeo --mesh name.glb --prompt 'description'\nthreefiner if2_fixgeo --mesh name.glb --prompt 'description'\n\n### advanced\n# directional text prompt (append front/side/back view in text prompt)\n# you need to know the mesh's front facing direction and specify it by '--front_dir'\n# we use the OpenGL coordinate system, i.e., +x is right, +y is up, +z is front (more details: https://kit.kiui.moe/camera/)\n# clock-wise rotation can be specified per 90 degree, e.g., +z1, -y2\nthreefiner if2 --mesh input.glb --prompt 'description' --text_dir --front_dir='+z'\n\n# adjust training iterations\nthreefiner if2 --mesh input.glb --prompt 'description' --iters 1000\n\n# explicitly fix the geometry and only refine texture\nthreefiner if2 --fix-geo --geom_mode mesh --mesh input.glb --prompt 'description' # equals if2_fixgeo\n\n# open a GUI to visualize the training progress (needs a desktop)\nthreefiner if2 --mesh input.glb --prompt 'description' --gui\n```\n\nGradio demo:\n```bash\n# requires gradio 4\npython gradio_app.py if2\n```\n\nFor more examples, please see [scripts](./scripts/).\n\n### Q&A\n\n* **How to make sure `--front_dir` for your model?**\n    \n    You may first visualize it in a 3D viewer that follows OpenGL coordinate system:\n    <p align=\"center\">\n        <picture>\n        <img alt=\"example_front_dir\" src=\"assets/coord.jpg\" width=\"50%\">\n        </picture>\n    </p>\n    The chair is facing down the Y axis (Green), so we can use `--front_dir=\"-y\"` to rectify it to face +Z axis (Blue).\n\n* **fatal error: EGL/egl.h: No such file or directory**\n\n    By default, we use the OpenGL rasterizer. 
This error means there is no OpenGL installation, which is often the case for headless servers.\n    It's recommended to install OpenGL (along with NVIDIA driver) as it brings better performance.\n    Otherwise, you can append `--force_cuda_rast` to use the CUDA rasterizer instead.\n\n## Acknowledgement\n\nThis work is built on many amazing research works and open-source projects, thanks a lot to all the authors for sharing!\n\n- SDS `guidance` classes are based on [diffusers](https://github.com/huggingface/diffusers).\n- `diffmc` geometry is based on [diso](https://github.com/SarahWeiii/diso).\n- `mesh` geometry is based on [nerf2mesh](https://github.com/ashawkey/nerf2mesh).\n- Texture encoding is based on [tinycudann](https://github.com/NVlabs/tiny-cuda-nn).\n- Mesh renderer is based on [nvdiffrast](https://github.com/NVlabs/nvdiffrast).\n- GUI is based on [dearpygui](https://github.com/hoffstadt/DearPyGui).\n- The coarse models used in demo are generated by [Genie](https://lumalabs.ai/genie?view=create) and [3DTopia](https://github.com/3DTopia/3DTopia).\n"
  },
  {
    "path": "scripts/run.sh",
    "content": "export CUDA_VISIBLE_DEVICES=0\n\n# the mesh is already with good initial texture, just refine it using IF2\nthreefiner if2 --mesh data/car.glb --prompt 'a red car' --outdir logs --save car_fine.glb --text_dir --front_dir='+x'\n\n# the mesh is coarse, using SD for diverse texture generation and IF2 for refinement\nthreefiner sd --mesh data/chair.ply --prompt 'a swivel chair' --outdir logs --save chair_coarse.glb --text_dir --front_dir='-y'\nthreefiner if2 --mesh logs/chair_coarse.glb --prompt 'a swivel chair' --outdir logs --save chair_fine.glb --text_dir --front_dir='+z'\n"
  },
  {
    "path": "scripts/test_all.sh",
    "content": "export CUDA_VISIBLE_DEVICES=1\n\n# geom_mode\nthreefiner if2 --geom_mode diffmc --save car_diffmc.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'\nthreefiner if2 --geom_mode mesh --save car_mesh.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'\nthreefiner if2 --geom_mode pbr_diffmc --save car_pbr_diffmc.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'\nthreefiner if2 --geom_mode pbr_mesh --save car_pbr_mesh.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'\n\n# tex_mode\nthreefiner if2 --tex_mode mlp --save car_mlp.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'\nthreefiner if2 --tex_mode triplane --save car_triplane.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'\n\n# guidance mode\nthreefiner sd --save car_SD.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'\nthreefiner if --save car_IF.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\n\nsetup(\n  name = 'threefiner',\n  packages = find_packages(exclude=[]),\n  include_package_data = True,\n  entry_points={\n    # CLI tools\n    'console_scripts': [\n      'threefiner = threefiner.cli:main'\n    ],\n  },\n  version = '0.1.2',\n  license='MIT',\n  description = 'Threefiner: a text-guided mesh refiner',\n  author = 'kiui',\n  author_email = 'ashawkey1999@gmail.com',\n  long_description=open(\"readme.md\", encoding=\"utf-8\").read(),\n  long_description_content_type = 'text/markdown',\n  url = 'https://github.com/3DTopia/threefiner',\n  keywords = [\n    'generative mesh refinement',\n  ],\n  install_requires=[\n    'tyro',\n    'tqdm',\n    'rich',\n    'ninja',\n    'numpy',\n    'pandas',\n    'matplotlib',\n    'opencv-python',\n    'imageio',\n    'imageio-ffmpeg',\n    'scipy',\n    'scikit-learn',\n    'torch',\n    'einops',\n    'huggingface_hub',\n    'diffusers',\n    'accelerate',\n    'transformers',\n    \"sentencepiece\", # required by deepfloyd-if T5 encoder\n    'plyfile',\n    'pygltflib',\n    'xatlas',\n    'trimesh',\n    'PyMCubes',\n    'pymeshlab',\n    \"pysdf\",\n    \"diso\",\n    \"envlight\",\n    'dearpygui',\n    'kiui >= 0.2.1',\n  ],\n  classifiers=[\n    'Topic :: Scientific/Engineering :: Artificial Intelligence',\n    'License :: OSI Approved :: MIT License',\n    'Programming Language :: Python :: 3',\n  ],\n)\n"
  },
  {
    "path": "threefiner/__init__.py",
    "content": ""
  },
  {
    "path": "threefiner/cli.py",
    "content": "import os\nimport tyro\nfrom threefiner.opt import config_defaults, config_doc, check_options\nfrom threefiner.gui import GUI\n\n\ndef main():    \n    opt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc))\n    opt = check_options(opt)\n    gui = GUI(opt)\n    if gui.gui:\n        gui.render()\n    else:\n        gui.train(opt.iters)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "threefiner/gui.py",
    "content": "import os\nimport tqdm\nimport random\nimport imageio\nimport numpy as np\n\nimport torch\nimport torch.nn.functional as F\n\nGUI_AVAILABLE = True\ntry:\n    import dearpygui.dearpygui as dpg\nexcept Exception as e:\n    GUI_AVAILABLE = False\n\nimport kiui\nfrom kiui.cam import orbit_camera, OrbitCamera\nfrom kiui.mesh_utils import laplacian_smooth_loss, normal_consistency\n\nfrom threefiner.opt import Options\n\nclass GUI:\n    def __init__(self, opt: Options):\n        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.\n        if not GUI_AVAILABLE and opt.gui:\n            print(f'[WARN] cannot import dearpygui, assume without --gui')\n        self.gui = opt.gui and GUI_AVAILABLE # enable gui\n        self.W = opt.W\n        self.H = opt.H\n        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)\n\n        self.mode = \"image\"\n        self.seed = \"random\"\n\n        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)\n        self.need_update = True  # update buffer_image\n\n        self.save_path = os.path.join(self.opt.outdir, self.opt.save)\n        os.makedirs(self.opt.outdir, exist_ok=True)\n\n        # models\n        self.device = torch.device(\"cuda\")\n\n        self.guidance = None\n\n        # renderer\n        if self.opt.geom_mode == 'mesh':\n            from threefiner.renderer.mesh_renderer import Renderer\n        elif self.opt.geom_mode == 'diffmc':\n            from threefiner.renderer.diffmc_renderer import Renderer\n        elif self.opt.geom_mode == 'pbr_mesh':\n            from threefiner.renderer.pbr_mesh_renderer import Renderer\n        elif self.opt.geom_mode == 'pbr_diffmc':\n            from threefiner.renderer.pbr_diffmc_renderer import Renderer\n        else:\n            raise NotImplementedError(f\"unknown geometry mode: {self.opt.geom_mode}\")\n\n        self.renderer_class = Renderer\n        \n        if self.opt.mesh is 
None:\n            self.renderer = None\n        else:\n            self.renderer = Renderer(opt, self.device).to(self.device)\n\n        # input prompt\n        self.prompt = self.opt.prompt\n        self.negative_prompt = \"\"\n\n        if self.opt.positive_prompt is not None:\n            self.prompt = self.opt.positive_prompt + ', ' + self.prompt\n        if self.opt.negative_prompt is not None:\n            self.negative_prompt = self.opt.negative_prompt\n        \n        # training stuff\n        self.training = False\n        self.optimizer = None\n        self.step = 0\n        self.train_steps = 1  # steps per rendering loop\n\n        if self.gui:\n            dpg.create_context()\n            self.register_dpg()\n            self.test_step()\n\n    def __del__(self):\n        if self.gui:\n            dpg.destroy_context()\n\n    def seed_everything(self):\n        try:\n            seed = int(self.seed)\n        except:\n            seed = np.random.randint(0, 1000000)\n\n        os.environ[\"PYTHONHASHSEED\"] = str(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed(seed)\n        \n        self.last_seed = seed\n\n    def prepare_train(self):\n\n        assert self.renderer is not None, 'no mesh loaded!'\n\n        self.step = 0\n\n        # setup training\n        self.optimizer = torch.optim.Adam(self.renderer.get_params())\n\n        # lazy load guidance model\n        if self.guidance is None:\n            print(f\"[INFO] loading guidance...\")\n            if self.opt.mode == 'SD':\n                from threefiner.guidance.sd_utils import StableDiffusion\n                self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O)\n            elif self.opt.mode == 'SD_NFSD':\n                from threefiner.guidance.sd_nfsd_utils import StableDiffusion\n                self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O)\n            elif self.opt.mode == 'SDCN':\n        
        from threefiner.guidance.sdcn_utils import StableDiffusionControlNet\n                self.guidance = StableDiffusionControlNet(self.device, vram_O=self.opt.vram_O)\n            elif self.opt.mode == 'IF':\n                from threefiner.guidance.if_utils import IF\n                self.guidance = IF(self.device, vram_O=self.opt.vram_O)\n            elif self.opt.mode == 'IF2':\n                from threefiner.guidance.if2_utils import IF2\n                self.guidance = IF2(self.device, vram_O=self.opt.vram_O)\n            elif self.opt.mode == 'IF2_NFSD':\n                from threefiner.guidance.if2_nfsd_utils import IF2\n                self.guidance = IF2(self.device, vram_O=self.opt.vram_O)\n            elif self.opt.mode == 'SD_ISM':\n                from threefiner.guidance.sd_ism_utils import StableDiffusion\n                self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O)\n            elif self.opt.mode == 'IF2_ISM':\n                from threefiner.guidance.if2_ism_utils import IF2\n                self.guidance = IF2(self.device, vram_O=self.opt.vram_O)\n            else:\n                raise NotImplementedError(f\"unknown guidance mode {self.opt.mode}!\")\n            print(f\"[INFO] loaded guidance!\")\n\n        # prepare embeddings\n        with torch.no_grad():\n            self.guidance.get_text_embeds([self.prompt], [self.negative_prompt])\n\n           \n    def train_step(self):\n        starter = torch.cuda.Event(enable_timing=True)\n        ender = torch.cuda.Event(enable_timing=True)\n        starter.record()\n\n        self.renderer.train()\n\n        for _ in range(self.train_steps):\n\n            self.step += 1\n            step_ratio = min(1, self.step / self.opt.iters)\n\n            loss = 0\n\n            ### novel view (manual batch)\n            images = []\n            poses = []\n            normals = []\n            ori_images = []\n            vers, hors, radii = [], [], []\n            \n        
    for _ in range(self.opt.batch_size):\n\n                # render random view\n                ver = np.random.randint(-60, 30)\n                hor = np.random.randint(-180, 180)\n                radius = np.random.uniform() - 0.5 # [-0.5, 0.5]\n                pose = orbit_camera(ver, hor, self.opt.radius + radius)\n\n                vers.append(ver)\n                hors.append(hor)\n                radii.append(radius)\n                poses.append(pose)\n\n                # random render resolution\n                ssaa = min(2.0, max(0.125, 2 * np.random.random()))\n                out = self.renderer.render(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=ssaa)\n\n                image = out[\"image\"] # [H, W, 3] in [0, 1]\n                image = image.permute(2,0,1).contiguous().unsqueeze(0) # [1, 3, H, W] in [0, 1]\n                images.append(image)\n\n                # mix_normal\n                if not self.opt.fix_geo and self.opt.mix_normal:\n                    normal = out['normal']\n                    normal = normal.permute(2,0,1).contiguous().unsqueeze(0)\n                    normals.append(normal)\n\n                # IF SR model requires the original rendering\n                if self.opt.mode in ['IF2', 'IF2_NFSD', 'IF2_ISM', 'SDCN']:\n                    out_mesh = self.renderer.render_mesh(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=1)\n                    ori_image = out_mesh[\"image\"] # [H, W, 3] in [0, 1]\n                    ori_image = ori_image.permute(2,0,1).contiguous().unsqueeze(0)\n                    ori_images.append(ori_image)\n                    # ori_images.append(image.clone())\n\n            # guidance loss\n            guidance_input = {'pred_rgb': torch.cat(images, dim=0)}\n\n            if not self.opt.fix_geo and self.opt.mix_normal:\n                if random.random() > 0.5:\n                    ratio = random.random()\n       
             guidance_input['pred_rgb'] = guidance_input['pred_rgb'] * ratio + torch.cat(normals, dim=0) * (1 - ratio)\n            \n            # guidance_input['step_ratio'] = step_ratio\n            if self.opt.mode in ['IF2', 'IF2_NFSD', 'IF2_ISM']:\n                guidance_input['ori_rgb'] = torch.cat(ori_images, dim=0)\n            if self.opt.mode == 'SDCN':\n                guidance_input['control_images'] = {'tile': torch.cat(ori_images, dim=0)}\n            if self.opt.text_dir:\n                guidance_input['vers'] = vers\n                guidance_input['hors'] = hors\n            \n            loss = loss + self.opt.lambda_sd * self.guidance.train_step(**guidance_input)\n            \n            # geom regularizations\n            if self.opt.geom_mode in ['diffmc', 'pbr_diffmc', 'mesh', 'pbr_mesh'] and not self.opt.fix_geo:\n                if self.opt.lambda_lap > 0:\n                    lap_loss = laplacian_smooth_loss(self.renderer.v, self.renderer.f)\n                    loss = loss + self.opt.lambda_lap * lap_loss\n                if self.opt.lambda_normal > 0:\n                    normal_loss = normal_consistency(self.renderer.v, self.renderer.f)\n                    loss = loss + self.opt.lambda_normal * normal_loss\n                if self.opt.geom_mode in ['mesh', 'pbr_mesh'] and self.opt.lambda_offsets > 0:\n                    offset_loss = (self.renderer.v_offsets ** 2).sum(-1).mean()\n                    loss = loss + self.opt.lambda_offsets * offset_loss\n            \n            # optimize step\n            loss.backward()\n            self.optimizer.step()\n            self.optimizer.zero_grad()\n\n            # for mesh geom_mode: peoriodically remesh\n            if self.opt.geom_mode in ['mesh', 'pbr_mesh'] and not self.opt.fix_geo:\n                if self.step > 0 and self.step % self.opt.remesh_interval == 0:\n                    self.renderer.remesh()\n                    # reset optimizer\n                    
self.optimizer = torch.optim.Adam(self.renderer.get_params())\n\n        ender.record()\n        torch.cuda.synchronize()\n        t = starter.elapsed_time(ender)\n\n        self.need_update = True\n\n        if self.gui:\n            dpg.set_value(\"_log_train_time\", f\"{t:.4f}ms\")\n            dpg.set_value(\n                \"_log_train_log\",\n                f\"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {loss.item():.4f}\",\n            )\n\n    @torch.no_grad()\n    def test_step(self):\n        # ignore if no need to update\n        if not self.need_update:\n            return\n\n        starter = torch.cuda.Event(enable_timing=True)\n        ender = torch.cuda.Event(enable_timing=True)\n        starter.record()\n\n        # should update image\n        if self.need_update:\n            # render image\n            self.renderer.eval()\n\n            out = self.renderer.render(self.cam.pose, self.cam.perspective, self.H, self.W)\n\n            buffer_image = out[self.mode]  # [H, W, 3]\n\n            if self.mode in ['depth', 'alpha']:\n                buffer_image = buffer_image.repeat(1, 1, 3)\n                if self.mode == 'depth':\n                    buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)\n\n            self.buffer_image = buffer_image.contiguous().clamp(0, 1).detach().cpu().numpy()\n\n            self.need_update = False\n\n        ender.record()\n        torch.cuda.synchronize()\n        t = starter.elapsed_time(ender)\n\n        if self.gui:\n            dpg.set_value(\"_log_infer_time\", f\"{t:.4f}ms ({int(1000/t)} FPS)\")\n            dpg.set_value(\n                \"_texture\", self.buffer_image\n            )  # buffer must be contiguous, else seg fault!\n\n    def save_model(self, save_path=None):\n\n        if save_path is None:\n            save_path = self.save_path\n\n        # export video\n        if save_path.endswith(\".mp4\"):\n            images = []\n     
       elevation = 0\n            azimuth = np.arange(0, 360, 3, dtype=np.int32) # front-->back-->front\n            for azi in tqdm.tqdm(azimuth):\n                pose = orbit_camera(elevation, azi, self.opt.radius)\n                out = self.renderer.render(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=1)    \n                image = (out[\"image\"].detach().cpu().numpy() * 255).astype(np.uint8)\n                images.append(image)\n            images = np.stack(images, axis=0)\n            # ~4 seconds, 120 frames at 30 fps\n            imageio.mimwrite(save_path, images, fps=30, quality=8, macro_block_size=1)\n        # export mesh\n        else:\n            self.renderer.export_mesh(save_path, texture_resolution=self.opt.texture_resolution)\n\n        print(f\"[INFO] save model to {save_path}.\")\n\n    def register_dpg(self):\n        ### register texture\n\n        with dpg.texture_registry(show=False):\n            dpg.add_raw_texture(\n                self.W,\n                self.H,\n                self.buffer_image,\n                format=dpg.mvFormat_Float_rgb,\n                tag=\"_texture\",\n            )\n\n        ### register window\n\n        # the rendered image, as the primary window\n        with dpg.window(\n            tag=\"_primary_window\",\n            width=self.W,\n            height=self.H,\n            pos=[0, 0],\n            no_move=True,\n            no_title_bar=True,\n            no_scrollbar=True,\n        ):\n            # add the texture\n            dpg.add_image(\"_texture\")\n\n        # dpg.set_primary_window(\"_primary_window\", True)\n\n        # control window\n        with dpg.window(\n            label=\"Control\",\n            tag=\"_control_window\",\n            width=600,\n            height=self.H,\n            pos=[self.W, 0],\n            no_move=True,\n            no_title_bar=True,\n        ):\n            # button theme\n            with dpg.theme() as 
theme_button:\n                with dpg.theme_component(dpg.mvButton):\n                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))\n                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))\n                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))\n                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)\n                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)\n\n            # timer stuff\n            with dpg.group(horizontal=True):\n                dpg.add_text(\"Infer time: \")\n                dpg.add_text(\"no data\", tag=\"_log_infer_time\")\n\n            def callback_setattr(sender, app_data, user_data):\n                setattr(self, user_data, app_data)\n\n            # init stuff\n            with dpg.collapsing_header(label=\"Initialize\", default_open=True):\n\n                # seed stuff\n                def callback_set_seed(sender, app_data):\n                    self.seed = app_data\n                    self.seed_everything()\n\n                dpg.add_input_text(\n                    label=\"seed\",\n                    default_value=self.seed,\n                    on_enter=True,\n                    callback=callback_set_seed,\n                )\n\n                # input stuff\n                def callback_select_input(sender, app_data):\n                    # only one item\n                    for k, v in app_data[\"selections\"].items():\n                        dpg.set_value(\"_log_input\", k)\n                        self.load_input(v)\n\n                    self.need_update = True\n\n                with dpg.file_dialog(\n                    directory_selector=False,\n                    show=False,\n                    callback=callback_select_input,\n                    file_count=1,\n                    tag=\"file_dialog_tag\",\n                    width=700,\n                    height=400,\n                ):\n                    
dpg.add_file_extension(\"Images{.jpg,.jpeg,.png}\")\n\n                with dpg.group(horizontal=True):\n                    dpg.add_button(\n                        label=\"input\",\n                        callback=lambda: dpg.show_item(\"file_dialog_tag\"),\n                    )\n                    dpg.add_text(\"\", tag=\"_log_input\")\n                \n                # prompt stuff\n            \n                dpg.add_input_text(\n                    label=\"prompt\",\n                    default_value=self.prompt,\n                    callback=callback_setattr,\n                    user_data=\"prompt\",\n                )\n\n                dpg.add_input_text(\n                    label=\"negative\",\n                    default_value=self.negative_prompt,\n                    callback=callback_setattr,\n                    user_data=\"negative_prompt\",\n                )\n\n                # save current model\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"Save: \")\n\n                    dpg.add_button(\n                        label=\"model\",\n                        tag=\"_button_save_model\",\n                        callback=self.save_model,\n                    )\n                    dpg.bind_item_theme(\"_button_save_model\", theme_button)\n\n                    dpg.add_input_text(\n                        label=\"\",\n                        default_value=self.save_path,\n                        callback=callback_setattr,\n                        user_data=\"save_path\",\n                    )\n\n            # training stuff\n            with dpg.collapsing_header(label=\"Train\", default_open=True):\n                # lr and train button\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"Train: \")\n\n                    def callback_train(sender, app_data):\n                        if self.training:\n                            self.training = False\n                  
          dpg.configure_item(\"_button_train\", label=\"start\")\n                        else:\n                            self.prepare_train()\n                            self.training = True\n                            dpg.configure_item(\"_button_train\", label=\"stop\")\n\n                    # dpg.add_button(\n                    #     label=\"init\", tag=\"_button_init\", callback=self.prepare_train\n                    # )\n                    # dpg.bind_item_theme(\"_button_init\", theme_button)\n\n                    dpg.add_button(\n                        label=\"start\", tag=\"_button_train\", callback=callback_train\n                    )\n                    dpg.bind_item_theme(\"_button_train\", theme_button)\n\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"\", tag=\"_log_train_time\")\n                    dpg.add_text(\"\", tag=\"_log_train_log\")\n\n            # rendering options\n            with dpg.collapsing_header(label=\"Rendering\", default_open=True):\n                # mode combo\n                def callback_change_mode(sender, app_data):\n                    self.mode = app_data\n                    self.need_update = True\n\n                dpg.add_combo(\n                    (\"image\", \"depth\", \"alpha\", \"normal\"),\n                    label=\"mode\",\n                    default_value=self.mode,\n                    callback=callback_change_mode,\n                )\n\n                # fov slider\n                def callback_set_fovy(sender, app_data):\n                    self.cam.fovy = np.deg2rad(app_data)\n                    self.need_update = True\n\n                dpg.add_slider_int(\n                    label=\"FoV (vertical)\",\n                    min_value=1,\n                    max_value=120,\n                    format=\"%d deg\",\n                    default_value=np.rad2deg(self.cam.fovy),\n                    callback=callback_set_fovy,\n                )\n\n        
### register camera handler\n\n        def callback_camera_drag_rotate_or_draw_mask(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            dx = app_data[1]\n            dy = app_data[2]\n\n            self.cam.orbit(dx, dy)\n            self.need_update = True\n\n        def callback_camera_wheel_scale(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            delta = app_data\n\n            self.cam.scale(delta)\n            self.need_update = True\n\n        def callback_camera_drag_pan(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            dx = app_data[1]\n            dy = app_data[2]\n\n            self.cam.pan(dx, dy)\n            self.need_update = True\n\n        with dpg.handler_registry():\n            # for camera moving\n            dpg.add_mouse_drag_handler(\n                button=dpg.mvMouseButton_Left,\n                callback=callback_camera_drag_rotate_or_draw_mask,\n            )\n            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)\n            dpg.add_mouse_drag_handler(\n                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan\n            )\n\n        dpg.create_viewport(\n            title=\"Threefiner\",\n            width=self.W + 600,\n            height=self.H + (45 if os.name == \"nt\" else 0),\n            resizable=False,\n        )\n\n        ### global theme\n        with dpg.theme() as theme_no_padding:\n            with dpg.theme_component(dpg.mvAll):\n                # set all padding to 0 to avoid scroll bar\n                dpg.add_theme_style(\n                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core\n                )\n                dpg.add_theme_style(\n                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core\n                )\n 
               dpg.add_theme_style(\n                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core\n                )\n\n        dpg.bind_item_theme(\"_primary_window\", theme_no_padding)\n\n        dpg.setup_dearpygui()\n\n        ### register a larger font\n        # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf\n        if os.path.exists(\"LXGWWenKai-Regular.ttf\"):\n            with dpg.font_registry():\n                with dpg.font(\"LXGWWenKai-Regular.ttf\", 18) as default_font:\n                    dpg.bind_font(default_font)\n\n        # dpg.show_metrics()\n\n        dpg.show_viewport()\n\n    def render(self):\n        assert self.gui\n        while dpg.is_dearpygui_running():\n            # update texture every frame\n            if self.training:\n                self.train_step()\n            self.test_step()\n            dpg.render_dearpygui_frame()\n    \n    # no gui mode\n    def train(self, iters=500):\n        if iters > 0:\n            self.prepare_train()\n            for i in tqdm.trange(iters):\n                self.train_step()\n        # save\n        self.save_model()"
  },
  {
    "path": "threefiner/guidance/__init__.py",
    "content": ""
  },
  {
    "path": "threefiner/guidance/if2_ism_utils.py",
    "content": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    IFPipeline,\n    IFSuperResolutionPipeline,\n)\n\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef invert_noise(scheduler, noisy_samples, noise, timesteps):\n    \"\"\"Invert scheduler.add_noise: recover x_0 from x_t and the noise that produced it.\n\n    x_0 = (x_t - sqrt(1 - alphas_cumprod[t]) * noise) / sqrt(alphas_cumprod[t])\n    \"\"\"\n    alphas_cumprod = scheduler.alphas_cumprod.to(device=noisy_samples.device, dtype=noisy_samples.dtype)\n    timesteps = timesteps.to(noisy_samples.device)\n\n    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n    sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n    while len(sqrt_alpha_prod.shape) < len(noisy_samples.shape):\n        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n    while len(sqrt_one_minus_alpha_prod.shape) < len(noisy_samples.shape):\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n\n    original_samples = 1 / sqrt_alpha_prod * (noisy_samples - sqrt_one_minus_alpha_prod * noise)\n    return original_samples\n\n\nclass IF2(nn.Module):\n    \"\"\"DeepFloyd IF stage-II (super-resolution) model used as ISM-style guidance.\"\"\"\n\n    def __init__(\n        self,\n        device,\n        fp16=True,\n        vram_O=False,\n        model_key = \"DeepFloyd/IF-II-M-v1.0\",\n        t_range=[0.02, 0.50],\n    ):\n        super().__init__()\n\n        self.device = device\n        self.model_key = model_key\n        self.dtype = torch.float16 if fp16 else torch.float32\n\n        # Create model; load weights in self.dtype so they match inputs cast with .to(self.dtype)\n        pipe = IFSuperResolutionPipeline.from_pretrained(\n            model_key, variant=\"fp16\", torch_dtype=self.dtype,\n            watermarker=None, safety_checker=None, requires_safety_checker=False,\n        )\n\n        # the sub-models are called directly with tensors on device, so they must live there\n        pipe.to(device)\n        if vram_O:\n            pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            # pipe.enable_model_cpu_offload()\n\n        self.unet = pipe.unet\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n\n        self.scheduler = pipe.scheduler\n        self.image_noising_scheduler = pipe.image_noising_scheduler\n\n        self.pipe = pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience\n\n        self.embeddings = {}\n\n    @torch.no_grad()\n    def get_text_embeds(self, prompts, negative_prompts):\n        pos_embeds = self.encode_text(prompts)  # [len(prompts), max_length, hidden_dim]\n        neg_embeds = self.encode_text(negative_prompts)\n        null_embeds = self.encode_text([\"\"])\n        self.embeddings['pos'] = pos_embeds\n        self.embeddings['neg'] = neg_embeds\n        self.embeddings['null'] = null_embeds\n\n        # directional embeddings\n        for d in ['front', 'side', 'back']:\n            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])\n            self.embeddings[d] = embeds\n\n    def encode_text(self, prompt):\n        # prompt: [str]\n        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)\n        inputs = self.tokenizer(prompt, padding=\"max_length\", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n        return embeddings\n\n    def train_step(\n        self,\n        pred_rgb,\n        ori_rgb,\n        step_ratio=None,\n        guidance_scale=5,\n        vers=None, hors=None,\n        delta_t=50, delta_s=200,\n    ):\n        \"\"\"Compute the ISM guidance loss for a batch of renders.\n\n        pred_rgb: [B, 3, H, W] renders in [0, 1]; the loss is differentiable w.r.t. them.\n        ori_rgb: [B, 3, H, W] renders used (resized and noised) as the low-resolution condition.\n        hors: optional per-view azimuths in degrees, selecting front/side/back prompt embeddings.\n        delta_t, delta_s: ISM interval and inversion step size, in training timesteps.\n        Returns a scalar loss whose gradient w.r.t. the resized renders is\n        w(t) * (eps_cfg(x_t, t) - eps_null(x_s, s)) / B.\n        \"\"\"\n        batch_size = pred_rgb.shape[0]\n        pred_rgb = pred_rgb.to(self.dtype)\n        ori_rgb = ori_rgb.to(self.dtype)\n\n        images = F.interpolate(pred_rgb, (256, 256), mode=\"bilinear\", align_corners=False) * 2 - 1\n\n        with torch.no_grad():\n            max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device)\n\n            # images_upscaled = images.clone()\n            images_upscaled = F.interpolate(ori_rgb, (256, 256), mode=\"bilinear\", align_corners=False).clamp(0, 1) * 2 - 1\n            noise = torch.randn_like(images_upscaled)\n            images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t)\n\n            if step_ratio is not None:\n                # dreamtime-like\n                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)\n                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)\n                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)\n            else:\n                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)\n\n            # w(t), sigma_t^2\n            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype)\n\n            ######### debug\n            # imagesx = self.produce_imgs(\n            #     images_upscaled=images_upscaled,\n            #     max_t=max_t,\n            #     images=torch.randn_like(images),\n            #     num_inference_steps=50,\n            #     guidance_scale=4.0,\n            # )  # [1, 3, 256, 256]\n            # import kiui\n            # kiui.vis.plot_image(images_upscaled * 0.5 + 0.5)\n            # kiui.vis.plot_image(imagesx * 0.5 + 0.5)\n            #########\n\n            null_embeddings = self.embeddings['null'].expand(batch_size, -1, -1)\n\n            if hors is None:\n                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n            else:\n                def _get_dir_ind(h):\n                    if abs(h) < 60: return 'front'\n                    elif abs(h) < 120: return 'side'\n                    else: return 'back'\n\n                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n            ########### ISM\n            # steps\n            t = t.clamp(min=delta_t)\n            s = t - delta_t\n            n = s // delta_s\n            r = s % delta_s\n\n            # construct trajectory\n            images_noisy = images.clone()\n            cur_t = torch.full((batch_size,), 0, dtype=torch.long, device=self.device)\n\n            noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0]\n            images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t)\n            cur_t += r\n            images_noisy = self.scheduler.add_noise(images_original, noise, cur_t)\n\n            # t is sampled per element, so elements may need different numbers of inversion steps\n            # (and range() on a multi-element tensor raises): run up to the largest n and only\n            # advance the elements that still have steps left.\n            for i in range(int(n.max().item())):\n                active = i < n  # [B] bool\n                noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0]\n                images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t)\n                next_t = torch.where(active, cur_t + delta_s, cur_t)\n                images_next = self.scheduler.add_noise(images_original, noise, next_t)\n                images_noisy = torch.where(active.view(batch_size, 1, 1, 1), images_next, images_noisy) # x_s\n                cur_t = next_t\n\n            # construct last step\n            noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0]\n            images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t) # \\hat x_0^s\n\n            # perform guidance\n            images_noisy = self.scheduler.add_noise(images_original, noise, t)\n            model_input = torch.cat([images_noisy, images_upscaled], dim=1)\n            model_input = torch.cat([model_input] * 2)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            tt = torch.cat([t] * 2)\n            max_tt = torch.cat([max_t] * 2)\n\n            noise_pred = self.unet(\n                model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt,\n            ).sample\n\n            # perform guidance (high scale from paper!)\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n\n            grad = w * (noise_pred - noise)\n            grad = torch.nan_to_num(grad)\n\n            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8\n            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm\n\n            target = (images - grad).detach()\n\n        loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0]\n\n        return loss\n\n    @torch.no_grad()\n    def produce_imgs(\n        self,\n        images_upscaled,\n        max_t,\n        height=256,\n        width=256,\n        num_inference_steps=50,\n        guidance_scale=4.0,\n        images=None,\n    ):\n        if images is None:\n            # the SR UNet input is [noisy image, upscaled condition] stacked on channels,\n            # so the noisy image has half of the UNet's input channels\n            images = torch.randn(\n                (\n                    images_upscaled.shape[0],\n                    self.unet.config.in_channels // 2,\n                    height,\n                    width,\n                ),\n                device=self.device,\n                dtype=self.dtype,\n            )\n\n        batch_size = images.shape[0]\n\n        self.scheduler.set_timesteps(num_inference_steps)\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            # expand the images if we are doing classifier-free guidance to avoid doing two forward passes.\n            model_input = torch.cat([images, images_upscaled], dim=1)\n            model_input = torch.cat([model_input] * 2)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            max_tt = torch.cat([max_t] * 2)\n\n            # predict the noise residual\n            noise_pred = self.unet(\n                model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt,\n            ).sample\n\n            # perform guidance\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            images = self.scheduler.step(noise_pred, t, images).prev_sample\n\n        return images\n\n    def prompt_to_img(\n        self,\n        images_upscaled,\n        max_t,\n        prompts,\n        negative_prompts=\"\",\n        height=256,\n        width=256,\n        num_inference_steps=50,\n        guidance_scale=4.0,\n        images=None,\n    ):\n        if isinstance(prompts, str):\n            prompts = [prompts]\n\n        if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        self.get_text_embeds(prompts, negative_prompts)\n\n        # Text embeds -> img images\n        images = self.produce_imgs(\n            images_upscaled=images_upscaled,\n            max_t=max_t,\n            height=height,\n            width=width,\n            images=images,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n        )  # [B, 3, H, W] in [-1, 1]\n\n        # Img to Numpy: map samples from [-1, 1] to [0, 1] before quantizing to uint8\n        images = (images / 2 + 0.5).clamp(0, 1)\n        images = images.detach().cpu().permute(0, 2, 3, 1).float().numpy()\n        images = (images * 255).round().astype(\"uint8\")\n\n        return images\n\n\nif __name__ == \"__main__\":\n    import kiui\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"prompt\", type=str)\n    parser.add_argument(\"--negative\", default=\"\", type=str)\n    parser.add_argument(\"--fp16\", action=\"store_true\", help=\"use float16 for training\")\n    parser.add_argument(\"--vram_O\", action=\"store_true\", help=\"optimization for low VRAM usage\")\n    parser.add_argument(\"-H\", type=int, default=256)\n    parser.add_argument(\"-W\", type=int, default=256)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--steps\", type=int, default=50)\n    opt = parser.parse_args()\n\n    kiui.seed_everything(opt.seed)\n\n    device = torch.device(\"cuda\")\n\n    sd = IF2(device, opt.fp16, opt.vram_O)\n\n    # prompt_to_img needs a low-resolution condition and its noise level; without an\n    # input image, condition on a flat gray image (0 in the model's [-1, 1] range).\n    images_upscaled = torch.zeros((1, 3, opt.H, opt.W), device=device, dtype=sd.dtype)\n    max_t = torch.full((1,), sd.max_step, dtype=torch.long, device=device)\n\n    imgs = sd.prompt_to_img(images_upscaled, max_t, opt.prompt, opt.negative, opt.H, opt.W, opt.steps)\n\n    # visualize image\n    plt.imshow(imgs[0])\n    plt.show()"
  },
  {
    "path": "threefiner/guidance/if2_nfsd_utils.py",
    "content": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    IFPipeline,\n    IFSuperResolutionPipeline,\n)\n\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass IF2(nn.Module):\n    def __init__(\n        self,\n        device,\n        fp16=True,\n        vram_O=False,\n        model_key = \"DeepFloyd/IF-II-L-v1.0\",\n        t_range=[0.02, 0.50],\n    ):\n        super().__init__()\n\n        self.device = device\n        self.model_key = model_key\n        self.dtype = torch.float16 if fp16 else torch.float32\n\n        # Create model\n        pipe = IFSuperResolutionPipeline.from_pretrained(\n            model_key, variant=\"fp16\", torch_dtype=torch.float16, \n            watermarker=None, safety_checker=None, requires_safety_checker=False,\n        )\n\n        if vram_O:\n            pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            # pipe.enable_model_cpu_offload()\n        else:\n            pipe.to(device)\n\n        self.unet = pipe.unet\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n\n        self.scheduler = pipe.scheduler\n        self.image_noising_scheduler = pipe.image_noising_scheduler\n\n        self.pipe = pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience\n\n        self.embeddings = {}\n        \n\n    @torch.no_grad()\n    def get_text_embeds(self, prompts, negative_prompts):\n        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]\n        neg_embeds = self.encode_text(negative_prompts)\n        null_embeds = self.encode_text([''])\n        self.embeddings['pos'] = pos_embeds\n        self.embeddings['neg'] 
= neg_embeds\n        self.embeddings['null'] = null_embeds\n\n        # directional embeddings\n        for d in ['front', 'side', 'back']:\n            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])\n            self.embeddings[d] = embeds\n        \n    \n    def encode_text(self, prompt):\n        # prompt: [str]\n        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)\n        inputs = self.tokenizer(prompt, padding=\"max_length\", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n        return embeddings\n\n    def train_step(\n        self,\n        pred_rgb,\n        ori_rgb,\n        step_ratio=None,\n        guidance_scale=5,\n        vers=None, hors=None,\n    ):\n        \n        batch_size = pred_rgb.shape[0]\n        pred_rgb = pred_rgb.to(self.dtype)\n        ori_rgb = ori_rgb.to(self.dtype)\n\n        images = F.interpolate(pred_rgb, (256, 256), mode=\"bilinear\", align_corners=False) * 2 - 1\n\n        with torch.no_grad():\n            max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device)\n\n            # images_upscaled = images.clone()\n            images_upscaled = F.interpolate(ori_rgb, (256, 256), mode=\"bilinear\", align_corners=False).clamp(0, 1) * 2 - 1\n            noise = torch.randn_like(images_upscaled)\n            images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t)\n            \n            if step_ratio is not None:\n                # dreamtime-like\n                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)\n                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)\n                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)\n            else:\n                t = torch.randint(self.min_step, self.max_step 
+ 1, (batch_size,), dtype=torch.long, device=self.device)\n\n            # w(t), sigma_t^2\n            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype)\n\n            ######### debug\n            # imagesx = self.produce_imgs(\n            #     images_upscaled=images_upscaled,\n            #     max_t=max_t,\n            #     images=torch.randn_like(images),\n            #     num_inference_steps=50,\n            #     guidance_scale=4.0,\n            # )  # [1, 3, 64, 64]\n            # import kiui\n            # kiui.vis.plot_image(images_upscaled * 0.5 + 0.5)\n            # kiui.vis.plot_image(imagesx * 0.5 + 0.5)\n            #########\n            \n            # add noise\n            noise = torch.randn_like(images)\n            images_noisy = self.scheduler.add_noise(images, noise, t)\n            # pred noise\n            model_input = torch.cat([images_noisy, images_upscaled], dim=1)\n            model_input = torch.cat([model_input] * 3)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            tt = torch.cat([t] * 3)\n            max_tt = torch.cat([max_t] * 3)\n\n            if hors is None:\n                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])\n            else:\n                def _get_dir_ind(h):\n                    if abs(h) < 60: return 'front'\n                    elif abs(h) < 120: return 'side'\n                    else: return 'back'\n\n                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])\n\n            noise_pred = self.unet(\n                model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt,\n            ).sample\n\n            # perform guidance (high scale from paper!)\n            
noise_pred_cond, noise_pred_uncond, noise_pred_null = noise_pred.chunk(3)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred_cond, _ = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred_null, _ = noise_pred_null.split(model_input.shape[1] // 2, dim=1)\n\n            delta_c = guidance_scale * (noise_pred_cond - noise_pred_null)\n            mask = (t < 200).int().view(batch_size, 1, 1, 1)\n            delta_d = mask * noise_pred_null + (1 - mask) * (noise_pred_null - noise_pred_uncond)\n            grad = w * (delta_c + delta_d)\n            \n            grad = torch.nan_to_num(grad)\n\n            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8\n            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm\n        \n            target = (images - grad).detach()\n\n        loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0]\n\n        return loss\n\n    @torch.no_grad()\n    def produce_imgs(\n        self,\n        images_upscaled,\n        max_t,\n        height=256,\n        width=256,\n        num_inference_steps=50,\n        guidance_scale=4.0,\n        images=None,\n    ):\n        if images is None:\n            images = torch.randn(\n                (\n                    1,\n                    self.unet.in_channels,\n                    height,\n                    width,\n                ),\n                device=self.device,\n            )\n        \n        batch_size = images.shape[0]\n\n        self.scheduler.set_timesteps(num_inference_steps)\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            # expand the images if we are doing classifier-free guidance to avoid doing two forward passes.\n            model_input = torch.cat([images, images_upscaled], dim=1)\n   
         model_input = torch.cat([model_input] * 2)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            max_tt = torch.cat([max_t] * 2)\n            \n            # predict the noise residual\n            noise_pred = self.unet(\n                model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt,\n            ).sample\n\n            # perform guidance\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            images = self.scheduler.step(noise_pred, t, images).prev_sample\n\n        return images\n\n    \n    def prompt_to_img(\n        self,\n        images_upscaled,\n        max_t,\n        prompts,\n        negative_prompts=\"\",\n        height=256,\n        width=256,\n        num_inference_steps=50,\n        guidance_scale=4.0,\n        images=None,\n    ):\n        if isinstance(prompts, str):\n            prompts = [prompts]\n\n        if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        self.get_text_embeds(prompts, negative_prompts)\n        \n        # Text embeds -> img images\n        images = self.produce_imgs(\n            images_upscaled=images_upscaled,\n            max_t=max_t,\n            height=height,\n            width=width,\n            images=images,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n        )  # [1, 4, 64, 64]\n\n        \n        # Img to Numpy\n     
   images = images.detach().cpu().permute(0, 2, 3, 1).numpy()\n        images = (images * 255).round().astype(\"uint8\")\n\n        return images\n\n\nif __name__ == \"__main__\":\n    import kiui\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"prompt\", type=str)\n    parser.add_argument(\"--negative\", default=\"\", type=str)\n    parser.add_argument(\"--fp16\", action=\"store_true\", help=\"use float16 for training\")\n    parser.add_argument(\"--vram_O\", action=\"store_true\", help=\"optimization for low VRAM usage\")\n    parser.add_argument(\"-H\", type=int, default=512)\n    parser.add_argument(\"-W\", type=int, default=512)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--steps\", type=int, default=50)\n    opt = parser.parse_args()\n\n    kiui.seed_everything(opt.seed)\n\n    device = torch.device(\"cuda\")\n\n    sd = IF2(device, opt.fp16, opt.vram_O)\n\n    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)\n\n    # visualize image\n    plt.imshow(imgs[0])\n    plt.show()"
  },
  {
    "path": "threefiner/guidance/if2_utils.py",
    "content": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    IFPipeline,\n    IFSuperResolutionPipeline,\n)\n\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass IF2(nn.Module):\n    def __init__(\n        self,\n        device,\n        fp16=True,\n        vram_O=False,\n        # model_key = \"DeepFloyd/IF-II-L-v1.0\",\n        model_key = \"DeepFloyd/IF-II-M-v1.0\",\n        t_range=[0.02, 0.50],\n    ):\n        super().__init__()\n\n        self.device = device\n        self.model_key = model_key\n        self.dtype = torch.float16 if fp16 else torch.float32\n\n        # Create model\n        pipe = IFSuperResolutionPipeline.from_pretrained(\n            model_key, variant=\"fp16\", torch_dtype=torch.float16, \n            watermarker=None, safety_checker=None, requires_safety_checker=False,\n        )\n\n        if vram_O:\n            pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            # pipe.enable_model_cpu_offload()\n            pipe.enable_sequential_cpu_offload()\n        else:\n            pipe.to(device)\n\n        self.unet = pipe.unet\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n\n        self.scheduler = pipe.scheduler\n        self.image_noising_scheduler = pipe.image_noising_scheduler\n\n        self.pipe = pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience\n\n        self.embeddings = {}\n        \n\n    @torch.no_grad()\n    def get_text_embeds(self, prompts, negative_prompts):\n        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]\n        neg_embeds = self.encode_text(negative_prompts)\n        
self.embeddings['pos'] = pos_embeds\n        self.embeddings['neg'] = neg_embeds\n\n        # directional embeddings\n        for d in ['front', 'side', 'back']:\n            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])\n            self.embeddings[d] = embeds\n        \n    \n    def encode_text(self, prompt):\n        # prompt: [str]\n        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)\n        inputs = self.tokenizer(prompt, padding=\"max_length\", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n        return embeddings\n\n    def train_step(\n        self,\n        pred_rgb,\n        ori_rgb,\n        step_ratio=None,\n        guidance_scale=50,\n        vers=None, hors=None,\n    ):\n        \n        batch_size = pred_rgb.shape[0]\n        pred_rgb = pred_rgb.to(self.dtype)\n        ori_rgb = ori_rgb.to(self.dtype)\n\n        images = F.interpolate(pred_rgb, (256, 256), mode=\"bilinear\", align_corners=False) * 2 - 1\n\n        with torch.no_grad():\n            max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device)\n\n            # images_upscaled = images.clone()\n            images_upscaled = F.interpolate(ori_rgb, (256, 256), mode=\"bilinear\", align_corners=False).clamp(0, 1) * 2 - 1\n            noise = torch.randn_like(images_upscaled)\n            images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t)\n            \n            if step_ratio is not None:\n                # dreamtime-like\n                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)\n                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)\n                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)\n            else:\n                t = 
torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)\n\n            # w(t), sigma_t^2\n            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype)\n\n            ######### debug\n            # imagesx = self.produce_imgs(\n            #     images_upscaled=images_upscaled,\n            #     max_t=max_t,\n            #     images=torch.randn_like(images),\n            #     num_inference_steps=50,\n            #     guidance_scale=4.0,\n            # )  # [1, 3, 64, 64]\n            # import kiui\n            # kiui.vis.plot_image(images_upscaled * 0.5 + 0.5)\n            # kiui.vis.plot_image(imagesx * 0.5 + 0.5)\n            #########\n            \n            # add noise\n            noise = torch.randn_like(images)\n            images_noisy = self.scheduler.add_noise(images, noise, t)\n            # pred noise\n            model_input = torch.cat([images_noisy, images_upscaled], dim=1)\n            model_input = torch.cat([model_input] * 2)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            tt = torch.cat([t] * 2)\n            max_tt = torch.cat([max_t] * 2)\n\n            if hors is None:\n                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n            else:\n                def _get_dir_ind(h):\n                    if abs(h) < 60: return 'front'\n                    elif abs(h) < 120: return 'side'\n                    else: return 'back'\n\n                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n            noise_pred = self.unet(\n                model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt,\n            ).sample\n\n            # perform guidance (high scale from paper!)\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n          
  noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n            grad = w * (noise_pred - noise)\n            grad = torch.nan_to_num(grad)\n\n            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8\n            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm\n        \n            target = (images - grad).detach()\n\n        loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0]\n\n        return loss\n\n    @torch.no_grad()\n    def produce_imgs(\n        self,\n        images_upscaled,\n        max_t,\n        height=256,\n        width=256,\n        num_inference_steps=50,\n        guidance_scale=4.0,\n        images=None,\n    ):\n        if images is None:\n            images = torch.randn(\n                (\n                    1,\n                    self.unet.in_channels,\n                    height,\n                    width,\n                ),\n                device=self.device,\n            )\n        \n        batch_size = images.shape[0]\n\n        self.scheduler.set_timesteps(num_inference_steps)\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            # expand the images if we are doing classifier-free guidance to avoid doing two forward passes.\n            model_input = torch.cat([images, images_upscaled], dim=1)\n            model_input = torch.cat([model_input] * 2)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            max_tt = torch.cat([max_t] * 2)\n            \n            # predict the noise residual\n            noise_pred = self.unet(\n        
        model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt,\n            ).sample\n\n            # perform guidance\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            images = self.scheduler.step(noise_pred, t, images).prev_sample\n\n        return images\n\n    \n    def prompt_to_img(\n        self,\n        images_upscaled,\n        max_t,\n        prompts,\n        negative_prompts=\"\",\n        height=256,\n        width=256,\n        num_inference_steps=50,\n        guidance_scale=4.0,\n        images=None,\n    ):\n        if isinstance(prompts, str):\n            prompts = [prompts]\n\n        if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        self.get_text_embeds(prompts, negative_prompts)\n        \n        # Text embeds -> img images\n        images = self.produce_imgs(\n            images_upscaled=images_upscaled,\n            max_t=max_t,\n            height=height,\n            width=width,\n            images=images,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n        )  # [1, 4, 64, 64]\n\n        \n        # Img to Numpy\n        images = images.detach().cpu().permute(0, 2, 3, 1).numpy()\n        images = (images * 255).round().astype(\"uint8\")\n\n        return images\n\n\nif __name__ == \"__main__\":\n    import kiui\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = 
argparse.ArgumentParser()\n    parser.add_argument(\"prompt\", type=str)\n    parser.add_argument(\"--negative\", default=\"\", type=str)\n    parser.add_argument(\"--fp16\", action=\"store_true\", help=\"use float16 for training\")\n    parser.add_argument(\"--vram_O\", action=\"store_true\", help=\"optimization for low VRAM usage\")\n    parser.add_argument(\"-H\", type=int, default=512)\n    parser.add_argument(\"-W\", type=int, default=512)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--steps\", type=int, default=50)\n    opt = parser.parse_args()\n\n    kiui.seed_everything(opt.seed)\n\n    device = torch.device(\"cuda\")\n\n    sd = IF2(device, opt.fp16, opt.vram_O)\n\n    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)\n\n    # visualize image\n    plt.imshow(imgs[0])\n    plt.show()"
  },
  {
    "path": "threefiner/guidance/if_utils.py",
    "content": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    IFPipeline,\n)\n\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass IF(nn.Module):\n    def __init__(\n        self,\n        device,\n        fp16=True,\n        vram_O=False,\n        model_key = \"DeepFloyd/IF-I-XL-v1.0\",\n        # model_key = \"DeepFloyd/IF-I-M-v1.0\",\n        t_range=[0.02, 0.98],\n    ):\n        super().__init__()\n\n        self.device = device\n        self.model_key = model_key\n        self.dtype = torch.float16 if fp16 else torch.float32\n\n        # Create model\n        pipe = IFPipeline.from_pretrained(\n            model_key, variant=\"fp16\", torch_dtype=torch.float16, \n            watermarker=None, safety_checker=None, requires_safety_checker=False,\n        )\n\n        if vram_O:\n            pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            # pipe.enable_model_cpu_offload()\n            pipe.enable_sequential_cpu_offload()\n        else:\n            pipe.to(device)\n\n        self.unet = pipe.unet\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n\n        self.scheduler = pipe.scheduler\n\n        self.pipe = pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience\n\n        self.embeddings = {}\n        \n\n    @torch.no_grad()\n    def get_text_embeds(self, prompts, negative_prompts):\n        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]\n        neg_embeds = self.encode_text(negative_prompts)\n        self.embeddings['pos'] = pos_embeds\n        self.embeddings['neg'] = neg_embeds\n\n        # directional embeddings\n        for d 
in ['front', 'side', 'back']:\n            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])\n            self.embeddings[d] = embeds\n        \n    \n    def encode_text(self, prompt):\n        # prompt: [str]\n        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)\n        inputs = self.tokenizer(prompt, padding=\"max_length\", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n        return embeddings\n\n    @torch.no_grad()\n    def refine(self, pred_rgb,\n               guidance_scale=100, steps=50, strength=0.8,\n        ):\n\n        batch_size = pred_rgb.shape[0]\n        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False)\n        \n        self.scheduler.set_timesteps(steps)\n        init_step = int(steps * strength)\n        images = self.scheduler.add_noise(images, torch.randn_like(images), self.scheduler.timesteps[init_step])\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps[init_step:]):\n    \n            model_input = torch.cat([images] * 2)\n\n            noise_pred = self.unet(\n                model_input, t, encoder_hidden_states=embeddings,\n            ).sample\n\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)\n            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)\n            \n            images = self.scheduler.step(noise_pred, t, images).prev_sample\n\n        return images\n\n    def 
train_step(\n        self,\n        pred_rgb,\n        step_ratio=None,\n        guidance_scale=50,\n        vers=None, hors=None,\n    ):\n        \n        batch_size = pred_rgb.shape[0]\n        pred_rgb = pred_rgb.to(self.dtype)\n\n        images = F.interpolate(pred_rgb, (64, 64), mode=\"bilinear\", align_corners=False) * 2 - 1\n\n        with torch.no_grad():\n            \n            if step_ratio is not None:\n                # dreamtime-like\n                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)\n                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)\n                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)\n            else:\n                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)\n\n            # w(t), sigma_t^2\n            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)\n\n            ######### debug\n            # # Text embeds -> img latents\n            # imagesx = self.produce_imgs(\n            #     height=64,\n            #     width=64,\n            #     images=torch.randn_like(images),\n            #     num_inference_steps=50,\n            #     guidance_scale=7.5,\n            # )  # [1, 3, 64, 64]\n            # import kiui\n            # kiui.vis.plot_image(imagesx)\n            #########\n\n            # add noise\n            noise = torch.randn_like(images)\n            images_noisy = self.scheduler.add_noise(images, noise, t)\n            # pred noise\n            model_input = torch.cat([images_noisy] * 2)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            tt = torch.cat([t] * 2)\n\n            if hors is None:\n                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n            else:\n                def 
_get_dir_ind(h):\n                    if abs(h) < 60: return 'front'\n                    elif abs(h) < 120: return 'side'\n                    else: return 'back'\n\n                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n            noise_pred = self.unet(\n                model_input, tt, encoder_hidden_states=embeddings\n            ).sample\n\n            # perform guidance (high scale from paper!)\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)\n            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n            grad = w * (noise_pred - noise)\n            grad = torch.nan_to_num(grad)\n\n            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8\n            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm\n\n            target = (images - grad).detach()\n            \n        loss = 0.5 * F.mse_loss(images.float(), target, reduction='sum') / images.shape[0]\n\n        return loss\n\n    @torch.no_grad()\n    def produce_imgs(\n        self,\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        images=None,\n    ):\n        if images is None:\n            images = torch.randn(\n                (\n                    1,\n                    self.unet.in_channels,\n                    height,\n                    width,\n                ),\n                device=self.device,\n            )\n        \n        batch_size = images.shape[0]\n\n        self.scheduler.set_timesteps(num_inference_steps)\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), 
self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            # expand the images if we are doing classifier-free guidance to avoid doing two forward passes.\n            model_input = torch.cat([images] * 2)\n            model_input = self.scheduler.scale_model_input(model_input, t)\n            # predict the noise residual\n            noise_pred = self.unet(\n                model_input, t, encoder_hidden_states=embeddings\n            ).sample\n\n            # perform guidance\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)\n            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            images = self.scheduler.step(noise_pred, t, images).prev_sample\n\n        return images\n\n    \n    def prompt_to_img(\n        self,\n        prompts,\n        negative_prompts=\"\",\n        height=64,\n        width=64,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        images=None,\n    ):\n        if isinstance(prompts, str):\n            prompts = [prompts]\n\n        if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        self.get_text_embeds(prompts, negative_prompts)\n        \n        # Text embeds -> img images\n        images = self.produce_imgs(\n            height=height,\n            width=width,\n            images=images,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n        )  # [1, 4, 64, 64]\n\n        \n        # Img 
to Numpy\n        images = images.detach().cpu().permute(0, 2, 3, 1).numpy()\n        images = (images * 255).round().astype(\"uint8\")\n\n        return images\n\n\nif __name__ == \"__main__\":\n    import kiui\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"prompt\", type=str)\n    parser.add_argument(\"--negative\", default=\"\", type=str)\n    parser.add_argument(\"--fp16\", action=\"store_true\", help=\"use float16 for training\")\n    parser.add_argument(\"--vram_O\", action=\"store_true\", help=\"optimization for low VRAM usage\")\n    parser.add_argument(\"-H\", type=int, default=512)\n    parser.add_argument(\"-W\", type=int, default=512)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--steps\", type=int, default=50)\n    opt = parser.parse_args()\n\n    kiui.seed_everything(opt.seed)\n\n    device = torch.device(\"cuda\")\n\n    sd = IF(device, opt.fp16, opt.vram_O)\n\n    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)\n\n    # visualize image\n    plt.imshow(imgs[0])\n    plt.show()"
  },
  {
    "path": "threefiner/guidance/sd_ism_utils.py",
    "content": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n)\n\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef invert_noise(scheduler, noisy_samples, noise, timesteps):\n    alphas_cumprod = scheduler.alphas_cumprod.to(device=noisy_samples.device, dtype=noisy_samples.dtype)\n    timesteps = timesteps.to(noisy_samples.device)\n\n    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n    sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n    while len(sqrt_alpha_prod.shape) < len(noisy_samples.shape):\n        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n    while len(sqrt_one_minus_alpha_prod.shape) < len(noisy_samples.shape):\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n    \n    original_samples = 1 / sqrt_alpha_prod * (noisy_samples - sqrt_one_minus_alpha_prod * noise)\n    return original_samples\n\n\nclass StableDiffusion(nn.Module):\n    def __init__(\n        self,\n        device,\n        fp16=True,\n        vram_O=False,\n        model_key=\"stabilityai/stable-diffusion-2-1-base\",\n        # model_key=\"philz1337/revanimated\",\n        t_range=[0.02, 0.98],\n    ):\n        super().__init__()\n\n        self.device = device\n        self.model_key = model_key\n        self.dtype = torch.float16 if fp16 else torch.float32\n\n        # Create model\n        pipe = StableDiffusionPipeline.from_pretrained(\n            model_key, torch_dtype=self.dtype\n        )\n\n        if vram_O:\n            pipe.enable_sequential_cpu_offload()\n            pipe.enable_vae_slicing()\n            pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            # pipe.enable_model_cpu_offload()\n        else:\n            pipe.to(device)\n\n        
self.vae = pipe.vae\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n        self.unet = pipe.unet\n\n        self.scheduler = DDIMScheduler.from_pretrained(\n            model_key, subfolder=\"scheduler\", torch_dtype=self.dtype\n        )\n\n        del pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience\n\n        self.embeddings = {}\n        \n\n    @torch.no_grad()\n    def get_text_embeds(self, prompts, negative_prompts):\n        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]\n        neg_embeds = self.encode_text(negative_prompts)\n        null_embeds = self.encode_text([\"\"])\n        self.embeddings['pos'] = pos_embeds\n        self.embeddings['neg'] = neg_embeds\n        self.embeddings['null'] = null_embeds\n\n        # directional embeddings\n        for d in ['front', 'side', 'back']:\n            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])\n            self.embeddings[d] = embeds\n        \n    \n    def encode_text(self, prompt):\n        # prompt: [str]\n        inputs = self.tokenizer(prompt, padding=\"max_length\", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n        return embeddings\n\n    @torch.no_grad()\n    def refine(self, pred_rgb,\n               guidance_scale=100, steps=50, strength=0.8,\n        ):\n\n        batch_size = pred_rgb.shape[0]\n        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)\n        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))\n        # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)\n\n        
self.scheduler.set_timesteps(steps)\n        init_step = int(steps * strength)\n        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps[init_step:]):\n    \n            latent_model_input = torch.cat([latents] * 2)\n\n            noise_pred = self.unet(\n                latent_model_input, t, encoder_hidden_states=embeddings,\n            ).sample\n\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n            \n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        imgs = self.decode_latents(latents) # [1, 3, 512, 512]\n        return imgs\n\n    def train_step(\n        self,\n        pred_rgb,\n        step_ratio=None,\n        guidance_scale=7.5,\n        as_latent=False,\n        vers=None, hors=None,\n        delta_t=50, delta_s=200,\n    ):\n        \n        batch_size = pred_rgb.shape[0]\n        pred_rgb = pred_rgb.to(self.dtype)\n\n        if as_latent:\n            latents = F.interpolate(pred_rgb, (64, 64), mode=\"bilinear\", align_corners=False) * 2 - 1\n        else:\n            # interp to 512x512 to be fed into vae.\n            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode=\"bilinear\", align_corners=False)\n            # encode image into latents with vae, requires grad!\n            latents = self.encode_imgs(pred_rgb_512)\n\n        with torch.no_grad():\n\n            if step_ratio is not None:\n                # dreamtime-like\n                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)\n                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)\n   
             t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)\n            else:\n                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)\n\n            # w(t), sigma_t^2\n            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)\n\n            ######### debug\n            # # Text embeds -> img latents\n            # latentsx = self.produce_latents(\n            #     height=512,\n            #     width=512,\n            #     latents=torch.randn_like(latents),\n            #     num_inference_steps=50,\n            #     guidance_scale=7.5,\n            # )  # [1, 4, 64, 64]\n\n            # # Img latents -> imgs\n            # imgs = self.decode_latents(latentsx)  # [1, 3, 512, 512]\n            # import kiui\n            # kiui.vis.plot_image(imgs)\n            #########\n\n            null_embeddings = self.embeddings['null'].expand(batch_size, -1, -1)\n\n            if hors is None:\n                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n            else:\n                def _get_dir_ind(h):\n                    if abs(h) < 60: return 'front'\n                    elif abs(h) < 120: return 'side'\n                    else: return 'back'\n\n                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n            ########### ISM\n            # steps\n            t = t.clamp(min=delta_t)\n            s = t - delta_t\n            n = s // delta_s\n            r = s % delta_s\n\n            # construct trajectory\n            latents_noisy = latents.clone()\n            cur_t = torch.full((batch_size,), 0, dtype=torch.long, device=self.device)\n\n            noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample\n            latents_original = invert_noise(self.scheduler, 
latents_noisy, noise, cur_t)\n            cur_t += r\n            latents_noisy = self.scheduler.add_noise(latents_original, noise, cur_t)\n\n            for i in range(n):\n                noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample\n                latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t)\n                cur_t += delta_s\n                latents_noisy = self.scheduler.add_noise(latents_original, noise, cur_t) # x_s\n\n            # construct last step\n            noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample # \\epsilon_s\n            latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t) # \\hat x_0^s\n\n            # perform guidance\n            latents_noisy = self.scheduler.add_noise(latents_original, noise, t) # x_t\n            latent_model_input = torch.cat([latents_noisy] * 2)\n            tt = torch.cat([t] * 2)\n            noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=embeddings).sample\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n\n            grad = w * (noise_pred - noise)\n            grad = torch.nan_to_num(grad)\n             \n            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8\n            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm\n\n            target = (latents - grad).detach()\n\n        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]\n\n        return loss\n\n    @torch.no_grad()\n    def produce_latents(\n        self,\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        latents=None,\n    ):\n        if latents is None:\n            latents = torch.randn(\n                (\n                    1,\n                    
self.unet.in_channels,\n                    height // 8,\n                    width // 8,\n                ),\n                device=self.device,\n            )\n        \n        batch_size = latents.shape[0]\n\n        self.scheduler.set_timesteps(num_inference_steps)\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.\n            latent_model_input = torch.cat([latents] * 2)\n            # predict the noise residual\n            noise_pred = self.unet(\n                latent_model_input, t, encoder_hidden_states=embeddings\n            ).sample\n\n            # perform guidance\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        return latents\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        imgs = self.vae.decode(latents).sample\n        imgs = (imgs / 2 + 0.5).clamp(0, 1)\n\n        return imgs\n\n    def encode_imgs(self, imgs):\n        # imgs: [B, 3, H, W]\n\n        imgs = 2 * imgs - 1\n\n        posterior = self.vae.encode(imgs).latent_dist\n        latents = posterior.sample() * self.vae.config.scaling_factor\n\n        return latents\n\n    def prompt_to_img(\n        self,\n        prompts,\n        negative_prompts=\"\",\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        latents=None,\n    ):\n        if isinstance(prompts, str):\n            prompts = [prompts]\n\n        
if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        self.get_text_embeds(prompts, negative_prompts)\n        \n        # Text embeds -> img latents\n        latents = self.produce_latents(\n            height=height,\n            width=width,\n            latents=latents,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n        )  # [1, 4, 64, 64]\n\n        # Img latents -> imgs\n        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]\n\n        # Img to Numpy\n        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()\n        imgs = (imgs * 255).round().astype(\"uint8\")\n\n        return imgs\n\n\nif __name__ == \"__main__\":\n    import kiui\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"prompt\", type=str)\n    parser.add_argument(\"--negative\", default=\"\", type=str)\n    parser.add_argument(\"--fp16\", action=\"store_true\", help=\"use float16 for training\")\n    parser.add_argument(\"--vram_O\", action=\"store_true\", help=\"optimization for low VRAM usage\")\n    parser.add_argument(\"-H\", type=int, default=512)\n    parser.add_argument(\"-W\", type=int, default=512)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--steps\", type=int, default=50)\n    opt = parser.parse_args()\n\n    kiui.seed_everything(opt.seed)\n\n    device = torch.device(\"cuda\")\n\n    sd = StableDiffusion(device, opt.fp16, opt.vram_O)\n\n    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)\n\n    # visualize image\n    plt.imshow(imgs[0])\n    plt.show()\n"
  },
  {
    "path": "threefiner/guidance/sd_nfsd_utils.py",
    "content": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n)\n\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass StableDiffusion(nn.Module):\n    def __init__(\n        self,\n        device,\n        fp16=True,\n        vram_O=False,\n        model_key=\"stabilityai/stable-diffusion-2-1-base\",\n        # model_key=\"philz1337/revanimated\",\n        t_range=[0.02, 0.50],\n    ):\n        super().__init__()\n\n        self.device = device\n        self.model_key = model_key\n        self.dtype = torch.float16 if fp16 else torch.float32\n\n        # Create model\n        pipe = StableDiffusionPipeline.from_pretrained(\n            model_key, torch_dtype=self.dtype\n        )\n\n        if vram_O:\n            pipe.enable_sequential_cpu_offload()\n            pipe.enable_vae_slicing()\n            pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            # pipe.enable_model_cpu_offload()\n        else:\n            pipe.to(device)\n\n        self.vae = pipe.vae\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n        self.unet = pipe.unet\n\n        self.scheduler = DDIMScheduler.from_pretrained(\n            model_key, subfolder=\"scheduler\", torch_dtype=self.dtype\n        )\n\n        del pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience\n\n        self.embeddings = {}\n        \n\n    @torch.no_grad()\n    def get_text_embeds(self, prompts, negative_prompts):\n        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]\n        neg_embeds = self.encode_text(negative_prompts)\n        null_embeds = 
self.encode_text([''])\n        self.embeddings['pos'] = pos_embeds\n        self.embeddings['neg'] = neg_embeds\n        self.embeddings['null'] = null_embeds\n\n        # directional embeddings\n        for d in ['front', 'side', 'back']:\n            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])\n            self.embeddings[d] = embeds\n        \n    \n    def encode_text(self, prompt):\n        # prompt: [str]\n        inputs = self.tokenizer(prompt, padding=\"max_length\", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n        return embeddings\n\n    @torch.no_grad()\n    def refine(self, pred_rgb,\n               guidance_scale=100, steps=50, strength=0.8,\n        ):\n\n        batch_size = pred_rgb.shape[0]\n        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)\n        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))\n        # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)\n\n        self.scheduler.set_timesteps(steps)\n        init_step = int(steps * strength)\n        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps[init_step:]):\n    \n            latent_model_input = torch.cat([latents] * 2)\n\n            noise_pred = self.unet(\n                latent_model_input, t, encoder_hidden_states=embeddings,\n            ).sample\n\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n            \n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        
imgs = self.decode_latents(latents) # [1, 3, 512, 512]\n        return imgs\n\n    def train_step(\n        self,\n        pred_rgb,\n        step_ratio=None,\n        guidance_scale=7.5,\n        as_latent=False,\n        vers=None, hors=None,\n    ):\n        \n        batch_size = pred_rgb.shape[0]\n        pred_rgb = pred_rgb.to(self.dtype)\n\n        if as_latent:\n            latents = F.interpolate(pred_rgb, (64, 64), mode=\"bilinear\", align_corners=False) * 2 - 1\n        else:\n            # interp to 512x512 to be fed into vae.\n            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode=\"bilinear\", align_corners=False)\n            # encode image into latents with vae, requires grad!\n            latents = self.encode_imgs(pred_rgb_512)\n\n        with torch.no_grad():\n\n            if step_ratio is not None:\n                # dreamtime-like\n                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)\n                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)\n                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)\n            else:\n                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)\n\n            # w(t), sigma_t^2\n            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)\n\n            ######### debug\n            # # Text embeds -> img latents\n            # latentsx = self.produce_latents(\n            #     height=512,\n            #     width=512,\n            #     latents=torch.randn_like(latents),\n            #     num_inference_steps=50,\n            #     guidance_scale=7.5,\n            # )  # [1, 4, 64, 64]\n\n            # # Img latents -> imgs\n            # imgs = self.decode_latents(latentsx)  # [1, 3, 512, 512]\n            # import kiui\n            # kiui.vis.plot_image(imgs)\n            #########\n\n            # add noise\n        
    noise = torch.randn_like(latents)\n            latents_noisy = self.scheduler.add_noise(latents, noise, t)\n            # pred noise\n            latent_model_input = torch.cat([latents_noisy] * 3)\n            tt = torch.cat([t] * 3)\n\n            if hors is None:\n                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])\n            else:\n                def _get_dir_ind(h):\n                    if abs(h) < 60: return 'front'\n                    elif abs(h) < 120: return 'side'\n                    else: return 'back'\n\n                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])\n\n            noise_pred = self.unet(\n                latent_model_input, tt, encoder_hidden_states=embeddings\n            ).sample\n\n            # perform guidance (high scale from paper!)\n            noise_pred_cond, noise_pred_uncond, noise_pred_null = noise_pred.chunk(3)\n            delta_c = guidance_scale * (noise_pred_cond - noise_pred_null)\n            mask = (t < 200).int().view(batch_size, 1, 1, 1)\n            delta_d = mask * noise_pred_null + (1 - mask) * (noise_pred_null - noise_pred_uncond) * 3\n            grad = w * (delta_c + delta_d)\n            grad = torch.nan_to_num(grad)\n\n            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8\n            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm\n\n            target = (latents - grad).detach()\n\n        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]\n\n        return loss\n\n    @torch.no_grad()\n    def produce_latents(\n        self,\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        latents=None,\n    ):\n        if 
latents is None:\n            latents = torch.randn(\n                (\n                    1,\n                    self.unet.in_channels,\n                    height // 8,\n                    width // 8,\n                ),\n                device=self.device,\n            )\n        \n        batch_size = latents.shape[0]\n\n        self.scheduler.set_timesteps(num_inference_steps)\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.\n            latent_model_input = torch.cat([latents] * 2)\n            # predict the noise residual\n            noise_pred = self.unet(\n                latent_model_input, t, encoder_hidden_states=embeddings\n            ).sample\n\n            # perform guidance\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        return latents\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        imgs = self.vae.decode(latents).sample\n        imgs = (imgs / 2 + 0.5).clamp(0, 1)\n\n        return imgs\n\n    def encode_imgs(self, imgs):\n        # imgs: [B, 3, H, W]\n\n        imgs = 2 * imgs - 1\n\n        posterior = self.vae.encode(imgs).latent_dist\n        latents = posterior.sample() * self.vae.config.scaling_factor\n\n        return latents\n\n    def prompt_to_img(\n        self,\n        prompts,\n        negative_prompts=\"\",\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        
guidance_scale=7.5,\n        latents=None,\n    ):\n        if isinstance(prompts, str):\n            prompts = [prompts]\n\n        if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        self.get_text_embeds(prompts, negative_prompts)\n        \n        # Text embeds -> img latents\n        latents = self.produce_latents(\n            height=height,\n            width=width,\n            latents=latents,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n        )  # [1, 4, 64, 64]\n\n        # Img latents -> imgs\n        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]\n\n        # Img to Numpy\n        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()\n        imgs = (imgs * 255).round().astype(\"uint8\")\n\n        return imgs\n\n\nif __name__ == \"__main__\":\n    import kiui\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"prompt\", type=str)\n    parser.add_argument(\"--negative\", default=\"\", type=str)\n    parser.add_argument(\"--fp16\", action=\"store_true\", help=\"use float16 for training\")\n    parser.add_argument(\"--vram_O\", action=\"store_true\", help=\"optimization for low VRAM usage\")\n    parser.add_argument(\"-H\", type=int, default=512)\n    parser.add_argument(\"-W\", type=int, default=512)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--steps\", type=int, default=50)\n    opt = parser.parse_args()\n\n    kiui.seed_everything(opt.seed)\n\n    device = torch.device(\"cuda\")\n\n    sd = StableDiffusion(device, opt.fp16, opt.vram_O)\n\n    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)\n\n    # visualize image\n    plt.imshow(imgs[0])\n    plt.show()\n"
  },
  {
    "path": "threefiner/guidance/sd_utils.py",
    "content": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n)\n\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass StableDiffusion(nn.Module):\n    def __init__(\n        self,\n        device,\n        fp16=True,\n        vram_O=False,\n        model_key=\"stabilityai/stable-diffusion-2-1-base\",\n        # model_key=\"philz1337/revanimated\",\n        t_range=[0.02, 0.50],\n    ):\n        super().__init__()\n\n        self.device = device\n        self.model_key = model_key\n        self.dtype = torch.float16 if fp16 else torch.float32\n\n        # Create model\n        pipe = StableDiffusionPipeline.from_pretrained(\n            model_key, torch_dtype=self.dtype\n        )\n\n        if vram_O:\n            pipe.enable_sequential_cpu_offload()\n            pipe.enable_vae_slicing()\n            pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            # pipe.enable_model_cpu_offload()\n        else:\n            pipe.to(device)\n\n        self.vae = pipe.vae\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n        self.unet = pipe.unet\n\n        self.scheduler = DDIMScheduler.from_pretrained(\n            model_key, subfolder=\"scheduler\", torch_dtype=self.dtype\n        )\n\n        del pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience\n\n        self.embeddings = {}\n        \n\n    @torch.no_grad()\n    def get_text_embeds(self, prompts, negative_prompts):\n        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]\n        neg_embeds = self.encode_text(negative_prompts)\n        self.embeddings['pos'] = 
pos_embeds\n        self.embeddings['neg'] = neg_embeds\n\n        # directional embeddings\n        for d in ['front', 'side', 'back']:\n            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])\n            self.embeddings[d] = embeds\n        \n    \n    def encode_text(self, prompt):\n        # prompt: [str]\n        inputs = self.tokenizer(prompt, padding=\"max_length\", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n        return embeddings\n\n    @torch.no_grad()\n    def refine(self, pred_rgb,\n               guidance_scale=100, steps=50, strength=0.8,\n        ):\n\n        batch_size = pred_rgb.shape[0]\n        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)\n        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))\n        # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)\n\n        self.scheduler.set_timesteps(steps)\n        init_step = int(steps * strength)\n        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps[init_step:]):\n    \n            latent_model_input = torch.cat([latents] * 2)\n\n            noise_pred = self.unet(\n                latent_model_input, t, encoder_hidden_states=embeddings,\n            ).sample\n\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n            \n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        imgs = self.decode_latents(latents) # [1, 3, 512, 512]\n        return imgs\n\n    def train_step(\n    
    self,\n        pred_rgb,\n        step_ratio=None,\n        guidance_scale=100,\n        as_latent=False,\n        vers=None, hors=None,\n    ):\n        \n        batch_size = pred_rgb.shape[0]\n        pred_rgb = pred_rgb.to(self.dtype)\n\n        if as_latent:\n            latents = F.interpolate(pred_rgb, (64, 64), mode=\"bilinear\", align_corners=False) * 2 - 1\n        else:\n            # interp to 512x512 to be fed into vae.\n            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode=\"bilinear\", align_corners=False)\n            # encode image into latents with vae, requires grad!\n            latents = self.encode_imgs(pred_rgb_512)\n\n        with torch.no_grad():\n\n            if step_ratio is not None:\n                # dreamtime-like\n                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)\n                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)\n                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)\n            else:\n                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)\n\n            # w(t), sigma_t^2\n            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)\n\n            ######### debug\n            # # Text embeds -> img latents\n            # latentsx = self.produce_latents(\n            #     height=512,\n            #     width=512,\n            #     latents=torch.randn_like(latents),\n            #     num_inference_steps=50,\n            #     guidance_scale=7.5,\n            # )  # [1, 4, 64, 64]\n\n            # # Img latents -> imgs\n            # imgs = self.decode_latents(latentsx)  # [1, 3, 512, 512]\n            # import kiui\n            # kiui.vis.plot_image(imgs)\n            #########\n\n            # add noise\n            noise = torch.randn_like(latents)\n            latents_noisy = self.scheduler.add_noise(latents, 
noise, t)\n            # pred noise\n            latent_model_input = torch.cat([latents_noisy] * 2)\n            tt = torch.cat([t] * 2)\n\n            if hors is None:\n                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n            else:\n                def _get_dir_ind(h):\n                    if abs(h) < 60: return 'front'\n                    elif abs(h) < 120: return 'side'\n                    else: return 'back'\n\n                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n            noise_pred = self.unet(\n                latent_model_input, tt, encoder_hidden_states=embeddings\n            ).sample\n\n            # perform guidance (high scale from paper!)\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n            grad = w * (noise_pred - noise)\n            grad = torch.nan_to_num(grad)\n\n            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8\n            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm\n\n            target = (latents - grad).detach()\n\n        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]\n\n        return loss\n\n    @torch.no_grad()\n    def produce_latents(\n        self,\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        latents=None,\n    ):\n        if latents is None:\n            latents = torch.randn(\n                (\n                    1,\n                    self.unet.in_channels,\n                    height // 8,\n                    width // 8,\n                ),\n                device=self.device,\n            )\n        \n        batch_size = 
latents.shape[0]\n\n        self.scheduler.set_timesteps(num_inference_steps)\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for i, t in enumerate(self.scheduler.timesteps):\n            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.\n            latent_model_input = torch.cat([latents] * 2)\n            # predict the noise residual\n            noise_pred = self.unet(\n                latent_model_input, t, encoder_hidden_states=embeddings\n            ).sample\n\n            # perform guidance\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        return latents\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        imgs = self.vae.decode(latents).sample\n        imgs = (imgs / 2 + 0.5).clamp(0, 1)\n\n        return imgs\n\n    def encode_imgs(self, imgs):\n        # imgs: [B, 3, H, W]\n\n        imgs = 2 * imgs - 1\n\n        posterior = self.vae.encode(imgs).latent_dist\n        latents = posterior.sample() * self.vae.config.scaling_factor\n\n        return latents\n\n    def prompt_to_img(\n        self,\n        prompts,\n        negative_prompts=\"\",\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        latents=None,\n    ):\n        if isinstance(prompts, str):\n            prompts = [prompts]\n\n        if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        self.get_text_embeds(prompts, negative_prompts)\n        \n  
      # Text embeds -> img latents\n        latents = self.produce_latents(\n            height=height,\n            width=width,\n            latents=latents,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n        )  # [1, 4, 64, 64]\n\n        # Img latents -> imgs\n        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]\n\n        # Img to Numpy\n        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()\n        imgs = (imgs * 255).round().astype(\"uint8\")\n\n        return imgs\n\n\nif __name__ == \"__main__\":\n    import kiui\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"prompt\", type=str)\n    parser.add_argument(\"--negative\", default=\"\", type=str)\n    parser.add_argument(\"--fp16\", action=\"store_true\", help=\"use float16 for training\")\n    parser.add_argument(\"--vram_O\", action=\"store_true\", help=\"optimization for low VRAM usage\")\n    parser.add_argument(\"-H\", type=int, default=512)\n    parser.add_argument(\"-W\", type=int, default=512)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--steps\", type=int, default=50)\n    opt = parser.parse_args()\n\n    kiui.seed_everything(opt.seed)\n\n    device = torch.device(\"cuda\")\n\n    sd = StableDiffusion(device, opt.fp16, opt.vram_O)\n\n    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)\n\n    # visualize image\n    plt.imshow(imgs[0])\n    plt.show()\n"
  },
  {
    "path": "threefiner/guidance/sdcn_utils.py",
    "content": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n    ControlNetModel,\n)\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass StableDiffusionControlNet(nn.Module):\n    \"\"\"Stable Diffusion guidance with optional ControlNet conditioning.\n\n    Provides an SDS-style `train_step`, an img2img-style `refine`, and plain sampling via\n    `produce_latents` / `prompt_to_img`. `train_step`, `refine` and `produce_latents` read\n    `self.embeddings`, so call `get_text_embeds` first (`prompt_to_img` does this itself).\n    \"\"\"\n\n    # control mode -> pretrained ControlNet checkpoint; models are loaded in this order\n    CONTROLNET_CHECKPOINTS = {\n        \"normal\": \"lllyasviel/control_v11p_sd15_normalbae\",\n        \"depth\": \"lllyasviel/control_v11f1p_sd15_depth\",\n        \"ip2p\": \"lllyasviel/control_v11e_sd15_ip2p\",\n        \"inpaint\": \"lllyasviel/control_v11p_sd15_inpaint\",\n        \"depth_inpaint\": \"lllyasviel/control_v11e_sd15_depth_aware_inpaint\",\n        \"pose\": \"lllyasviel/control_v11p_sd15_openpose\",\n        \"tile\": \"lllyasviel/control_v11f1e_sd15_tile\",\n    }\n\n    def __init__(\n        self,\n        device,\n        fp16=True,\n        vram_O=False,\n        control_mode=[\"tile\"],\n        model_key=\"runwayml/stable-diffusion-v1-5\",\n        # model_key=\"philz1337/revanimated\",\n        t_range=[0.02, 0.50],\n    ):\n        super().__init__()\n\n        self.device = device\n        self.control_mode = control_mode\n        self.dtype = torch.float16 if fp16 else torch.float32\n\n        # Create model\n        pipe = StableDiffusionPipeline.from_pretrained(\n            model_key, torch_dtype=self.dtype\n        )\n\n        if vram_O:\n            pipe.enable_sequential_cpu_offload()\n            pipe.enable_vae_slicing()\n            pipe.unet.to(memory_format=torch.channels_last)\n            pipe.enable_attention_slicing(1)\n            # pipe.enable_model_cpu_offload()\n        else:\n            pipe.to(device)\n\n        self.vae = pipe.vae\n        self.tokenizer = pipe.tokenizer\n        self.text_encoder = pipe.text_encoder\n        self.unet = pipe.unet\n\n        # controlnet: one model per requested mode listed in CONTROLNET_CHECKPOINTS (other modes are ignored)\n        if self.control_mode is not None:\n            self.controlnet = {}\n            self.controlnet_conditioning_scale = {}\n            for mode, ckpt in self.CONTROLNET_CHECKPOINTS.items():\n                if mode in self.control_mode:\n                    self.controlnet[mode] = ControlNetModel.from_pretrained(ckpt, torch_dtype=self.dtype).to(self.device)\n                    self.controlnet_conditioning_scale[mode] = 1.0\n\n        self.scheduler = DDIMScheduler.from_pretrained(\n            model_key, subfolder=\"scheduler\", torch_dtype=self.dtype\n        )\n\n        del pipe\n\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience\n\n        self.embeddings = {}\n\n    @torch.no_grad()\n    def get_text_embeds(self, prompts, negative_prompts):\n        \"\"\"Encode prompts into self.embeddings: 'pos', 'neg', and directional 'front'/'side'/'back'.\"\"\"\n        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]\n        neg_embeds = self.encode_text(negative_prompts)\n        self.embeddings['pos'] = pos_embeds\n        self.embeddings['neg'] = neg_embeds\n\n        # directional embeddings\n        for d in ['front', 'side', 'back']:\n            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])\n            self.embeddings[d] = embeds\n\n    def encode_text(self, prompt):\n        # prompt: [str]\n        inputs = self.tokenizer(prompt, padding=\"max_length\", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]\n        return embeddings\n\n    def _predict_noise(self, latent_model_input, t, embeddings, control_images=None):\n        \"\"\"Predict noise for a CFG-doubled batch, optionally adding ControlNet residuals.\n\n        Each loaded ControlNet whose mode has an entry in `control_images` contributes one UNet\n        prediction, and those predictions are averaged over the modes actually used. Without any\n        usable control input, a plain UNet prediction is returned.\n        \"\"\"\n        modes = []\n        if self.control_mode is not None and control_images is not None:\n            # may omit control mode if input is not provided\n            modes = [mode for mode in self.controlnet if mode in control_images]\n\n        if len(modes) == 0:\n            return self.unet(latent_model_input, t, encoder_hidden_states=embeddings).sample\n\n        # average over the modes actually used, so omitting one does not shrink the prediction\n        weight = 1 / len(modes)\n        noise_pred = 0\n        for mode in modes:\n            control_image = control_images[mode].to(self.dtype)\n            control_image_input = torch.cat([control_image] * 2)\n            down_samples, mid_sample = self.controlnet[mode](\n                latent_model_input, t, encoder_hidden_states=embeddings,\n                controlnet_cond=control_image_input,\n                conditioning_scale=self.controlnet_conditioning_scale[mode],\n                return_dict=False\n            )\n\n            # predict the noise residual\n            noise_pred_cur = self.unet(\n                latent_model_input, t, encoder_hidden_states=embeddings,\n                down_block_additional_residuals=down_samples,\n                mid_block_additional_residual=mid_sample\n            ).sample\n\n            # merge after unet\n            noise_pred = noise_pred + weight * noise_pred_cur\n\n        return noise_pred\n\n    @torch.no_grad()\n    def refine(self, pred_rgb,\n               guidance_scale=100, steps=50, strength=0.8,\n               control_images=None\n        ):\n        \"\"\"Noise the encoded input to an intermediate timestep, then denoise it with CFG.\n\n        Denoising runs over self.scheduler.timesteps[int(steps * strength):], so `strength`\n        selects the starting index into the scheduler's timestep list.\n        Returns decoded images clamped to [0, 1].\n        \"\"\"\n        batch_size = pred_rgb.shape[0]\n        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)\n        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))\n\n        self.scheduler.set_timesteps(steps)\n        init_step = int(steps * strength)\n        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for t in self.scheduler.timesteps[init_step:]:\n            latent_model_input = torch.cat([latents] * 2)\n            noise_pred = self._predict_noise(latent_model_input, t, embeddings, control_images)\n\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n\n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        imgs = self.decode_latents(latents) # [1, 3, 512, 512]\n        return imgs\n\n    def train_step(\n        self,\n        pred_rgb,\n        step_ratio=None,\n        guidance_scale=100,\n        as_latent=False,\n        control_images=None,\n        vers=None, hors=None,\n    ):\n        \"\"\"Score-distillation loss on rendered images; gradients reach `pred_rgb` only through `latents`.\n\n        Args:\n            pred_rgb: [B, 3, H, W] renders (presumably in [0, 1], since encode_imgs maps x to 2x - 1).\n            step_ratio: training progress; if given, t = round((1 - step_ratio) * T) clipped to\n                [min_step, max_step], otherwise t is sampled uniformly in that range.\n            guidance_scale: classifier-free guidance scale.\n            as_latent: use `pred_rgb` itself (resized to 64x64, mapped by 2x - 1) as the latents.\n            control_images: optional {mode: image} inputs for the loaded ControlNets.\n            vers: unused here.\n            hors: optional per-sample azimuths selecting the 'front'/'side'/'back' prompt embeddings\n                (|h| < 60 front, < 120 side, else back; presumably degrees).\n\n        Returns:\n            Scalar loss whose gradient w.r.t. the latents equals grad / B,\n            with grad = w(t) * (noise_pred - noise).\n        \"\"\"\n        batch_size = pred_rgb.shape[0]\n        pred_rgb = pred_rgb.to(self.dtype)\n\n        if as_latent:\n            latents = F.interpolate(pred_rgb, (64, 64), mode=\"bilinear\", align_corners=False) * 2 - 1\n        else:\n            # interp to 512x512 to be fed into vae.\n            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode=\"bilinear\", align_corners=False)\n            # encode image into latents with vae, requires grad!\n            latents = self.encode_imgs(pred_rgb_512)\n\n        with torch.no_grad():\n            if step_ratio is not None:\n                # dreamtime-like\n                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)\n                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)\n                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)\n            else:\n                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)\n\n            # w(t), sigma_t^2\n            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)\n\n            # add noise\n            noise = torch.randn_like(latents)\n            latents_noisy = self.scheduler.add_noise(latents, noise, t)\n            # pred noise\n            latent_model_input = torch.cat([latents_noisy] * 2)\n            tt = torch.cat([t] * 2)\n\n            if hors is None:\n                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n            else:\n                def _get_dir_ind(h):\n                    if abs(h) < 60: return 'front'\n                    elif abs(h) < 120: return 'side'\n                    else: return 'back'\n\n                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n            noise_pred = self._predict_noise(latent_model_input, tt, embeddings, control_images)\n\n            # perform guidance (high scale from paper!)\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n            grad = w * (noise_pred - noise)\n            grad = torch.nan_to_num(grad)\n\n            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8\n            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm\n\n            # reparameterized SDS: d(loss)/d(latents) == grad / B\n            target = (latents - grad).detach()\n\n        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]\n\n        return loss\n\n    @torch.no_grad()\n    def produce_latents(\n        self,\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        latents=None,\n        control_images=None,\n    ):\n        \"\"\"Sample latents with classifier-free guidance over the full scheduler trajectory.\"\"\"\n        if latents is None:\n            # create noise in the model dtype so fp16 UNets accept it\n            latents = torch.randn(\n                (\n                    1,\n                    self.unet.config.in_channels,\n                    height // 8,\n                    width // 8,\n                ),\n                device=self.device,\n                dtype=self.dtype,\n            )\n\n        batch_size = latents.shape[0]\n\n        self.scheduler.set_timesteps(num_inference_steps)\n        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])\n\n        for t in self.scheduler.timesteps:\n            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.\n            latent_model_input = torch.cat([latents] * 2)\n\n            # predict the noise residual\n            noise_pred = self._predict_noise(latent_model_input, t, embeddings, control_images)\n\n            # perform guidance\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (\n                noise_pred_cond - noise_pred_uncond\n            )\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        return latents\n\n    def decode_latents(self, latents):\n        \"\"\"VAE-decode latents into images clamped to [0, 1].\"\"\"\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        imgs = self.vae.decode(latents).sample\n        imgs = (imgs / 2 + 0.5).clamp(0, 1)\n\n        return imgs\n\n    def encode_imgs(self, imgs):\n        \"\"\"VAE-encode images (mapped via 2x - 1) into sampled, scaled latents.\"\"\"\n        # imgs: [B, 3, H, W]\n\n        imgs = 2 * imgs - 1\n\n        posterior = self.vae.encode(imgs).latent_dist\n        latents = posterior.sample() * self.vae.config.scaling_factor\n\n        return latents\n\n    def prompt_to_img(\n        self,\n        prompts,\n        negative_prompts=\"\",\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        latents=None,\n        control_images=None,\n    ):\n        \"\"\"Text (+ optional control images) -> uint8 numpy images laid out as [B, H, W, C].\"\"\"\n        if isinstance(prompts, str):\n            prompts = [prompts]\n\n        if isinstance(negative_prompts, str):\n            negative_prompts = [negative_prompts]\n\n        # Prompts -> text embeds\n        self.get_text_embeds(prompts, negative_prompts)\n\n        # Text embeds -> img latents\n        latents = self.produce_latents(\n            height=height,\n            width=width,\n            latents=latents,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            control_images=control_images,\n        )  # [1, 4, 64, 64]\n\n        # Img latents -> imgs\n        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]\n\n        # Img to Numpy\n        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()\n        imgs = (imgs * 255).round().astype(\"uint8\")\n\n        return imgs\n\n\nif __name__ == \"__main__\":\n    import kiui\n    import argparse\n    import matplotlib.pyplot as plt\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"image\", type=str)\n    # nargs=\"?\" is required for the default to apply when the prompt is omitted\n    parser.add_argument(\"prompt\", nargs=\"?\", default=\"\", type=str)\n    parser.add_argument(\"--control\", default='tile', type=str)\n    parser.add_argument(\"--negative\", default=\"\", type=str)\n    parser.add_argument(\"--fp16\", action=\"store_true\", help=\"use float16 for training\")\n    parser.add_argument(\"--vram_O\", action=\"store_true\", help=\"optimization for low VRAM usage\")\n    parser.add_argument(\"-H\", type=int, default=512)\n    parser.add_argument(\"-W\", type=int, default=512)\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--steps\", type=int, default=50)\n    opt = parser.parse_args()\n\n    kiui.seed_everything(opt.seed)\n\n    device = torch.device(\"cuda\")\n\n    # load image\n    control_image = kiui.read_image(opt.image, mode='tensor').permute(2,0,1).contiguous().unsqueeze(0).to(device)\n    control_image = F.interpolate(control_image, (opt.H, opt.W), mode='bilinear', align_corners=False)\n\n    kiui.lo(control_image)\n\n    control_images = {}\n    control_images[opt.control] = control_image\n\n    sd = StableDiffusionControlNet(device, opt.fp16, opt.vram_O, control_mode=[opt.control])\n\n    while True:\n        imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, control_images=control_images)\n\n        # visualize image\n        plt.imshow(imgs[0])\n        plt.show()"
  },
  {
    "path": "threefiner/lights/LICENSE.txt",
    "content": "The mud_road_puresky.hdr HDR probe is from https://polyhaven.com/a/mud_road_puresky\nCC0 License.\n"
  },
  {
    "path": "threefiner/nn.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport numpy as np\nimport tinycudann as tcnn\n\nclass HashGridEncoder(nn.Module):\n    \"\"\"Multi-resolution hash-grid encoding (tiny-cuda-nn) of points in [-bound, bound].\n\n    Level resolutions grow geometrically from `base_resolution` (first level)\n    to `desired_resolution` (last level).\n    \"\"\"\n    def __init__(self,\n                 input_dim=3,\n                 num_levels=16,\n                 level_dim=2,\n                 log2_hashmap_size=18,\n                 base_resolution=16,\n                 desired_resolution=1024,\n                 interpolation='linear'\n                 ):\n        super().__init__()\n        # per-level growth so that base_resolution * scale^(num_levels - 1) == desired_resolution;\n        # max() guards the single-level case against division by zero\n        per_level_scale = float(np.exp2(np.log2(desired_resolution / base_resolution) / max(num_levels - 1, 1)))\n        self.encoder = tcnn.Encoding(\n            n_input_dims=input_dim,\n            encoding_config={\n                \"otype\": \"HashGrid\",\n                \"n_levels\": num_levels,\n                \"n_features_per_level\": level_dim,\n                \"log2_hashmap_size\": log2_hashmap_size,\n                \"base_resolution\": base_resolution,\n                \"per_level_scale\": per_level_scale,\n                \"interpolation\": \"Smoothstep\" if interpolation == 'smoothstep' else \"Linear\",\n            },\n            dtype=torch.float32,\n        )\n        self.input_dim = input_dim\n        self.output_dim = self.encoder.n_output_dims # patch\n\n    def forward(self, x, bound=1):\n        # rescale [-bound, bound] to [0, 1] before encoding\n        return self.encoder((x + bound) / (2 * bound))\n\nclass FrequencyEncoder(nn.Module):\n    \"\"\"tiny-cuda-nn frequency encoding followed by a 5-layer, 128-wide MLP to `output_dim` features.\"\"\"\n    def __init__(self,\n                 input_dim=3,\n                 output_dim=32,\n                 n_frequencies=12,\n                 ):\n        super().__init__()\n        self.encoder = tcnn.Encoding(\n            n_input_dims=input_dim,\n            encoding_config={\n                \"otype\": \"Frequency\",\n                \"n_frequencies\": n_frequencies,\n            },\n            dtype=torch.float32,\n        )\n        self.implicit_mlp = MLP(self.encoder.n_output_dims, output_dim, 128, 5, bias=True)\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n\n    def forward(self, x, **kwargs):\n        # extra keyword args (presumably `bound`, for parity with the other encoders) are ignored\n        return self.implicit_mlp(self.encoder(x))\n\n\nclass TriplaneEncoder(nn.Module):\n    \"\"\"Three learnable axis-aligned feature planes (xy, xz, yz), bilinearly sampled and summed.\"\"\"\n    def __init__(self,\n                 input_dim=3,\n                 output_dim=32,\n                 resolution=256,\n                 ):\n        super().__init__()\n\n        self.C_mat = nn.Parameter(torch.randn(3, output_dim, resolution, resolution))\n        torch.nn.init.kaiming_normal_(self.C_mat)\n\n        # coordinate pairs indexing each plane: xy, xz, yz\n        self.mat_ids = [[0, 1], [0, 2], [1, 2]]\n\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n\n    def forward(self, x, bound=1):\n        # assumes x is [N, 3] with coordinates in [-bound, bound]; returns [N, output_dim]\n        N = x.shape[0]\n        x = x / bound # to [-1, 1]\n\n        mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2]\n\n        feat = (\n            F.grid_sample(self.C_mat[[0]], mat_coord[[0]], align_corners=False).view(-1, N) +\n            F.grid_sample(self.C_mat[[1]], mat_coord[[1]], align_corners=False).view(-1, N) +\n            F.grid_sample(self.C_mat[[2]], mat_coord[[2]], align_corners=False).view(-1, N)\n        ) # [C, N]\n\n        feat = feat.transpose(0, 1).contiguous() # [N, C]\n        return feat\n\nclass MLP(nn.Module):\n    \"\"\"Stack of `num_layers` Linear layers with ReLU between them (no activation after the last).\"\"\"\n    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):\n        super().__init__()\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n        self.dim_hidden = dim_hidden\n        self.num_layers = num_layers\n\n        net = []\n        for l in range(num_layers):\n            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))\n\n        self.net = nn.ModuleList(net)\n\n    def forward(self, x):\n        for l in range(self.num_layers):\n            x = self.net[l](x)\n            if l != self.num_layers - 1:\n                x = F.relu(x, inplace=True)\n        return x"
  },
  {
    "path": "threefiner/opt.py",
    "content": "import os\nfrom dataclasses import dataclass\nfrom typing import Tuple, Literal, Dict, Optional\n\n@dataclass\nclass Options:\n    # path to input mesh\n    mesh: Optional[str] = None\n    # input text prompt\n    prompt: Optional[str] = None\n    # additional positive prompt\n    positive_prompt: str = \"best quality, extremely detailed, masterpiece, high resolution, high quality\"\n    # additional negative prompt\n    negative_prompt: str = \"blur, lowres, cropped, low quality, worst quality, ugly, dark, shadow, oversaturated\"\n    # whether to append directional text prompt\n    text_dir: bool = False\n    # set mesh front-facing direction (camera front=+z, right=+x, up=+y, clock-wise rotation 90=1, 180=2, 270=3, e.g., +z, -y1)\n    front_dir: str = \"+z\"\n\n    # training iterations\n    iters: int = 500\n    # training resolution\n    render_resolution: int = 512\n    # training camera radius\n    radius: float = 2.5\n    # training camera fovy in degree\n    fovy: float = 49.1\n    # whether to freeze the geometry and train only the texture\n    fix_geo: bool = False\n    # whether to mix normal with rgb for geometry training\n    mix_normal: bool = True\n    # whether to pretrain texture first\n    fit_tex: bool = True\n    # pretrain texture iterations\n    fit_tex_iters: int = 512\n\n    # output folder\n    outdir: str = '.'\n    # output filename, default to {name}_fine.glb\n    save: Optional[str] = None\n\n    # guidance mode\n    mode: Literal['SD', 'IF', 'IF2', 'SDCN', 'SD_NFSD', 'IF2_NFSD', 'SD_ISM', 'IF2_ISM'] = 'IF2'\n    # renderer geometry mode\n    geom_mode: Literal['mesh', 'diffmc', 'pbr_mesh', 'pbr_diffmc'] = 'diffmc'\n    # renderer texture mode\n    tex_mode: Literal['hashgrid', 'mlp', 'triplane'] = 'hashgrid'\n\n    # training batch size per iter\n    batch_size: int = 1\n    # environmental texture\n    env_texture: Optional[str] = None\n    # environmental light scale\n    env_scale: float = 2\n\n    # DiffMC grid size\n    mc_grid_size: int = 128\n    # Mesh remeshing interval\n    remesh_interval: int = 200\n    # mesh decimation target face number\n    decimate_target: int = 50000\n    # remesh target edge length (smaller value leads to finer mesh)\n    remesh_size: float = 0.015\n    # texture resolution\n    texture_resolution: int = 1024\n    # learning rate for hashgrid\n    hashgrid_lr: float = 0.01\n    # learning rate for feature MLP\n    mlp_lr: float = 0.001\n    # learning rate for SDF\n    sdf_lr: float = 0.0001\n    # learning rate for deformation\n    deform_lr: float = 0.0001\n    # learning rate for mesh geometry\n    geom_lr: float = 0.0001\n\n    # guidance loss weights\n    lambda_sd: float = 1\n    # mesh laplacian regularization weight\n    lambda_lap: float = 0\n    # mesh normal consistency weight (should be large enough)\n    lambda_normal: float = 10000\n    # mesh vertices offset penalty weight\n    lambda_offsets: float = 100\n\n    # whether to open a GUI\n    gui: bool = False\n    # GUI height\n    H: int = 800\n    # GUI width\n    W: int = 800\n    # whether to use CUDA rasterizer (in case OpenGL fails)\n    force_cuda_rast: bool = False\n    # whether to use GPU memory-optimized mode (slower, but uses less GPU memory)\n    vram_O: bool = False\n\n\n# all the default settings\nconfig_defaults: Dict[str, Options] = {}\nconfig_doc: Dict[str, str] = {}\n\nconfig_doc['sd'] = 'coarse-level generation with stable-diffusion 2.'\nconfig_defaults['sd'] = Options(\n    mode='SD',\n    iters=800,\n)\n\nconfig_doc['if'] = 'coarse-level generation with deepfloyd-if I.'\nconfig_defaults['if'] = Options(\n    mode='IF',\n    iters=400,\n)\n\nconfig_doc['if2'] = 'fine-level refinement with deepfloyd-if II.'\nconfig_defaults['if2'] = Options(\n    mode='IF2',\n    iters=400,\n)\n\nconfig_doc['sd_fixgeo'] = 'coarse-level generation with stable-diffusion 2, fixed geometry.'\nconfig_defaults['sd_fixgeo'] = Options(\n    mode='SD',\n    iters=800,\n    fix_geo=True,\n    geom_mode='mesh',\n)\n\nconfig_doc['if_fixgeo'] = 'coarse-level generation with deepfloyd-if I, fixed geometry.'\nconfig_defaults['if_fixgeo'] = Options(\n    mode='IF',\n    iters=400,\n    fix_geo=True,\n    geom_mode='mesh',\n)\n\nconfig_doc['if2_fixgeo'] = 'fine-level refinement with deepfloyd-if II, fixed geometry.'\nconfig_defaults['if2_fixgeo'] = Options(\n    mode='IF2',\n    iters=400,\n    fix_geo=True,\n    geom_mode='mesh',\n)\n\ndef check_options(opt: Options):\n    \"\"\"Validate required fields and fill in the default output filename; mutates and returns `opt`.\"\"\"\n    assert opt.mesh is not None, 'mesh path must be specified!'\n    assert opt.prompt is not None, 'prompt must be specified!'\n\n    if opt.save is None:\n        input_name = os.path.splitext(os.path.basename(opt.mesh))[0]\n        opt.save = input_name + '_fine' + '.glb'\n        print(f'[INFO] save to default output path: {os.path.join(opt.outdir, opt.save)}.')\n\n    return opt"
  },
  {
    "path": "threefiner/renderer/__init__.py",
    "content": "\n"
  },
  {
    "path": "threefiner/renderer/diffmc_renderer.py",
    "content": "import os\nimport tqdm\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport nvdiffrast.torch as dr\n\nimport kiui\nfrom kiui.mesh import Mesh\nfrom kiui.mesh_utils import clean_mesh, decimate_mesh\nfrom kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding\nfrom kiui.cam import orbit_camera, get_perspective\n\nfrom threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder\nfrom threefiner.renderer.mesh_renderer import render_mesh\n\nfrom diso import DiffMC, DiffDMC\n\nclass Renderer(nn.Module):\n    def __init__(self, opt, device):\n        \n        super().__init__()\n\n        self.opt = opt\n        self.device = device\n\n        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):\n            self.glctx = dr.RasterizeGLContext()\n        else:\n            self.glctx = dr.RasterizeCudaContext()\n\n        # diffmc\n        self.verts = torch.stack(\n            torch.meshgrid(\n                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),\n                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),\n                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),\n                indexing=\"ij\",\n            ), dim=-1,\n        ) # [N, N, N, 3]\n        self.grid_scale = 1\n        self.diffmc = DiffMC(dtype=torch.float32).to(device)\n        \n        # vert sdf and deform\n        self.sdf = nn.Parameter(torch.zeros_like(self.verts[..., 0]))\n        self.deform = nn.Parameter(torch.zeros_like(self.verts))\n        \n        # init diffmc from mesh\n        self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)\n\n        vertices = self.mesh.v.detach().cpu().numpy()\n        triangles = self.mesh.f.detach().cpu().numpy()\n        vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)\n        self.mesh.v = 
torch.from_numpy(vertices).contiguous().float().to(self.device)\n        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)\n\n        self.grid_scale = self.mesh.v.abs().max() + 1e-1\n        self.verts = self.verts * self.grid_scale\n        \n        try:\n            import cubvh\n            BVH = cubvh.cuBVH(self.mesh.v, self.mesh.f)\n            sdf, _, _ = BVH.signed_distance(self.verts.reshape(-1, 3), return_uvw=False, mode='raystab') # some mesh may not be watertight...\n        except:\n            from pysdf import SDF\n            sdf_func = SDF(self.mesh.v.detach().cpu().numpy(), self.mesh.f.detach().cpu().numpy())\n            sdf = sdf_func(self.verts.detach().cpu().numpy().reshape(-1, 3))\n            sdf = torch.from_numpy(sdf).to(self.device)\n            sdf *= -1\n\n        # OUTER is POSITIVE\n        self.sdf.data += sdf.reshape(*self.sdf.data.shape).to(self.sdf.data.dtype)\n\n        # texture\n        if self.opt.tex_mode == 'hashgrid':\n            self.encoder = HashGridEncoder().to(self.device)\n        elif self.opt.tex_mode == 'mlp':\n            self.encoder = FrequencyEncoder().to(self.device)\n        elif self.opt.tex_mode == 'triplane':\n            self.encoder = TriplaneEncoder().to(self.device)\n        else:\n            raise NotImplementedError(f\"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}\")\n        \n        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=True).to(self.device)\n\n        self.v, self.f = None, None # placeholder\n\n        # init hashgrid texture from mesh\n        if self.opt.fit_tex:\n            self.fit_texture_from_mesh(self.opt.fit_tex_iters)\n    \n    def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):\n        return render_mesh(\n            self.glctx, \n            self.mesh.v, self.mesh.f, self.mesh.vt, \n            self.mesh.ft, self.mesh.albedo, \n            self.mesh.vc, self.mesh.vn, self.mesh.fn, \n            pose, 
proj, h, w, \n            ssaa=ssaa, bg_color=bg_color,\n        )\n\n    def fit_texture_from_mesh(self, iters=512):\n        # a small training loop...\n\n        loss_fn = torch.nn.MSELoss()\n        optimizer = torch.optim.Adam([\n            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},\n            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},\n        ])\n\n        resolution = 512\n\n        print(f\"[INFO] fitting texture...\")\n        pbar = tqdm.trange(iters)\n        for i in pbar:\n\n            ver = np.random.randint(-45, 45)\n            hor = np.random.randint(-180, 180)\n            \n            pose = orbit_camera(ver, hor, self.opt.radius)\n            proj = get_perspective(self.opt.fovy)\n\n            image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']\n            image_pred = self.render(pose, proj, resolution, resolution)['image']\n\n            loss = loss_fn(image_pred, image_mesh)\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            pbar.set_description(f\"MSE = {loss.item():.6f}\")\n        \n        print(f\"[INFO] finished fitting texture!\")\n\n    def get_params(self):\n\n        params = [\n            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},\n            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},\n        ]\n\n        if not self.opt.fix_geo:\n            params.append({'params': self.sdf, 'lr': self.opt.sdf_lr})\n            params.append({'params': self.deform, 'lr': self.opt.deform_lr})\n\n        return params\n\n    @torch.no_grad()\n    def export_mesh(self, save_path, texture_resolution=2048, padding=16):\n\n        # get v\n        sdf = self.sdf\n        deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]\n\n        v, f = self.diffmc(sdf, deform)\n        v = (2 * v - 1) * self.grid_scale\n        f = f.int()\n        self.v, self.f = v, f\n        \n        vertices 
= v.detach().cpu().numpy()\n        triangles = f.detach().cpu().numpy()\n\n        # clean\n        vertices = vertices.astype(np.float32)\n        triangles = triangles.astype(np.int32)\n        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size)\n        \n        # decimation\n        if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:\n            vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target)\n        \n        v = torch.from_numpy(vertices).contiguous().float().to(self.device)\n        f = torch.from_numpy(triangles).contiguous().int().to(self.device)\n        \n        mesh = Mesh(v=v, f=f, albedo=None, device=self.device)\n        print(f\"[INFO] uv unwrapping...\")\n        mesh.auto_normal()\n        mesh.auto_uv()\n\n        # render uv maps\n        h = w = texture_resolution\n        uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]\n        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]\n\n        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]\n        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]\n        mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]\n\n        # masked query \n        xyzs = xyzs.view(-1, 3)\n        mask = (mask > 0).view(-1)\n        \n        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)\n\n        if mask.any():\n            print(f\"[INFO] querying texture...\")\n\n            xyzs = xyzs[mask] # [M, 3]\n\n            # batched inference to avoid OOM\n            batch = []\n            head = 0\n            while head < xyzs.shape[0]:\n                tail = min(head + 640000, xyzs.shape[0])\n                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())\n                head += 640000\n\n            
albedo[mask] = torch.cat(batch, dim=0)\n        \n        albedo = albedo.view(h, w, -1)\n        mask = mask.view(h, w)\n\n        print(f\"[INFO] uv padding...\")\n        albedo = uv_padding(albedo, mask, padding)\n\n        mesh.albedo = albedo\n        mesh.write(save_path)\n\n    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):\n        \n        # do super-sampling\n        if ssaa != 1:\n            h = make_divisible(h0 * ssaa, 8)\n            w = make_divisible(w0 * ssaa, 8)\n        else:\n            h, w = h0, w0\n        \n        results = {}\n\n        # get v\n        sdf = self.sdf\n        deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]\n\n        v, f = self.diffmc(sdf, deform)\n        v = (2 * v - 1) * self.grid_scale\n        f = f.int()\n        self.v, self.f = v, f\n\n        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)\n        proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)\n\n        # get v_clip and render rgb\n        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)\n        v_clip = v_cam @ proj.T\n\n        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))\n\n        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [V, H, W, 1]\n        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!\n        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]\n        depth = depth.squeeze(0) # [H, W, 1]\n\n        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]\n        xyzs = xyzs.view(-1, 3)\n        mask = (alpha > 0).view(-1)\n        color = torch.zeros_like(xyzs, dtype=torch.float32)\n        if mask.any():\n            masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))\n            color[mask] = masked_albedo.float()\n        color = color.view(1, h, w, 3)\n\n        # antialias\n        color = 
dr.antialias(color, rast, v_clip, f).squeeze(0) # [H, W, 3]\n        color = alpha * color + (1 - alpha) * bg_color\n\n        # get vn and render normal\n        i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()\n        v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]\n\n        face_normals = torch.cross(v1 - v0, v2 - v0)\n        face_normals = safe_normalize(face_normals)\n        \n        vn = torch.zeros_like(v)\n        vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)\n        vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)\n        vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)\n\n        vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))\n\n        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)\n        normal = safe_normalize(normal[0])\n\n        # rotated normal (where [0, 0, 1] always faces camera)\n        rot_normal = normal @ pose[:3, :3]\n        viewcos = rot_normal[..., [2]]\n\n        # ssaa\n        if ssaa != 1:\n            color = scale_img_hwc(color, (h0, w0))\n            alpha = scale_img_hwc(alpha, (h0, w0))\n            depth = scale_img_hwc(depth, (h0, w0))\n            normal = scale_img_hwc(normal, (h0, w0))\n            viewcos = scale_img_hwc(viewcos, (h0, w0))\n\n        results['image'] = color.clamp(0, 1)\n        results['alpha'] = alpha\n        results['depth'] = depth\n        results['normal'] = (normal + 1) / 2\n        results['viewcos'] = viewcos\n\n        return results"
  },
  {
    "path": "threefiner/renderer/mesh_renderer.py",
    "content": "import os\nimport copy\nimport tqdm\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport nvdiffrast.torch as dr\n\nfrom kiui.mesh import Mesh\nfrom kiui.mesh_utils import clean_mesh, decimate_mesh\nfrom kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding\nfrom kiui.cam import orbit_camera, get_perspective\nfrom threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder\n\ndef compute_vertex_normals(v, f):\n    \"\"\"Accumulate unit face normals onto their vertices.\n\n    Args:\n        v: [N, 3] float vertex positions.\n        f: [M, 3] int triangle indices into v.\n\n    Returns:\n        [N, 3] per-vertex normals (not normalized), indexed by f. Vertices whose\n        accumulated normal is (near) zero get [0, 0, 1].\n    \"\"\"\n    i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()\n    v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]\n\n    # explicit dim: without it torch.cross uses the first axis of size 3,\n    # which is the face axis for a mesh with exactly 3 faces.\n    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)\n    face_normals = safe_normalize(face_normals)\n\n    vn = torch.zeros_like(v)\n    vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)\n    vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)\n    vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)\n\n    fallback = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)\n    vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, fallback)\n    return vn\n\ndef render_mesh(\n        glctx,\n        v, f,\n        vt, ft, albedo, vc,\n        vn, fn,\n        pose, proj,\n        h0, w0,\n        ssaa=1, bg_color=1,\n        texture_filter='linear-mipmap-linear',\n        color_activation=None,\n    ):\n    \"\"\"Rasterize an explicit mesh with nvdiffrast.\n\n    Args:\n        glctx: nvdiffrast rasterization context.\n        v, f: [N, 3] vertex positions and [M, 3] int triangle indices.\n        vt, ft, albedo: uv coordinates, uv triangle indices and texture image (used when vc is None).\n        vc: optional per-vertex colors indexed by f; takes priority over the texture.\n        vn, fn: optional vertex normals and their triangle indices. If vn is None the normals\n            are computed from (v, f); if only fn is None, vn is assumed to be indexed by f.\n        pose: [4, 4] numpy camera-to-world matrix.\n        proj: [4, 4] numpy projection matrix.\n        h0, w0: output resolution.\n        ssaa: super-sampling factor.\n        bg_color: background color blended outside the mesh.\n        texture_filter: nvdiffrast texture filter mode.\n        color_activation: optional callable applied to the sampled color.\n\n    Returns:\n        dict with 'image' [H, W, 3], 'alpha' [H, W, 1], 'depth' [H, W, 1],\n        'normal' [H, W, 3] mapped to [0, 1] and 'viewcos' [H, W, 1].\n    \"\"\"\n\n    # do super-sampling\n    if ssaa != 1:\n        h = make_divisible(h0 * ssaa, 8)\n        w = make_divisible(w0 * ssaa, 8)\n    else:\n        h, w = h0, w0\n\n    results = {}\n\n    pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)\n    proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)\n\n    # get v_clip and render rgb\n    v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)\n    v_clip = v_cam @ proj.T\n\n    rast, rast_db = dr.rasterize(glctx, v_clip, f, (h, w))\n\n    alpha = (rast[0, ..., 3:] > 0).float()\n    depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]\n    depth = depth.squeeze(0) # [H, W, 1]\n\n    if vc is not None:\n        # use vertex color\n        color, _ = dr.interpolate(vc.unsqueeze(0).contiguous(), rast, f)\n    else:\n        # use texture image\n        texc, texc_db = dr.interpolate(vt.unsqueeze(0).contiguous(), rast, ft, rast_db=rast_db, diff_attrs='all')\n        color = dr.texture(albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3]\n\n    if color_activation is not None:\n        color = color_activation(color)\n\n    # antialias\n    color = dr.antialias(color, rast, v_clip, f).squeeze(0) # [H, W, 3]\n    color = alpha * color + (1 - alpha) * bg_color\n\n    # get vn and render normal\n    if vn is None:\n        # normals computed from (v, f) live on v's vertices, so f is their index buffer\n        vn = compute_vertex_normals(v, f)\n        fn = f\n    elif fn is None:\n        fn = f\n\n    normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, fn)\n    normal = safe_normalize(normal[0])\n\n    # rotated normal (where [0, 0, 1] always faces camera)\n    rot_normal = normal @ pose[:3, :3]\n    viewcos = rot_normal[..., [2]]\n\n    # ssaa\n    if ssaa != 1:\n        color = scale_img_hwc(color, (h0, w0))\n        alpha = scale_img_hwc(alpha, (h0, w0))\n        depth = scale_img_hwc(depth, (h0, w0))\n        normal = scale_img_hwc(normal, (h0, w0))\n        viewcos = scale_img_hwc(viewcos, (h0, w0))\n\n    results['image'] = color.clamp(0, 1)\n    results['alpha'] = alpha\n    results['depth'] = depth\n    results['normal'] = (normal + 1) / 2\n    results['viewcos'] = viewcos\n\n    return results\n\nclass Renderer(nn.Module):\n    def __init__(self, opt, device):\n\n        super().__init__()\n\n        self.opt = opt\n        self.device = device\n\n        self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)\n\n        # untouched copy of the loaded mesh, used as the reference by render_mesh: the cleaning\n        # below re-indexes v/f, after which the loaded vt/ft and vn/fn may no longer line up with f.\n        self.ref_mesh = copy.copy(self.mesh)\n\n        # it's necessary to clean the mesh to facilitate later remeshing!\n        vertices = self.mesh.v.detach().cpu().numpy()\n        triangles = self.mesh.f.detach().cpu().numpy()\n        vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)\n        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)\n        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)\n        # recompute normals so that mesh.vn / mesh.fn match the cleaned topology\n        self.mesh.auto_normal()\n\n        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):\n            self.glctx = dr.RasterizeGLContext()\n        else:\n            self.glctx = dr.RasterizeCudaContext()\n\n        # extract trainable parameters\n        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))\n\n        # texture\n        if self.opt.tex_mode == 'hashgrid':\n            self.encoder = HashGridEncoder().to(self.device)\n        elif self.opt.tex_mode == 'mlp':\n            self.encoder = FrequencyEncoder().to(self.device)\n        elif self.opt.tex_mode == 'triplane':\n            self.encoder = TriplaneEncoder().to(self.device)\n        else:\n            raise NotImplementedError(f\"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}\")\n\n        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=True).to(self.device)\n\n        # init hashgrid texture from mesh\n        if self.opt.fit_tex:\n            self.fit_texture_from_mesh(self.opt.fit_tex_iters)\n\n    def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):\n        # render the loaded (uncleaned) input mesh as the reference image\n        return render_mesh(\n            self.glctx,\n            self.ref_mesh.v, self.ref_mesh.f, self.ref_mesh.vt,\n            self.ref_mesh.ft, self.ref_mesh.albedo,\n            self.ref_mesh.vc, self.ref_mesh.vn, self.ref_mesh.fn,\n            pose, proj, h, w,\n            ssaa=ssaa, bg_color=bg_color,\n        )\n\n    def fit_texture_from_mesh(self, iters=512):\n        # a small training loop distilling the reference mesh colors into the neural texture\n\n        loss_fn = torch.nn.MSELoss()\n        optimizer = torch.optim.Adam([\n            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},\n            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},\n        ])\n\n        resolution = 512\n\n        print(f\"[INFO] fitting texture...\")\n        pbar = tqdm.trange(iters)\n        for i in pbar:\n\n            ver = np.random.randint(-45, 45)\n            hor = np.random.randint(-180, 180)\n\n            pose = orbit_camera(ver, hor, self.opt.radius)\n            proj = get_perspective(self.opt.fovy)\n\n            image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']\n            image_pred = self.render(pose, proj, resolution, resolution)['image']\n\n            loss = loss_fn(image_pred, image_mesh)\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            pbar.set_description(f\"MSE = {loss.item():.6f}\")\n\n        print(f\"[INFO] finished fitting texture!\")\n\n    def get_params(self):\n\n        params = [\n            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},\n            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},\n        ]\n\n        if not self.opt.fix_geo:\n            params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})\n\n        return params\n\n    @torch.no_grad()\n    def export_mesh(self, save_path, texture_resolution=2048, padding=16):\n\n        mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)\n        print(f\"[INFO] uv unwrapping...\")\n        mesh.auto_normal()\n        mesh.auto_uv()\n\n        # render uv maps\n        h = w = texture_resolution\n        uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]\n        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]\n\n        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]\n\n        # masked query\n        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]\n        mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1]\n        xyzs = xyzs.view(-1, 3)\n        mask = (mask > 0).view(-1)\n\n        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)\n\n        if mask.any():\n            print(f\"[INFO] querying texture...\")\n\n            xyzs = xyzs[mask] # [M, 3]\n\n            # batched inference to avoid OOM\n            batch = []\n            head = 0\n            while head < xyzs.shape[0]:\n                tail = min(head + 640000, xyzs.shape[0])\n                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())\n                head += 640000\n\n            albedo[mask] = torch.cat(batch, dim=0)\n\n        albedo = albedo.view(h, w, -1)\n        mask = mask.view(h, w)\n\n        print(f\"[INFO] uv padding...\")\n        albedo = uv_padding(albedo, mask, padding)\n\n        mesh.albedo = albedo\n        mesh.write(save_path)\n\n    @property\n    def v(self):\n        if self.opt.fix_geo:\n            return self.mesh.v\n        else:\n            return self.mesh.v + self.v_offsets\n\n    @property\n    def f(self):\n        return self.mesh.f\n\n    @torch.no_grad()\n    def remesh(self):\n        vertices = self.v.detach().cpu().numpy()\n        triangles = self.f.detach().cpu().numpy()\n        vertices, triangles = clean_mesh(vertices, triangles, repair=False, remesh=True, remesh_size=self.opt.remesh_size)\n        if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:\n            vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target, optimalplacement=False)\n        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)\n        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)\n        # keep mesh.vn / mesh.fn consistent with the new topology\n        self.mesh.auto_normal()\n        # a new Parameter object: optimizers built from an earlier get_params() still hold the old one\n        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))\n\n    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):\n\n        # do super-sampling\n        if ssaa != 1:\n            h = make_divisible(h0 * ssaa, 8)\n            w = make_divisible(w0 * ssaa, 8)\n        else:\n            h, w = h0, w0\n\n        results = {}\n\n        # get v\n        v = self.v\n        f = self.f\n\n        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)\n        proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)\n\n        # get v_clip and render rgb\n        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)\n        v_clip = v_cam @ proj.T\n\n        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))\n\n        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]\n        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!\n\n        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]\n        depth = depth.squeeze(0) # [H, W, 1]\n\n        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]\n        xyzs = xyzs.view(-1, 3)\n        mask = (alpha > 0).view(-1)\n        color = torch.zeros_like(xyzs, dtype=torch.float32)\n        if mask.any():\n            masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))\n            color[mask] = masked_albedo.float()\n        color = color.view(1, h, w, 3)\n\n        # antialias\n        color = dr.antialias(color, rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3]\n        color = alpha * color + (1 - alpha) * bg_color\n\n        # get vn and render normal\n        if self.opt.fix_geo:\n            # mesh.vn is indexed by its own face buffer mesh.fn, not by f\n            vn, fn = self.mesh.vn, self.mesh.fn\n        else:\n            vn, fn = compute_vertex_normals(v, f), f\n\n        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, fn)\n        normal = safe_normalize(normal[0])\n\n        # rotated normal (where [0, 0, 1] always faces camera)\n        rot_normal = normal @ pose[:3, :3]\n        viewcos = rot_normal[..., [2]]\n\n        # ssaa\n        if ssaa != 1:\n            color = scale_img_hwc(color, (h0, w0))\n            alpha = scale_img_hwc(alpha, (h0, w0))\n            depth = scale_img_hwc(depth, (h0, w0))\n            normal = scale_img_hwc(normal, (h0, w0))\n            viewcos = scale_img_hwc(viewcos, (h0, w0))\n\n        results['image'] = color\n        results['alpha'] = alpha\n        results['depth'] = depth\n        results['normal'] = (normal + 1) / 2\n        results['viewcos'] = viewcos\n\n        return results"
  },
  {
    "path": "threefiner/renderer/pbr_diffmc_renderer.py",
    "content": "import os\nimport copy\nimport tqdm\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport envlight\n\nimport nvdiffrast.torch as dr\n\nimport kiui\nfrom kiui.mesh import Mesh\nfrom kiui.mesh_utils import clean_mesh, decimate_mesh\nfrom kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding\nfrom kiui.cam import orbit_camera, get_perspective\n\nfrom threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder\nfrom threefiner.renderer.mesh_renderer import render_mesh\n\nfrom diso import DiffMC, DiffDMC\n\nclass Renderer(nn.Module):\n    def __init__(self, opt, device):\n\n        super().__init__()\n\n        self.opt = opt\n        self.device = device\n\n        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):\n            self.glctx = dr.RasterizeGLContext()\n        else:\n            self.glctx = dr.RasterizeCudaContext()\n\n        # diffmc\n        self.verts = torch.stack(\n            torch.meshgrid(\n                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),\n                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),\n                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),\n                indexing=\"ij\",\n            ), dim=-1,\n        ) # [N, N, N, 3]\n        self.grid_scale = 1\n        self.diffmc = DiffMC(dtype=torch.float32).to(device)\n\n        # vert sdf and deform\n        self.sdf = nn.Parameter(torch.zeros_like(self.verts[..., 0]))\n        self.deform = nn.Parameter(torch.zeros_like(self.verts))\n\n        # init diffmc from mesh\n        self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)\n\n        # untouched copy of the loaded mesh, used as the reference by render_mesh: the cleaning\n        # below re-indexes v/f, after which the loaded vt/ft and vn/fn may no longer line up with f.\n        self.ref_mesh = copy.copy(self.mesh)\n\n        vertices = self.mesh.v.detach().cpu().numpy()\n        triangles = self.mesh.f.detach().cpu().numpy()\n        vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)\n        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)\n        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)\n\n        self.grid_scale = self.mesh.v.abs().max() + 1e-1\n        self.verts = self.verts * self.grid_scale\n\n        try:\n            import cubvh\n            BVH = cubvh.cuBVH(self.mesh.v, self.mesh.f)\n            sdf, _, _ = BVH.signed_distance(self.verts.reshape(-1, 3), return_uvw=False, mode='raystab') # some mesh may not be watertight...\n        except Exception:\n            # cubvh unavailable or failed: fall back to pysdf on CPU, negated to follow the OUTER-is-POSITIVE convention\n            from pysdf import SDF\n            sdf_func = SDF(self.mesh.v.detach().cpu().numpy(), self.mesh.f.detach().cpu().numpy())\n            sdf = sdf_func(self.verts.detach().cpu().numpy().reshape(-1, 3))\n            sdf = torch.from_numpy(sdf).to(self.device)\n            sdf *= -1\n\n        # OUTER is POSITIVE\n        self.sdf.data += sdf.reshape(*self.sdf.data.shape).to(self.sdf.data.dtype)\n\n        # texture\n        if self.opt.tex_mode == 'hashgrid':\n            self.encoder = HashGridEncoder().to(self.device)\n        elif self.opt.tex_mode == 'mlp':\n            self.encoder = FrequencyEncoder().to(self.device)\n        elif self.opt.tex_mode == 'triplane':\n            self.encoder = TriplaneEncoder().to(self.device)\n        else:\n            raise NotImplementedError(f\"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}\")\n\n        # 5 output channels: albedo (3), metallic (1), roughness (1)\n        self.mlp = MLP(self.encoder.output_dim, 3+2, 32, 2, bias=True).to(self.device)\n\n        self.v, self.f = None, None # placeholder\n\n        # env light\n        if self.opt.env_texture is None:\n            hdr_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../lights/mud_road_puresky_1k.hdr')\n        else:\n            hdr_path = self.opt.env_texture\n        self.light = envlight.EnvLight(hdr_path, scale=self.opt.env_scale, device=self.device)\n\n        FG_LUT = torch.from_numpy(np.fromfile(os.path.join(os.path.dirname(os.path.abspath(__file__)), \"../lights/bsdf_256_256.bin\"), dtype=np.float32).reshape(1, 256, 256, 2)).to(self.device)\n        self.register_buffer(\"FG_LUT\", FG_LUT)\n\n        # init hashgrid texture from mesh\n        if self.opt.fit_tex:\n            self.fit_texture_from_mesh(self.opt.fit_tex_iters)\n\n    def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):\n        # render the loaded (uncleaned) input mesh as the reference image\n        return render_mesh(\n            self.glctx,\n            self.ref_mesh.v, self.ref_mesh.f, self.ref_mesh.vt,\n            self.ref_mesh.ft, self.ref_mesh.albedo,\n            self.ref_mesh.vc, self.ref_mesh.vn, self.ref_mesh.fn,\n            pose, proj, h, w,\n            ssaa=ssaa, bg_color=bg_color,\n        )\n\n    def fit_texture_from_mesh(self, iters=512):\n        # a small training loop distilling the reference mesh colors into the neural material field\n\n        loss_fn = torch.nn.MSELoss()\n        optimizer = torch.optim.Adam([\n            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},\n            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},\n        ])\n\n        resolution = 512\n\n        print(f\"[INFO] fitting texture...\")\n        pbar = tqdm.trange(iters)\n        for i in pbar:\n\n            ver = np.random.randint(-45, 45)\n            hor = np.random.randint(-180, 180)\n\n            pose = orbit_camera(ver, hor, self.opt.radius)\n            proj = get_perspective(self.opt.fovy)\n\n            image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']\n            image_pred = self.render(pose, proj, resolution, resolution)['image']\n\n            loss = loss_fn(image_pred, image_mesh)\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            pbar.set_description(f\"MSE = {loss.item():.6f}\")\n\n        print(f\"[INFO] finished fitting texture!\")\n\n    def get_params(self):\n\n        params = [\n            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},\n            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},\n        ]\n\n        if not self.opt.fix_geo:\n            params.append({'params': self.sdf, 'lr': self.opt.sdf_lr})\n            params.append({'params': self.deform, 'lr': self.opt.deform_lr})\n\n        return params\n\n    @torch.no_grad()\n    def export_mesh(self, save_path, texture_resolution=2048, padding=16):\n\n        # get v\n        sdf = self.sdf\n        deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]\n\n        v, f = self.diffmc(sdf, deform)\n        v = (2 * v - 1) * self.grid_scale\n        f = f.int()\n        self.v, self.f = v, f\n\n        vertices = v.detach().cpu().numpy()\n        triangles = f.detach().cpu().numpy()\n\n        # clean\n        vertices = vertices.astype(np.float32)\n        triangles = triangles.astype(np.int32)\n        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size)\n\n        # decimation\n        if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:\n            vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target)\n\n        v = torch.from_numpy(vertices).contiguous().float().to(self.device)\n        f = torch.from_numpy(triangles).contiguous().int().to(self.device)\n\n        mesh = Mesh(v=v, f=f, albedo=None, device=self.device)\n        print(f\"[INFO] uv unwrapping...\")\n        mesh.auto_normal()\n        mesh.auto_uv()\n\n        # render uv maps\n        h = w = texture_resolution\n        uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]\n        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]\n\n        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]\n        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]\n        mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]\n\n        # masked query\n        xyzs = xyzs.view(-1, 3)\n        mask = (mask > 0).view(-1)\n\n        material = torch.zeros(h * w, 5, device=self.device, dtype=torch.float32)\n\n        if mask.any():\n            print(f\"[INFO] querying texture...\")\n\n            xyzs = xyzs[mask] # [M, 3]\n\n            # batched inference to avoid OOM\n            batch = []\n            head = 0\n            while head < xyzs.shape[0]:\n                tail = min(head + 640000, xyzs.shape[0])\n                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())\n                head += 640000\n\n            material[mask] = torch.cat(batch, dim=0)\n\n        material = material.view(h, w, -1)\n        mask = mask.view(h, w)\n\n        print(f\"[INFO] uv padding...\")\n        material = uv_padding(material, mask, padding)\n\n        mesh.albedo = material[..., :3]\n        # glTF metallicRoughness layout: G = roughness, B = metallic\n        mesh.metallicRoughness = torch.cat([torch.zeros_like(material[..., 3:4]), material[..., 4:5], material[..., 3:4]], dim=-1)\n\n        mesh.write(save_path)\n\n    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):\n\n        # do super-sampling\n        if ssaa != 1:\n            h = make_divisible(h0 * ssaa, 8)\n            w = make_divisible(w0 * ssaa, 8)\n        else:\n            h, w = h0, w0\n\n        results = {}\n\n        # get v\n        sdf = self.sdf\n        deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]\n\n        v, f = self.diffmc(sdf, deform)\n        v = (2 * v - 1) * self.grid_scale\n        f = f.int()\n        self.v, self.f = v, f\n\n        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)\n        proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)\n\n        # get v_clip and render rgb\n        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)\n        v_clip = v_cam @ proj.T\n\n        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))\n\n        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]\n        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!\n        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]\n        depth = depth.squeeze(0) # [H, W, 1]\n\n        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]\n        # unit vector from the surface point towards the camera (pose is camera-to-world, so\n        # pose[:3, 3] is the camera position); n_dot_v and the reflection below require this direction\n        viewdir = safe_normalize(pose[:3, 3] - xyzs).squeeze(0) # [H, W, 3]\n        xyzs = xyzs.view(-1, 3)\n        mask = (alpha > 0).view(-1)\n        material = torch.zeros(xyzs.shape[0], 5, dtype=torch.float32, device=xyzs.device)\n        if mask.any():\n            masked_material = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))\n            material[mask] = masked_material.float()\n        material = material.view(h, w, -1)\n\n        # get vn and render normal\n        i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()\n        v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]\n\n        # explicit dim: without it torch.cross uses the first axis of size 3 (the face axis if there are exactly 3 faces)\n        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)\n        face_normals = safe_normalize(face_normals)\n\n        vn = torch.zeros_like(v)\n        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)\n        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)\n        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)\n\n        vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))\n\n        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)\n        normal = safe_normalize(normal[0])\n\n        # rotated normal (where [0, 0, 1] always faces camera)\n        rot_normal = normal @ pose[:3, :3]\n        viewcos = rot_normal[..., [2]]\n\n        # shading: diffuse + specular, with the specular term using the pre-integrated FG_LUT indexed by (n_dot_v, roughness)\n        albedo = material[..., :3]\n        metallic = material[..., 3:4]\n        roughness = material[..., 4:5]\n\n        n_dot_v = (normal * viewdir).sum(-1, keepdim=True) # [H, W, 1]\n        reflective = n_dot_v * normal * 2 - viewdir\n\n        diffuse_albedo = (1 - metallic) * albedo\n\n        fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) # [H, W, 2]\n        fg = dr.texture(\n            self.FG_LUT,\n            fg_uv.reshape(1, -1, 1, 2).contiguous(),\n            filter_mode=\"linear\",\n            boundary_mode=\"clamp\",\n        ).reshape(h, w, 2)\n        F0 = (1 - metallic) * 0.04 + metallic * albedo\n        specular_albedo = F0 * fg[..., 0:1] + fg[..., 1:2]\n\n        diffuse_light = self.light(normal)\n        specular_light = self.light(reflective, roughness)\n\n        color = diffuse_albedo * diffuse_light + specular_albedo * specular_light # [H, W, 3]\n\n        # antialias\n        color = dr.antialias(color.unsqueeze(0), rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3]\n        color = alpha * color + (1 - alpha) * bg_color\n\n        # ssaa\n        if ssaa != 1:\n            color = scale_img_hwc(color, (h0, w0))\n            alpha = scale_img_hwc(alpha, (h0, w0))\n            depth = scale_img_hwc(depth, (h0, w0))\n            normal = scale_img_hwc(normal, (h0, w0))\n            viewcos = scale_img_hwc(viewcos, (h0, w0))\n\n        results['image'] = color\n        results['alpha'] = alpha\n        results['depth'] = depth\n        results['normal'] = (normal + 1) / 2\n        results['viewcos'] = viewcos\n\n        return results"
  },
  {
    "path": "threefiner/renderer/pbr_mesh_renderer.py",
    "content": "import os\nimport tqdm\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport envlight\n\nimport nvdiffrast.torch as dr\n\nfrom kiui.mesh import Mesh\nfrom kiui.mesh_utils import clean_mesh, decimate_mesh\nfrom kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding\nfrom kiui.cam import orbit_camera, get_perspective\nfrom threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder\nfrom threefiner.renderer.mesh_renderer import render_mesh\n\n\nclass Renderer(nn.Module):\n    def __init__(self, opt, device):\n        \n        super().__init__()\n\n        self.opt = opt\n        self.device = device\n\n        self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)\n\n        # it's necessary to clean the mesh to facilitate later remeshing!\n        vertices = self.mesh.v.detach().cpu().numpy()\n        triangles = self.mesh.f.detach().cpu().numpy()\n        vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)\n        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)\n        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)\n\n        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):\n            self.glctx = dr.RasterizeGLContext()\n        else:\n            self.glctx = dr.RasterizeCudaContext()\n        \n        # extract trainable parameters\n        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))\n        \n        # texture\n        if self.opt.tex_mode == 'hashgrid':\n            self.encoder = HashGridEncoder().to(self.device)\n        elif self.opt.tex_mode == 'mlp':\n            self.encoder = FrequencyEncoder().to(self.device)\n        elif self.opt.tex_mode == 'triplane':\n            self.encoder = TriplaneEncoder().to(self.device)\n        else:\n            raise NotImplementedError(f\"unsupported 
texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}\")\n        \n        self.mlp = MLP(self.encoder.output_dim, 3+2, 32, 2, bias=True).to(self.device)\n\n        # env light\n        if self.opt.env_texture is None:\n            hdr_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../lights/mud_road_puresky_1k.hdr')\n        else:\n            hdr_path = self.opt.env_texture\n        self.light = envlight.EnvLight(hdr_path, scale=self.opt.env_scale, device=self.device)\n\n        FG_LUT = torch.from_numpy(np.fromfile(os.path.join(os.path.dirname(os.path.abspath(__file__)), \"../lights/bsdf_256_256.bin\"), dtype=np.float32).reshape(1, 256, 256, 2)).to(self.device)\n        self.register_buffer(\"FG_LUT\", FG_LUT)\n\n        # init hashgrid texture from mesh\n        if self.opt.fit_tex:\n            self.fit_texture_from_mesh(self.opt.fit_tex_iters)\n\n    def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):\n        return render_mesh(\n            self.glctx, \n            self.mesh.v, self.mesh.f, self.mesh.vt, \n            self.mesh.ft, self.mesh.albedo, \n            self.mesh.vc, self.mesh.vn, self.mesh.fn, \n            pose, proj, h, w, \n            ssaa=ssaa, bg_color=bg_color,\n        )\n    \n    def fit_texture_from_mesh(self, iters=512):\n        # a small training loop...\n\n        loss_fn = torch.nn.MSELoss()\n        optimizer = torch.optim.Adam([\n            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},\n            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},\n        ])\n\n        resolution = 512\n\n        print(f\"[INFO] fitting texture...\")\n        pbar = tqdm.trange(iters)\n        for i in pbar:\n\n            ver = np.random.randint(-45, 45)\n            hor = np.random.randint(-180, 180)\n            \n            pose = orbit_camera(ver, hor, self.opt.radius)\n            proj = get_perspective(self.opt.fovy)\n\n            image_mesh = self.render_mesh(pose, proj, 
resolution, resolution)['image']\n            image_pred = self.render(pose, proj, resolution, resolution)['image']\n\n            loss = loss_fn(image_pred, image_mesh)\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            pbar.set_description(f\"MSE = {loss.item():.6f}\")\n        \n        print(f\"[INFO] finished fitting texture!\")\n\n    def get_params(self):\n\n        params = [\n            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},\n            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},\n        ]\n\n        if not self.opt.fix_geo:\n            params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})\n\n        return params\n\n    @torch.no_grad()\n    def export_mesh(self, save_path, texture_resolution=2048, padding=16):\n\n        mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)\n        print(f\"[INFO] uv unwrapping...\")\n        mesh.auto_normal()\n        mesh.auto_uv()\n\n        # render uv maps\n        h = w = texture_resolution\n        uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]\n        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]\n\n        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]\n\n        # masked query \n        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]\n        mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1]\n        xyzs = xyzs.view(-1, 3)\n        mask = (mask > 0).view(-1)\n        \n        material = torch.zeros(h * w, 5, device=self.device, dtype=torch.float32)\n\n        if mask.any():\n            print(f\"[INFO] querying texture...\")\n\n            xyzs = xyzs[mask] # [M, 3]\n\n            # batched inference to avoid OOM\n            batch = []\n            head = 0\n            while head < xyzs.shape[0]:\n                
tail = min(head + 640000, xyzs.shape[0])\n                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())\n                head += 640000\n\n            material[mask] = torch.cat(batch, dim=0)\n        \n        material = material.view(h, w, -1)\n        mask = mask.view(h, w)\n\n        print(f\"[INFO] uv padding...\")\n        material = uv_padding(material, mask, padding)\n\n        mesh.albedo = material[..., :3]\n        mesh.metallicRoughness = torch.cat([torch.zeros_like(material[..., 3:4]), material[..., 4:5], material[..., 3:4]], dim=-1)\n        mesh.write(save_path)\n\n    @property\n    def v(self):\n        if self.opt.fix_geo:\n            return self.mesh.v\n        else:\n            return self.mesh.v + self.v_offsets\n    \n    @property\n    def f(self):\n        return self.mesh.f\n    \n    @torch.no_grad()\n    def remesh(self):\n        vertices = self.v.detach().cpu().numpy()\n        triangles = self.f.detach().cpu().numpy()\n        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size)\n        if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:\n            vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target, optimalplacement=False)\n        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)\n        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)\n        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v)).to(self.device)\n    \n    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):\n\n        # do super-sampling\n        if ssaa != 1:\n            h = make_divisible(h0 * ssaa, 8)\n            w = make_divisible(w0 * ssaa, 8)\n        else:\n            h, w = h0, w0\n        \n        results = {}\n\n        # get v\n        v = self.v\n        f = self.f\n\n        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)\n  
      proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)\n\n        # get v_clip and render rgb\n        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)\n        v_clip = v_cam @ proj.T\n\n        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))\n\n        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]\n        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!\n\n        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]\n        depth = depth.squeeze(0) # [H, W, 1]\n\n        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]\n        viewdir = safe_normalize(xyzs - pose[:3, 3]).squeeze(0)\n\n        xyzs = xyzs.view(-1, 3)\n        mask = (alpha > 0).view(-1)\n        material = torch.zeros(xyzs.shape[0], 5, dtype=torch.float32, device=xyzs.device)\n        if mask.any():\n            masked_material = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))\n            material[mask] = masked_material.float()\n        material = material.view(h, w, -1)\n        \n        # get vn and render normal\n        if self.opt.fix_geo:\n            vn = self.mesh.vn\n        else:\n            i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()\n            v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]\n\n            face_normals = torch.cross(v1 - v0, v2 - v0)\n            face_normals = safe_normalize(face_normals)\n\n            vn = torch.zeros_like(v)\n            vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)\n            vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)\n            vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)\n\n            vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))\n\n        normal, _ = 
dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)\n        normal = safe_normalize(normal[0])\n\n        # rotated normal (where [0, 0, 1] always faces camera)\n        rot_normal = normal @ pose[:3, :3]\n        viewcos = rot_normal[..., [2]]\n\n        # shading\n        albedo = material[..., :3]\n        metallic = material[..., 3:4]\n        roughness = material[..., 4:5]\n\n        n_dot_v = (normal * viewdir).sum(-1, keepdim=True) # [H, W, 1]\n        reflective = n_dot_v * normal * 2 - viewdir\n\n        diffuse_albedo = (1 - metallic) * albedo\n\n        fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) # [H, W, 2]\n        fg = dr.texture(\n            self.FG_LUT,\n            fg_uv.reshape(1, -1, 1, 2).contiguous(),\n            filter_mode=\"linear\",\n            boundary_mode=\"clamp\",\n        ).reshape(h, w, 2)\n        F0 = (1 - metallic) * 0.04 + metallic * albedo\n        specular_albedo = F0 * fg[..., 0:1] + fg[..., 1:2]\n\n        diffuse_light = self.light(normal)\n        specular_light = self.light(reflective, roughness)\n\n        color = diffuse_albedo * diffuse_light + specular_albedo * specular_light # [H, W, 3]\n\n        # antialias\n        color = dr.antialias(color.unsqueeze(0), rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3]\n        color = alpha * color + (1 - alpha) * bg_color\n\n        # ssaa\n        if ssaa != 1:\n            color = scale_img_hwc(color, (h0, w0))\n            alpha = scale_img_hwc(alpha, (h0, w0))\n            depth = scale_img_hwc(depth, (h0, w0))\n            normal = scale_img_hwc(normal, (h0, w0))\n            viewcos = scale_img_hwc(viewcos, (h0, w0))\n\n        results['image'] = color\n        results['alpha'] = alpha\n        results['depth'] = depth\n        results['normal'] = (normal + 1) / 2\n        results['viewcos'] = viewcos\n\n        return results"
  }
]