Full Code of 3DTopia/threefiner for AI

main 6c34f089e61a cached
32 files
196.8 KB
50.2k tokens
133 symbols
1 requests
Download .txt
Showing preview only (207K chars total). Download the full file or copy to clipboard to get everything.
Repository: 3DTopia/threefiner
Branch: main
Commit: 6c34f089e61a
Files: 32
Total size: 196.8 KB

Directory structure:
gitextract_wp6txn7z/

├── .github/
│   └── workflows/
│       └── pypi-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── data/
│   ├── car.glb
│   └── chair.ply
├── gradio_app.py
├── readme.md
├── scripts/
│   ├── run.sh
│   └── test_all.sh
├── setup.py
└── threefiner/
    ├── __init__.py
    ├── cli.py
    ├── gui.py
    ├── guidance/
    │   ├── __init__.py
    │   ├── if2_ism_utils.py
    │   ├── if2_nfsd_utils.py
    │   ├── if2_utils.py
    │   ├── if_utils.py
    │   ├── sd_ism_utils.py
    │   ├── sd_nfsd_utils.py
    │   ├── sd_utils.py
    │   └── sdcn_utils.py
    ├── lights/
    │   ├── LICENSE.txt
    │   └── mud_road_puresky_1k.hdr
    ├── nn.py
    ├── opt.py
    └── renderer/
        ├── __init__.py
        ├── diffmc_renderer.py
        ├── mesh_renderer.py
        ├── pbr_diffmc_renderer.py
        └── pbr_mesh_renderer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/pypi-publish.yml
================================================
name: Upload Python Package

on:
  release:
    types: [created]
  workflow_dispatch:

jobs:
  deploy:

    runs-on: ubuntu-latest

    environment:
      name: pypi
      url: https://pypi.org/project/threefiner/
    permissions:
      id-token: write  # IMPORTANT: this permission is mandatory for trusted publishing

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.10'
    # prepare distributions in dist/
    - name: Install dependencies and Build
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel
        python setup.py sdist bdist_wheel
    # publish by trusted publishers (need to first setup in pypi.org projects-manage-publishing!)
    # ref: https://github.com/marketplace/actions/pypi-publish
    - name: Publish package distributions to PyPI
      uses: pypa/gh-action-pypi-publish@release/v1

================================================
FILE: .gitignore
================================================
__pycache__
tmp*
data_*
logs
logs*
videos*


*.egg-info
build/

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: MANIFEST.in
================================================
recursive-include threefiner/lights *

================================================
FILE: gradio_app.py
================================================
import os
import tyro
import tqdm
import torch
import gradio as gr

import kiui

from threefiner.opt import config_defaults, config_doc, check_options
from threefiner.gui import GUI

GRADIO_SAVE_PATH_MESH = 'gradio_output.glb'
GRADIO_SAVE_PATH_VIDEO = 'gradio_output.mp4'

opt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc))

# hacks for not loading mesh at initialization
opt.save = GRADIO_SAVE_PATH_MESH
opt.prompt = ''
opt.text_dir = True
opt.front_dir = '+z'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gui = GUI(opt)

# process function
def process(input_model, input_text, input_dir, iters):

    # set front facing direction (map from gradio model3D's mysterious coordinate system to OpenGL...)
    opt.text_dir = True
    if input_dir == 'front':
        opt.front_dir = '-z'
    elif input_dir == 'back':
        opt.front_dir = '+z'
    elif input_dir == 'left':
        opt.front_dir = '+x'
    elif input_dir == 'right':
        opt.front_dir = '-x'
    elif input_dir == 'up':
        opt.front_dir = '+y'
    elif input_dir == 'down':
        opt.front_dir = '-y'
    else:
        # turn off text_dir
        opt.text_dir = False
        opt.front_dir = '+z'
    
    # set mesh path
    opt.mesh = input_model

    # load mesh!
    gui.renderer = gui.renderer_class(opt, device).to(device)

    # set prompt
    gui.prompt = opt.positive_prompt + ', ' + input_text

    # train
    gui.prepare_train() # update optimizer and prompt embeddings
    for i in tqdm.trange(iters):
        gui.train_step()

    # save mesh & video
    gui.save_model(GRADIO_SAVE_PATH_MESH)
    gui.save_model(GRADIO_SAVE_PATH_VIDEO)
    
    # return 3d model & video
    return GRADIO_SAVE_PATH_MESH, GRADIO_SAVE_PATH_VIDEO

# gradio UI
block = gr.Blocks().queue()
with block:
    gr.Markdown("""
    ## Threefiner: Text-guided mesh refinement.
    """)

    with gr.Row(variant='panel'):
        with gr.Column(scale=1):
            input_model = gr.Model3D(label="input mesh")
            input_text = gr.Text(label="prompt")
            input_dir = gr.Radio(['front', 'back', 'left', 'right', 'up', 'down'], label="front-facing direction")
            iters = gr.Slider(minimum=100, maximum=1000, step=100, value=400, label="training iterations")
            button_gen = gr.Button("Refine!")
        
        with gr.Column(scale=1):
            output_model = gr.Model3D(label="output mesh")
            output_video = gr.Video(label="output video")

        button_gen.click(process, inputs=[input_model, input_text, input_dir, iters], outputs=[output_model, output_video])
    
block.launch(server_name="0.0.0.0", share=True)

================================================
FILE: readme.md
================================================
<p align="center">
    <picture>
    <img alt="logo" src="assets/threefiner_icon.png" width="20%">
    </picture>
    </br>
    <b>Threefiner</b>
</p>

An interface for text-guided mesh refinement.

https://github.com/3DTopia/threefiner/assets/25863658/a4abe725-b542-4a4a-a6d4-e4c4821f7d96

### Features
* **Mesh in, mesh out**: we support `ply` with vertex colors, `obj`, and single object `glb/gltf` with textures!
* **Easy to use**: both a CLI and a GUI is available.
* **Performant**: Refine your texture in 1 minute with Deepfloyd-IF-II.

### Install

We rely on `torch` and several CUDA extensions, please make sure you install them correctly first!
```bash
# tiny-cuda-nn
pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch

# nvdiffrast
pip install git+https://github.com/NVlabs/nvdiffrast

# [optional, will use pysdf if unavailable] cubvh:
pip install git+https://github.com/ashawkey/cubvh
```

To use [Deepfloyd-IF](https://github.com/deep-floyd/IF) models, please log in to your huggingface and accept the [license](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0).

To install this package:
```bash
# install from pypi
pip install threefiner

# install from github
pip install git+https://github.com/3DTopia/threefiner

# local install
git clone https://github.com/3DTopia/threefiner
cd threefiner
pip install .
```

### Usage

```bash
### command line interface
threefiner --help
# this is short for
python -m threefiner.cli --help

### refine a coarse mesh ('input.obj') using Stable-diffusion and save to 'logs/hamburger.glb'
threefiner sd --mesh input.obj --prompt 'a hamburger' --outdir logs --save hamburger.glb

### if the initial texture is good, we recommend using IF2 for refinement.
# by default, it will save to './name_fine.glb'
threefiner if2 --mesh name.glb --prompt 'description'

### if the initial texture is not good, we recommend using SD or IF first.
threefiner sd --mesh name.glb --prompt 'description'
threefiner if --mesh name.glb --prompt 'description'

### if the initial geometry is good, you can fix the geometry.
threefiner sd_fixgeo --mesh name.glb --prompt 'description'
threefiner if_fixgeo --mesh name.glb --prompt 'description'
threefiner if2_fixgeo --mesh name.glb --prompt 'description'

### advanced
# directional text prompt (append front/side/back view in text prompt)
# you need to know the mesh's front facing direction and specify it by '--front_dir'
# we use the OpenGL coordinate system, i.e., +x is right, +y is up, +z is front (more details: https://kit.kiui.moe/camera/)
# clock-wise rotation can be specified per 90 degree, e.g., +z1, -y2
threefiner if2 --mesh input.glb --prompt 'description' --text_dir --front_dir='+z'

# adjust training iterations
threefiner if2 --mesh input.glb --prompt 'description' --iters 1000

# explicitly fix the geometry and only refine texture
threefiner if2 --fix-geo --geom_mode mesh --mesh input.glb --prompt 'description' # equals if2_fixgeo

# open a GUI to visualize the training progress (needs a desktop)
threefiner if2 --mesh input.glb --prompt 'description' --gui
```

Gradio demo:
```bash
# requires gradio 4
python gradio_app.py if2
```

For more examples, please see [scripts](./scripts/).

### Q&A

* **How to make sure `--front_dir` for your model?**
    
    You may first visualize it in a 3D viewer that follows OpenGL coordinate system:
    <p align="center">
        <picture>
        <img alt="example_front_dir" src="assets/coord.jpg" width="50%">
        </picture>
    </p>
    The chair is facing down the Y axis (Green), so we can use `--front_dir="-y"` to rectify it to face +Z axis (Blue).

* **fatal error: EGL/egl.h: No such file or directory**

    By default, we use the OpenGL rasterizer. This error means there is no OpenGL installation, which is often the case for headless servers.
    It's recommended to install OpenGL (along with NVIDIA driver) as it brings better performance.
    Otherwise, you can append `--force_cuda_rast` to use the CUDA rasterizer instead.

## Acknowledgement

This work is built on many amazing research works and open-source projects, thanks a lot to all the authors for sharing!

- SDS `guidance` classes are based on [diffusers](https://github.com/huggingface/diffusers).
- `diffmc` geometry is based on [diso](https://github.com/SarahWeiii/diso).
- `mesh` geometry is based on [nerf2mesh](https://github.com/ashawkey/nerf2mesh).
- Texture encoding is based on [tinycudann](https://github.com/NVlabs/tiny-cuda-nn).
- Mesh renderer is based on [nvdiffrast](https://github.com/NVlabs/nvdiffrast).
- GUI is based on [dearpygui](https://github.com/hoffstadt/DearPyGui).
- The coarse models used in demo are generated by [Genie](https://lumalabs.ai/genie?view=create) and [3DTopia](https://github.com/3DTopia/3DTopia).


================================================
FILE: scripts/run.sh
================================================
export CUDA_VISIBLE_DEVICES=0

# the mesh is already with good initial texture, just refine it using IF2
threefiner if2 --mesh data/car.glb --prompt 'a red car' --outdir logs --save car_fine.glb --text_dir --front_dir='+x'

# the mesh is coarse, using SD for diverse texture generation and IF2 for refinement
threefiner sd --mesh data/chair.ply --prompt 'a swivel chair' --outdir logs --save chair_coarse.glb --text_dir --front_dir='-y'
threefiner if2 --mesh logs/chair_coarse.glb --prompt 'a swivel chair' --outdir logs --save chair_fine.glb --text_dir --front_dir='+z'


================================================
FILE: scripts/test_all.sh
================================================
export CUDA_VISIBLE_DEVICES=1

# geom_mode
threefiner if2 --geom_mode diffmc --save car_diffmc.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if2 --geom_mode mesh --save car_mesh.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if2 --geom_mode pbr_diffmc --save car_pbr_diffmc.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if2 --geom_mode pbr_mesh --save car_pbr_mesh.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'

# tex_mode
threefiner if2 --tex_mode mlp --save car_mlp.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if2 --tex_mode triplane --save car_triplane.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'

# guidance mode
threefiner sd --save car_SD.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if --save car_IF.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'

================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages


setup(
  name = 'threefiner',
  packages = find_packages(exclude=[]),
  include_package_data = True,
  entry_points={
    # CLI tools
    'console_scripts': [
      'threefiner = threefiner.cli:main'
    ],
  },
  version = '0.1.2',
  license='MIT',
  description = 'Threefiner: a text-guided mesh refiner',
  author = 'kiui',
  author_email = 'ashawkey1999@gmail.com',
  long_description=open("readme.md", encoding="utf-8").read(),
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/3DTopia/threefiner',
  keywords = [
    'generative mesh refinement',
  ],
  install_requires=[
    'tyro',
    'tqdm',
    'rich',
    'ninja',
    'numpy',
    'pandas',
    'matplotlib',
    'opencv-python',
    'imageio',
    'imageio-ffmpeg',
    'scipy',
    'scikit-learn',
    'torch',
    'einops',
    'huggingface_hub',
    'diffusers',
    'accelerate',
    'transformers',
    "sentencepiece", # required by deepfloyd-if T5 encoder
    'plyfile',
    'pygltflib',
    'xatlas',
    'trimesh',
    'PyMCubes',
    'pymeshlab',
    "pysdf",
    "diso",
    "envlight",
    'dearpygui',
    'kiui >= 0.2.1',
  ],
  classifiers=[
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3',
  ],
)


================================================
FILE: threefiner/__init__.py
================================================


================================================
FILE: threefiner/cli.py
================================================
import os
import tyro
from threefiner.opt import config_defaults, config_doc, check_options
from threefiner.gui import GUI


def main():    
    opt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc))
    opt = check_options(opt)
    gui = GUI(opt)
    if gui.gui:
        gui.render()
    else:
        gui.train(opt.iters)

if __name__ == "__main__":
    main()


================================================
FILE: threefiner/gui.py
================================================
import os
import tqdm
import random
import imageio
import numpy as np

import torch
import torch.nn.functional as F

GUI_AVAILABLE = True
try:
    import dearpygui.dearpygui as dpg
except Exception as e:
    GUI_AVAILABLE = False

import kiui
from kiui.cam import orbit_camera, OrbitCamera
from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency

from threefiner.opt import Options

class GUI:
    def __init__(self, opt: Options):
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        if not GUI_AVAILABLE and opt.gui:
            print(f'[WARN] cannot import dearpygui, assume without --gui')
        self.gui = opt.gui and GUI_AVAILABLE # enable gui
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)

        self.mode = "image"
        self.seed = "random"

        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # update buffer_image

        self.save_path = os.path.join(self.opt.outdir, self.opt.save)
        os.makedirs(self.opt.outdir, exist_ok=True)

        # models
        self.device = torch.device("cuda")

        self.guidance = None

        # renderer
        if self.opt.geom_mode == 'mesh':
            from threefiner.renderer.mesh_renderer import Renderer
        elif self.opt.geom_mode == 'diffmc':
            from threefiner.renderer.diffmc_renderer import Renderer
        elif self.opt.geom_mode == 'pbr_mesh':
            from threefiner.renderer.pbr_mesh_renderer import Renderer
        elif self.opt.geom_mode == 'pbr_diffmc':
            from threefiner.renderer.pbr_diffmc_renderer import Renderer
        else:
            raise NotImplementedError(f"unknown geometry mode: {self.opt.geom_mode}")

        self.renderer_class = Renderer
        
        if self.opt.mesh is None:
            self.renderer = None
        else:
            self.renderer = Renderer(opt, self.device).to(self.device)

        # input prompt
        self.prompt = self.opt.prompt
        self.negative_prompt = ""

        if self.opt.positive_prompt is not None:
            self.prompt = self.opt.positive_prompt + ', ' + self.prompt
        if self.opt.negative_prompt is not None:
            self.negative_prompt = self.opt.negative_prompt
        
        # training stuff
        self.training = False
        self.optimizer = None
        self.step = 0
        self.train_steps = 1  # steps per rendering loop

        if self.gui:
            dpg.create_context()
            self.register_dpg()
            self.test_step()

    def __del__(self):
        if self.gui:
            dpg.destroy_context()

    def seed_everything(self):
        try:
            seed = int(self.seed)
        except:
            seed = np.random.randint(0, 1000000)

        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        
        self.last_seed = seed

    def prepare_train(self):

        assert self.renderer is not None, 'no mesh loaded!'

        self.step = 0

        # setup training
        self.optimizer = torch.optim.Adam(self.renderer.get_params())

        # lazy load guidance model
        if self.guidance is None:
            print(f"[INFO] loading guidance...")
            if self.opt.mode == 'SD':
                from threefiner.guidance.sd_utils import StableDiffusion
                self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'SD_NFSD':
                from threefiner.guidance.sd_nfsd_utils import StableDiffusion
                self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'SDCN':
                from threefiner.guidance.sdcn_utils import StableDiffusionControlNet
                self.guidance = StableDiffusionControlNet(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'IF':
                from threefiner.guidance.if_utils import IF
                self.guidance = IF(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'IF2':
                from threefiner.guidance.if2_utils import IF2
                self.guidance = IF2(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'IF2_NFSD':
                from threefiner.guidance.if2_nfsd_utils import IF2
                self.guidance = IF2(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'SD_ISM':
                from threefiner.guidance.sd_ism_utils import StableDiffusion
                self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'IF2_ISM':
                from threefiner.guidance.if2_ism_utils import IF2
                self.guidance = IF2(self.device, vram_O=self.opt.vram_O)
            else:
                raise NotImplementedError(f"unknown guidance mode {self.opt.mode}!")
            print(f"[INFO] loaded guidance!")

        # prepare embeddings
        with torch.no_grad():
            self.guidance.get_text_embeds([self.prompt], [self.negative_prompt])

           
    def train_step(self):
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        self.renderer.train()

        for _ in range(self.train_steps):

            self.step += 1
            step_ratio = min(1, self.step / self.opt.iters)

            loss = 0

            ### novel view (manual batch)
            images = []
            poses = []
            normals = []
            ori_images = []
            vers, hors, radii = [], [], []
            
            for _ in range(self.opt.batch_size):

                # render random view
                ver = np.random.randint(-60, 30)
                hor = np.random.randint(-180, 180)
                radius = np.random.uniform() - 0.5 # [-0.5, 0.5]
                pose = orbit_camera(ver, hor, self.opt.radius + radius)

                vers.append(ver)
                hors.append(hor)
                radii.append(radius)
                poses.append(pose)

                # random render resolution
                ssaa = min(2.0, max(0.125, 2 * np.random.random()))
                out = self.renderer.render(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=ssaa)

                image = out["image"] # [H, W, 3] in [0, 1]
                image = image.permute(2,0,1).contiguous().unsqueeze(0) # [1, 3, H, W] in [0, 1]
                images.append(image)

                # mix_normal
                if not self.opt.fix_geo and self.opt.mix_normal:
                    normal = out['normal']
                    normal = normal.permute(2,0,1).contiguous().unsqueeze(0)
                    normals.append(normal)

                # IF SR model requires the original rendering
                if self.opt.mode in ['IF2', 'IF2_NFSD', 'IF2_ISM', 'SDCN']:
                    out_mesh = self.renderer.render_mesh(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=1)
                    ori_image = out_mesh["image"] # [H, W, 3] in [0, 1]
                    ori_image = ori_image.permute(2,0,1).contiguous().unsqueeze(0)
                    ori_images.append(ori_image)
                    # ori_images.append(image.clone())

            # guidance loss
            guidance_input = {'pred_rgb': torch.cat(images, dim=0)}

            if not self.opt.fix_geo and self.opt.mix_normal:
                if random.random() > 0.5:
                    ratio = random.random()
                    guidance_input['pred_rgb'] = guidance_input['pred_rgb'] * ratio + torch.cat(normals, dim=0) * (1 - ratio)
            
            # guidance_input['step_ratio'] = step_ratio
            if self.opt.mode in ['IF2', 'IF2_NFSD', 'IF2_ISM']:
                guidance_input['ori_rgb'] = torch.cat(ori_images, dim=0)
            if self.opt.mode == 'SDCN':
                guidance_input['control_images'] = {'tile': torch.cat(ori_images, dim=0)}
            if self.opt.text_dir:
                guidance_input['vers'] = vers
                guidance_input['hors'] = hors
            
            loss = loss + self.opt.lambda_sd * self.guidance.train_step(**guidance_input)
            
            # geom regularizations
            if self.opt.geom_mode in ['diffmc', 'pbr_diffmc', 'mesh', 'pbr_mesh'] and not self.opt.fix_geo:
                if self.opt.lambda_lap > 0:
                    lap_loss = laplacian_smooth_loss(self.renderer.v, self.renderer.f)
                    loss = loss + self.opt.lambda_lap * lap_loss
                if self.opt.lambda_normal > 0:
                    normal_loss = normal_consistency(self.renderer.v, self.renderer.f)
                    loss = loss + self.opt.lambda_normal * normal_loss
                if self.opt.geom_mode in ['mesh', 'pbr_mesh'] and self.opt.lambda_offsets > 0:
                    offset_loss = (self.renderer.v_offsets ** 2).sum(-1).mean()
                    loss = loss + self.opt.lambda_offsets * offset_loss
            
            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            # for mesh geom_mode: peoriodically remesh
            if self.opt.geom_mode in ['mesh', 'pbr_mesh'] and not self.opt.fix_geo:
                if self.step > 0 and self.step % self.opt.remesh_interval == 0:
                    self.renderer.remesh()
                    # reset optimizer
                    self.optimizer = torch.optim.Adam(self.renderer.get_params())

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.need_update = True

        if self.gui:
            dpg.set_value("_log_train_time", f"{t:.4f}ms")
            dpg.set_value(
                "_log_train_log",
                f"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {loss.item():.4f}",
            )

    @torch.no_grad()
    def test_step(self):
        # ignore if no need to update
        if not self.need_update:
            return

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        # should update image
        if self.need_update:
            # render image
            self.renderer.eval()

            out = self.renderer.render(self.cam.pose, self.cam.perspective, self.H, self.W)

            buffer_image = out[self.mode]  # [H, W, 3]

            if self.mode in ['depth', 'alpha']:
                buffer_image = buffer_image.repeat(1, 1, 3)
                if self.mode == 'depth':
                    buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)

            self.buffer_image = buffer_image.contiguous().clamp(0, 1).detach().cpu().numpy()

            self.need_update = False

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        if self.gui:
            dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)")
            dpg.set_value(
                "_texture", self.buffer_image
            )  # buffer must be contiguous, else seg fault!

    def save_model(self, save_path=None):

        if save_path is None:
            save_path = self.save_path

        # export video
        if save_path.endswith(".mp4"):
            images = []
            elevation = 0
            azimuth = np.arange(0, 360, 3, dtype=np.int32) # front-->back-->front
            for azi in tqdm.tqdm(azimuth):
                pose = orbit_camera(elevation, azi, self.opt.radius)
                out = self.renderer.render(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=1)    
                image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8)
                images.append(image)
            images = np.stack(images, axis=0)
            # ~4 seconds, 120 frames at 30 fps
            imageio.mimwrite(save_path, images, fps=30, quality=8, macro_block_size=1)
        # export mesh
        else:
            self.renderer.export_mesh(save_path, texture_resolution=self.opt.texture_resolution)

        print(f"[INFO] save model to {save_path}.")

    def register_dpg(self):
        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(
                self.W,
                self.H,
                self.buffer_image,
                format=dpg.mvFormat_Float_rgb,
                tag="_texture",
            )

        ### register window

        # the rendered image, as the primary window
        with dpg.window(
            tag="_primary_window",
            width=self.W,
            height=self.H,
            pos=[0, 0],
            no_move=True,
            no_title_bar=True,
            no_scrollbar=True,
        ):
            # add the texture
            dpg.add_image("_texture")

        # dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(
            label="Control",
            tag="_control_window",
            width=600,
            height=self.H,
            pos=[self.W, 0],
            no_move=True,
            no_title_bar=True,
        ):
            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # timer stuff
            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            def callback_setattr(sender, app_data, user_data):
                setattr(self, user_data, app_data)

            # init stuff
            with dpg.collapsing_header(label="Initialize", default_open=True):

                # seed stuff
                def callback_set_seed(sender, app_data):
                    self.seed = app_data
                    self.seed_everything()

                dpg.add_input_text(
                    label="seed",
                    default_value=self.seed,
                    on_enter=True,
                    callback=callback_set_seed,
                )

                # input stuff
                def callback_select_input(sender, app_data):
                    # only one item
                    for k, v in app_data["selections"].items():
                        dpg.set_value("_log_input", k)
                        self.load_input(v)

                    self.need_update = True

                with dpg.file_dialog(
                    directory_selector=False,
                    show=False,
                    callback=callback_select_input,
                    file_count=1,
                    tag="file_dialog_tag",
                    width=700,
                    height=400,
                ):
                    dpg.add_file_extension("Images{.jpg,.jpeg,.png}")

                with dpg.group(horizontal=True):
                    dpg.add_button(
                        label="input",
                        callback=lambda: dpg.show_item("file_dialog_tag"),
                    )
                    dpg.add_text("", tag="_log_input")
                
                # prompt stuff
            
                dpg.add_input_text(
                    label="prompt",
                    default_value=self.prompt,
                    callback=callback_setattr,
                    user_data="prompt",
                )

                dpg.add_input_text(
                    label="negative",
                    default_value=self.negative_prompt,
                    callback=callback_setattr,
                    user_data="negative_prompt",
                )

                # save current model
                with dpg.group(horizontal=True):
                    dpg.add_text("Save: ")

                    dpg.add_button(
                        label="model",
                        tag="_button_save_model",
                        callback=self.save_model,
                    )
                    dpg.bind_item_theme("_button_save_model", theme_button)

                    dpg.add_input_text(
                        label="",
                        default_value=self.save_path,
                        callback=callback_setattr,
                        user_data="save_path",
                    )

            # training stuff
            with dpg.collapsing_header(label="Train", default_open=True):
                # lr and train button
                with dpg.group(horizontal=True):
                    dpg.add_text("Train: ")

                    def callback_train(sender, app_data):
                        if self.training:
                            self.training = False
                            dpg.configure_item("_button_train", label="start")
                        else:
                            self.prepare_train()
                            self.training = True
                            dpg.configure_item("_button_train", label="stop")

                    # dpg.add_button(
                    #     label="init", tag="_button_init", callback=self.prepare_train
                    # )
                    # dpg.bind_item_theme("_button_init", theme_button)

                    dpg.add_button(
                        label="start", tag="_button_train", callback=callback_train
                    )
                    dpg.bind_item_theme("_button_train", theme_button)

                with dpg.group(horizontal=True):
                    dpg.add_text("", tag="_log_train_time")
                    dpg.add_text("", tag="_log_train_log")

            # rendering options
            with dpg.collapsing_header(label="Rendering", default_open=True):
                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(
                    ("image", "depth", "alpha", "normal"),
                    label="mode",
                    default_value=self.mode,
                    callback=callback_change_mode,
                )

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = np.deg2rad(app_data)
                    self.need_update = True

                dpg.add_slider_int(
                    label="FoV (vertical)",
                    min_value=1,
                    max_value=120,
                    format="%d deg",
                    default_value=np.rad2deg(self.cam.fovy),
                    callback=callback_set_fovy,
                )

        ### register camera handler

        def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

        with dpg.handler_registry():
            # for camera moving
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Left,
                callback=callback_camera_drag_rotate_or_draw_mask,
            )
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
            )

        dpg.create_viewport(
            title="Threefiner",
            width=self.W + 600,
            height=self.H + (45 if os.name == "nt" else 0),
            resizable=False,
        )

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(
                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
                )

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        ### register a larger font
        # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
        if os.path.exists("LXGWWenKai-Regular.ttf"):
            with dpg.font_registry():
                with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
                    dpg.bind_font(default_font)

        # dpg.show_metrics()

        dpg.show_viewport()

    def render(self):
        assert self.gui
        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()
    
    # no gui mode
    def train(self, iters=500):
        if iters > 0:
            self.prepare_train()
            for i in tqdm.trange(iters):
                self.train_step()
        # save
        self.save_model()

================================================
FILE: threefiner/guidance/__init__.py
================================================


================================================
FILE: threefiner/guidance/if2_ism_utils.py
================================================
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    IFPipeline,
    IFSuperResolutionPipeline,
)



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def invert_noise(scheduler, noisy_samples, noise, timesteps):
    alphas_cumprod = scheduler.alphas_cumprod.to(device=noisy_samples.device, dtype=noisy_samples.dtype)
    timesteps = timesteps.to(noisy_samples.device)

    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(noisy_samples.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(noisy_samples.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
    
    original_samples = 1 / sqrt_alpha_prod * (noisy_samples - sqrt_one_minus_alpha_prod * noise)
    return original_samples


class IF2(nn.Module):
    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        model_key = "DeepFloyd/IF-II-M-v1.0",
        t_range=[0.02, 0.50],
    ):
        super().__init__()

        self.device = device
        self.model_key = model_key
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = IFSuperResolutionPipeline.from_pretrained(
            model_key, variant="fp16", torch_dtype=torch.float16, 
            watermarker=None, safety_checker=None, requires_safety_checker=False,
        )

        if vram_O:
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.unet = pipe.unet
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder

        self.scheduler = pipe.scheduler
        self.image_noising_scheduler = pipe.image_noising_scheduler

        self.pipe = pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = {}
        

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]
        neg_embeds = self.encode_text(negative_prompts)
        null_embeds = self.encode_text([""])
        self.embeddings['pos'] = pos_embeds
        self.embeddings['neg'] = neg_embeds
        self.embeddings['null'] = null_embeds

        # directional embeddings
        for d in ['front', 'side', 'back']:
            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
            self.embeddings[d] = embeds
        
    
    def encode_text(self, prompt):
        # prompt: [str]
        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
        inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    def train_step(
        self,
        pred_rgb,
        ori_rgb,
        step_ratio=None,
        guidance_scale=5,
        vers=None, hors=None,
        delta_t=50, delta_s=200,
    ):
        
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)
        ori_rgb = ori_rgb.to(self.dtype)

        images = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) * 2 - 1

        with torch.no_grad():
            max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device)

            # images_upscaled = images.clone()
            images_upscaled = F.interpolate(ori_rgb, (256, 256), mode="bilinear", align_corners=False).clamp(0, 1) * 2 - 1
            noise = torch.randn_like(images_upscaled)
            images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t)
            
            if step_ratio is not None:
                # dreamtime-like
                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            else:
                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype)

            ######### debug
            # imagesx = self.produce_imgs(
            #     images_upscaled=images_upscaled,
            #     max_t=max_t,
            #     images=torch.randn_like(images),
            #     num_inference_steps=50,
            #     guidance_scale=4.0,
            # )  # [1, 3, 64, 64]
            # import kiui
            # kiui.vis.plot_image(images_upscaled * 0.5 + 0.5)
            # kiui.vis.plot_image(imagesx * 0.5 + 0.5)
            #########

            null_embeddings = self.embeddings['null'].expand(batch_size, -1, -1)

            if hors is None:
                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
            else:
                def _get_dir_ind(h):
                    if abs(h) < 60: return 'front'
                    elif abs(h) < 120: return 'side'
                    else: return 'back'

                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])

            ########### ISM
            # steps
            t = t.clamp(min=delta_t)
            s = t - delta_t
            n = s // delta_s
            r = s % delta_s

            # construct trajectory
            images_noisy = images.clone()
            cur_t = torch.full((batch_size,), 0, dtype=torch.long, device=self.device)

            noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0]
            images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t)
            cur_t += r
            images_noisy = self.scheduler.add_noise(images_original, noise, cur_t)

            for i in range(n):
                noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0]
                images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t)
                cur_t += delta_s
                images_noisy = self.scheduler.add_noise(images_original, noise, cur_t) # x_s

            # construct last step
            noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0]
            images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t) # \hat x_0^s

            # perform guidance
            images_noisy = self.scheduler.add_noise(images_original, noise, t)
            model_input = torch.cat([images_noisy, images_upscaled], dim=1)
            model_input = torch.cat([model_input] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * 2)
            max_tt = torch.cat([max_t] * 2)

            noise_pred = self.unet(
                model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt,
            ).sample

            # perform guidance (high scale from paper!)
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            grad = w * (noise_pred - noise)
            grad = torch.nan_to_num(grad)

            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm
        
            target = (images - grad).detach()

        loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0]

        return loss

    @torch.no_grad()
    def produce_imgs(
        self,
        images_upscaled,
        max_t,
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=4.0,
        images=None,
    ):
        if images is None:
            images = torch.randn(
                (
                    1,
                    self.unet.in_channels,
                    height,
                    width,
                ),
                device=self.device,
            )
        
        batch_size = images.shape[0]

        self.scheduler.set_timesteps(num_inference_steps)
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the images if we are doing classifier-free guidance to avoid doing two forward passes.
            model_input = torch.cat([images, images_upscaled], dim=1)
            model_input = torch.cat([model_input] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            max_tt = torch.cat([max_t] * 2)
            
            # predict the noise residual
            noise_pred = self.unet(
                model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt,
            ).sample

            # perform guidance
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            images = self.scheduler.step(noise_pred, t, images).prev_sample

        return images

    
    def prompt_to_img(
        self,
        images_upscaled,
        max_t,
        prompts,
        negative_prompts="",
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=4.0,
        images=None,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)
        
        # Text embeds -> img images
        images = self.produce_imgs(
            images_upscaled=images_upscaled,
            max_t=max_t,
            height=height,
            width=width,
            images=images,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )  # [1, 4, 64, 64]

        
        # Img to Numpy
        images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (images * 255).round().astype("uint8")

        return images


if __name__ == "__main__":
    import kiui
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    opt = parser.parse_args()

    kiui.seed_everything(opt.seed)

    device = torch.device("cuda")

    sd = IF2(device, opt.fp16, opt.vram_O)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()

================================================
FILE: threefiner/guidance/if2_nfsd_utils.py
================================================
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    IFPipeline,
    IFSuperResolutionPipeline,
)



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class IF2(nn.Module):
    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        model_key = "DeepFloyd/IF-II-L-v1.0",
        t_range=[0.02, 0.50],
    ):
        super().__init__()

        self.device = device
        self.model_key = model_key
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = IFSuperResolutionPipeline.from_pretrained(
            model_key, variant="fp16", torch_dtype=torch.float16, 
            watermarker=None, safety_checker=None, requires_safety_checker=False,
        )

        if vram_O:
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.unet = pipe.unet
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder

        self.scheduler = pipe.scheduler
        self.image_noising_scheduler = pipe.image_noising_scheduler

        self.pipe = pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = {}
        

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]
        neg_embeds = self.encode_text(negative_prompts)
        null_embeds = self.encode_text([''])
        self.embeddings['pos'] = pos_embeds
        self.embeddings['neg'] = neg_embeds
        self.embeddings['null'] = null_embeds

        # directional embeddings
        for d in ['front', 'side', 'back']:
            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
            self.embeddings[d] = embeds
        
    
    def encode_text(self, prompt):
        # prompt: [str]
        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
        inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    def train_step(
        self,
        pred_rgb,
        ori_rgb,
        step_ratio=None,
        guidance_scale=5,
        vers=None, hors=None,
    ):
        
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)
        ori_rgb = ori_rgb.to(self.dtype)

        images = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) * 2 - 1

        with torch.no_grad():
            max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device)

            # images_upscaled = images.clone()
            images_upscaled = F.interpolate(ori_rgb, (256, 256), mode="bilinear", align_corners=False).clamp(0, 1) * 2 - 1
            noise = torch.randn_like(images_upscaled)
            images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t)
            
            if step_ratio is not None:
                # dreamtime-like
                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            else:
                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype)

            ######### debug
            # imagesx = self.produce_imgs(
            #     images_upscaled=images_upscaled,
            #     max_t=max_t,
            #     images=torch.randn_like(images),
            #     num_inference_steps=50,
            #     guidance_scale=4.0,
            # )  # [1, 3, 64, 64]
            # import kiui
            # kiui.vis.plot_image(images_upscaled * 0.5 + 0.5)
            # kiui.vis.plot_image(imagesx * 0.5 + 0.5)
            #########
            
            # add noise
            noise = torch.randn_like(images)
            images_noisy = self.scheduler.add_noise(images, noise, t)
            # pred noise
            model_input = torch.cat([images_noisy, images_upscaled], dim=1)
            model_input = torch.cat([model_input] * 3)
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * 3)
            max_tt = torch.cat([max_t] * 3)

            if hors is None:
                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])
            else:
                def _get_dir_ind(h):
                    if abs(h) < 60: return 'front'
                    elif abs(h) < 120: return 'side'
                    else: return 'back'

                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])

            noise_pred = self.unet(
                model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt,
            ).sample

            # perform guidance (high scale from paper!)
            noise_pred_cond, noise_pred_uncond, noise_pred_null = noise_pred.chunk(3)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
            noise_pred_cond, _ = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
            noise_pred_null, _ = noise_pred_null.split(model_input.shape[1] // 2, dim=1)

            delta_c = guidance_scale * (noise_pred_cond - noise_pred_null)
            mask = (t < 200).int().view(batch_size, 1, 1, 1)
            delta_d = mask * noise_pred_null + (1 - mask) * (noise_pred_null - noise_pred_uncond)
            grad = w * (delta_c + delta_d)
            
            grad = torch.nan_to_num(grad)

            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm
        
            target = (images - grad).detach()

        loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0]

        return loss

    @torch.no_grad()
    def produce_imgs(
        self,
        images_upscaled,
        max_t,
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=4.0,
        images=None,
    ):
        if images is None:
            images = torch.randn(
                (
                    1,
                    self.unet.in_channels,
                    height,
                    width,
                ),
                device=self.device,
            )
        
        batch_size = images.shape[0]

        self.scheduler.set_timesteps(num_inference_steps)
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the images if we are doing classifier-free guidance to avoid doing two forward passes.
            model_input = torch.cat([images, images_upscaled], dim=1)
            model_input = torch.cat([model_input] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            max_tt = torch.cat([max_t] * 2)
            
            # predict the noise residual
            noise_pred = self.unet(
                model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt,
            ).sample

            # perform guidance
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            images = self.scheduler.step(noise_pred, t, images).prev_sample

        return images

    
    def prompt_to_img(
        self,
        images_upscaled,
        max_t,
        prompts,
        negative_prompts="",
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=4.0,
        images=None,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)
        
        # Text embeds -> img images
        images = self.produce_imgs(
            images_upscaled=images_upscaled,
            max_t=max_t,
            height=height,
            width=width,
            images=images,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )  # [1, 4, 64, 64]

        
        # Img to Numpy
        images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (images * 255).round().astype("uint8")

        return images


if __name__ == "__main__":
    import kiui
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    opt = parser.parse_args()

    kiui.seed_everything(opt.seed)

    device = torch.device("cuda")

    sd = IF2(device, opt.fp16, opt.vram_O)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()

================================================
FILE: threefiner/guidance/if2_utils.py
================================================
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    IFPipeline,
    IFSuperResolutionPipeline,
)



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class IF2(nn.Module):
    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        # model_key = "DeepFloyd/IF-II-L-v1.0",
        model_key = "DeepFloyd/IF-II-M-v1.0",
        t_range=[0.02, 0.50],
    ):
        super().__init__()

        self.device = device
        self.model_key = model_key
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = IFSuperResolutionPipeline.from_pretrained(
            model_key, variant="fp16", torch_dtype=torch.float16, 
            watermarker=None, safety_checker=None, requires_safety_checker=False,
        )

        if vram_O:
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
            pipe.enable_sequential_cpu_offload()
        else:
            pipe.to(device)

        self.unet = pipe.unet
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder

        self.scheduler = pipe.scheduler
        self.image_noising_scheduler = pipe.image_noising_scheduler

        self.pipe = pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = {}
        

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]
        neg_embeds = self.encode_text(negative_prompts)
        self.embeddings['pos'] = pos_embeds
        self.embeddings['neg'] = neg_embeds

        # directional embeddings
        for d in ['front', 'side', 'back']:
            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
            self.embeddings[d] = embeds
        
    
    def encode_text(self, prompt):
        # prompt: [str]
        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
        inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    def train_step(
        self,
        pred_rgb,
        ori_rgb,
        step_ratio=None,
        guidance_scale=50,
        vers=None, hors=None,
    ):
        
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)
        ori_rgb = ori_rgb.to(self.dtype)

        images = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) * 2 - 1

        with torch.no_grad():
            max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device)

            # images_upscaled = images.clone()
            images_upscaled = F.interpolate(ori_rgb, (256, 256), mode="bilinear", align_corners=False).clamp(0, 1) * 2 - 1
            noise = torch.randn_like(images_upscaled)
            images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t)
            
            if step_ratio is not None:
                # dreamtime-like
                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            else:
                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype)

            ######### debug
            # imagesx = self.produce_imgs(
            #     images_upscaled=images_upscaled,
            #     max_t=max_t,
            #     images=torch.randn_like(images),
            #     num_inference_steps=50,
            #     guidance_scale=4.0,
            # )  # [1, 3, 64, 64]
            # import kiui
            # kiui.vis.plot_image(images_upscaled * 0.5 + 0.5)
            # kiui.vis.plot_image(imagesx * 0.5 + 0.5)
            #########
            
            # add noise
            noise = torch.randn_like(images)
            images_noisy = self.scheduler.add_noise(images, noise, t)
            # pred noise
            model_input = torch.cat([images_noisy, images_upscaled], dim=1)
            model_input = torch.cat([model_input] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * 2)
            max_tt = torch.cat([max_t] * 2)

            if hors is None:
                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
            else:
                def _get_dir_ind(h):
                    if abs(h) < 60: return 'front'
                    elif abs(h) < 120: return 'side'
                    else: return 'back'

                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])

            noise_pred = self.unet(
                model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt,
            ).sample

            # perform guidance (high scale from paper!)
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            grad = w * (noise_pred - noise)
            grad = torch.nan_to_num(grad)

            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm
        
            target = (images - grad).detach()

        loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0]

        return loss

    @torch.no_grad()
    def produce_imgs(
        self,
        images_upscaled,
        max_t,
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=4.0,
        images=None,
    ):
        if images is None:
            images = torch.randn(
                (
                    1,
                    self.unet.in_channels,
                    height,
                    width,
                ),
                device=self.device,
            )
        
        batch_size = images.shape[0]

        self.scheduler.set_timesteps(num_inference_steps)
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the images if we are doing classifier-free guidance to avoid doing two forward passes.
            model_input = torch.cat([images, images_upscaled], dim=1)
            model_input = torch.cat([model_input] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            max_tt = torch.cat([max_t] * 2)
            
            # predict the noise residual
            noise_pred = self.unet(
                model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt,
            ).sample

            # perform guidance
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            images = self.scheduler.step(noise_pred, t, images).prev_sample

        return images

    
    def prompt_to_img(
        self,
        images_upscaled,
        max_t,
        prompts,
        negative_prompts="",
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=4.0,
        images=None,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)
        
        # Text embeds -> img images
        images = self.produce_imgs(
            images_upscaled=images_upscaled,
            max_t=max_t,
            height=height,
            width=width,
            images=images,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )  # [1, 4, 64, 64]

        
        # Img to Numpy
        images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (images * 255).round().astype("uint8")

        return images


if __name__ == "__main__":
    import kiui
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    opt = parser.parse_args()

    kiui.seed_everything(opt.seed)

    device = torch.device("cuda")

    sd = IF2(device, opt.fp16, opt.vram_O)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()

================================================
FILE: threefiner/guidance/if_utils.py
================================================
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    IFPipeline,
)



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class IF(nn.Module):
    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        model_key = "DeepFloyd/IF-I-XL-v1.0",
        # model_key = "DeepFloyd/IF-I-M-v1.0",
        t_range=[0.02, 0.98],
    ):
        super().__init__()

        self.device = device
        self.model_key = model_key
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = IFPipeline.from_pretrained(
            model_key, variant="fp16", torch_dtype=torch.float16, 
            watermarker=None, safety_checker=None, requires_safety_checker=False,
        )

        if vram_O:
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
            pipe.enable_sequential_cpu_offload()
        else:
            pipe.to(device)

        self.unet = pipe.unet
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder

        self.scheduler = pipe.scheduler

        self.pipe = pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = {}
        

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]
        neg_embeds = self.encode_text(negative_prompts)
        self.embeddings['pos'] = pos_embeds
        self.embeddings['neg'] = neg_embeds

        # directional embeddings
        for d in ['front', 'side', 'back']:
            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
            self.embeddings[d] = embeds
        
    
    def encode_text(self, prompt):
        # prompt: [str]
        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
        inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    @torch.no_grad()
    def refine(self, pred_rgb,
               guidance_scale=100, steps=50, strength=0.8,
        ):

        batch_size = pred_rgb.shape[0]
        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False)
        
        self.scheduler.set_timesteps(steps)
        init_step = int(steps * strength)
        images = self.scheduler.add_noise(images, torch.randn_like(images), self.scheduler.timesteps[init_step])
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):
    
            model_input = torch.cat([images] * 2)

            noise_pred = self.unet(
                model_input, t, encoder_hidden_states=embeddings,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
            
            images = self.scheduler.step(noise_pred, t, images).prev_sample

        return images

    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        guidance_scale=50,
        vers=None, hors=None,
    ):
        
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        images = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1

        with torch.no_grad():
            
            if step_ratio is not None:
                # dreamtime-like
                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            else:
                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

            ######### debug
            # # Text embeds -> img latents
            # imagesx = self.produce_imgs(
            #     height=64,
            #     width=64,
            #     images=torch.randn_like(images),
            #     num_inference_steps=50,
            #     guidance_scale=7.5,
            # )  # [1, 3, 64, 64]
            # import kiui
            # kiui.vis.plot_image(imagesx)
            #########

            # add noise
            noise = torch.randn_like(images)
            images_noisy = self.scheduler.add_noise(images, noise, t)
            # pred noise
            model_input = torch.cat([images_noisy] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * 2)

            if hors is None:
                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
            else:
                def _get_dir_ind(h):
                    if abs(h) < 60: return 'front'
                    elif abs(h) < 120: return 'side'
                    else: return 'back'

                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])

            noise_pred = self.unet(
                model_input, tt, encoder_hidden_states=embeddings
            ).sample

            # perform guidance (high scale from paper!)
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            grad = w * (noise_pred - noise)
            grad = torch.nan_to_num(grad)

            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm

            target = (images - grad).detach()
            
        loss = 0.5 * F.mse_loss(images.float(), target, reduction='sum') / images.shape[0]

        return loss

    @torch.no_grad()
    def produce_imgs(
        self,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        images=None,
    ):
        if images is None:
            images = torch.randn(
                (
                    1,
                    self.unet.in_channels,
                    height,
                    width,
                ),
                device=self.device,
            )
        
        batch_size = images.shape[0]

        self.scheduler.set_timesteps(num_inference_steps)
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the images if we are doing classifier-free guidance to avoid doing two forward passes.
            model_input = torch.cat([images] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            # predict the noise residual
            noise_pred = self.unet(
                model_input, t, encoder_hidden_states=embeddings
            ).sample

            # perform guidance
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            images = self.scheduler.step(noise_pred, t, images).prev_sample

        return images

    
    def prompt_to_img(
        self,
        prompts,
        negative_prompts="",
        height=64,
        width=64,
        num_inference_steps=50,
        guidance_scale=7.5,
        images=None,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)
        
        # Text embeds -> img images
        images = self.produce_imgs(
            height=height,
            width=width,
            images=images,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )  # [1, 4, 64, 64]

        
        # Img to Numpy
        images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (images * 255).round().astype("uint8")

        return images


if __name__ == "__main__":
    import kiui
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    opt = parser.parse_args()

    kiui.seed_everything(opt.seed)

    device = torch.device("cuda")

    sd = IF(device, opt.fp16, opt.vram_O)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()

================================================
FILE: threefiner/guidance/sd_ism_utils.py
================================================
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    StableDiffusionPipeline,
)



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def invert_noise(scheduler, noisy_samples, noise, timesteps):
    alphas_cumprod = scheduler.alphas_cumprod.to(device=noisy_samples.device, dtype=noisy_samples.dtype)
    timesteps = timesteps.to(noisy_samples.device)

    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(noisy_samples.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(noisy_samples.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
    
    original_samples = 1 / sqrt_alpha_prod * (noisy_samples - sqrt_one_minus_alpha_prod * noise)
    return original_samples


class StableDiffusion(nn.Module):
    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        model_key="stabilityai/stable-diffusion-2-1-base",
        # model_key="philz1337/revanimated",
        t_range=[0.02, 0.98],
    ):
        super().__init__()

        self.device = device
        self.model_key = model_key
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(
            model_key, torch_dtype=self.dtype
        )

        if vram_O:
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        self.scheduler = DDIMScheduler.from_pretrained(
            model_key, subfolder="scheduler", torch_dtype=self.dtype
        )

        del pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = {}
        

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]
        neg_embeds = self.encode_text(negative_prompts)
        null_embeds = self.encode_text([""])
        self.embeddings['pos'] = pos_embeds
        self.embeddings['neg'] = neg_embeds
        self.embeddings['null'] = null_embeds

        # directional embeddings
        for d in ['front', 'side', 'back']:
            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
            self.embeddings[d] = embeds
        
    
    def encode_text(self, prompt):
        # prompt: [str]
        inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    @torch.no_grad()
    def refine(self, pred_rgb,
               guidance_scale=100, steps=50, strength=0.8,
        ):

        batch_size = pred_rgb.shape[0]
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))
        # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)

        self.scheduler.set_timesteps(steps)
        init_step = int(steps * strength)
        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):
    
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embeddings,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents) # [1, 3, 512, 512]
        return imgs

    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        guidance_scale=7.5,
        as_latent=False,
        vers=None, hors=None,
        delta_t=50, delta_s=200,
    ):
        
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
        else:
            # interp to 512x512 to be fed into vae.
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_512)

        with torch.no_grad():

            if step_ratio is not None:
                # dreamtime-like
                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            else:
                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

            ######### debug
            # # Text embeds -> img latents
            # latentsx = self.produce_latents(
            #     height=512,
            #     width=512,
            #     latents=torch.randn_like(latents),
            #     num_inference_steps=50,
            #     guidance_scale=7.5,
            # )  # [1, 4, 64, 64]

            # # Img latents -> imgs
            # imgs = self.decode_latents(latentsx)  # [1, 3, 512, 512]
            # import kiui
            # kiui.vis.plot_image(imgs)
            #########

            null_embeddings = self.embeddings['null'].expand(batch_size, -1, -1)

            if hors is None:
                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
            else:
                def _get_dir_ind(h):
                    if abs(h) < 60: return 'front'
                    elif abs(h) < 120: return 'side'
                    else: return 'back'

                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])

            ########### ISM
            # steps
            t = t.clamp(min=delta_t)
            s = t - delta_t
            n = s // delta_s
            r = s % delta_s

            # construct trajectory
            latents_noisy = latents.clone()
            cur_t = torch.full((batch_size,), 0, dtype=torch.long, device=self.device)

            noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample
            latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t)
            cur_t += r
            latents_noisy = self.scheduler.add_noise(latents_original, noise, cur_t)

            for i in range(n):
                noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample
                latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t)
                cur_t += delta_s
                latents_noisy = self.scheduler.add_noise(latents_original, noise, cur_t) # x_s

            # construct last step
            noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample # \epsilon_s
            latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t) # \hat x_0^s

            # perform guidance
            latents_noisy = self.scheduler.add_noise(latents_original, noise, t) # x_t
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)
            noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=embeddings).sample
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            grad = w * (noise_pred - noise)
            grad = torch.nan_to_num(grad)
             
            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm

            target = (latents - grad).detach()

        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]

        return loss

    @torch.no_grad()
    def produce_latents(
        self,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        if latents is None:
            latents = torch.randn(
                (
                    1,
                    self.unet.in_channels,
                    height // 8,
                    width // 8,
                ),
                device=self.device,
            )
        
        batch_size = latents.shape[0]

        self.scheduler.set_timesteps(num_inference_steps)
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embeddings
            ).sample

            # perform guidance
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        # imgs: [B, 3, H, W]

        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        return latents

    def prompt_to_img(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)
        
        # Text embeds -> img latents
        latents = self.produce_latents(
            height=height,
            width=width,
            latents=latents,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )  # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype("uint8")

        return imgs


if __name__ == "__main__":
    import kiui
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    opt = parser.parse_args()

    kiui.seed_everything(opt.seed)

    device = torch.device("cuda")

    sd = StableDiffusion(device, opt.fp16, opt.vram_O)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()


================================================
FILE: threefiner/guidance/sd_nfsd_utils.py
================================================
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    StableDiffusionPipeline,
)



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class StableDiffusion(nn.Module):
    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        model_key="stabilityai/stable-diffusion-2-1-base",
        # model_key="philz1337/revanimated",
        t_range=[0.02, 0.50],
    ):
        super().__init__()

        self.device = device
        self.model_key = model_key
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(
            model_key, torch_dtype=self.dtype
        )

        if vram_O:
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        self.scheduler = DDIMScheduler.from_pretrained(
            model_key, subfolder="scheduler", torch_dtype=self.dtype
        )

        del pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = {}
        

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]
        neg_embeds = self.encode_text(negative_prompts)
        null_embeds = self.encode_text([''])
        self.embeddings['pos'] = pos_embeds
        self.embeddings['neg'] = neg_embeds
        self.embeddings['null'] = null_embeds

        # directional embeddings
        for d in ['front', 'side', 'back']:
            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
            self.embeddings[d] = embeds
        
    
    def encode_text(self, prompt):
        # prompt: [str]
        inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    @torch.no_grad()
    def refine(self, pred_rgb,
               guidance_scale=100, steps=50, strength=0.8,
        ):

        batch_size = pred_rgb.shape[0]
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))
        # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)

        self.scheduler.set_timesteps(steps)
        init_step = int(steps * strength)
        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):
    
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embeddings,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents) # [1, 3, 512, 512]
        return imgs

    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        guidance_scale=7.5,
        as_latent=False,
        vers=None, hors=None,
    ):
        
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
        else:
            # interp to 512x512 to be fed into vae.
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_512)

        with torch.no_grad():

            if step_ratio is not None:
                # dreamtime-like
                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            else:
                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

            ######### debug
            # # Text embeds -> img latents
            # latentsx = self.produce_latents(
            #     height=512,
            #     width=512,
            #     latents=torch.randn_like(latents),
            #     num_inference_steps=50,
            #     guidance_scale=7.5,
            # )  # [1, 4, 64, 64]

            # # Img latents -> imgs
            # imgs = self.decode_latents(latentsx)  # [1, 3, 512, 512]
            # import kiui
            # kiui.vis.plot_image(imgs)
            #########

            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 3)
            tt = torch.cat([t] * 3)

            if hors is None:
                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])
            else:
                def _get_dir_ind(h):
                    if abs(h) < 60: return 'front'
                    elif abs(h) < 120: return 'side'
                    else: return 'back'

                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])

            noise_pred = self.unet(
                latent_model_input, tt, encoder_hidden_states=embeddings
            ).sample

            # perform guidance (high scale from paper!)
            noise_pred_cond, noise_pred_uncond, noise_pred_null = noise_pred.chunk(3)
            delta_c = guidance_scale * (noise_pred_cond - noise_pred_null)
            mask = (t < 200).int().view(batch_size, 1, 1, 1)
            delta_d = mask * noise_pred_null + (1 - mask) * (noise_pred_null - noise_pred_uncond) * 3
            grad = w * (delta_c + delta_d)
            grad = torch.nan_to_num(grad)

            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm

            target = (latents - grad).detach()

        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]

        return loss

    @torch.no_grad()
    def produce_latents(
        self,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        if latents is None:
            latents = torch.randn(
                (
                    1,
                    self.unet.in_channels,
                    height // 8,
                    width // 8,
                ),
                device=self.device,
            )
        
        batch_size = latents.shape[0]

        self.scheduler.set_timesteps(num_inference_steps)
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embeddings
            ).sample

            # perform guidance
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        # imgs: [B, 3, H, W]

        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        return latents

    def prompt_to_img(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)
        
        # Text embeds -> img latents
        latents = self.produce_latents(
            height=height,
            width=width,
            latents=latents,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )  # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype("uint8")

        return imgs


if __name__ == "__main__":
    import kiui
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    opt = parser.parse_args()

    kiui.seed_everything(opt.seed)

    device = torch.device("cuda")

    sd = StableDiffusion(device, opt.fp16, opt.vram_O)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()


================================================
FILE: threefiner/guidance/sd_utils.py
================================================
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    StableDiffusionPipeline,
)



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class StableDiffusion(nn.Module):
    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        model_key="stabilityai/stable-diffusion-2-1-base",
        # model_key="philz1337/revanimated",
        t_range=[0.02, 0.50],
    ):
        super().__init__()

        self.device = device
        self.model_key = model_key
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(
            model_key, torch_dtype=self.dtype
        )

        if vram_O:
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        self.scheduler = DDIMScheduler.from_pretrained(
            model_key, subfolder="scheduler", torch_dtype=self.dtype
        )

        del pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = {}
        

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]
        neg_embeds = self.encode_text(negative_prompts)
        self.embeddings['pos'] = pos_embeds
        self.embeddings['neg'] = neg_embeds

        # directional embeddings
        for d in ['front', 'side', 'back']:
            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
            self.embeddings[d] = embeds
        
    
    def encode_text(self, prompt):
        # prompt: [str]
        inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    @torch.no_grad()
    def refine(self, pred_rgb,
               guidance_scale=100, steps=50, strength=0.8,
        ):

        batch_size = pred_rgb.shape[0]
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))
        # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)

        self.scheduler.set_timesteps(steps)
        init_step = int(steps * strength)
        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):
    
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embeddings,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents) # [1, 3, 512, 512]
        return imgs

    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        guidance_scale=100,
        as_latent=False,
        vers=None, hors=None,
    ):
        
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
        else:
            # interp to 512x512 to be fed into vae.
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_512)

        with torch.no_grad():

            if step_ratio is not None:
                # dreamtime-like
                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            else:
                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

            ######### debug
            # # Text embeds -> img latents
            # latentsx = self.produce_latents(
            #     height=512,
            #     width=512,
            #     latents=torch.randn_like(latents),
            #     num_inference_steps=50,
            #     guidance_scale=7.5,
            # )  # [1, 4, 64, 64]

            # # Img latents -> imgs
            # imgs = self.decode_latents(latentsx)  # [1, 3, 512, 512]
            # import kiui
            # kiui.vis.plot_image(imgs)
            #########

            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)

            if hors is None:
                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
            else:
                def _get_dir_ind(h):
                    if abs(h) < 60: return 'front'
                    elif abs(h) < 120: return 'side'
                    else: return 'back'

                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])

            noise_pred = self.unet(
                latent_model_input, tt, encoder_hidden_states=embeddings
            ).sample

            # perform guidance (high scale from paper!)
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            grad = w * (noise_pred - noise)
            grad = torch.nan_to_num(grad)

            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm

            target = (latents - grad).detach()

        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]

        return loss

    @torch.no_grad()
    def produce_latents(
        self,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        if latents is None:
            latents = torch.randn(
                (
                    1,
                    self.unet.in_channels,
                    height // 8,
                    width // 8,
                ),
                device=self.device,
            )
        
        batch_size = latents.shape[0]

        self.scheduler.set_timesteps(num_inference_steps)
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embeddings
            ).sample

            # perform guidance
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        # imgs: [B, 3, H, W]

        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        return latents

    def prompt_to_img(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)
        
        # Text embeds -> img latents
        latents = self.produce_latents(
            height=height,
            width=width,
            latents=latents,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )  # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype("uint8")

        return imgs


if __name__ == "__main__":
    import kiui
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    opt = parser.parse_args()

    kiui.seed_everything(opt.seed)

    device = torch.device("cuda")

    sd = StableDiffusion(device, opt.fp16, opt.vram_O)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()


================================================
FILE: threefiner/guidance/sdcn_utils.py
================================================
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    StableDiffusionPipeline,
    ControlNetModel,
)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class StableDiffusionControlNet(nn.Module):
    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        control_mode=["tile"],
        model_key="runwayml/stable-diffusion-v1-5",
        # model_key="philz1337/revanimated",
        t_range=[0.02, 0.50],
    ):
        super().__init__()

        self.device = device
        self.control_mode = control_mode
        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(
            model_key, torch_dtype=self.dtype
        )

        if vram_O:
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        # controlnet
        if self.control_mode is not None:
            self.controlnet = {}
            self.controlnet_conditioning_scale = {}
            
            if "normal" in self.control_mode:
                self.controlnet['normal'] = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_normalbae",torch_dtype=self.dtype).to(self.device)
                self.controlnet_conditioning_scale['normal'] = 1.0
            if "depth" in self.control_mode:
                self.controlnet['depth'] = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth",torch_dtype=self.dtype).to(self.device)
                self.controlnet_conditioning_scale['depth'] = 1.0
            if "ip2p" in self.control_mode:
                self.controlnet['ip2p'] = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_ip2p",torch_dtype=self.dtype).to(self.device)
                self.controlnet_conditioning_scale['ip2p'] = 1.0
            if "inpaint" in self.control_mode:
                self.controlnet['inpaint'] = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint",torch_dtype=self.dtype).to(self.device)
                self.controlnet_conditioning_scale['inpaint'] = 1.0
            if "depth_inpaint" in self.control_mode:
                self.controlnet['depth_inpaint'] = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_depth_aware_inpaint",torch_dtype=self.dtype).to(self.device)
                self.controlnet_conditioning_scale['depth_inpaint'] = 1.0
            if "pose" in self.control_mode:
                self.controlnet['pose'] = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose",torch_dtype=self.dtype).to(self.device)
                self.controlnet_conditioning_scale['pose'] = 1.0
            if "tile" in self.control_mode:
                self.controlnet['tile'] = ControlNetModel.from_pretrained("lllyasviel/control_v11f1e_sd15_tile",torch_dtype=self.dtype).to(self.device)
                self.controlnet_conditioning_scale['tile'] = 1.0

        self.scheduler = DDIMScheduler.from_pretrained(
            model_key, subfolder="scheduler", torch_dtype=self.dtype
        )

        del pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = {}
        

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        pos_embeds = self.encode_text(prompts)  # [1, 77, 768]
        neg_embeds = self.encode_text(negative_prompts)
        self.embeddings['pos'] = pos_embeds
        self.embeddings['neg'] = neg_embeds

        # directional embeddings
        for d in ['front', 'side', 'back']:
            embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
            self.embeddings[d] = embeds
        
    
    def encode_text(self, prompt):
        # prompt: [str]
        inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    @torch.no_grad()
    def refine(self, pred_rgb,
               guidance_scale=100, steps=50, strength=0.8,
               control_images=None
        ):

        batch_size = pred_rgb.shape[0]
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))
        # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)

        self.scheduler.set_timesteps(steps)
        init_step = int(steps * strength)
        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):
    
            latent_model_input = torch.cat([latents] * 2)

            if self.control_mode is not None and control_images is not None:

                noise_pred = 0

                for mode, controlnet in self.controlnet.items():
                    # may omit control mode if input is not provided
                    if mode not in control_images: continue
                    
                    control_image = control_images[mode].to(self.dtype)
                    weight = 1 / len(self.controlnet)

                    control_image_input = torch.cat([control_image] * 2)
                    down_samples, mid_sample = controlnet(
                        latent_model_input, t, encoder_hidden_states=embeddings, 
                        controlnet_cond=control_image_input, 
                        conditioning_scale=self.controlnet_conditioning_scale[mode],
                        return_dict=False
                    )

                    # predict the noise residual
                    noise_pred_cur = self.unet(
                        latent_model_input, t, encoder_hidden_states=embeddings, 
                        down_block_additional_residuals=down_samples, 
                        mid_block_additional_residual=mid_sample
                    ).sample

                    # merge after unet
                    noise_pred = noise_pred + weight * noise_pred_cur
                
            else:
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=self.embeddings,
                ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents) # [1, 3, 512, 512]
        return imgs

    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        guidance_scale=100,
        as_latent=False,
        control_images=None,
        vers=None, hors=None,
    ):
        
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
        else:
            # interp to 512x512 to be fed into vae.
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_512)

        with torch.no_grad():
            if step_ratio is not None:
                # dreamtime-like
                # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
                t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
                t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
            else:
                t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

            ######### debug
            # # Text embeds -> img latents
            # latentsx = self.produce_latents(
            #     height=512,
            #     width=512,
            #     latents=torch.randn_like(latents),
            #     num_inference_steps=50,
            #     guidance_scale=7.5,
            #     control_images=control_images,
            # )  # [1, 4, 64, 64]
            # # Img latents -> imgs
            # imgs = self.decode_latents(latentsx)  # [1, 3, 512, 512]
            # import kiui
            # kiui.vis.plot_image(control_images['tile'], imgs)
            #########

            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)

            if hors is None:
                embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
            else:
                def _get_dir_ind(h):
                    if abs(h) < 60: return 'front'
                    elif abs(h) < 120: return 'side'
                    else: return 'back'

                embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])
                

            if self.control_mode is not None and control_images is not None:

                noise_pred = 0

                for mode, controlnet in self.controlnet.items():
                    # may omit control mode if input is not provided
                    if mode not in control_images: continue
                    
                    control_image = control_images[mode].to(self.dtype)
                    weight = 1 / len(self.controlnet)

                    control_image_input = torch.cat([control_image] * 2)
                    down_samples, mid_sample = controlnet(
                        latent_model_input, tt, encoder_hidden_states=embeddings, 
                        controlnet_cond=control_image_input, 
                        conditioning_scale=self.controlnet_conditioning_scale[mode],
                        return_dict=False
                    )

                    # predict the noise residual
                    noise_pred_cur = self.unet(
                        latent_model_input, tt, encoder_hidden_states=embeddings, 
                        down_block_additional_residuals=down_samples, 
                        mid_block_additional_residual=mid_sample
                    ).sample

                    # merge after unet
                    noise_pred = noise_pred + weight * noise_pred_cur
                
            else:
                noise_pred = self.unet(
                    latent_model_input, tt, encoder_hidden_states=embeddings,
                ).sample

            # perform guidance (high scale from paper!)
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            grad = w * (noise_pred - noise)
            grad = torch.nan_to_num(grad)

            # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
            # grad = grad_norm.clamp(max=0.1) * grad / grad_norm

            target = (latents - grad).detach()

        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]

        return loss

    @torch.no_grad()
    def produce_latents(
        self,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
        control_images=None,
    ):
        if latents is None:
            latents = torch.randn(
                (
                    1,
                    self.unet.in_channels,
                    height // 8,
                    width // 8,
                ),
                device=self.device,
            )
        
        batch_size = latents.shape[0]

        self.scheduler.set_timesteps(num_inference_steps)
        embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)

            # predict the noise residual
            if self.control_mode is not None and control_images is not None:

                noise_pred = 0

                for mode, controlnet in self.controlnet.items():
                    # may omit control mode if input is not provided
                    if mode not in control_images: continue
                    
                    control_image = control_images[mode].to(self.dtype)
                    weight = 1 / len(self.controlnet)

                    control_image_input = torch.cat([control_image] * 2)
                    down_samples, mid_sample = controlnet(
                        latent_model_input, t, encoder_hidden_states=embeddings, 
                        controlnet_cond=control_image_input, 
                        conditioning_scale=self.controlnet_conditioning_scale[mode],
                        return_dict=False
                    )

                    # predict the noise residual
                    noise_pred_cur = self.unet(
                        latent_model_input, t, encoder_hidden_states=embeddings, 
                        down_block_additional_residuals=down_samples, 
                        mid_block_additional_residual=mid_sample
                    ).sample

                    # merge after unet
                    noise_pred = noise_pred + weight * noise_pred_cur
                
            else:
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=embeddings,
                ).sample

            # perform guidance
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        # imgs: [B, 3, H, W]

        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        return latents

    def prompt_to_img(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
        control_images=None,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)
        
        # Text embeds -> img latents
        latents = self.produce_latents(
            height=height,
            width=width,
            latents=latents,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            control_images=control_images,
        )  # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype("uint8")

        return imgs


if __name__ == "__main__":
    import kiui
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("image", type=str)
    parser.add_argument("prompt", default="", type=str)
    parser.add_argument("--control", default='tile', type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    opt = parser.parse_args()

    kiui.seed_everything(opt.seed)

    device = torch.device("cuda")

    # load image
    control_image = kiui.read_image(opt.image, mode='tensor').permute(2,0,1).contiguous().unsqueeze(0).to(device)
    control_image = F.interpolate(control_image, (opt.H, opt.W), mode='bilinear', align_corners=False)

    kiui.lo(control_image)

    control_images = {}
    control_images[opt.control] = control_image

    sd = StableDiffusionControlNet(device, opt.fp16, opt.vram_O, control_mode=[opt.control])

    while True:
        imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, control_images=control_images)

        # visualize image
        plt.imshow(imgs[0])
        plt.show()

================================================
FILE: threefiner/lights/LICENSE.txt
================================================
The mud_road_puresky.hdr HDR probe is from https://polyhaven.com/a/mud_road_puresky
CC0 License.


================================================
FILE: threefiner/nn.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import tinycudann as tcnn

class HashGridEncoder(nn.Module):
    def __init__(self, 
                 input_dim=3,
                 num_levels=16,
                 level_dim=2,
                 log2_hashmap_size=18, 
                 base_resolution=16, 
                 desired_resolution=1024, 
                 interpolation='linear'
                 ):
        super().__init__()
        self.encoder = tcnn.Encoding(
            n_input_dims=input_dim,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": num_levels,
                "n_features_per_level": level_dim,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": np.exp2(np.log2(desired_resolution / num_levels) / (num_levels - 1)),
                "interpolation": "Smoothstep" if interpolation == 'smoothstep' else "Linear",
            },
            dtype=torch.float32,
        )
        self.input_dim = input_dim
        self.output_dim = self.encoder.n_output_dims # patch
    
    def forward(self, x, bound=1):
        return self.encoder((x + bound) / (2 * bound))

class FrequencyEncoder(nn.Module):
    def __init__(self, 
                 input_dim=3,
                 output_dim=32,
                 n_frequencies=12,
                 ):
        super().__init__()
        self.encoder = tcnn.Encoding(
            n_input_dims=input_dim,
            encoding_config={
                "otype": "Frequency",
                "n_frequencies": n_frequencies,
            },
            dtype=torch.float32,
        )
        self.implicit_mlp = MLP(self.encoder.n_output_dims, output_dim, 128, 5, bias=True)
        self.input_dim = input_dim
        self.output_dim = output_dim
    
    def forward(self, x, **kwargs):
        return self.implicit_mlp(self.encoder(x))
    

class TriplaneEncoder(nn.Module):
    def __init__(self, 
                 input_dim=3,
                 output_dim=32,
                 resolution=256,
                 ):
        super().__init__()

        self.C_mat = nn.Parameter(torch.randn(3, output_dim, resolution, resolution))
        torch.nn.init.kaiming_normal_(self.C_mat)
        
        self.mat_ids = [[0, 1], [0, 2], [1, 2]]

        self.input_dim = input_dim
        self.output_dim = output_dim
    
    def forward(self, x, bound=1):

        N = x.shape[0]
        x = x / bound # to [-1, 1]

        mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2]

        feat = F.grid_sample(self.C_mat[[0]], mat_coord[[0]], align_corners=False).view(-1, N) + \
               F.grid_sample(self.C_mat[[1]], mat_coord[[1]], align_corners=False).view(-1, N) + \
               F.grid_sample(self.C_mat[[2]], mat_coord[[2]], align_corners=False).view(-1, N) # [r, N]

        # density
        feat = feat.transpose(0, 1).contiguous() # [N, C]
        return feat

class MLP(nn.Module):
    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))

        self.net = nn.ModuleList(net)
    
    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
            if l != self.num_layers - 1:
                x = F.relu(x, inplace=True)
        return x

================================================
FILE: threefiner/opt.py
================================================
import os
from dataclasses import dataclass
from typing import Tuple, Literal, Dict, Optional

@dataclass
class Options:
    # path to input mesh
    mesh: Optional[str] = None
    # input text prompt
    prompt: Optional[str] = None
    # additional positive prompt
    positive_prompt: str = "best quality, extremely detailed, masterpiece, high resolution, high quality"
    # additional negative prompt
    negative_prompt: str = "blur, lowres, cropped, low quality, worst quality, ugly, dark, shadow, oversaturated"
    # whether to append directional text prompt
    text_dir: bool = False
    # set mesh front-facing direction (camera front=+z, right=+x, up=+y, clock-wise rotation 90=1, 180=2, 270=3, e.g., +z, -y1)
    front_dir: str = "+z"

    # training iterations
    iters: int = 500
    # training resolution
    render_resolution: int = 512
    # training camera radius
    radius: float = 2.5
    # training camera fovy in degree
    fovy: float = 49.1
    # whether to allow geom training
    fix_geo: bool = False
    # whether to mix normal with rgb for geometry training
    mix_normal: bool = True
    # whether to pretrain texture first
    fit_tex: bool = True
    # pretrain texture iterations
    fit_tex_iters: int = 512

    # output folder
    outdir: str = '.'
    # output filename, default to {name}_fine.{ext}
    save: Optional[str] = None

    # guidance mode
    mode: Literal['SD', 'IF', 'IF2', 'SDCN', 'SD_NFSD', 'IF2_NFSD', 'SD_ISM', 'IF2_ISM'] = 'IF2'
    # renderer geometry mode
    geom_mode: Literal['mesh', 'diffmc', 'pbr_mesh', 'pbr_diffmc'] = 'diffmc'
    # renderer texture mode
    tex_mode: Literal['hashgrid', 'mlp', 'triplane'] = 'hashgrid'
    
    # training batch size per iter
    batch_size: int = 1
    # environmental texture
    env_texture: Optional[str] = None
    # environmental light scale
    env_scale: float = 2
    
    # DiffMC grid size
    mc_grid_size: int = 128
    # Mesh remeshing interval
    remesh_interval: int = 200
    # mesh decimation target face number
    decimate_target: int = 5e4
    # remesh target edge length (smaller value lead to finer mesh)
    remesh_size: float = 0.015
    # texture resolution
    texture_resolution: int = 1024
    # learning rate for hashgrid
    hashgrid_lr: float = 0.01
    # learning rate for feature MLP
    mlp_lr: float = 0.001
    # learning rate for SDF
    sdf_lr: float = 0.0001
    # learning rate for deformation
    deform_lr: float = 0.0001
    # learning rate for mesh geometry
    geom_lr: float = 0.0001

    # guidance loss weights
    lambda_sd: float = 1
    # mesh laplacian regularization weight
    lambda_lap: float = 0
    # mesh normal consistency weight (should be large enough)
    lambda_normal: float = 10000
    # mesh vertices offset penalty weight
    lambda_offsets: float = 100

    # whether to open a GUI
    gui: bool = False
    # GUI height
    H: int = 800
    # GUI width
    W: int = 800
    # whether to use CUDA rasterizer (in case OpenGL fails)
    force_cuda_rast: bool = False
    # whether to use GPU memory-optimized mode (slower, but uses less GPU memory)
    vram_O: bool = False


# all the default settings
config_defaults: Dict[str, Options] = {}
config_doc: Dict[str, str] = {}

config_doc['sd'] = 'coarse-level generation with stable-diffusion 2.'
config_defaults['sd'] = Options(
    mode='SD',
    iters=800,
)

config_doc['if'] = 'coarse-level generation with deepfloyd-if I.'
config_defaults['if'] = Options(
    mode='IF',
    iters=400,
)

config_doc['if2'] = 'fine-level refinement with deepfloyd-if II.'
config_defaults['if2'] = Options(
    mode='IF2',
    iters=400,
)

config_doc['sd_fixgeo'] = 'coarse-level generation with stable-diffusion 2, fixed goemetry.'
config_defaults['sd_fixgeo'] = Options(
    mode='SD',
    iters=800,
    fix_geo=True,
    geom_mode='mesh',
)

config_doc['if_fixgeo'] = 'coarse-level generation with deepfloyd-if I, fixed goemetry.'
config_defaults['if_fixgeo'] = Options(
    mode='IF',
    iters=400,
    fix_geo=True,
    geom_mode='mesh',
)

config_doc['if2_fixgeo'] = 'fine-level refinement with deepfloyd-if II, fixed goemetry.'
config_defaults['if2_fixgeo'] = Options(
    mode='IF2',
    iters=400,
    fix_geo=True,
    geom_mode='mesh',
)

def check_options(opt: Options):
    assert opt.mesh is not None, 'mesh path must be specified!'
    assert opt.prompt is not None, 'prompt must be specified!'

    if opt.save is None:
        input_name, input_ext = os.path.splitext(os.path.basename(opt.mesh))
        opt.save = input_name + '_fine' + '.glb'
        print(f'[INFO] save to default output path: {os.path.join(opt.outdir, opt.save)}.')

    return opt

================================================
FILE: threefiner/renderer/__init__.py
================================================



================================================
FILE: threefiner/renderer/diffmc_renderer.py
================================================
import os
import tqdm
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import nvdiffrast.torch as dr

import kiui
from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding
from kiui.cam import orbit_camera, get_perspective

from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder
from threefiner.renderer.mesh_renderer import render_mesh

from diso import DiffMC, DiffDMC

class Renderer(nn.Module):
    def __init__(self, opt, device):
        
        super().__init__()

        self.opt = opt
        self.device = device

        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        # diffmc
        self.verts = torch.stack(
            torch.meshgrid(
                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
                indexing="ij",
            ), dim=-1,
        ) # [N, N, N, 3]
        self.grid_scale = 1
        self.diffmc = DiffMC(dtype=torch.float32).to(device)
        
        # vert sdf and deform
        self.sdf = nn.Parameter(torch.zeros_like(self.verts[..., 0]))
        self.deform = nn.Parameter(torch.zeros_like(self.verts))
        
        # init diffmc from mesh
        self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)

        vertices = self.mesh.v.detach().cpu().numpy()
        triangles = self.mesh.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)
        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)

        self.grid_scale = self.mesh.v.abs().max() + 1e-1
        self.verts = self.verts * self.grid_scale
        
        try:
            import cubvh
            BVH = cubvh.cuBVH(self.mesh.v, self.mesh.f)
            sdf, _, _ = BVH.signed_distance(self.verts.reshape(-1, 3), return_uvw=False, mode='raystab') # some mesh may not be watertight...
        except:
            from pysdf import SDF
            sdf_func = SDF(self.mesh.v.detach().cpu().numpy(), self.mesh.f.detach().cpu().numpy())
            sdf = sdf_func(self.verts.detach().cpu().numpy().reshape(-1, 3))
            sdf = torch.from_numpy(sdf).to(self.device)
            sdf *= -1

        # OUTER is POSITIVE
        self.sdf.data += sdf.reshape(*self.sdf.data.shape).to(self.sdf.data.dtype)

        # texture
        if self.opt.tex_mode == 'hashgrid':
            self.encoder = HashGridEncoder().to(self.device)
        elif self.opt.tex_mode == 'mlp':
            self.encoder = FrequencyEncoder().to(self.device)
        elif self.opt.tex_mode == 'triplane':
            self.encoder = TriplaneEncoder().to(self.device)
        else:
            raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}")
        
        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=True).to(self.device)

        self.v, self.f = None, None # placeholder

        # init hashgrid texture from mesh
        if self.opt.fit_tex:
            self.fit_texture_from_mesh(self.opt.fit_tex_iters)
    
    def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
        return render_mesh(
            self.glctx, 
            self.mesh.v, self.mesh.f, self.mesh.vt, 
            self.mesh.ft, self.mesh.albedo, 
            self.mesh.vc, self.mesh.vn, self.mesh.fn, 
            pose, proj, h, w, 
            ssaa=ssaa, bg_color=bg_color,
        )

    def fit_texture_from_mesh(self, iters=512):
        # a small training loop...

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam([
            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
        ])

        resolution = 512

        print(f"[INFO] fitting texture...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            ver = np.random.randint(-45, 45)
            hor = np.random.randint(-180, 180)
            
            pose = orbit_camera(ver, hor, self.opt.radius)
            proj = get_perspective(self.opt.fovy)

            image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']
            image_pred = self.render(pose, proj, resolution, resolution)['image']

            loss = loss_fn(image_pred, image_mesh)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description(f"MSE = {loss.item():.6f}")
        
        print(f"[INFO] finished fitting texture!")

    def get_params(self):

        params = [
            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
        ]

        if not self.opt.fix_geo:
            params.append({'params': self.sdf, 'lr': self.opt.sdf_lr})
            params.append({'params': self.deform, 'lr': self.opt.deform_lr})

        return params

    @torch.no_grad()
    def export_mesh(self, save_path, texture_resolution=2048, padding=16):

        # get v
        sdf = self.sdf
        deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]

        v, f = self.diffmc(sdf, deform)
        v = (2 * v - 1) * self.grid_scale
        f = f.int()
        self.v, self.f = v, f
        
        vertices = v.detach().cpu().numpy()
        triangles = f.detach().cpu().numpy()

        # clean
        vertices = vertices.astype(np.float32)
        triangles = triangles.astype(np.int32)
        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size)
        
        # decimation
        if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:
            vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target)
        
        v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        
        mesh = Mesh(v=v, f=f, albedo=None, device=self.device)
        print(f"[INFO] uv unwrapping...")
        mesh.auto_normal()
        mesh.auto_uv()

        # render uv maps
        h = w = texture_resolution
        uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]

        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]
        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
        mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]

        # masked query 
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)
        
        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)

        if mask.any():
            print(f"[INFO] querying texture...")

            xyzs = xyzs[mask] # [M, 3]

            # batched inference to avoid OOM
            batch = []
            head = 0
            while head < xyzs.shape[0]:
                tail = min(head + 640000, xyzs.shape[0])
                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
                head += 640000

            albedo[mask] = torch.cat(batch, dim=0)
        
        albedo = albedo.view(h, w, -1)
        mask = mask.view(h, w)

        print(f"[INFO] uv padding...")
        albedo = uv_padding(albedo, mask, padding)

        mesh.albedo = albedo
        mesh.write(save_path)

    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):
        
        # do super-sampling
        if ssaa != 1:
            h = make_divisible(h0 * ssaa, 8)
            w = make_divisible(w0 * ssaa, 8)
        else:
            h, w = h0, w0
        
        results = {}

        # get v
        sdf = self.sdf
        deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]

        v, f = self.diffmc(sdf, deform)
        v = (2 * v - 1) * self.grid_scale
        f = f.int()
        self.v, self.f = v, f

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
        proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [V, H, W, 1]
        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!
        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]
        depth = depth.squeeze(0) # [H, W, 1]

        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
        xyzs = xyzs.view(-1, 3)
        mask = (alpha > 0).view(-1)
        color = torch.zeros_like(xyzs, dtype=torch.float32)
        if mask.any():
            masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))
            color[mask] = masked_albedo.float()
        color = color.view(1, h, w, 3)

        # antialias
        color = dr.antialias(color, rast, v_clip, f).squeeze(0) # [H, W, 3]
        color = alpha * color + (1 - alpha) * bg_color

        # get vn and render normal
        i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
        v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0)
        face_normals = safe_normalize(face_normals)
        
        vn = torch.zeros_like(v)
        vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
        vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
        vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)

        vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))

        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)
        normal = safe_normalize(normal[0])

        # rotated normal (where [0, 0, 1] always faces camera)
        rot_normal = normal @ pose[:3, :3]
        viewcos = rot_normal[..., [2]]

        # ssaa
        if ssaa != 1:
            color = scale_img_hwc(color, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            depth = scale_img_hwc(depth, (h0, w0))
            normal = scale_img_hwc(normal, (h0, w0))
            viewcos = scale_img_hwc(viewcos, (h0, w0))

        results['image'] = color.clamp(0, 1)
        results['alpha'] = alpha
        results['depth'] = depth
        results['normal'] = (normal + 1) / 2
        results['viewcos'] = viewcos

        return results

================================================
FILE: threefiner/renderer/mesh_renderer.py
================================================
import os
import tqdm
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import nvdiffrast.torch as dr

from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding
from kiui.cam import orbit_camera, get_perspective
from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder

def render_mesh(
        glctx, 
        v, f, 
        vt, ft, albedo, vc,
        vn, fn,
        pose, proj, 
        h0, w0, 
        ssaa=1, bg_color=1, 
        texture_filter='linear-mipmap-linear', 
        color_activation=None,
    ):
    
    # do super-sampling
    if ssaa != 1:
        h = make_divisible(h0 * ssaa, 8)
        w = make_divisible(w0 * ssaa, 8)
    else:
        h, w = h0, w0
    
    results = {}

    pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
    proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)

    # get v_clip and render rgb
    v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
    v_clip = v_cam @ proj.T

    rast, rast_db = dr.rasterize(glctx, v_clip, f, (h, w))

    alpha = (rast[0, ..., 3:] > 0).float()
    depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]
    depth = depth.squeeze(0) # [H, W, 1]

    if vc is not None:
        # use vertex color
        color, _ = dr.interpolate(vc.unsqueeze(0).contiguous(), rast, f)
    else:
        # use texture image
        texc, texc_db = dr.interpolate(vt.unsqueeze(0).contiguous(), rast, ft, rast_db=rast_db, diff_attrs='all')
        color = dr.texture(albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3]

    if color_activation is not None:
        color = color_activation(color)

    # antialias
    color = dr.antialias(color, rast, v_clip, f).squeeze(0) # [H, W, 3]
    color = alpha * color + (1 - alpha) * bg_color

    # get vn and render normal
    if vn is None:
        i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
        v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0)
        face_normals = safe_normalize(face_normals)
        
        vn = torch.zeros_like(v)
        vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
        vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
        vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)

        vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
    
    normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, fn)
    normal = safe_normalize(normal[0])

    # rotated normal (where [0, 0, 1] always faces camera)
    rot_normal = normal @ pose[:3, :3]
    viewcos = rot_normal[..., [2]]

    # ssaa
    if ssaa != 1:
        color = scale_img_hwc(color, (h0, w0))
        alpha = scale_img_hwc(alpha, (h0, w0))
        depth = scale_img_hwc(depth, (h0, w0))
        normal = scale_img_hwc(normal, (h0, w0))
        viewcos = scale_img_hwc(viewcos, (h0, w0))

    results['image'] = color.clamp(0, 1)
    results['alpha'] = alpha
    results['depth'] = depth
    results['normal'] = (normal + 1) / 2
    results['viewcos'] = viewcos

    return results

class Renderer(nn.Module):
    def __init__(self, opt, device):
        
        super().__init__()

        self.opt = opt
        self.device = device

        self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)

        # it's necessary to clean the mesh to facilitate later remeshing!
        vertices = self.mesh.v.detach().cpu().numpy()
        triangles = self.mesh.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)
        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)

        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()
        
        # extract trainable parameters
        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
        
        # texture
        if self.opt.tex_mode == 'hashgrid':
            self.encoder = HashGridEncoder().to(self.device)
        elif self.opt.tex_mode == 'mlp':
            self.encoder = FrequencyEncoder().to(self.device)
        elif self.opt.tex_mode == 'triplane':
            self.encoder = TriplaneEncoder().to(self.device)
        else:
            raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}")
        
        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=True).to(self.device)

        # init hashgrid texture from mesh
        if self.opt.fit_tex:
            self.fit_texture_from_mesh(self.opt.fit_tex_iters)

    def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
        return render_mesh(
            self.glctx, 
            self.mesh.v, self.mesh.f, self.mesh.vt, 
            self.mesh.ft, self.mesh.albedo, 
            self.mesh.vc, self.mesh.vn, self.mesh.fn, 
            pose, proj, h, w, 
            ssaa=ssaa, bg_color=bg_color,
        )
    
    def fit_texture_from_mesh(self, iters=512):
        # a small training loop...

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam([
            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
        ])

        resolution = 512

        print(f"[INFO] fitting texture...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            ver = np.random.randint(-45, 45)
            hor = np.random.randint(-180, 180)
            
            pose = orbit_camera(ver, hor, self.opt.radius)
            proj = get_perspective(self.opt.fovy)

            image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']
            image_pred = self.render(pose, proj, resolution, resolution)['image']

            loss = loss_fn(image_pred, image_mesh)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description(f"MSE = {loss.item():.6f}")
        
        print(f"[INFO] finished fitting texture!")

    def get_params(self):

        params = [
            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
        ]

        if not self.opt.fix_geo:
            params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})

        return params

    @torch.no_grad()
    def export_mesh(self, save_path, texture_resolution=2048, padding=16):

        mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
        print(f"[INFO] uv unwrapping...")
        mesh.auto_normal()
        mesh.auto_uv()

        # render uv maps
        h = w = texture_resolution
        uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]

        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]

        # masked query 
        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]
        mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1]
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)
        
        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)

        if mask.any():
            print(f"[INFO] querying texture...")

            xyzs = xyzs[mask] # [M, 3]

            # batched inference to avoid OOM
            batch = []
            head = 0
            while head < xyzs.shape[0]:
                tail = min(head + 640000, xyzs.shape[0])
                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
                head += 640000

            albedo[mask] = torch.cat(batch, dim=0)
        
        albedo = albedo.view(h, w, -1)
        mask = mask.view(h, w)

        print(f"[INFO] uv padding...")
        albedo = uv_padding(albedo, mask, padding)

        mesh.albedo = albedo
        mesh.write(save_path)

    @property
    def v(self):
        if self.opt.fix_geo:
            return self.mesh.v
        else:
            return self.mesh.v + self.v_offsets
    
    @property
    def f(self):
        return self.mesh.f

    @torch.no_grad()
    def remesh(self):
        vertices = self.v.detach().cpu().numpy()
        triangles = self.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, repair=False, remesh=True, remesh_size=self.opt.remesh_size)
        if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:
            vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target, optimalplacement=False)
        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v)).to(self.device)
        
    
    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):

        # do super-sampling
        if ssaa != 1:
            h = make_divisible(h0 * ssaa, 8)
            w = make_divisible(w0 * ssaa, 8)
        else:
            h, w = h0, w0
        
        results = {}

        # get v
        v = self.v
        f = self.f

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
        proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!

        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]
        depth = depth.squeeze(0) # [H, W, 1]

        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
        xyzs = xyzs.view(-1, 3)
        mask = (alpha > 0).view(-1)
        color = torch.zeros_like(xyzs, dtype=torch.float32)
        if mask.any():
            masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))
            color[mask] = masked_albedo.float()
        color = color.view(1, h, w, 3)

        # antialias
        color = dr.antialias(color, rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3]
        color = alpha * color + (1 - alpha) * bg_color

        # get vn and render normal
        if self.opt.fix_geo:
            vn = self.mesh.vn
        else:
            i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
            v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

            face_normals = torch.cross(v1 - v0, v2 - v0)
            face_normals = safe_normalize(face_normals)

            vn = torch.zeros_like(v)
            vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
            vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
            vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)

            vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))

        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)
        normal = safe_normalize(normal[0])

        # rotated normal (where [0, 0, 1] always faces camera)
        rot_normal = normal @ pose[:3, :3]
        viewcos = rot_normal[..., [2]]

        # ssaa
        if ssaa != 1:
            color = scale_img_hwc(color, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            depth = scale_img_hwc(depth, (h0, w0))
            normal = scale_img_hwc(normal, (h0, w0))
            viewcos = scale_img_hwc(viewcos, (h0, w0))

        results['image'] = color
        results['alpha'] = alpha
        results['depth'] = depth
        results['normal'] = (normal + 1) / 2
        results['viewcos'] = viewcos

        return results

================================================
FILE: threefiner/renderer/pbr_diffmc_renderer.py
================================================
import os
import tqdm
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import envlight

import nvdiffrast.torch as dr

import kiui
from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding
from kiui.cam import orbit_camera, get_perspective

from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder
from threefiner.renderer.mesh_renderer import render_mesh

from diso import DiffMC, DiffDMC

class Renderer(nn.Module):
    def __init__(self, opt, device):
        
        super().__init__()

        self.opt = opt
        self.device = device

        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        # diffmc
        self.verts = torch.stack(
            torch.meshgrid(
                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
                torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
                indexing="ij",
            ), dim=-1,
        ) # [N, N, N, 3]
        self.grid_scale = 1
        self.diffmc = DiffMC(dtype=torch.float32).to(device)
        
        # vert sdf and deform
        self.sdf = nn.Parameter(torch.zeros_like(self.verts[..., 0]))
        self.deform = nn.Parameter(torch.zeros_like(self.verts))
        
        # init diffmc from mesh
        self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)

        vertices = self.mesh.v.detach().cpu().numpy()
        triangles = self.mesh.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)
        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)

        self.grid_scale = self.mesh.v.abs().max() + 1e-1
        self.verts = self.verts * self.grid_scale
        
        try:
            import cubvh
            BVH = cubvh.cuBVH(self.mesh.v, self.mesh.f)
            sdf, _, _ = BVH.signed_distance(self.verts.reshape(-1, 3), return_uvw=False, mode='raystab') # some mesh may not be watertight...
        except:
            from pysdf import SDF
            sdf_func = SDF(self.mesh.v.detach().cpu().numpy(), self.mesh.f.detach().cpu().numpy())
            sdf = sdf_func(self.verts.detach().cpu().numpy().reshape(-1, 3))
            sdf = torch.from_numpy(sdf).to(self.device)
            sdf *= -1

        # OUTER is POSITIVE
        self.sdf.data += sdf.reshape(*self.sdf.data.shape).to(self.sdf.data.dtype)

        # texture
        if self.opt.tex_mode == 'hashgrid':
            self.encoder = HashGridEncoder().to(self.device)
        elif self.opt.tex_mode == 'mlp':
            self.encoder = FrequencyEncoder().to(self.device)
        elif self.opt.tex_mode == 'triplane':
            self.encoder = TriplaneEncoder().to(self.device)
        else:
            raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}")
        
        self.mlp = MLP(self.encoder.output_dim, 3+2, 32, 2, bias=True).to(self.device)

        self.v, self.f = None, None # placeholder

        # env light
        if self.opt.env_texture is None:
            hdr_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../lights/mud_road_puresky_1k.hdr')
        else:
            hdr_path = self.opt.env_texture
        self.light = envlight.EnvLight(hdr_path, scale=self.opt.env_scale, device=self.device)

        FG_LUT = torch.from_numpy(np.fromfile(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../lights/bsdf_256_256.bin"), dtype=np.float32).reshape(1, 256, 256, 2)).to(self.device)
        self.register_buffer("FG_LUT", FG_LUT)

        # init hashgrid texture from mesh
        if self.opt.fit_tex:
            self.fit_texture_from_mesh(self.opt.fit_tex_iters)
    
    def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
        return render_mesh(
            self.glctx, 
            self.mesh.v, self.mesh.f, self.mesh.vt, 
            self.mesh.ft, self.mesh.albedo, 
            self.mesh.vc, self.mesh.vn, self.mesh.fn, 
            pose, proj, h, w, 
            ssaa=ssaa, bg_color=bg_color,
        )

    def fit_texture_from_mesh(self, iters=512):
        # a small training loop...

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam([
            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
        ])

        resolution = 512

        print(f"[INFO] fitting texture...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            ver = np.random.randint(-45, 45)
            hor = np.random.randint(-180, 180)
            
            pose = orbit_camera(ver, hor, self.opt.radius)
            proj = get_perspective(self.opt.fovy)

            image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']
            image_pred = self.render(pose, proj, resolution, resolution)['image']

            loss = loss_fn(image_pred, image_mesh)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description(f"MSE = {loss.item():.6f}")
        
        print(f"[INFO] finished fitting texture!")

    def get_params(self):

        params = [
            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
        ]

        if not self.opt.fix_geo:
            params.append({'params': self.sdf, 'lr': self.opt.sdf_lr})
            params.append({'params': self.deform, 'lr': self.opt.deform_lr})

        return params

    @torch.no_grad()
    def export_mesh(self, save_path, texture_resolution=2048, padding=16):

        # get v
        sdf = self.sdf
        deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]

        v, f = self.diffmc(sdf, deform)
        v = (2 * v - 1) * self.grid_scale
        f = f.int()
        self.v, self.f = v, f
        
        vertices = v.detach().cpu().numpy()
        triangles = f.detach().cpu().numpy()

        # clean
        vertices = vertices.astype(np.float32)
        triangles = triangles.astype(np.int32)
        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size)
        
        # decimation
        if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:
            vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target)
        
        v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        
        mesh = Mesh(v=v, f=f, albedo=None, device=self.device)
        print(f"[INFO] uv unwrapping...")
        mesh.auto_normal()
        mesh.auto_uv()

        # render uv maps
        h = w = texture_resolution
        uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]

        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]
        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
        mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]

        # masked query 
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)
        
        material = torch.zeros(h * w, 5, device=self.device, dtype=torch.float32)

        if mask.any():
            print(f"[INFO] querying texture...")

            xyzs = xyzs[mask] # [M, 3]

            # batched inference to avoid OOM
            batch = []
            head = 0
            while head < xyzs.shape[0]:
                tail = min(head + 640000, xyzs.shape[0])
                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
                head += 640000

            material[mask] = torch.cat(batch, dim=0)
        
        material = material.view(h, w, -1)
        mask = mask.view(h, w)

        print(f"[INFO] uv padding...")
        material = uv_padding(material, mask, padding)

        mesh.albedo = material[..., :3]
        mesh.metallicRoughness = torch.cat([torch.zeros_like(material[..., 3:4]), material[..., 4:5], material[..., 3:4]], dim=-1)

        mesh.write(save_path)

    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):
        
        # do super-sampling
        if ssaa != 1:
            h = make_divisible(h0 * ssaa, 8)
            w = make_divisible(w0 * ssaa, 8)
        else:
            h, w = h0, w0
        
        results = {}

        # get v
        sdf = self.sdf
        deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]

        v, f = self.diffmc(sdf, deform)
        v = (2 * v - 1) * self.grid_scale
        f = f.int()
        self.v, self.f = v, f

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
        proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [V, H, W, 1]
        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!
        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]
        depth = depth.squeeze(0) # [H, W, 1]

        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
        viewdir = safe_normalize(xyzs - pose[:3, 3]).squeeze(0)
        xyzs = xyzs.view(-1, 3)
        mask = (alpha > 0).view(-1)
        material = torch.zeros(xyzs.shape[0], 5, dtype=torch.float32, device=xyzs.device)
        if mask.any():
            masked_material = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))
            material[mask] = masked_material.float()
        material = material.view(h, w, -1)

        # get vn and render normal
        i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
        v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0)
        face_normals = safe_normalize(face_normals)

        vn = torch.zeros_like(v)
        vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
        vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
        vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)

        vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))

        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)
        normal = safe_normalize(normal[0])

        # rotated normal (where [0, 0, 1] always faces camera)
        rot_normal = normal @ pose[:3, :3]
        viewcos = rot_normal[..., [2]]

        # shading
        albedo = material[..., :3]
        metallic = material[..., 3:4]
        roughness = material[..., 4:5]

        n_dot_v = (normal * viewdir).sum(-1, keepdim=True) # [H, W, 1]
        reflective = n_dot_v * normal * 2 - viewdir

        diffuse_albedo = (1 - metallic) * albedo

        fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) # [H, W, 2]
        fg = dr.texture(
            self.FG_LUT,
            fg_uv.reshape(1, -1, 1, 2).contiguous(),
            filter_mode="linear",
            boundary_mode="clamp",
        ).reshape(h, w, 2)
        F0 = (1 - metallic) * 0.04 + metallic * albedo
        specular_albedo = F0 * fg[..., 0:1] + fg[..., 1:2]

        diffuse_light = self.light(normal)
        specular_light = self.light(reflective, roughness)

        color = diffuse_albedo * diffuse_light + specular_albedo * specular_light # [H, W, 3]

        # antialias
        color = dr.antialias(color.unsqueeze(0), rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3]
        color = alpha * color + (1 - alpha) * bg_color

        # ssaa
        if ssaa != 1:
            color = scale_img_hwc(color, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            depth = scale_img_hwc(depth, (h0, w0))
            normal = scale_img_hwc(normal, (h0, w0))
            viewcos = scale_img_hwc(viewcos, (h0, w0))

        results['image'] = color
        results['alpha'] = alpha
        results['depth'] = depth
        results['normal'] = (normal + 1) / 2
        results['viewcos'] = viewcos

        return results

================================================
FILE: threefiner/renderer/pbr_mesh_renderer.py
================================================
import os
import tqdm
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import envlight

import nvdiffrast.torch as dr

from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding
from kiui.cam import orbit_camera, get_perspective
from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder
from threefiner.renderer.mesh_renderer import render_mesh


class Renderer(nn.Module):
    def __init__(self, opt, device):
        
        super().__init__()

        self.opt = opt
        self.device = device

        self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)

        # it's necessary to clean the mesh to facilitate later remeshing!
        vertices = self.mesh.v.detach().cpu().numpy()
        triangles = self.mesh.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)
        self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)

        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()
        
        # extract trainable parameters
        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
        
        # texture
        if self.opt.tex_mode == 'hashgrid':
            self.encoder = HashGridEncoder().to(self.device)
        elif self.opt.tex_mode == 'mlp':
            self.encoder = FrequencyEncoder().to(self.device)
        elif self.opt.tex_mode == 'triplane':
            self.encoder = TriplaneEncoder().to(self.device)
        else:
            raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}")
        
        self.mlp = MLP(self.encoder.output_dim, 3+2, 32, 2, bias=True).to(self.device)

        # env light
        if self.opt.env_texture is None:
            hdr_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../lights/mud_road_puresky_1k.hdr')
        else:
            hdr_path = self.opt.env_texture
        self.light = envlight.EnvLight(hdr_path, scale=self.opt.env_scale, device=self.device)

        FG_LUT = torch.from_numpy(np.fromfile(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../lights/bsdf_256_256.bin"), dtype=np.float32).reshape(1, 256, 256, 2)).to(self.device)
        self.register_buffer("FG_LUT", FG_LUT)

        # init hashgrid texture from mesh
        if self.opt.fit_tex:
            self.fit_texture_from_mesh(self.opt.fit_tex_iters)

    def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
        return render_mesh(
            self.glctx, 
            self.mesh.v, self.mesh.f, self.mesh.vt, 
            self.mesh.ft, self.mesh.albedo, 
            self.mesh.vc, self.mesh.vn, self.mesh.fn, 
            pose, proj, h, w, 
            ssaa=ssaa, bg_color=bg_color,
        )
    
    def fit_texture_from_mesh(self, iters=512):
        # a small training loop...

        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam([
            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
        ])

        resolution = 512

        print(f"[INFO] fitting texture...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            ver = np.random.randint(-45, 45)
            hor = np.random.randint(-180, 180)
            
            pose = orbit_camera(ver, hor, self.opt.radius)
            proj = get_perspective(self.opt.fovy)

            image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']
            image_pred = self.render(pose, proj, resolution, resolution)['image']

            loss = loss_fn(image_pred, image_mesh)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description(f"MSE = {loss.item():.6f}")
        
        print(f"[INFO] finished fitting texture!")

    def get_params(self):

        params = [
            {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
            {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
        ]

        if not self.opt.fix_geo:
            params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})

        return params

    @torch.no_grad()
    def export_mesh(self, save_path, texture_resolution=2048, padding=16):

        mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
        print(f"[INFO] uv unwrapping...")
        mesh.auto_normal()
        mesh.auto_uv()

        # render uv maps
        h = w = texture_resolution
        uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]

        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]

        # masked query 
        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]
        mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1]
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)
        
        material = torch.zeros(h * w, 5, device=self.device, dtype=torch.float32)

        if mask.any():
            print(f"[INFO] querying texture...")

            xyzs = xyzs[mask] # [M, 3]

            # batched inference to avoid OOM
            batch = []
            head = 0
            while head < xyzs.shape[0]:
                tail = min(head + 640000, xyzs.shape[0])
                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
                head += 640000

            material[mask] = torch.cat(batch, dim=0)
        
        material = material.view(h, w, -1)
        mask = mask.view(h, w)

        print(f"[INFO] uv padding...")
        mater
Download .txt
gitextract_wp6txn7z/

├── .github/
│   └── workflows/
│       └── pypi-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── data/
│   ├── car.glb
│   └── chair.ply
├── gradio_app.py
├── readme.md
├── scripts/
│   ├── run.sh
│   └── test_all.sh
├── setup.py
└── threefiner/
    ├── __init__.py
    ├── cli.py
    ├── gui.py
    ├── guidance/
    │   ├── __init__.py
    │   ├── if2_ism_utils.py
    │   ├── if2_nfsd_utils.py
    │   ├── if2_utils.py
    │   ├── if_utils.py
    │   ├── sd_ism_utils.py
    │   ├── sd_nfsd_utils.py
    │   ├── sd_utils.py
    │   └── sdcn_utils.py
    ├── lights/
    │   ├── LICENSE.txt
    │   └── mud_road_puresky_1k.hdr
    ├── nn.py
    ├── opt.py
    └── renderer/
        ├── __init__.py
        ├── diffmc_renderer.py
        ├── mesh_renderer.py
        ├── pbr_diffmc_renderer.py
        └── pbr_mesh_renderer.py
Download .txt
SYMBOL INDEX (133 symbols across 17 files)

FILE: gradio_app.py
  function process (line 27) | def process(input_model, input_text, input_dir, iters):

FILE: threefiner/cli.py
  function main (line 7) | def main():

FILE: threefiner/gui.py
  class GUI (line 22) | class GUI:
    method __init__ (line 23) | def __init__(self, opt: Options):
    method __del__ (line 85) | def __del__(self):
    method seed_everything (line 89) | def seed_everything(self):
    method prepare_train (line 102) | def prepare_train(self):
    method train_step (line 147) | def train_step(self):
    method test_step (line 260) | def test_step(self):
    method save_model (line 297) | def save_model(self, save_path=None):
    method register_dpg (line 321) | def register_dpg(self):
    method render (line 588) | def render(self):
    method train (line 598) | def train(self, iters=500):

FILE: threefiner/guidance/if2_ism_utils.py
  function invert_noise (line 16) | def invert_noise(scheduler, noisy_samples, noise, timesteps):
  class IF2 (line 34) | class IF2(nn.Module):
    method __init__ (line 35) | def __init__(
    method get_text_embeds (line 80) | def get_text_embeds(self, prompts, negative_prompts):
    method encode_text (line 94) | def encode_text(self, prompt):
    method train_step (line 101) | def train_step(
    method produce_imgs (line 218) | def produce_imgs(
    method prompt_to_img (line 271) | def prompt_to_img(

FILE: threefiner/guidance/if2_nfsd_utils.py
  class IF2 (line 16) | class IF2(nn.Module):
    method __init__ (line 17) | def __init__(
    method get_text_embeds (line 62) | def get_text_embeds(self, prompts, negative_prompts):
    method encode_text (line 76) | def encode_text(self, prompt):
    method train_step (line 83) | def train_step(
    method produce_imgs (line 177) | def produce_imgs(
    method prompt_to_img (line 230) | def prompt_to_img(

FILE: threefiner/guidance/if2_utils.py
  class IF2 (line 16) | class IF2(nn.Module):
    method __init__ (line 17) | def __init__(
    method get_text_embeds (line 64) | def get_text_embeds(self, prompts, negative_prompts):
    method encode_text (line 76) | def encode_text(self, prompt):
    method train_step (line 83) | def train_step(
    method produce_imgs (line 175) | def produce_imgs(
    method prompt_to_img (line 228) | def prompt_to_img(

FILE: threefiner/guidance/if_utils.py
  class IF (line 15) | class IF(nn.Module):
    method __init__ (line 16) | def __init__(
    method get_text_embeds (line 62) | def get_text_embeds(self, prompts, negative_prompts):
    method encode_text (line 74) | def encode_text(self, prompt):
    method refine (line 82) | def refine(self, pred_rgb,
    method train_step (line 112) | def train_step(
    method produce_imgs (line 194) | def produce_imgs(
    method prompt_to_img (line 242) | def prompt_to_img(

FILE: threefiner/guidance/sd_ism_utils.py
  function invert_noise (line 15) | def invert_noise(scheduler, noisy_samples, noise, timesteps):
  class StableDiffusion (line 33) | class StableDiffusion(nn.Module):
    method __init__ (line 34) | def __init__(
    method get_text_embeds (line 83) | def get_text_embeds(self, prompts, negative_prompts):
    method encode_text (line 97) | def encode_text(self, prompt):
    method refine (line 104) | def refine(self, pred_rgb,
    method train_step (line 134) | def train_step(
    method produce_latents (line 243) | def produce_latents(
    method decode_latents (line 286) | def decode_latents(self, latents):
    method encode_imgs (line 294) | def encode_imgs(self, imgs):
    method prompt_to_img (line 304) | def prompt_to_img(

FILE: threefiner/guidance/sd_nfsd_utils.py
  class StableDiffusion (line 15) | class StableDiffusion(nn.Module):
    method __init__ (line 16) | def __init__(
    method get_text_embeds (line 65) | def get_text_embeds(self, prompts, negative_prompts):
    method encode_text (line 79) | def encode_text(self, prompt):
    method refine (line 86) | def refine(self, pred_rgb,
    method train_step (line 116) | def train_step(
    method produce_latents (line 204) | def produce_latents(
    method decode_latents (line 247) | def decode_latents(self, latents):
    method encode_imgs (line 255) | def encode_imgs(self, imgs):
    method prompt_to_img (line 265) | def prompt_to_img(

FILE: threefiner/guidance/sd_utils.py
  class StableDiffusion (line 15) | class StableDiffusion(nn.Module):
    method __init__ (line 16) | def __init__(
    method get_text_embeds (line 65) | def get_text_embeds(self, prompts, negative_prompts):
    method encode_text (line 77) | def encode_text(self, prompt):
    method refine (line 84) | def refine(self, pred_rgb,
    method train_step (line 114) | def train_step(
    method produce_latents (line 203) | def produce_latents(
    method decode_latents (line 246) | def decode_latents(self, latents):
    method encode_imgs (line 254) | def encode_imgs(self, imgs):
    method prompt_to_img (line 264) | def prompt_to_img(

FILE: threefiner/guidance/sdcn_utils.py
  class StableDiffusionControlNet (line 14) | class StableDiffusionControlNet(nn.Module):
    method __init__ (line 15) | def __init__(
    method get_text_embeds (line 92) | def get_text_embeds(self, prompts, negative_prompts):
    method encode_text (line 104) | def encode_text(self, prompt):
    method refine (line 111) | def refine(self, pred_rgb,
    method train_step (line 172) | def train_step(
    method produce_latents (line 292) | def produce_latents(
    method decode_latents (line 367) | def decode_latents(self, latents):
    method encode_imgs (line 375) | def encode_imgs(self, imgs):
    method prompt_to_img (line 385) | def prompt_to_img(

FILE: threefiner/nn.py
  class HashGridEncoder (line 8) | class HashGridEncoder(nn.Module):
    method __init__ (line 9) | def __init__(self,
    method forward (line 35) | def forward(self, x, bound=1):
  class FrequencyEncoder (line 38) | class FrequencyEncoder(nn.Module):
    method __init__ (line 39) | def __init__(self,
    method forward (line 57) | def forward(self, x, **kwargs):
  class TriplaneEncoder (line 61) | class TriplaneEncoder(nn.Module):
    method __init__ (line 62) | def __init__(self,
    method forward (line 77) | def forward(self, x, bound=1):
  class MLP (line 92) | class MLP(nn.Module):
    method __init__ (line 93) | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
    method forward (line 106) | def forward(self, x):

FILE: threefiner/opt.py
  class Options (line 6) | class Options:
  function check_options (line 144) | def check_options(opt: Options):

FILE: threefiner/renderer/diffmc_renderer.py
  class Renderer (line 22) | class Renderer(nn.Module):
    method __init__ (line 23) | def __init__(self, opt, device):
    method render_mesh (line 95) | def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
    method fit_texture_from_mesh (line 105) | def fit_texture_from_mesh(self, iters=512):
    method get_params (line 139) | def get_params(self):
    method export_mesh (line 153) | def export_mesh(self, save_path, texture_resolution=2048, padding=16):
    method render (line 223) | def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):

FILE: threefiner/renderer/mesh_renderer.py
  function render_mesh (line 17) | def render_mesh(
  class Renderer (line 104) | class Renderer(nn.Module):
    method __init__ (line 105) | def __init__(self, opt, device):
    method render_mesh (line 145) | def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
    method fit_texture_from_mesh (line 155) | def fit_texture_from_mesh(self, iters=512):
    method get_params (line 189) | def get_params(self):
    method export_mesh (line 202) | def export_mesh(self, save_path, texture_resolution=2048, padding=16):
    method v (line 249) | def v(self):
    method f (line 256) | def f(self):
    method remesh (line 260) | def remesh(self):
    method render (line 271) | def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):

FILE: threefiner/renderer/pbr_diffmc_renderer.py
  class Renderer (line 24) | class Renderer(nn.Module):
    method __init__ (line 25) | def __init__(self, opt, device):
    method render_mesh (line 107) | def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
    method fit_texture_from_mesh (line 117) | def fit_texture_from_mesh(self, iters=512):
    method get_params (line 151) | def get_params(self):
    method export_mesh (line 165) | def export_mesh(self, save_path, texture_resolution=2048, padding=16):
    method render (line 237) | def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):

FILE: threefiner/renderer/pbr_mesh_renderer.py
  class Renderer (line 21) | class Renderer(nn.Module):
    method __init__ (line 22) | def __init__(self, opt, device):
    method render_mesh (line 72) | def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
    method fit_texture_from_mesh (line 82) | def fit_texture_from_mesh(self, iters=512):
    method get_params (line 116) | def get_params(self):
    method export_mesh (line 129) | def export_mesh(self, save_path, texture_resolution=2048, padding=16):
    method v (line 177) | def v(self):
    method f (line 184) | def f(self):
    method remesh (line 188) | def remesh(self):
    method render (line 198) | def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):
Condensed preview — 32 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (209K chars).
[
  {
    "path": ".github/workflows/pypi-publish.yml",
    "chars": 945,
    "preview": "name: Upload Python Package\n\non:\n  release:\n    types: [created]\n  workflow_dispatch:\n\njobs:\n  deploy:\n\n    runs-on: ubu"
  },
  {
    "path": ".gitignore",
    "chars": 62,
    "preview": "__pycache__\ntmp*\ndata_*\nlogs\nlogs*\nvideos*\n\n\n*.egg-info\nbuild/"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "MANIFEST.in",
    "chars": 37,
    "preview": "recursive-include threefiner/lights *"
  },
  {
    "path": "gradio_app.py",
    "chars": 2689,
    "preview": "import os\nimport tyro\nimport tqdm\nimport torch\nimport gradio as gr\n\nimport kiui\n\nfrom threefiner.opt import config_defau"
  },
  {
    "path": "readme.md",
    "chars": 4803,
    "preview": "<p align=\"center\">\n    <picture>\n    <img alt=\"logo\" src=\"assets/threefiner_icon.png\" width=\"20%\">\n    </picture>\n    </"
  },
  {
    "path": "scripts/run.sh",
    "chars": 571,
    "preview": "export CUDA_VISIBLE_DEVICES=0\n\n# the mesh is already with good initial texture, just refine it using IF2\nthreefiner if2 "
  },
  {
    "path": "scripts/test_all.sh",
    "chars": 1179,
    "preview": "export CUDA_VISIBLE_DEVICES=1\n\n# geom_mode\nthreefiner if2 --geom_mode diffmc --save car_diffmc.glb --mesh data/car.glb -"
  },
  {
    "path": "setup.py",
    "chars": 1359,
    "preview": "from setuptools import setup, find_packages\n\n\nsetup(\n  name = 'threefiner',\n  packages = find_packages(exclude=[]),\n  in"
  },
  {
    "path": "threefiner/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "threefiner/cli.py",
    "chars": 395,
    "preview": "import os\nimport tyro\nfrom threefiner.opt import config_defaults, config_doc, check_options\nfrom threefiner.gui import G"
  },
  {
    "path": "threefiner/gui.py",
    "chars": 22374,
    "preview": "import os\nimport tqdm\nimport random\nimport imageio\nimport numpy as np\n\nimport torch\nimport torch.nn.functional as F\n\nGUI"
  },
  {
    "path": "threefiner/guidance/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "threefiner/guidance/if2_ism_utils.py",
    "chars": 12846,
    "preview": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    IFPipeline,\n    IFSuperResolutionPipeline,\n)\n\n\n\nimport"
  },
  {
    "path": "threefiner/guidance/if2_nfsd_utils.py",
    "chars": 10773,
    "preview": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    IFPipeline,\n    IFSuperResolutionPipeline,\n)\n\n\n\nimport"
  },
  {
    "path": "threefiner/guidance/if2_utils.py",
    "chars": 10471,
    "preview": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    IFPipeline,\n    IFSuperResolutionPipeline,\n)\n\n\n\nimport"
  },
  {
    "path": "threefiner/guidance/if_utils.py",
    "chars": 10722,
    "preview": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    IFPipeline,\n)\n\n\n\nimport numpy as np\nimport torch\nimpor"
  },
  {
    "path": "threefiner/guidance/sd_ism_utils.py",
    "chars": 13299,
    "preview": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n)\n\n\n\nimport numpy as np\nimpor"
  },
  {
    "path": "threefiner/guidance/sd_nfsd_utils.py",
    "chars": 11462,
    "preview": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n)\n\n\n\nimport numpy as np\nimpor"
  },
  {
    "path": "threefiner/guidance/sd_utils.py",
    "chars": 11144,
    "preview": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n)\n\n\n\nimport numpy as np\nimpor"
  },
  {
    "path": "threefiner/guidance/sdcn_utils.py",
    "chars": 18209,
    "preview": "from diffusers import (\n    PNDMScheduler,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n    ControlNetModel,\n)\n\nimpor"
  },
  {
    "path": "threefiner/lights/LICENSE.txt",
    "chars": 97,
    "preview": "The mud_road_puresky.hdr HDR probe is from https://polyhaven.com/a/mud_road_puresky\nCC0 License.\n"
  },
  {
    "path": "threefiner/nn.py",
    "chars": 3785,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport numpy as np\nimport tinycudann as tcnn\n\nclass "
  },
  {
    "path": "threefiner/opt.py",
    "chars": 4685,
    "preview": "import os\nfrom dataclasses import dataclass\nfrom typing import Tuple, Literal, Dict, Optional\n\n@dataclass\nclass Options:"
  },
  {
    "path": "threefiner/renderer/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "threefiner/renderer/diffmc_renderer.py",
    "chars": 11190,
    "preview": "import os\nimport tqdm\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport nvd"
  },
  {
    "path": "threefiner/renderer/mesh_renderer.py",
    "chars": 12616,
    "preview": "import os\nimport tqdm\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport nvd"
  },
  {
    "path": "threefiner/renderer/pbr_diffmc_renderer.py",
    "chars": 12940,
    "preview": "import os\nimport tqdm\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport env"
  },
  {
    "path": "threefiner/renderer/pbr_mesh_renderer.py",
    "chars": 11482,
    "preview": "import os\nimport tqdm\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport env"
  }
]

// ... and 3 more files (download for full content)

About this extraction

This page contains the full source code of the 3DTopia/threefiner GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 32 files (196.8 KB), approximately 50.2k tokens, and a symbol index with 133 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!