Repository: 3DTopia/threefiner Branch: main Commit: 6c34f089e61a Files: 32 Total size: 196.8 KB Directory structure: gitextract_wp6txn7z/ ├── .github/ │ └── workflows/ │ └── pypi-publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── data/ │ ├── car.glb │ └── chair.ply ├── gradio_app.py ├── readme.md ├── scripts/ │ ├── run.sh │ └── test_all.sh ├── setup.py └── threefiner/ ├── __init__.py ├── cli.py ├── gui.py ├── guidance/ │ ├── __init__.py │ ├── if2_ism_utils.py │ ├── if2_nfsd_utils.py │ ├── if2_utils.py │ ├── if_utils.py │ ├── sd_ism_utils.py │ ├── sd_nfsd_utils.py │ ├── sd_utils.py │ └── sdcn_utils.py ├── lights/ │ ├── LICENSE.txt │ └── mud_road_puresky_1k.hdr ├── nn.py ├── opt.py └── renderer/ ├── __init__.py ├── diffmc_renderer.py ├── mesh_renderer.py ├── pbr_diffmc_renderer.py └── pbr_mesh_renderer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/pypi-publish.yml ================================================ name: Upload Python Package on: release: types: [created] workflow_dispatch: jobs: deploy: runs-on: ubuntu-latest environment: name: pypi url: https://pypi.org/project/threefiner/ permissions: id-token: write # IMPORTANT: this permission is mandatory for trusted publishing steps: - uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v3 with: python-version: '3.10' # prepare distributions in dist/ - name: Install dependencies and Build run: | python -m pip install --upgrade pip pip install setuptools wheel python setup.py sdist bdist_wheel # publish by trusted publishers (need to first setup in pypi.org projects-manage-publishing!) # ref: https://github.com/marketplace/actions/pypi-publish - name: Publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@release/v1 ================================================ FILE: .gitignore ================================================ __pycache__ tmp* data_* logs logs* videos* *.egg-info build/ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MANIFEST.in ================================================ recursive-include threefiner/lights * ================================================ FILE: gradio_app.py ================================================ import os import tyro import tqdm import torch import gradio as gr import kiui from threefiner.opt import config_defaults, config_doc, check_options from threefiner.gui import GUI GRADIO_SAVE_PATH_MESH = 'gradio_output.glb' GRADIO_SAVE_PATH_VIDEO = 'gradio_output.mp4' opt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)) # hacks for not loading mesh at initialization opt.save = GRADIO_SAVE_PATH_MESH opt.prompt = '' opt.text_dir = True opt.front_dir = '+z' device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') gui = GUI(opt) # process function def process(input_model, input_text, input_dir, iters): # set front facing direction (map from gradio model3D's mysterious coordinate system to OpenGL...) opt.text_dir = True if input_dir == 'front': opt.front_dir = '-z' elif input_dir == 'back': opt.front_dir = '+z' elif input_dir == 'left': opt.front_dir = '+x' elif input_dir == 'right': opt.front_dir = '-x' elif input_dir == 'up': opt.front_dir = '+y' elif input_dir == 'down': opt.front_dir = '-y' else: # turn off text_dir opt.text_dir = False opt.front_dir = '+z' # set mesh path opt.mesh = input_model # load mesh! gui.renderer = gui.renderer_class(opt, device).to(device) # set prompt gui.prompt = opt.positive_prompt + ', ' + input_text # train gui.prepare_train() # update optimizer and prompt embeddings for i in tqdm.trange(iters): gui.train_step() # save mesh & video gui.save_model(GRADIO_SAVE_PATH_MESH) gui.save_model(GRADIO_SAVE_PATH_VIDEO) # return 3d model & video return GRADIO_SAVE_PATH_MESH, GRADIO_SAVE_PATH_VIDEO # gradio UI block = gr.Blocks().queue() with block: gr.Markdown(""" ## Threefiner: Text-guided mesh refinement. """) with gr.Row(variant='panel'): with gr.Column(scale=1): input_model = gr.Model3D(label="input mesh") input_text = gr.Text(label="prompt") input_dir = gr.Radio(['front', 'back', 'left', 'right', 'up', 'down'], label="front-facing direction") iters = gr.Slider(minimum=100, maximum=1000, step=100, value=400, label="training iterations") button_gen = gr.Button("Refine!") with gr.Column(scale=1): output_model = gr.Model3D(label="output mesh") output_video = gr.Video(label="output video") button_gen.click(process, inputs=[input_model, input_text, input_dir, iters], outputs=[output_model, output_video]) block.launch(server_name="0.0.0.0", share=True) ================================================ FILE: readme.md ================================================

logo
Threefiner

An interface for text-guided mesh refinement. https://github.com/3DTopia/threefiner/assets/25863658/a4abe725-b542-4a4a-a6d4-e4c4821f7d96 ### Features * **Mesh in, mesh out**: we support `ply` with vertex colors, `obj`, and single object `glb/gltf` with textures! * **Easy to use**: both a CLI and a GUI is available. * **Performant**: Refine your texture in 1 minute with Deepfloyd-IF-II. ### Install We rely on `torch` and several CUDA extensions, please make sure you install them correctly first! ```bash # tiny-cuda-nn pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch # nvdiffrast pip install git+https://github.com/NVlabs/nvdiffrast # [optional, will use pysdf if unavailable] cubvh: pip install git+https://github.com/ashawkey/cubvh ``` To use [Deepfloyd-IF](https://github.com/deep-floyd/IF) models, please log in to your huggingface and accept the [license](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0). To install this package: ```bash # install from pypi pip install threefiner # install from github pip install git+https://github.com/3DTopia/threefiner # local install git clone https://github.com/3DTopia/threefiner cd threefiner pip install . ``` ### Usage ```bash ### command line interface threefiner --help # this is short for python -m threefiner.cli --help ### refine a coarse mesh ('input.obj') using Stable-diffusion and save to 'logs/hamburger.glb' threefiner sd --mesh input.obj --prompt 'a hamburger' --outdir logs --save hamburger.glb ### if the initial texture is good, we recommend using IF2 for refinement. # by default, it will save to './name_fine.glb' threefiner if2 --mesh name.glb --prompt 'description' ### if the initial texture is not good, we recommend using SD or IF first. threefiner sd --mesh name.glb --prompt 'description' threefiner if --mesh name.glb --prompt 'description' ### if the initial geometry is good, you can fix the geometry. threefiner sd_fixgeo --mesh name.glb --prompt 'description' threefiner if_fixgeo --mesh name.glb --prompt 'description' threefiner if2_fixgeo --mesh name.glb --prompt 'description' ### advanced # directional text prompt (append front/side/back view in text prompt) # you need to know the mesh's front facing direction and specify it by '--front_dir' # we use the OpenGL coordinate system, i.e., +x is right, +y is up, +z is front (more details: https://kit.kiui.moe/camera/) # clock-wise rotation can be specified per 90 degree, e.g., +z1, -y2 threefiner if2 --mesh input.glb --prompt 'description' --text_dir --front_dir='+z' # adjust training iterations threefiner if2 --mesh input.glb --prompt 'description' --iters 1000 # explicitly fix the geometry and only refine texture threefiner if2 --fix-geo --geom_mode mesh --mesh input.glb --prompt 'description' # equals if2_fixgeo # open a GUI to visualize the training progress (needs a desktop) threefiner if2 --mesh input.glb --prompt 'description' --gui ``` Gradio demo: ```bash # requires gradio 4 python gradio_app.py if2 ``` For more examples, please see [scripts](./scripts/). ### Q&A * **How to make sure `--front_dir` for your model?** You may first visualize it in a 3D viewer that follows OpenGL coordinate system:

example_front_dir

The chair is facing down the Y axis (Green), so we can use `--front_dir="-y"` to rectify it to face +Z axis (Blue). * **fatal error: EGL/egl.h: No such file or directory** By default, we use the OpenGL rasterizer. This error means there is no OpenGL installation, which is often the case for headless servers. It's recommended to install OpenGL (along with NVIDIA driver) as it brings better performance. Otherwise, you can append `--force_cuda_rast` to use the CUDA rasterizer instead. ## Acknowledgement This work is built on many amazing research works and open-source projects, thanks a lot to all the authors for sharing! - SDS `guidance` classes are based on [diffusers](https://github.com/huggingface/diffusers). - `diffmc` geometry is based on [diso](https://github.com/SarahWeiii/diso). - `mesh` geometry is based on [nerf2mesh](https://github.com/ashawkey/nerf2mesh). - Texture encoding is based on [tinycudann](https://github.com/NVlabs/tiny-cuda-nn). - Mesh renderer is based on [nvdiffrast](https://github.com/NVlabs/nvdiffrast). - GUI is based on [dearpygui](https://github.com/hoffstadt/DearPyGui). - The coarse models used in demo are generated by [Genie](https://lumalabs.ai/genie?view=create) and [3DTopia](https://github.com/3DTopia/3DTopia). ================================================ FILE: scripts/run.sh ================================================ export CUDA_VISIBLE_DEVICES=0 # the mesh is already with good initial texture, just refine it using IF2 threefiner if2 --mesh data/car.glb --prompt 'a red car' --outdir logs --save car_fine.glb --text_dir --front_dir='+x' # the mesh is coarse, using SD for diverse texture generation and IF2 for refinement threefiner sd --mesh data/chair.ply --prompt 'a swivel chair' --outdir logs --save chair_coarse.glb --text_dir --front_dir='-y' threefiner if2 --mesh logs/chair_coarse.glb --prompt 'a swivel chair' --outdir logs --save chair_fine.glb --text_dir --front_dir='+z' ================================================ FILE: scripts/test_all.sh ================================================ export CUDA_VISIBLE_DEVICES=1 # geom_mode threefiner if2 --geom_mode diffmc --save car_diffmc.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x' threefiner if2 --geom_mode mesh --save car_mesh.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x' threefiner if2 --geom_mode pbr_diffmc --save car_pbr_diffmc.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x' threefiner if2 --geom_mode pbr_mesh --save car_pbr_mesh.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x' # tex_mode threefiner if2 --tex_mode mlp --save car_mlp.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x' threefiner if2 --tex_mode triplane --save car_triplane.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x' # guidance mode threefiner sd --save car_SD.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x' threefiner if --save car_IF.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x' ================================================ FILE: setup.py ================================================ from setuptools import setup, find_packages setup( name = 'threefiner', packages = find_packages(exclude=[]), include_package_data = True, entry_points={ # CLI tools 'console_scripts': [ 'threefiner = threefiner.cli:main' ], }, version = '0.1.2', license='MIT', description = 'Threefiner: a text-guided mesh refiner', author = 'kiui', author_email = 'ashawkey1999@gmail.com', long_description=open("readme.md", encoding="utf-8").read(), long_description_content_type = 'text/markdown', url = 'https://github.com/3DTopia/threefiner', keywords = [ 'generative mesh refinement', ], install_requires=[ 'tyro', 'tqdm', 'rich', 'ninja', 'numpy', 'pandas', 'matplotlib', 'opencv-python', 'imageio', 'imageio-ffmpeg', 'scipy', 'scikit-learn', 'torch', 'einops', 'huggingface_hub', 'diffusers', 'accelerate', 'transformers', "sentencepiece", # required by deepfloyd-if T5 encoder 'plyfile', 'pygltflib', 'xatlas', 'trimesh', 'PyMCubes', 'pymeshlab', "pysdf", "diso", "envlight", 'dearpygui', 'kiui >= 0.2.1', ], classifiers=[ 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'License :: OSI Approved :: MIT License', 'Programming Language :: Python :: 3', ], ) ================================================ FILE: threefiner/__init__.py ================================================ ================================================ FILE: threefiner/cli.py ================================================ import os import tyro from threefiner.opt import config_defaults, config_doc, check_options from threefiner.gui import GUI def main(): opt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)) opt = check_options(opt) gui = GUI(opt) if gui.gui: gui.render() else: gui.train(opt.iters) if __name__ == "__main__": main() ================================================ FILE: threefiner/gui.py ================================================ import os import tqdm import random import imageio import numpy as np import torch import torch.nn.functional as F GUI_AVAILABLE = True try: import dearpygui.dearpygui as dpg except Exception as e: GUI_AVAILABLE = False import kiui from kiui.cam import orbit_camera, OrbitCamera from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency from threefiner.opt import Options class GUI: def __init__(self, opt: Options): self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. if not GUI_AVAILABLE and opt.gui: print(f'[WARN] cannot import dearpygui, assume without --gui') self.gui = opt.gui and GUI_AVAILABLE # enable gui self.W = opt.W self.H = opt.H self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy) self.mode = "image" self.seed = "random" self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32) self.need_update = True # update buffer_image self.save_path = os.path.join(self.opt.outdir, self.opt.save) os.makedirs(self.opt.outdir, exist_ok=True) # models self.device = torch.device("cuda") self.guidance = None # renderer if self.opt.geom_mode == 'mesh': from threefiner.renderer.mesh_renderer import Renderer elif self.opt.geom_mode == 'diffmc': from threefiner.renderer.diffmc_renderer import Renderer elif self.opt.geom_mode == 'pbr_mesh': from threefiner.renderer.pbr_mesh_renderer import Renderer elif self.opt.geom_mode == 'pbr_diffmc': from threefiner.renderer.pbr_diffmc_renderer import Renderer else: raise NotImplementedError(f"unknown geometry mode: {self.opt.geom_mode}") self.renderer_class = Renderer if self.opt.mesh is None: self.renderer = None else: self.renderer = Renderer(opt, self.device).to(self.device) # input prompt self.prompt = self.opt.prompt self.negative_prompt = "" if self.opt.positive_prompt is not None: self.prompt = self.opt.positive_prompt + ', ' + self.prompt if self.opt.negative_prompt is not None: self.negative_prompt = self.opt.negative_prompt # training stuff self.training = False self.optimizer = None self.step = 0 self.train_steps = 1 # steps per rendering loop if self.gui: dpg.create_context() self.register_dpg() self.test_step() def __del__(self): if self.gui: dpg.destroy_context() def seed_everything(self): try: seed = int(self.seed) except: seed = np.random.randint(0, 1000000) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) self.last_seed = seed def prepare_train(self): assert self.renderer is not None, 'no mesh loaded!' self.step = 0 # setup training self.optimizer = torch.optim.Adam(self.renderer.get_params()) # lazy load guidance model if self.guidance is None: print(f"[INFO] loading guidance...") if self.opt.mode == 'SD': from threefiner.guidance.sd_utils import StableDiffusion self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O) elif self.opt.mode == 'SD_NFSD': from threefiner.guidance.sd_nfsd_utils import StableDiffusion self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O) elif self.opt.mode == 'SDCN': from threefiner.guidance.sdcn_utils import StableDiffusionControlNet self.guidance = StableDiffusionControlNet(self.device, vram_O=self.opt.vram_O) elif self.opt.mode == 'IF': from threefiner.guidance.if_utils import IF self.guidance = IF(self.device, vram_O=self.opt.vram_O) elif self.opt.mode == 'IF2': from threefiner.guidance.if2_utils import IF2 self.guidance = IF2(self.device, vram_O=self.opt.vram_O) elif self.opt.mode == 'IF2_NFSD': from threefiner.guidance.if2_nfsd_utils import IF2 self.guidance = IF2(self.device, vram_O=self.opt.vram_O) elif self.opt.mode == 'SD_ISM': from threefiner.guidance.sd_ism_utils import StableDiffusion self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O) elif self.opt.mode == 'IF2_ISM': from threefiner.guidance.if2_ism_utils import IF2 self.guidance = IF2(self.device, vram_O=self.opt.vram_O) else: raise NotImplementedError(f"unknown guidance mode {self.opt.mode}!") print(f"[INFO] loaded guidance!") # prepare embeddings with torch.no_grad(): self.guidance.get_text_embeds([self.prompt], [self.negative_prompt]) def train_step(self): starter = torch.cuda.Event(enable_timing=True) ender = torch.cuda.Event(enable_timing=True) starter.record() self.renderer.train() for _ in range(self.train_steps): self.step += 1 step_ratio = min(1, self.step / self.opt.iters) loss = 0 ### novel view (manual batch) images = [] poses = [] normals = [] ori_images = [] vers, hors, radii = [], [], [] for _ in range(self.opt.batch_size): # render random view ver = np.random.randint(-60, 30) hor = np.random.randint(-180, 180) radius = np.random.uniform() - 0.5 # [-0.5, 0.5] pose = orbit_camera(ver, hor, self.opt.radius + radius) vers.append(ver) hors.append(hor) radii.append(radius) poses.append(pose) # random render resolution ssaa = min(2.0, max(0.125, 2 * np.random.random())) out = self.renderer.render(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=ssaa) image = out["image"] # [H, W, 3] in [0, 1] image = image.permute(2,0,1).contiguous().unsqueeze(0) # [1, 3, H, W] in [0, 1] images.append(image) # mix_normal if not self.opt.fix_geo and self.opt.mix_normal: normal = out['normal'] normal = normal.permute(2,0,1).contiguous().unsqueeze(0) normals.append(normal) # IF SR model requires the original rendering if self.opt.mode in ['IF2', 'IF2_NFSD', 'IF2_ISM', 'SDCN']: out_mesh = self.renderer.render_mesh(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=1) ori_image = out_mesh["image"] # [H, W, 3] in [0, 1] ori_image = ori_image.permute(2,0,1).contiguous().unsqueeze(0) ori_images.append(ori_image) # ori_images.append(image.clone()) # guidance loss guidance_input = {'pred_rgb': torch.cat(images, dim=0)} if not self.opt.fix_geo and self.opt.mix_normal: if random.random() > 0.5: ratio = random.random() guidance_input['pred_rgb'] = guidance_input['pred_rgb'] * ratio + torch.cat(normals, dim=0) * (1 - ratio) # guidance_input['step_ratio'] = step_ratio if self.opt.mode in ['IF2', 'IF2_NFSD', 'IF2_ISM']: guidance_input['ori_rgb'] = torch.cat(ori_images, dim=0) if self.opt.mode == 'SDCN': guidance_input['control_images'] = {'tile': torch.cat(ori_images, dim=0)} if self.opt.text_dir: guidance_input['vers'] = vers guidance_input['hors'] = hors loss = loss + self.opt.lambda_sd * self.guidance.train_step(**guidance_input) # geom regularizations if self.opt.geom_mode in ['diffmc', 'pbr_diffmc', 'mesh', 'pbr_mesh'] and not self.opt.fix_geo: if self.opt.lambda_lap > 0: lap_loss = laplacian_smooth_loss(self.renderer.v, self.renderer.f) loss = loss + self.opt.lambda_lap * lap_loss if self.opt.lambda_normal > 0: normal_loss = normal_consistency(self.renderer.v, self.renderer.f) loss = loss + self.opt.lambda_normal * normal_loss if self.opt.geom_mode in ['mesh', 'pbr_mesh'] and self.opt.lambda_offsets > 0: offset_loss = (self.renderer.v_offsets ** 2).sum(-1).mean() loss = loss + self.opt.lambda_offsets * offset_loss # optimize step loss.backward() self.optimizer.step() self.optimizer.zero_grad() # for mesh geom_mode: peoriodically remesh if self.opt.geom_mode in ['mesh', 'pbr_mesh'] and not self.opt.fix_geo: if self.step > 0 and self.step % self.opt.remesh_interval == 0: self.renderer.remesh() # reset optimizer self.optimizer = torch.optim.Adam(self.renderer.get_params()) ender.record() torch.cuda.synchronize() t = starter.elapsed_time(ender) self.need_update = True if self.gui: dpg.set_value("_log_train_time", f"{t:.4f}ms") dpg.set_value( "_log_train_log", f"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {loss.item():.4f}", ) @torch.no_grad() def test_step(self): # ignore if no need to update if not self.need_update: return starter = torch.cuda.Event(enable_timing=True) ender = torch.cuda.Event(enable_timing=True) starter.record() # should update image if self.need_update: # render image self.renderer.eval() out = self.renderer.render(self.cam.pose, self.cam.perspective, self.H, self.W) buffer_image = out[self.mode] # [H, W, 3] if self.mode in ['depth', 'alpha']: buffer_image = buffer_image.repeat(1, 1, 3) if self.mode == 'depth': buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20) self.buffer_image = buffer_image.contiguous().clamp(0, 1).detach().cpu().numpy() self.need_update = False ender.record() torch.cuda.synchronize() t = starter.elapsed_time(ender) if self.gui: dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)") dpg.set_value( "_texture", self.buffer_image ) # buffer must be contiguous, else seg fault! def save_model(self, save_path=None): if save_path is None: save_path = self.save_path # export video if save_path.endswith(".mp4"): images = [] elevation = 0 azimuth = np.arange(0, 360, 3, dtype=np.int32) # front-->back-->front for azi in tqdm.tqdm(azimuth): pose = orbit_camera(elevation, azi, self.opt.radius) out = self.renderer.render(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=1) image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8) images.append(image) images = np.stack(images, axis=0) # ~4 seconds, 120 frames at 30 fps imageio.mimwrite(save_path, images, fps=30, quality=8, macro_block_size=1) # export mesh else: self.renderer.export_mesh(save_path, texture_resolution=self.opt.texture_resolution) print(f"[INFO] save model to {save_path}.") def register_dpg(self): ### register texture with dpg.texture_registry(show=False): dpg.add_raw_texture( self.W, self.H, self.buffer_image, format=dpg.mvFormat_Float_rgb, tag="_texture", ) ### register window # the rendered image, as the primary window with dpg.window( tag="_primary_window", width=self.W, height=self.H, pos=[0, 0], no_move=True, no_title_bar=True, no_scrollbar=True, ): # add the texture dpg.add_image("_texture") # dpg.set_primary_window("_primary_window", True) # control window with dpg.window( label="Control", tag="_control_window", width=600, height=self.H, pos=[self.W, 0], no_move=True, no_title_bar=True, ): # button theme with dpg.theme() as theme_button: with dpg.theme_component(dpg.mvButton): dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) # timer stuff with dpg.group(horizontal=True): dpg.add_text("Infer time: ") dpg.add_text("no data", tag="_log_infer_time") def callback_setattr(sender, app_data, user_data): setattr(self, user_data, app_data) # init stuff with dpg.collapsing_header(label="Initialize", default_open=True): # seed stuff def callback_set_seed(sender, app_data): self.seed = app_data self.seed_everything() dpg.add_input_text( label="seed", default_value=self.seed, on_enter=True, callback=callback_set_seed, ) # input stuff def callback_select_input(sender, app_data): # only one item for k, v in app_data["selections"].items(): dpg.set_value("_log_input", k) self.load_input(v) self.need_update = True with dpg.file_dialog( directory_selector=False, show=False, callback=callback_select_input, file_count=1, tag="file_dialog_tag", width=700, height=400, ): dpg.add_file_extension("Images{.jpg,.jpeg,.png}") with dpg.group(horizontal=True): dpg.add_button( label="input", callback=lambda: dpg.show_item("file_dialog_tag"), ) dpg.add_text("", tag="_log_input") # prompt stuff dpg.add_input_text( label="prompt", default_value=self.prompt, callback=callback_setattr, user_data="prompt", ) dpg.add_input_text( label="negative", default_value=self.negative_prompt, callback=callback_setattr, user_data="negative_prompt", ) # save current model with dpg.group(horizontal=True): dpg.add_text("Save: ") dpg.add_button( label="model", tag="_button_save_model", callback=self.save_model, ) dpg.bind_item_theme("_button_save_model", theme_button) dpg.add_input_text( label="", default_value=self.save_path, callback=callback_setattr, user_data="save_path", ) # training stuff with dpg.collapsing_header(label="Train", default_open=True): # lr and train button with dpg.group(horizontal=True): dpg.add_text("Train: ") def callback_train(sender, app_data): if self.training: self.training = False dpg.configure_item("_button_train", label="start") else: self.prepare_train() self.training = True dpg.configure_item("_button_train", label="stop") # dpg.add_button( # label="init", tag="_button_init", callback=self.prepare_train # ) # dpg.bind_item_theme("_button_init", theme_button) dpg.add_button( label="start", tag="_button_train", callback=callback_train ) dpg.bind_item_theme("_button_train", theme_button) with dpg.group(horizontal=True): dpg.add_text("", tag="_log_train_time") dpg.add_text("", tag="_log_train_log") # rendering options with dpg.collapsing_header(label="Rendering", default_open=True): # mode combo def callback_change_mode(sender, app_data): self.mode = app_data self.need_update = True dpg.add_combo( ("image", "depth", "alpha", "normal"), label="mode", default_value=self.mode, callback=callback_change_mode, ) # fov slider def callback_set_fovy(sender, app_data): self.cam.fovy = np.deg2rad(app_data) self.need_update = True dpg.add_slider_int( label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=np.rad2deg(self.cam.fovy), callback=callback_set_fovy, ) ### register camera handler def callback_camera_drag_rotate_or_draw_mask(sender, app_data): if not dpg.is_item_focused("_primary_window"): return dx = app_data[1] dy = app_data[2] self.cam.orbit(dx, dy) self.need_update = True def callback_camera_wheel_scale(sender, app_data): if not dpg.is_item_focused("_primary_window"): return delta = app_data self.cam.scale(delta) self.need_update = True def callback_camera_drag_pan(sender, app_data): if not dpg.is_item_focused("_primary_window"): return dx = app_data[1] dy = app_data[2] self.cam.pan(dx, dy) self.need_update = True with dpg.handler_registry(): # for camera moving dpg.add_mouse_drag_handler( button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate_or_draw_mask, ) dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) dpg.add_mouse_drag_handler( button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan ) dpg.create_viewport( title="Threefiner", width=self.W + 600, height=self.H + (45 if os.name == "nt" else 0), resizable=False, ) ### global theme with dpg.theme() as theme_no_padding: with dpg.theme_component(dpg.mvAll): # set all padding to 0 to avoid scroll bar dpg.add_theme_style( dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core ) dpg.add_theme_style( dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core ) dpg.add_theme_style( dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core ) dpg.bind_item_theme("_primary_window", theme_no_padding) dpg.setup_dearpygui() ### register a larger font # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf if os.path.exists("LXGWWenKai-Regular.ttf"): with dpg.font_registry(): with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font: dpg.bind_font(default_font) # dpg.show_metrics() dpg.show_viewport() def render(self): assert self.gui while dpg.is_dearpygui_running(): # update texture every frame if self.training: self.train_step() self.test_step() dpg.render_dearpygui_frame() # no gui mode def train(self, iters=500): if iters > 0: self.prepare_train() for i in tqdm.trange(iters): self.train_step() # save self.save_model() ================================================ FILE: threefiner/guidance/__init__.py ================================================ ================================================ FILE: threefiner/guidance/if2_ism_utils.py ================================================ from diffusers import ( PNDMScheduler, DDIMScheduler, IFPipeline, IFSuperResolutionPipeline, ) import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def invert_noise(scheduler, noisy_samples, noise, timesteps): alphas_cumprod = scheduler.alphas_cumprod.to(device=noisy_samples.device, dtype=noisy_samples.dtype) timesteps = timesteps.to(noisy_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(noisy_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(noisy_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) original_samples = 1 / sqrt_alpha_prod * (noisy_samples - sqrt_one_minus_alpha_prod * noise) return original_samples class IF2(nn.Module): def __init__( self, device, fp16=True, vram_O=False, model_key = "DeepFloyd/IF-II-M-v1.0", t_range=[0.02, 0.50], ): super().__init__() self.device = device self.model_key = model_key self.dtype = torch.float16 if fp16 else torch.float32 # Create model pipe = IFSuperResolutionPipeline.from_pretrained( model_key, variant="fp16", torch_dtype=torch.float16, watermarker=None, safety_checker=None, requires_safety_checker=False, ) if vram_O: pipe.unet.to(memory_format=torch.channels_last) pipe.enable_attention_slicing(1) # pipe.enable_model_cpu_offload() else: pipe.to(device) self.unet = pipe.unet self.tokenizer = pipe.tokenizer self.text_encoder = pipe.text_encoder self.scheduler = pipe.scheduler self.image_noising_scheduler = pipe.image_noising_scheduler self.pipe = pipe self.num_train_timesteps = self.scheduler.config.num_train_timesteps self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience self.embeddings = {} @torch.no_grad() def get_text_embeds(self, prompts, negative_prompts): pos_embeds = self.encode_text(prompts) # [1, 77, 768] neg_embeds = self.encode_text(negative_prompts) null_embeds = self.encode_text([""]) self.embeddings['pos'] = pos_embeds self.embeddings['neg'] = neg_embeds self.embeddings['null'] = null_embeds # directional embeddings for d in ['front', 'side', 'back']: embeds = self.encode_text([f'{p}, {d} view' for p in prompts]) self.embeddings[d] = embeds def encode_text(self, prompt): # prompt: [str] prompt = self.pipe._text_preprocessing(prompt, clean_caption=False) inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] return embeddings def train_step( self, pred_rgb, ori_rgb, step_ratio=None, guidance_scale=5, vers=None, hors=None, delta_t=50, delta_s=200, ): batch_size = pred_rgb.shape[0] pred_rgb = pred_rgb.to(self.dtype) ori_rgb = ori_rgb.to(self.dtype) images = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) * 2 - 1 with torch.no_grad(): max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device) # images_upscaled = images.clone() images_upscaled = F.interpolate(ori_rgb, (256, 256), mode="bilinear", align_corners=False).clamp(0, 1) * 2 - 1 noise = torch.randn_like(images_upscaled) images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t) if step_ratio is not None: # dreamtime-like # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) else: t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) # w(t), sigma_t^2 w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype) ######### debug # imagesx = self.produce_imgs( # images_upscaled=images_upscaled, # max_t=max_t, # images=torch.randn_like(images), # num_inference_steps=50, # guidance_scale=4.0, # ) # [1, 3, 64, 64] # import kiui # kiui.vis.plot_image(images_upscaled * 0.5 + 0.5) # kiui.vis.plot_image(imagesx * 0.5 + 0.5) ######### null_embeddings = self.embeddings['null'].expand(batch_size, -1, -1) if hors is None: embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) else: def _get_dir_ind(h): if abs(h) < 60: return 'front' elif abs(h) < 120: return 'side' else: return 'back' embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)]) ########### ISM # steps t = t.clamp(min=delta_t) s = t - delta_t n = s // delta_s r = s % delta_s # construct trajectory images_noisy = images.clone() cur_t = torch.full((batch_size,), 0, dtype=torch.long, device=self.device) noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0] images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t) cur_t += r images_noisy = self.scheduler.add_noise(images_original, noise, cur_t) for i in range(n): noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0] images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t) cur_t += delta_s images_noisy = self.scheduler.add_noise(images_original, noise, cur_t) # x_s # construct last step noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0] images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t) # \hat x_0^s # perform guidance images_noisy = self.scheduler.add_noise(images_original, noise, t) model_input = torch.cat([images_noisy, images_upscaled], dim=1) model_input = torch.cat([model_input] * 2) model_input = self.scheduler.scale_model_input(model_input, t) tt = torch.cat([t] * 2) max_tt = torch.cat([max_t] * 2) noise_pred = self.unet( model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt, ).sample # perform guidance (high scale from paper!) noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) grad = w * (noise_pred - noise) grad = torch.nan_to_num(grad) # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8 # grad = grad_norm.clamp(max=0.1) * grad / grad_norm target = (images - grad).detach() loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0] return loss @torch.no_grad() def produce_imgs( self, images_upscaled, max_t, height=256, width=256, num_inference_steps=50, guidance_scale=4.0, images=None, ): if images is None: images = torch.randn( ( 1, self.unet.in_channels, height, width, ), device=self.device, ) batch_size = images.shape[0] self.scheduler.set_timesteps(num_inference_steps) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps): # expand the images if we are doing classifier-free guidance to avoid doing two forward passes. model_input = torch.cat([images, images_upscaled], dim=1) model_input = torch.cat([model_input] * 2) model_input = self.scheduler.scale_model_input(model_input, t) max_tt = torch.cat([max_t] * 2) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt, ).sample # perform guidance noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) # compute the previous noisy sample x_t -> x_t-1 images = self.scheduler.step(noise_pred, t, images).prev_sample return images def prompt_to_img( self, images_upscaled, max_t, prompts, negative_prompts="", height=256, width=256, num_inference_steps=50, guidance_scale=4.0, images=None, ): if isinstance(prompts, str): prompts = [prompts] if isinstance(negative_prompts, str): negative_prompts = [negative_prompts] # Prompts -> text embeds self.get_text_embeds(prompts, negative_prompts) # Text embeds -> img images images = self.produce_imgs( images_upscaled=images_upscaled, max_t=max_t, height=height, width=width, images=images, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) # [1, 4, 64, 64] # Img to Numpy images = images.detach().cpu().permute(0, 2, 3, 1).numpy() images = (images * 255).round().astype("uint8") return images if __name__ == "__main__": import kiui import argparse import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument("prompt", type=str) parser.add_argument("--negative", default="", type=str) parser.add_argument("--fp16", action="store_true", help="use float16 for training") parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage") parser.add_argument("-H", type=int, default=512) parser.add_argument("-W", type=int, default=512) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--steps", type=int, default=50) opt = parser.parse_args() kiui.seed_everything(opt.seed) device = torch.device("cuda") sd = IF2(device, opt.fp16, opt.vram_O) imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) # visualize image plt.imshow(imgs[0]) plt.show() ================================================ FILE: threefiner/guidance/if2_nfsd_utils.py ================================================ from diffusers import ( PNDMScheduler, DDIMScheduler, IFPipeline, IFSuperResolutionPipeline, ) import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class IF2(nn.Module): def __init__( self, device, fp16=True, vram_O=False, model_key = "DeepFloyd/IF-II-L-v1.0", t_range=[0.02, 0.50], ): super().__init__() self.device = device self.model_key = model_key self.dtype = torch.float16 if fp16 else torch.float32 # Create model pipe = IFSuperResolutionPipeline.from_pretrained( model_key, variant="fp16", torch_dtype=torch.float16, watermarker=None, safety_checker=None, requires_safety_checker=False, ) if vram_O: pipe.unet.to(memory_format=torch.channels_last) pipe.enable_attention_slicing(1) # pipe.enable_model_cpu_offload() else: pipe.to(device) self.unet = pipe.unet self.tokenizer = pipe.tokenizer self.text_encoder = pipe.text_encoder self.scheduler = pipe.scheduler self.image_noising_scheduler = pipe.image_noising_scheduler self.pipe = pipe self.num_train_timesteps = self.scheduler.config.num_train_timesteps self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience self.embeddings = {} @torch.no_grad() def get_text_embeds(self, prompts, negative_prompts): pos_embeds = self.encode_text(prompts) # [1, 77, 768] neg_embeds = self.encode_text(negative_prompts) null_embeds = self.encode_text(['']) self.embeddings['pos'] = pos_embeds self.embeddings['neg'] = neg_embeds self.embeddings['null'] = null_embeds # directional embeddings for d in ['front', 'side', 'back']: embeds = self.encode_text([f'{p}, {d} view' for p in prompts]) self.embeddings[d] = embeds def encode_text(self, prompt): # prompt: [str] prompt = self.pipe._text_preprocessing(prompt, clean_caption=False) inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] return embeddings def train_step( self, pred_rgb, ori_rgb, step_ratio=None, guidance_scale=5, vers=None, hors=None, ): batch_size = pred_rgb.shape[0] pred_rgb = pred_rgb.to(self.dtype) ori_rgb = ori_rgb.to(self.dtype) images = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) * 2 - 1 with torch.no_grad(): max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device) # images_upscaled = images.clone() images_upscaled = F.interpolate(ori_rgb, (256, 256), mode="bilinear", align_corners=False).clamp(0, 1) * 2 - 1 noise = torch.randn_like(images_upscaled) images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t) if step_ratio is not None: # dreamtime-like # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) else: t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) # w(t), sigma_t^2 w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype) ######### debug # imagesx = self.produce_imgs( # images_upscaled=images_upscaled, # max_t=max_t, # images=torch.randn_like(images), # num_inference_steps=50, # guidance_scale=4.0, # ) # [1, 3, 64, 64] # import kiui # kiui.vis.plot_image(images_upscaled * 0.5 + 0.5) # kiui.vis.plot_image(imagesx * 0.5 + 0.5) ######### # add noise noise = torch.randn_like(images) images_noisy = self.scheduler.add_noise(images, noise, t) # pred noise model_input = torch.cat([images_noisy, images_upscaled], dim=1) model_input = torch.cat([model_input] * 3) model_input = self.scheduler.scale_model_input(model_input, t) tt = torch.cat([t] * 3) max_tt = torch.cat([max_t] * 3) if hors is None: embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)]) else: def _get_dir_ind(h): if abs(h) < 60: return 'front' elif abs(h) < 120: return 'side' else: return 'back' embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)]) noise_pred = self.unet( model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt, ).sample # perform guidance (high scale from paper!) noise_pred_cond, noise_pred_uncond, noise_pred_null = noise_pred.chunk(3) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_cond, _ = noise_pred_cond.split(model_input.shape[1] // 2, dim=1) noise_pred_null, _ = noise_pred_null.split(model_input.shape[1] // 2, dim=1) delta_c = guidance_scale * (noise_pred_cond - noise_pred_null) mask = (t < 200).int().view(batch_size, 1, 1, 1) delta_d = mask * noise_pred_null + (1 - mask) * (noise_pred_null - noise_pred_uncond) grad = w * (delta_c + delta_d) grad = torch.nan_to_num(grad) # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8 # grad = grad_norm.clamp(max=0.1) * grad / grad_norm target = (images - grad).detach() loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0] return loss @torch.no_grad() def produce_imgs( self, images_upscaled, max_t, height=256, width=256, num_inference_steps=50, guidance_scale=4.0, images=None, ): if images is None: images = torch.randn( ( 1, self.unet.in_channels, height, width, ), device=self.device, ) batch_size = images.shape[0] self.scheduler.set_timesteps(num_inference_steps) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps): # expand the images if we are doing classifier-free guidance to avoid doing two forward passes. model_input = torch.cat([images, images_upscaled], dim=1) model_input = torch.cat([model_input] * 2) model_input = self.scheduler.scale_model_input(model_input, t) max_tt = torch.cat([max_t] * 2) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt, ).sample # perform guidance noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) # compute the previous noisy sample x_t -> x_t-1 images = self.scheduler.step(noise_pred, t, images).prev_sample return images def prompt_to_img( self, images_upscaled, max_t, prompts, negative_prompts="", height=256, width=256, num_inference_steps=50, guidance_scale=4.0, images=None, ): if isinstance(prompts, str): prompts = [prompts] if isinstance(negative_prompts, str): negative_prompts = [negative_prompts] # Prompts -> text embeds self.get_text_embeds(prompts, negative_prompts) # Text embeds -> img images images = self.produce_imgs( images_upscaled=images_upscaled, max_t=max_t, height=height, width=width, images=images, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) # [1, 4, 64, 64] # Img to Numpy images = images.detach().cpu().permute(0, 2, 3, 1).numpy() images = (images * 255).round().astype("uint8") return images if __name__ == "__main__": import kiui import argparse import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument("prompt", type=str) parser.add_argument("--negative", default="", type=str) parser.add_argument("--fp16", action="store_true", help="use float16 for training") parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage") parser.add_argument("-H", type=int, default=512) parser.add_argument("-W", type=int, default=512) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--steps", type=int, default=50) opt = parser.parse_args() kiui.seed_everything(opt.seed) device = torch.device("cuda") sd = IF2(device, opt.fp16, opt.vram_O) imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) # visualize image plt.imshow(imgs[0]) plt.show() ================================================ FILE: threefiner/guidance/if2_utils.py ================================================ from diffusers import ( PNDMScheduler, DDIMScheduler, IFPipeline, IFSuperResolutionPipeline, ) import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class IF2(nn.Module): def __init__( self, device, fp16=True, vram_O=False, # model_key = "DeepFloyd/IF-II-L-v1.0", model_key = "DeepFloyd/IF-II-M-v1.0", t_range=[0.02, 0.50], ): super().__init__() self.device = device self.model_key = model_key self.dtype = torch.float16 if fp16 else torch.float32 # Create model pipe = IFSuperResolutionPipeline.from_pretrained( model_key, variant="fp16", torch_dtype=torch.float16, watermarker=None, safety_checker=None, requires_safety_checker=False, ) if vram_O: pipe.unet.to(memory_format=torch.channels_last) pipe.enable_attention_slicing(1) # pipe.enable_model_cpu_offload() pipe.enable_sequential_cpu_offload() else: pipe.to(device) self.unet = pipe.unet self.tokenizer = pipe.tokenizer self.text_encoder = pipe.text_encoder self.scheduler = pipe.scheduler self.image_noising_scheduler = pipe.image_noising_scheduler self.pipe = pipe self.num_train_timesteps = self.scheduler.config.num_train_timesteps self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience self.embeddings = {} @torch.no_grad() def get_text_embeds(self, prompts, negative_prompts): pos_embeds = self.encode_text(prompts) # [1, 77, 768] neg_embeds = self.encode_text(negative_prompts) self.embeddings['pos'] = pos_embeds self.embeddings['neg'] = neg_embeds # directional embeddings for d in ['front', 'side', 'back']: embeds = self.encode_text([f'{p}, {d} view' for p in prompts]) self.embeddings[d] = embeds def encode_text(self, prompt): # prompt: [str] prompt = self.pipe._text_preprocessing(prompt, clean_caption=False) inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] return embeddings def train_step( self, pred_rgb, ori_rgb, step_ratio=None, guidance_scale=50, vers=None, hors=None, ): batch_size = pred_rgb.shape[0] pred_rgb = pred_rgb.to(self.dtype) ori_rgb = ori_rgb.to(self.dtype) images = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) * 2 - 1 with torch.no_grad(): max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device) # images_upscaled = images.clone() images_upscaled = F.interpolate(ori_rgb, (256, 256), mode="bilinear", align_corners=False).clamp(0, 1) * 2 - 1 noise = torch.randn_like(images_upscaled) images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t) if step_ratio is not None: # dreamtime-like # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) else: t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) # w(t), sigma_t^2 w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype) ######### debug # imagesx = self.produce_imgs( # images_upscaled=images_upscaled, # max_t=max_t, # images=torch.randn_like(images), # num_inference_steps=50, # guidance_scale=4.0, # ) # [1, 3, 64, 64] # import kiui # kiui.vis.plot_image(images_upscaled * 0.5 + 0.5) # kiui.vis.plot_image(imagesx * 0.5 + 0.5) ######### # add noise noise = torch.randn_like(images) images_noisy = self.scheduler.add_noise(images, noise, t) # pred noise model_input = torch.cat([images_noisy, images_upscaled], dim=1) model_input = torch.cat([model_input] * 2) model_input = self.scheduler.scale_model_input(model_input, t) tt = torch.cat([t] * 2) max_tt = torch.cat([max_t] * 2) if hors is None: embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) else: def _get_dir_ind(h): if abs(h) < 60: return 'front' elif abs(h) < 120: return 'side' else: return 'back' embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)]) noise_pred = self.unet( model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt, ).sample # perform guidance (high scale from paper!) noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) grad = w * (noise_pred - noise) grad = torch.nan_to_num(grad) # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8 # grad = grad_norm.clamp(max=0.1) * grad / grad_norm target = (images - grad).detach() loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0] return loss @torch.no_grad() def produce_imgs( self, images_upscaled, max_t, height=256, width=256, num_inference_steps=50, guidance_scale=4.0, images=None, ): if images is None: images = torch.randn( ( 1, self.unet.in_channels, height, width, ), device=self.device, ) batch_size = images.shape[0] self.scheduler.set_timesteps(num_inference_steps) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps): # expand the images if we are doing classifier-free guidance to avoid doing two forward passes. model_input = torch.cat([images, images_upscaled], dim=1) model_input = torch.cat([model_input] * 2) model_input = self.scheduler.scale_model_input(model_input, t) max_tt = torch.cat([max_t] * 2) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt, ).sample # perform guidance noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) # compute the previous noisy sample x_t -> x_t-1 images = self.scheduler.step(noise_pred, t, images).prev_sample return images def prompt_to_img( self, images_upscaled, max_t, prompts, negative_prompts="", height=256, width=256, num_inference_steps=50, guidance_scale=4.0, images=None, ): if isinstance(prompts, str): prompts = [prompts] if isinstance(negative_prompts, str): negative_prompts = [negative_prompts] # Prompts -> text embeds self.get_text_embeds(prompts, negative_prompts) # Text embeds -> img images images = self.produce_imgs( images_upscaled=images_upscaled, max_t=max_t, height=height, width=width, images=images, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) # [1, 4, 64, 64] # Img to Numpy images = images.detach().cpu().permute(0, 2, 3, 1).numpy() images = (images * 255).round().astype("uint8") return images if __name__ == "__main__": import kiui import argparse import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument("prompt", type=str) parser.add_argument("--negative", default="", type=str) parser.add_argument("--fp16", action="store_true", help="use float16 for training") parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage") parser.add_argument("-H", type=int, default=512) parser.add_argument("-W", type=int, default=512) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--steps", type=int, default=50) opt = parser.parse_args() kiui.seed_everything(opt.seed) device = torch.device("cuda") sd = IF2(device, opt.fp16, opt.vram_O) imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) # visualize image plt.imshow(imgs[0]) plt.show() ================================================ FILE: threefiner/guidance/if_utils.py ================================================ from diffusers import ( PNDMScheduler, DDIMScheduler, IFPipeline, ) import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class IF(nn.Module): def __init__( self, device, fp16=True, vram_O=False, model_key = "DeepFloyd/IF-I-XL-v1.0", # model_key = "DeepFloyd/IF-I-M-v1.0", t_range=[0.02, 0.98], ): super().__init__() self.device = device self.model_key = model_key self.dtype = torch.float16 if fp16 else torch.float32 # Create model pipe = IFPipeline.from_pretrained( model_key, variant="fp16", torch_dtype=torch.float16, watermarker=None, safety_checker=None, requires_safety_checker=False, ) if vram_O: pipe.unet.to(memory_format=torch.channels_last) pipe.enable_attention_slicing(1) # pipe.enable_model_cpu_offload() pipe.enable_sequential_cpu_offload() else: pipe.to(device) self.unet = pipe.unet self.tokenizer = pipe.tokenizer self.text_encoder = pipe.text_encoder self.scheduler = pipe.scheduler self.pipe = pipe self.num_train_timesteps = self.scheduler.config.num_train_timesteps self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience self.embeddings = {} @torch.no_grad() def get_text_embeds(self, prompts, negative_prompts): pos_embeds = self.encode_text(prompts) # [1, 77, 768] neg_embeds = self.encode_text(negative_prompts) self.embeddings['pos'] = pos_embeds self.embeddings['neg'] = neg_embeds # directional embeddings for d in ['front', 'side', 'back']: embeds = self.encode_text([f'{p}, {d} view' for p in prompts]) self.embeddings[d] = embeds def encode_text(self, prompt): # prompt: [str] prompt = self.pipe._text_preprocessing(prompt, clean_caption=False) inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] return embeddings @torch.no_grad() def refine(self, pred_rgb, guidance_scale=100, steps=50, strength=0.8, ): batch_size = pred_rgb.shape[0] images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) self.scheduler.set_timesteps(steps) init_step = int(steps * strength) images = self.scheduler.add_noise(images, torch.randn_like(images), self.scheduler.timesteps[init_step]) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps[init_step:]): model_input = torch.cat([images] * 2) noise_pred = self.unet( model_input, t, encoder_hidden_states=embeddings, ).sample noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) images = self.scheduler.step(noise_pred, t, images).prev_sample return images def train_step( self, pred_rgb, step_ratio=None, guidance_scale=50, vers=None, hors=None, ): batch_size = pred_rgb.shape[0] pred_rgb = pred_rgb.to(self.dtype) images = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1 with torch.no_grad(): if step_ratio is not None: # dreamtime-like # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) else: t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) # w(t), sigma_t^2 w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) ######### debug # # Text embeds -> img latents # imagesx = self.produce_imgs( # height=64, # width=64, # images=torch.randn_like(images), # num_inference_steps=50, # guidance_scale=7.5, # ) # [1, 3, 64, 64] # import kiui # kiui.vis.plot_image(imagesx) ######### # add noise noise = torch.randn_like(images) images_noisy = self.scheduler.add_noise(images, noise, t) # pred noise model_input = torch.cat([images_noisy] * 2) model_input = self.scheduler.scale_model_input(model_input, t) tt = torch.cat([t] * 2) if hors is None: embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) else: def _get_dir_ind(h): if abs(h) < 60: return 'front' elif abs(h) < 120: return 'side' else: return 'back' embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)]) noise_pred = self.unet( model_input, tt, encoder_hidden_states=embeddings ).sample # perform guidance (high scale from paper!) noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) grad = w * (noise_pred - noise) grad = torch.nan_to_num(grad) # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8 # grad = grad_norm.clamp(max=0.1) * grad / grad_norm target = (images - grad).detach() loss = 0.5 * F.mse_loss(images.float(), target, reduction='sum') / images.shape[0] return loss @torch.no_grad() def produce_imgs( self, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, images=None, ): if images is None: images = torch.randn( ( 1, self.unet.in_channels, height, width, ), device=self.device, ) batch_size = images.shape[0] self.scheduler.set_timesteps(num_inference_steps) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps): # expand the images if we are doing classifier-free guidance to avoid doing two forward passes. model_input = torch.cat([images] * 2) model_input = self.scheduler.scale_model_input(model_input, t) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=embeddings ).sample # perform guidance noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) # compute the previous noisy sample x_t -> x_t-1 images = self.scheduler.step(noise_pred, t, images).prev_sample return images def prompt_to_img( self, prompts, negative_prompts="", height=64, width=64, num_inference_steps=50, guidance_scale=7.5, images=None, ): if isinstance(prompts, str): prompts = [prompts] if isinstance(negative_prompts, str): negative_prompts = [negative_prompts] # Prompts -> text embeds self.get_text_embeds(prompts, negative_prompts) # Text embeds -> img images images = self.produce_imgs( height=height, width=width, images=images, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) # [1, 4, 64, 64] # Img to Numpy images = images.detach().cpu().permute(0, 2, 3, 1).numpy() images = (images * 255).round().astype("uint8") return images if __name__ == "__main__": import kiui import argparse import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument("prompt", type=str) parser.add_argument("--negative", default="", type=str) parser.add_argument("--fp16", action="store_true", help="use float16 for training") parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage") parser.add_argument("-H", type=int, default=512) parser.add_argument("-W", type=int, default=512) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--steps", type=int, default=50) opt = parser.parse_args() kiui.seed_everything(opt.seed) device = torch.device("cuda") sd = IF(device, opt.fp16, opt.vram_O) imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) # visualize image plt.imshow(imgs[0]) plt.show() ================================================ FILE: threefiner/guidance/sd_ism_utils.py ================================================ from diffusers import ( PNDMScheduler, DDIMScheduler, StableDiffusionPipeline, ) import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def invert_noise(scheduler, noisy_samples, noise, timesteps): alphas_cumprod = scheduler.alphas_cumprod.to(device=noisy_samples.device, dtype=noisy_samples.dtype) timesteps = timesteps.to(noisy_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(noisy_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(noisy_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) original_samples = 1 / sqrt_alpha_prod * (noisy_samples - sqrt_one_minus_alpha_prod * noise) return original_samples class StableDiffusion(nn.Module): def __init__( self, device, fp16=True, vram_O=False, model_key="stabilityai/stable-diffusion-2-1-base", # model_key="philz1337/revanimated", t_range=[0.02, 0.98], ): super().__init__() self.device = device self.model_key = model_key self.dtype = torch.float16 if fp16 else torch.float32 # Create model pipe = StableDiffusionPipeline.from_pretrained( model_key, torch_dtype=self.dtype ) if vram_O: pipe.enable_sequential_cpu_offload() pipe.enable_vae_slicing() pipe.unet.to(memory_format=torch.channels_last) pipe.enable_attention_slicing(1) # pipe.enable_model_cpu_offload() else: pipe.to(device) self.vae = pipe.vae self.tokenizer = pipe.tokenizer self.text_encoder = pipe.text_encoder self.unet = pipe.unet self.scheduler = DDIMScheduler.from_pretrained( model_key, subfolder="scheduler", torch_dtype=self.dtype ) del pipe self.num_train_timesteps = self.scheduler.config.num_train_timesteps self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience self.embeddings = {} @torch.no_grad() def get_text_embeds(self, prompts, negative_prompts): pos_embeds = self.encode_text(prompts) # [1, 77, 768] neg_embeds = self.encode_text(negative_prompts) null_embeds = self.encode_text([""]) self.embeddings['pos'] = pos_embeds self.embeddings['neg'] = neg_embeds self.embeddings['null'] = null_embeds # directional embeddings for d in ['front', 'side', 'back']: embeds = self.encode_text([f'{p}, {d} view' for p in prompts]) self.embeddings[d] = embeds def encode_text(self, prompt): # prompt: [str] inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] return embeddings @torch.no_grad() def refine(self, pred_rgb, guidance_scale=100, steps=50, strength=0.8, ): batch_size = pred_rgb.shape[0] pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) latents = self.encode_imgs(pred_rgb_512.to(self.dtype)) # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype) self.scheduler.set_timesteps(steps) init_step = int(steps * strength) latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps[init_step:]): latent_model_input = torch.cat([latents] * 2) noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=embeddings, ).sample noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) latents = self.scheduler.step(noise_pred, t, latents).prev_sample imgs = self.decode_latents(latents) # [1, 3, 512, 512] return imgs def train_step( self, pred_rgb, step_ratio=None, guidance_scale=7.5, as_latent=False, vers=None, hors=None, delta_t=50, delta_s=200, ): batch_size = pred_rgb.shape[0] pred_rgb = pred_rgb.to(self.dtype) if as_latent: latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1 else: # interp to 512x512 to be fed into vae. pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False) # encode image into latents with vae, requires grad! latents = self.encode_imgs(pred_rgb_512) with torch.no_grad(): if step_ratio is not None: # dreamtime-like # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) else: t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) # w(t), sigma_t^2 w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) ######### debug # # Text embeds -> img latents # latentsx = self.produce_latents( # height=512, # width=512, # latents=torch.randn_like(latents), # num_inference_steps=50, # guidance_scale=7.5, # ) # [1, 4, 64, 64] # # Img latents -> imgs # imgs = self.decode_latents(latentsx) # [1, 3, 512, 512] # import kiui # kiui.vis.plot_image(imgs) ######### null_embeddings = self.embeddings['null'].expand(batch_size, -1, -1) if hors is None: embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) else: def _get_dir_ind(h): if abs(h) < 60: return 'front' elif abs(h) < 120: return 'side' else: return 'back' embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)]) ########### ISM # steps t = t.clamp(min=delta_t) s = t - delta_t n = s // delta_s r = s % delta_s # construct trajectory latents_noisy = latents.clone() cur_t = torch.full((batch_size,), 0, dtype=torch.long, device=self.device) noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t) cur_t += r latents_noisy = self.scheduler.add_noise(latents_original, noise, cur_t) for i in range(n): noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t) cur_t += delta_s latents_noisy = self.scheduler.add_noise(latents_original, noise, cur_t) # x_s # construct last step noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample # \epsilon_s latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t) # \hat x_0^s # perform guidance latents_noisy = self.scheduler.add_noise(latents_original, noise, t) # x_t latent_model_input = torch.cat([latents_noisy] * 2) tt = torch.cat([t] * 2) noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=embeddings).sample noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) grad = w * (noise_pred - noise) grad = torch.nan_to_num(grad) # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8 # grad = grad_norm.clamp(max=0.1) * grad / grad_norm target = (latents - grad).detach() loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] return loss @torch.no_grad() def produce_latents( self, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, ): if latents is None: latents = torch.randn( ( 1, self.unet.in_channels, height // 8, width // 8, ), device=self.device, ) batch_size = latents.shape[0] self.scheduler.set_timesteps(num_inference_steps) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps): # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. latent_model_input = torch.cat([latents] * 2) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=embeddings ).sample # perform guidance noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents).prev_sample return latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents imgs = self.vae.decode(latents).sample imgs = (imgs / 2 + 0.5).clamp(0, 1) return imgs def encode_imgs(self, imgs): # imgs: [B, 3, H, W] imgs = 2 * imgs - 1 posterior = self.vae.encode(imgs).latent_dist latents = posterior.sample() * self.vae.config.scaling_factor return latents def prompt_to_img( self, prompts, negative_prompts="", height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, ): if isinstance(prompts, str): prompts = [prompts] if isinstance(negative_prompts, str): negative_prompts = [negative_prompts] # Prompts -> text embeds self.get_text_embeds(prompts, negative_prompts) # Text embeds -> img latents latents = self.produce_latents( height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) # [1, 4, 64, 64] # Img latents -> imgs imgs = self.decode_latents(latents) # [1, 3, 512, 512] # Img to Numpy imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() imgs = (imgs * 255).round().astype("uint8") return imgs if __name__ == "__main__": import kiui import argparse import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument("prompt", type=str) parser.add_argument("--negative", default="", type=str) parser.add_argument("--fp16", action="store_true", help="use float16 for training") parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage") parser.add_argument("-H", type=int, default=512) parser.add_argument("-W", type=int, default=512) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--steps", type=int, default=50) opt = parser.parse_args() kiui.seed_everything(opt.seed) device = torch.device("cuda") sd = StableDiffusion(device, opt.fp16, opt.vram_O) imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) # visualize image plt.imshow(imgs[0]) plt.show() ================================================ FILE: threefiner/guidance/sd_nfsd_utils.py ================================================ from diffusers import ( PNDMScheduler, DDIMScheduler, StableDiffusionPipeline, ) import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class StableDiffusion(nn.Module): def __init__( self, device, fp16=True, vram_O=False, model_key="stabilityai/stable-diffusion-2-1-base", # model_key="philz1337/revanimated", t_range=[0.02, 0.50], ): super().__init__() self.device = device self.model_key = model_key self.dtype = torch.float16 if fp16 else torch.float32 # Create model pipe = StableDiffusionPipeline.from_pretrained( model_key, torch_dtype=self.dtype ) if vram_O: pipe.enable_sequential_cpu_offload() pipe.enable_vae_slicing() pipe.unet.to(memory_format=torch.channels_last) pipe.enable_attention_slicing(1) # pipe.enable_model_cpu_offload() else: pipe.to(device) self.vae = pipe.vae self.tokenizer = pipe.tokenizer self.text_encoder = pipe.text_encoder self.unet = pipe.unet self.scheduler = DDIMScheduler.from_pretrained( model_key, subfolder="scheduler", torch_dtype=self.dtype ) del pipe self.num_train_timesteps = self.scheduler.config.num_train_timesteps self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience self.embeddings = {} @torch.no_grad() def get_text_embeds(self, prompts, negative_prompts): pos_embeds = self.encode_text(prompts) # [1, 77, 768] neg_embeds = self.encode_text(negative_prompts) null_embeds = self.encode_text(['']) self.embeddings['pos'] = pos_embeds self.embeddings['neg'] = neg_embeds self.embeddings['null'] = null_embeds # directional embeddings for d in ['front', 'side', 'back']: embeds = self.encode_text([f'{p}, {d} view' for p in prompts]) self.embeddings[d] = embeds def encode_text(self, prompt): # prompt: [str] inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] return embeddings @torch.no_grad() def refine(self, pred_rgb, guidance_scale=100, steps=50, strength=0.8, ): batch_size = pred_rgb.shape[0] pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) latents = self.encode_imgs(pred_rgb_512.to(self.dtype)) # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype) self.scheduler.set_timesteps(steps) init_step = int(steps * strength) latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps[init_step:]): latent_model_input = torch.cat([latents] * 2) noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=embeddings, ).sample noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) latents = self.scheduler.step(noise_pred, t, latents).prev_sample imgs = self.decode_latents(latents) # [1, 3, 512, 512] return imgs def train_step( self, pred_rgb, step_ratio=None, guidance_scale=7.5, as_latent=False, vers=None, hors=None, ): batch_size = pred_rgb.shape[0] pred_rgb = pred_rgb.to(self.dtype) if as_latent: latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1 else: # interp to 512x512 to be fed into vae. pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False) # encode image into latents with vae, requires grad! latents = self.encode_imgs(pred_rgb_512) with torch.no_grad(): if step_ratio is not None: # dreamtime-like # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) else: t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) # w(t), sigma_t^2 w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) ######### debug # # Text embeds -> img latents # latentsx = self.produce_latents( # height=512, # width=512, # latents=torch.randn_like(latents), # num_inference_steps=50, # guidance_scale=7.5, # ) # [1, 4, 64, 64] # # Img latents -> imgs # imgs = self.decode_latents(latentsx) # [1, 3, 512, 512] # import kiui # kiui.vis.plot_image(imgs) ######### # add noise noise = torch.randn_like(latents) latents_noisy = self.scheduler.add_noise(latents, noise, t) # pred noise latent_model_input = torch.cat([latents_noisy] * 3) tt = torch.cat([t] * 3) if hors is None: embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)]) else: def _get_dir_ind(h): if abs(h) < 60: return 'front' elif abs(h) < 120: return 'side' else: return 'back' embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)]) noise_pred = self.unet( latent_model_input, tt, encoder_hidden_states=embeddings ).sample # perform guidance (high scale from paper!) noise_pred_cond, noise_pred_uncond, noise_pred_null = noise_pred.chunk(3) delta_c = guidance_scale * (noise_pred_cond - noise_pred_null) mask = (t < 200).int().view(batch_size, 1, 1, 1) delta_d = mask * noise_pred_null + (1 - mask) * (noise_pred_null - noise_pred_uncond) * 3 grad = w * (delta_c + delta_d) grad = torch.nan_to_num(grad) # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8 # grad = grad_norm.clamp(max=0.1) * grad / grad_norm target = (latents - grad).detach() loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] return loss @torch.no_grad() def produce_latents( self, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, ): if latents is None: latents = torch.randn( ( 1, self.unet.in_channels, height // 8, width // 8, ), device=self.device, ) batch_size = latents.shape[0] self.scheduler.set_timesteps(num_inference_steps) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps): # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. latent_model_input = torch.cat([latents] * 2) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=embeddings ).sample # perform guidance noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents).prev_sample return latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents imgs = self.vae.decode(latents).sample imgs = (imgs / 2 + 0.5).clamp(0, 1) return imgs def encode_imgs(self, imgs): # imgs: [B, 3, H, W] imgs = 2 * imgs - 1 posterior = self.vae.encode(imgs).latent_dist latents = posterior.sample() * self.vae.config.scaling_factor return latents def prompt_to_img( self, prompts, negative_prompts="", height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, ): if isinstance(prompts, str): prompts = [prompts] if isinstance(negative_prompts, str): negative_prompts = [negative_prompts] # Prompts -> text embeds self.get_text_embeds(prompts, negative_prompts) # Text embeds -> img latents latents = self.produce_latents( height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) # [1, 4, 64, 64] # Img latents -> imgs imgs = self.decode_latents(latents) # [1, 3, 512, 512] # Img to Numpy imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() imgs = (imgs * 255).round().astype("uint8") return imgs if __name__ == "__main__": import kiui import argparse import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument("prompt", type=str) parser.add_argument("--negative", default="", type=str) parser.add_argument("--fp16", action="store_true", help="use float16 for training") parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage") parser.add_argument("-H", type=int, default=512) parser.add_argument("-W", type=int, default=512) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--steps", type=int, default=50) opt = parser.parse_args() kiui.seed_everything(opt.seed) device = torch.device("cuda") sd = StableDiffusion(device, opt.fp16, opt.vram_O) imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) # visualize image plt.imshow(imgs[0]) plt.show() ================================================ FILE: threefiner/guidance/sd_utils.py ================================================ from diffusers import ( PNDMScheduler, DDIMScheduler, StableDiffusionPipeline, ) import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class StableDiffusion(nn.Module): def __init__( self, device, fp16=True, vram_O=False, model_key="stabilityai/stable-diffusion-2-1-base", # model_key="philz1337/revanimated", t_range=[0.02, 0.50], ): super().__init__() self.device = device self.model_key = model_key self.dtype = torch.float16 if fp16 else torch.float32 # Create model pipe = StableDiffusionPipeline.from_pretrained( model_key, torch_dtype=self.dtype ) if vram_O: pipe.enable_sequential_cpu_offload() pipe.enable_vae_slicing() pipe.unet.to(memory_format=torch.channels_last) pipe.enable_attention_slicing(1) # pipe.enable_model_cpu_offload() else: pipe.to(device) self.vae = pipe.vae self.tokenizer = pipe.tokenizer self.text_encoder = pipe.text_encoder self.unet = pipe.unet self.scheduler = DDIMScheduler.from_pretrained( model_key, subfolder="scheduler", torch_dtype=self.dtype ) del pipe self.num_train_timesteps = self.scheduler.config.num_train_timesteps self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience self.embeddings = {} @torch.no_grad() def get_text_embeds(self, prompts, negative_prompts): pos_embeds = self.encode_text(prompts) # [1, 77, 768] neg_embeds = self.encode_text(negative_prompts) self.embeddings['pos'] = pos_embeds self.embeddings['neg'] = neg_embeds # directional embeddings for d in ['front', 'side', 'back']: embeds = self.encode_text([f'{p}, {d} view' for p in prompts]) self.embeddings[d] = embeds def encode_text(self, prompt): # prompt: [str] inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] return embeddings @torch.no_grad() def refine(self, pred_rgb, guidance_scale=100, steps=50, strength=0.8, ): batch_size = pred_rgb.shape[0] pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) latents = self.encode_imgs(pred_rgb_512.to(self.dtype)) # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype) self.scheduler.set_timesteps(steps) init_step = int(steps * strength) latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps[init_step:]): latent_model_input = torch.cat([latents] * 2) noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=embeddings, ).sample noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) latents = self.scheduler.step(noise_pred, t, latents).prev_sample imgs = self.decode_latents(latents) # [1, 3, 512, 512] return imgs def train_step( self, pred_rgb, step_ratio=None, guidance_scale=100, as_latent=False, vers=None, hors=None, ): batch_size = pred_rgb.shape[0] pred_rgb = pred_rgb.to(self.dtype) if as_latent: latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1 else: # interp to 512x512 to be fed into vae. pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False) # encode image into latents with vae, requires grad! latents = self.encode_imgs(pred_rgb_512) with torch.no_grad(): if step_ratio is not None: # dreamtime-like # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) else: t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) # w(t), sigma_t^2 w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) ######### debug # # Text embeds -> img latents # latentsx = self.produce_latents( # height=512, # width=512, # latents=torch.randn_like(latents), # num_inference_steps=50, # guidance_scale=7.5, # ) # [1, 4, 64, 64] # # Img latents -> imgs # imgs = self.decode_latents(latentsx) # [1, 3, 512, 512] # import kiui # kiui.vis.plot_image(imgs) ######### # add noise noise = torch.randn_like(latents) latents_noisy = self.scheduler.add_noise(latents, noise, t) # pred noise latent_model_input = torch.cat([latents_noisy] * 2) tt = torch.cat([t] * 2) if hors is None: embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) else: def _get_dir_ind(h): if abs(h) < 60: return 'front' elif abs(h) < 120: return 'side' else: return 'back' embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)]) noise_pred = self.unet( latent_model_input, tt, encoder_hidden_states=embeddings ).sample # perform guidance (high scale from paper!) noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) grad = w * (noise_pred - noise) grad = torch.nan_to_num(grad) # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8 # grad = grad_norm.clamp(max=0.1) * grad / grad_norm target = (latents - grad).detach() loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] return loss @torch.no_grad() def produce_latents( self, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, ): if latents is None: latents = torch.randn( ( 1, self.unet.in_channels, height // 8, width // 8, ), device=self.device, ) batch_size = latents.shape[0] self.scheduler.set_timesteps(num_inference_steps) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps): # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. latent_model_input = torch.cat([latents] * 2) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=embeddings ).sample # perform guidance noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents).prev_sample return latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents imgs = self.vae.decode(latents).sample imgs = (imgs / 2 + 0.5).clamp(0, 1) return imgs def encode_imgs(self, imgs): # imgs: [B, 3, H, W] imgs = 2 * imgs - 1 posterior = self.vae.encode(imgs).latent_dist latents = posterior.sample() * self.vae.config.scaling_factor return latents def prompt_to_img( self, prompts, negative_prompts="", height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, ): if isinstance(prompts, str): prompts = [prompts] if isinstance(negative_prompts, str): negative_prompts = [negative_prompts] # Prompts -> text embeds self.get_text_embeds(prompts, negative_prompts) # Text embeds -> img latents latents = self.produce_latents( height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) # [1, 4, 64, 64] # Img latents -> imgs imgs = self.decode_latents(latents) # [1, 3, 512, 512] # Img to Numpy imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() imgs = (imgs * 255).round().astype("uint8") return imgs if __name__ == "__main__": import kiui import argparse import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument("prompt", type=str) parser.add_argument("--negative", default="", type=str) parser.add_argument("--fp16", action="store_true", help="use float16 for training") parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage") parser.add_argument("-H", type=int, default=512) parser.add_argument("-W", type=int, default=512) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--steps", type=int, default=50) opt = parser.parse_args() kiui.seed_everything(opt.seed) device = torch.device("cuda") sd = StableDiffusion(device, opt.fp16, opt.vram_O) imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) # visualize image plt.imshow(imgs[0]) plt.show() ================================================ FILE: threefiner/guidance/sdcn_utils.py ================================================ from diffusers import ( PNDMScheduler, DDIMScheduler, StableDiffusionPipeline, ControlNetModel, ) import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class StableDiffusionControlNet(nn.Module): def __init__( self, device, fp16=True, vram_O=False, control_mode=["tile"], model_key="runwayml/stable-diffusion-v1-5", # model_key="philz1337/revanimated", t_range=[0.02, 0.50], ): super().__init__() self.device = device self.control_mode = control_mode self.dtype = torch.float16 if fp16 else torch.float32 # Create model pipe = StableDiffusionPipeline.from_pretrained( model_key, torch_dtype=self.dtype ) if vram_O: pipe.enable_sequential_cpu_offload() pipe.enable_vae_slicing() pipe.unet.to(memory_format=torch.channels_last) pipe.enable_attention_slicing(1) # pipe.enable_model_cpu_offload() else: pipe.to(device) self.vae = pipe.vae self.tokenizer = pipe.tokenizer self.text_encoder = pipe.text_encoder self.unet = pipe.unet # controlnet if self.control_mode is not None: self.controlnet = {} self.controlnet_conditioning_scale = {} if "normal" in self.control_mode: self.controlnet['normal'] = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_normalbae",torch_dtype=self.dtype).to(self.device) self.controlnet_conditioning_scale['normal'] = 1.0 if "depth" in self.control_mode: self.controlnet['depth'] = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth",torch_dtype=self.dtype).to(self.device) self.controlnet_conditioning_scale['depth'] = 1.0 if "ip2p" in self.control_mode: self.controlnet['ip2p'] = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_ip2p",torch_dtype=self.dtype).to(self.device) self.controlnet_conditioning_scale['ip2p'] = 1.0 if "inpaint" in self.control_mode: self.controlnet['inpaint'] = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint",torch_dtype=self.dtype).to(self.device) self.controlnet_conditioning_scale['inpaint'] = 1.0 if "depth_inpaint" in self.control_mode: self.controlnet['depth_inpaint'] = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_depth_aware_inpaint",torch_dtype=self.dtype).to(self.device) self.controlnet_conditioning_scale['depth_inpaint'] = 1.0 if "pose" in self.control_mode: self.controlnet['pose'] = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose",torch_dtype=self.dtype).to(self.device) self.controlnet_conditioning_scale['pose'] = 1.0 if "tile" in self.control_mode: self.controlnet['tile'] = ControlNetModel.from_pretrained("lllyasviel/control_v11f1e_sd15_tile",torch_dtype=self.dtype).to(self.device) self.controlnet_conditioning_scale['tile'] = 1.0 self.scheduler = DDIMScheduler.from_pretrained( model_key, subfolder="scheduler", torch_dtype=self.dtype ) del pipe self.num_train_timesteps = self.scheduler.config.num_train_timesteps self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience self.embeddings = {} @torch.no_grad() def get_text_embeds(self, prompts, negative_prompts): pos_embeds = self.encode_text(prompts) # [1, 77, 768] neg_embeds = self.encode_text(negative_prompts) self.embeddings['pos'] = pos_embeds self.embeddings['neg'] = neg_embeds # directional embeddings for d in ['front', 'side', 'back']: embeds = self.encode_text([f'{p}, {d} view' for p in prompts]) self.embeddings[d] = embeds def encode_text(self, prompt): # prompt: [str] inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] return embeddings @torch.no_grad() def refine(self, pred_rgb, guidance_scale=100, steps=50, strength=0.8, control_images=None ): batch_size = pred_rgb.shape[0] pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) latents = self.encode_imgs(pred_rgb_512.to(self.dtype)) # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype) self.scheduler.set_timesteps(steps) init_step = int(steps * strength) latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps[init_step:]): latent_model_input = torch.cat([latents] * 2) if self.control_mode is not None and control_images is not None: noise_pred = 0 for mode, controlnet in self.controlnet.items(): # may omit control mode if input is not provided if mode not in control_images: continue control_image = control_images[mode].to(self.dtype) weight = 1 / len(self.controlnet) control_image_input = torch.cat([control_image] * 2) down_samples, mid_sample = controlnet( latent_model_input, t, encoder_hidden_states=embeddings, controlnet_cond=control_image_input, conditioning_scale=self.controlnet_conditioning_scale[mode], return_dict=False ) # predict the noise residual noise_pred_cur = self.unet( latent_model_input, t, encoder_hidden_states=embeddings, down_block_additional_residuals=down_samples, mid_block_additional_residual=mid_sample ).sample # merge after unet noise_pred = noise_pred + weight * noise_pred_cur else: noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=self.embeddings, ).sample noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) latents = self.scheduler.step(noise_pred, t, latents).prev_sample imgs = self.decode_latents(latents) # [1, 3, 512, 512] return imgs def train_step( self, pred_rgb, step_ratio=None, guidance_scale=100, as_latent=False, control_images=None, vers=None, hors=None, ): batch_size = pred_rgb.shape[0] pred_rgb = pred_rgb.to(self.dtype) if as_latent: latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1 else: # interp to 512x512 to be fed into vae. pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False) # encode image into latents with vae, requires grad! latents = self.encode_imgs(pred_rgb_512) with torch.no_grad(): if step_ratio is not None: # dreamtime-like # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) else: t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) # w(t), sigma_t^2 w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) ######### debug # # Text embeds -> img latents # latentsx = self.produce_latents( # height=512, # width=512, # latents=torch.randn_like(latents), # num_inference_steps=50, # guidance_scale=7.5, # control_images=control_images, # ) # [1, 4, 64, 64] # # Img latents -> imgs # imgs = self.decode_latents(latentsx) # [1, 3, 512, 512] # import kiui # kiui.vis.plot_image(control_images['tile'], imgs) ######### # add noise noise = torch.randn_like(latents) latents_noisy = self.scheduler.add_noise(latents, noise, t) # pred noise latent_model_input = torch.cat([latents_noisy] * 2) tt = torch.cat([t] * 2) if hors is None: embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) else: def _get_dir_ind(h): if abs(h) < 60: return 'front' elif abs(h) < 120: return 'side' else: return 'back' embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)]) if self.control_mode is not None and control_images is not None: noise_pred = 0 for mode, controlnet in self.controlnet.items(): # may omit control mode if input is not provided if mode not in control_images: continue control_image = control_images[mode].to(self.dtype) weight = 1 / len(self.controlnet) control_image_input = torch.cat([control_image] * 2) down_samples, mid_sample = controlnet( latent_model_input, tt, encoder_hidden_states=embeddings, controlnet_cond=control_image_input, conditioning_scale=self.controlnet_conditioning_scale[mode], return_dict=False ) # predict the noise residual noise_pred_cur = self.unet( latent_model_input, tt, encoder_hidden_states=embeddings, down_block_additional_residuals=down_samples, mid_block_additional_residual=mid_sample ).sample # merge after unet noise_pred = noise_pred + weight * noise_pred_cur else: noise_pred = self.unet( latent_model_input, tt, encoder_hidden_states=embeddings, ).sample # perform guidance (high scale from paper!) noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) grad = w * (noise_pred - noise) grad = torch.nan_to_num(grad) # grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8 # grad = grad_norm.clamp(max=0.1) * grad / grad_norm target = (latents - grad).detach() loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] return loss @torch.no_grad() def produce_latents( self, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, control_images=None, ): if latents is None: latents = torch.randn( ( 1, self.unet.in_channels, height // 8, width // 8, ), device=self.device, ) batch_size = latents.shape[0] self.scheduler.set_timesteps(num_inference_steps) embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)]) for i, t in enumerate(self.scheduler.timesteps): # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. latent_model_input = torch.cat([latents] * 2) # predict the noise residual if self.control_mode is not None and control_images is not None: noise_pred = 0 for mode, controlnet in self.controlnet.items(): # may omit control mode if input is not provided if mode not in control_images: continue control_image = control_images[mode].to(self.dtype) weight = 1 / len(self.controlnet) control_image_input = torch.cat([control_image] * 2) down_samples, mid_sample = controlnet( latent_model_input, t, encoder_hidden_states=embeddings, controlnet_cond=control_image_input, conditioning_scale=self.controlnet_conditioning_scale[mode], return_dict=False ) # predict the noise residual noise_pred_cur = self.unet( latent_model_input, t, encoder_hidden_states=embeddings, down_block_additional_residuals=down_samples, mid_block_additional_residual=mid_sample ).sample # merge after unet noise_pred = noise_pred + weight * noise_pred_cur else: noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=embeddings, ).sample # perform guidance noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents).prev_sample return latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents imgs = self.vae.decode(latents).sample imgs = (imgs / 2 + 0.5).clamp(0, 1) return imgs def encode_imgs(self, imgs): # imgs: [B, 3, H, W] imgs = 2 * imgs - 1 posterior = self.vae.encode(imgs).latent_dist latents = posterior.sample() * self.vae.config.scaling_factor return latents def prompt_to_img( self, prompts, negative_prompts="", height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, control_images=None, ): if isinstance(prompts, str): prompts = [prompts] if isinstance(negative_prompts, str): negative_prompts = [negative_prompts] # Prompts -> text embeds self.get_text_embeds(prompts, negative_prompts) # Text embeds -> img latents latents = self.produce_latents( height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, control_images=control_images, ) # [1, 4, 64, 64] # Img latents -> imgs imgs = self.decode_latents(latents) # [1, 3, 512, 512] # Img to Numpy imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() imgs = (imgs * 255).round().astype("uint8") return imgs if __name__ == "__main__": import kiui import argparse import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument("image", type=str) parser.add_argument("prompt", default="", type=str) parser.add_argument("--control", default='tile', type=str) parser.add_argument("--negative", default="", type=str) parser.add_argument("--fp16", action="store_true", help="use float16 for training") parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage") parser.add_argument("-H", type=int, default=512) parser.add_argument("-W", type=int, default=512) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--steps", type=int, default=50) opt = parser.parse_args() kiui.seed_everything(opt.seed) device = torch.device("cuda") # load image control_image = kiui.read_image(opt.image, mode='tensor').permute(2,0,1).contiguous().unsqueeze(0).to(device) control_image = F.interpolate(control_image, (opt.H, opt.W), mode='bilinear', align_corners=False) kiui.lo(control_image) control_images = {} control_images[opt.control] = control_image sd = StableDiffusionControlNet(device, opt.fp16, opt.vram_O, control_mode=[opt.control]) while True: imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, control_images=control_images) # visualize image plt.imshow(imgs[0]) plt.show() ================================================ FILE: threefiner/lights/LICENSE.txt ================================================ The mud_road_puresky.hdr HDR probe is from https://polyhaven.com/a/mud_road_puresky CC0 License. ================================================ FILE: threefiner/nn.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import tinycudann as tcnn class HashGridEncoder(nn.Module): def __init__(self, input_dim=3, num_levels=16, level_dim=2, log2_hashmap_size=18, base_resolution=16, desired_resolution=1024, interpolation='linear' ): super().__init__() self.encoder = tcnn.Encoding( n_input_dims=input_dim, encoding_config={ "otype": "HashGrid", "n_levels": num_levels, "n_features_per_level": level_dim, "log2_hashmap_size": log2_hashmap_size, "base_resolution": base_resolution, "per_level_scale": np.exp2(np.log2(desired_resolution / num_levels) / (num_levels - 1)), "interpolation": "Smoothstep" if interpolation == 'smoothstep' else "Linear", }, dtype=torch.float32, ) self.input_dim = input_dim self.output_dim = self.encoder.n_output_dims # patch def forward(self, x, bound=1): return self.encoder((x + bound) / (2 * bound)) class FrequencyEncoder(nn.Module): def __init__(self, input_dim=3, output_dim=32, n_frequencies=12, ): super().__init__() self.encoder = tcnn.Encoding( n_input_dims=input_dim, encoding_config={ "otype": "Frequency", "n_frequencies": n_frequencies, }, dtype=torch.float32, ) self.implicit_mlp = MLP(self.encoder.n_output_dims, output_dim, 128, 5, bias=True) self.input_dim = input_dim self.output_dim = output_dim def forward(self, x, **kwargs): return self.implicit_mlp(self.encoder(x)) class TriplaneEncoder(nn.Module): def __init__(self, input_dim=3, output_dim=32, resolution=256, ): super().__init__() self.C_mat = nn.Parameter(torch.randn(3, output_dim, resolution, resolution)) torch.nn.init.kaiming_normal_(self.C_mat) self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.input_dim = input_dim self.output_dim = output_dim def forward(self, x, bound=1): N = x.shape[0] x = x / bound # to [-1, 1] mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2] feat = F.grid_sample(self.C_mat[[0]], mat_coord[[0]], align_corners=False).view(-1, N) + \ F.grid_sample(self.C_mat[[1]], mat_coord[[1]], align_corners=False).view(-1, N) + \ F.grid_sample(self.C_mat[[2]], mat_coord[[2]], align_corners=False).view(-1, N) # [r, N] # density feat = feat.transpose(0, 1).contiguous() # [N, C] return feat class MLP(nn.Module): def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): super().__init__() self.dim_in = dim_in self.dim_out = dim_out self.dim_hidden = dim_hidden self.num_layers = num_layers net = [] for l in range(num_layers): net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) self.net = nn.ModuleList(net) def forward(self, x): for l in range(self.num_layers): x = self.net[l](x) if l != self.num_layers - 1: x = F.relu(x, inplace=True) return x ================================================ FILE: threefiner/opt.py ================================================ import os from dataclasses import dataclass from typing import Tuple, Literal, Dict, Optional @dataclass class Options: # path to input mesh mesh: Optional[str] = None # input text prompt prompt: Optional[str] = None # additional positive prompt positive_prompt: str = "best quality, extremely detailed, masterpiece, high resolution, high quality" # additional negative prompt negative_prompt: str = "blur, lowres, cropped, low quality, worst quality, ugly, dark, shadow, oversaturated" # whether to append directional text prompt text_dir: bool = False # set mesh front-facing direction (camera front=+z, right=+x, up=+y, clock-wise rotation 90=1, 180=2, 270=3, e.g., +z, -y1) front_dir: str = "+z" # training iterations iters: int = 500 # training resolution render_resolution: int = 512 # training camera radius radius: float = 2.5 # training camera fovy in degree fovy: float = 49.1 # whether to allow geom training fix_geo: bool = False # whether to mix normal with rgb for geometry training mix_normal: bool = True # whether to pretrain texture first fit_tex: bool = True # pretrain texture iterations fit_tex_iters: int = 512 # output folder outdir: str = '.' # output filename, default to {name}_fine.{ext} save: Optional[str] = None # guidance mode mode: Literal['SD', 'IF', 'IF2', 'SDCN', 'SD_NFSD', 'IF2_NFSD', 'SD_ISM', 'IF2_ISM'] = 'IF2' # renderer geometry mode geom_mode: Literal['mesh', 'diffmc', 'pbr_mesh', 'pbr_diffmc'] = 'diffmc' # renderer texture mode tex_mode: Literal['hashgrid', 'mlp', 'triplane'] = 'hashgrid' # training batch size per iter batch_size: int = 1 # environmental texture env_texture: Optional[str] = None # environmental light scale env_scale: float = 2 # DiffMC grid size mc_grid_size: int = 128 # Mesh remeshing interval remesh_interval: int = 200 # mesh decimation target face number decimate_target: int = 5e4 # remesh target edge length (smaller value lead to finer mesh) remesh_size: float = 0.015 # texture resolution texture_resolution: int = 1024 # learning rate for hashgrid hashgrid_lr: float = 0.01 # learning rate for feature MLP mlp_lr: float = 0.001 # learning rate for SDF sdf_lr: float = 0.0001 # learning rate for deformation deform_lr: float = 0.0001 # learning rate for mesh geometry geom_lr: float = 0.0001 # guidance loss weights lambda_sd: float = 1 # mesh laplacian regularization weight lambda_lap: float = 0 # mesh normal consistency weight (should be large enough) lambda_normal: float = 10000 # mesh vertices offset penalty weight lambda_offsets: float = 100 # whether to open a GUI gui: bool = False # GUI height H: int = 800 # GUI width W: int = 800 # whether to use CUDA rasterizer (in case OpenGL fails) force_cuda_rast: bool = False # whether to use GPU memory-optimized mode (slower, but uses less GPU memory) vram_O: bool = False # all the default settings config_defaults: Dict[str, Options] = {} config_doc: Dict[str, str] = {} config_doc['sd'] = 'coarse-level generation with stable-diffusion 2.' config_defaults['sd'] = Options( mode='SD', iters=800, ) config_doc['if'] = 'coarse-level generation with deepfloyd-if I.' config_defaults['if'] = Options( mode='IF', iters=400, ) config_doc['if2'] = 'fine-level refinement with deepfloyd-if II.' config_defaults['if2'] = Options( mode='IF2', iters=400, ) config_doc['sd_fixgeo'] = 'coarse-level generation with stable-diffusion 2, fixed goemetry.' config_defaults['sd_fixgeo'] = Options( mode='SD', iters=800, fix_geo=True, geom_mode='mesh', ) config_doc['if_fixgeo'] = 'coarse-level generation with deepfloyd-if I, fixed goemetry.' config_defaults['if_fixgeo'] = Options( mode='IF', iters=400, fix_geo=True, geom_mode='mesh', ) config_doc['if2_fixgeo'] = 'fine-level refinement with deepfloyd-if II, fixed goemetry.' config_defaults['if2_fixgeo'] = Options( mode='IF2', iters=400, fix_geo=True, geom_mode='mesh', ) def check_options(opt: Options): assert opt.mesh is not None, 'mesh path must be specified!' assert opt.prompt is not None, 'prompt must be specified!' if opt.save is None: input_name, input_ext = os.path.splitext(os.path.basename(opt.mesh)) opt.save = input_name + '_fine' + '.glb' print(f'[INFO] save to default output path: {os.path.join(opt.outdir, opt.save)}.') return opt ================================================ FILE: threefiner/renderer/__init__.py ================================================ ================================================ FILE: threefiner/renderer/diffmc_renderer.py ================================================ import os import tqdm import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import nvdiffrast.torch as dr import kiui from kiui.mesh import Mesh from kiui.mesh_utils import clean_mesh, decimate_mesh from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding from kiui.cam import orbit_camera, get_perspective from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder from threefiner.renderer.mesh_renderer import render_mesh from diso import DiffMC, DiffDMC class Renderer(nn.Module): def __init__(self, opt, device): super().__init__() self.opt = opt self.device = device if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'): self.glctx = dr.RasterizeGLContext() else: self.glctx = dr.RasterizeCudaContext() # diffmc self.verts = torch.stack( torch.meshgrid( torch.linspace(-1, 1, self.opt.mc_grid_size, device=device), torch.linspace(-1, 1, self.opt.mc_grid_size, device=device), torch.linspace(-1, 1, self.opt.mc_grid_size, device=device), indexing="ij", ), dim=-1, ) # [N, N, N, 3] self.grid_scale = 1 self.diffmc = DiffMC(dtype=torch.float32).to(device) # vert sdf and deform self.sdf = nn.Parameter(torch.zeros_like(self.verts[..., 0])) self.deform = nn.Parameter(torch.zeros_like(self.verts)) # init diffmc from mesh self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir) vertices = self.mesh.v.detach().cpu().numpy() triangles = self.mesh.f.detach().cpu().numpy() vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False) self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device) self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device) self.grid_scale = self.mesh.v.abs().max() + 1e-1 self.verts = self.verts * self.grid_scale try: import cubvh BVH = cubvh.cuBVH(self.mesh.v, self.mesh.f) sdf, _, _ = BVH.signed_distance(self.verts.reshape(-1, 3), return_uvw=False, mode='raystab') # some mesh may not be watertight... except: from pysdf import SDF sdf_func = SDF(self.mesh.v.detach().cpu().numpy(), self.mesh.f.detach().cpu().numpy()) sdf = sdf_func(self.verts.detach().cpu().numpy().reshape(-1, 3)) sdf = torch.from_numpy(sdf).to(self.device) sdf *= -1 # OUTER is POSITIVE self.sdf.data += sdf.reshape(*self.sdf.data.shape).to(self.sdf.data.dtype) # texture if self.opt.tex_mode == 'hashgrid': self.encoder = HashGridEncoder().to(self.device) elif self.opt.tex_mode == 'mlp': self.encoder = FrequencyEncoder().to(self.device) elif self.opt.tex_mode == 'triplane': self.encoder = TriplaneEncoder().to(self.device) else: raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}") self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=True).to(self.device) self.v, self.f = None, None # placeholder # init hashgrid texture from mesh if self.opt.fit_tex: self.fit_texture_from_mesh(self.opt.fit_tex_iters) def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1): return render_mesh( self.glctx, self.mesh.v, self.mesh.f, self.mesh.vt, self.mesh.ft, self.mesh.albedo, self.mesh.vc, self.mesh.vn, self.mesh.fn, pose, proj, h, w, ssaa=ssaa, bg_color=bg_color, ) def fit_texture_from_mesh(self, iters=512): # a small training loop... loss_fn = torch.nn.MSELoss() optimizer = torch.optim.Adam([ {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr}, {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr}, ]) resolution = 512 print(f"[INFO] fitting texture...") pbar = tqdm.trange(iters) for i in pbar: ver = np.random.randint(-45, 45) hor = np.random.randint(-180, 180) pose = orbit_camera(ver, hor, self.opt.radius) proj = get_perspective(self.opt.fovy) image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image'] image_pred = self.render(pose, proj, resolution, resolution)['image'] loss = loss_fn(image_pred, image_mesh) optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_description(f"MSE = {loss.item():.6f}") print(f"[INFO] finished fitting texture!") def get_params(self): params = [ {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr}, {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr}, ] if not self.opt.fix_geo: params.append({'params': self.sdf, 'lr': self.opt.sdf_lr}) params.append({'params': self.deform, 'lr': self.opt.deform_lr}) return params @torch.no_grad() def export_mesh(self, save_path, texture_resolution=2048, padding=16): # get v sdf = self.sdf deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5] v, f = self.diffmc(sdf, deform) v = (2 * v - 1) * self.grid_scale f = f.int() self.v, self.f = v, f vertices = v.detach().cpu().numpy() triangles = f.detach().cpu().numpy() # clean vertices = vertices.astype(np.float32) triangles = triangles.astype(np.int32) vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size) # decimation if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target: vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target) v = torch.from_numpy(vertices).contiguous().float().to(self.device) f = torch.from_numpy(triangles).contiguous().int().to(self.device) mesh = Mesh(v=v, f=f, albedo=None, device=self.device) print(f"[INFO] uv unwrapping...") mesh.auto_normal() mesh.auto_uv() # render uv maps h = w = texture_resolution uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1] uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4] xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3] mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1] # masked query xyzs = xyzs.view(-1, 3) mask = (mask > 0).view(-1) albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32) if mask.any(): print(f"[INFO] querying texture...") xyzs = xyzs[mask] # [M, 3] # batched inference to avoid OOM batch = [] head = 0 while head < xyzs.shape[0]: tail = min(head + 640000, xyzs.shape[0]) batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()) head += 640000 albedo[mask] = torch.cat(batch, dim=0) albedo = albedo.view(h, w, -1) mask = mask.view(h, w) print(f"[INFO] uv padding...") albedo = uv_padding(albedo, mask, padding) mesh.albedo = albedo mesh.write(save_path) def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1): # do super-sampling if ssaa != 1: h = make_divisible(h0 * ssaa, 8) w = make_divisible(w0 * ssaa, 8) else: h, w = h0, w0 results = {} # get v sdf = self.sdf deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5] v, f = self.diffmc(sdf, deform) v = (2 * v - 1) * self.grid_scale f = f.int() self.v, self.f = v, f pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) proj = torch.from_numpy(proj.astype(np.float32)).to(v.device) # get v_clip and render rgb v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0) v_clip = v_cam @ proj.T rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w)) alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [V, H, W, 1] alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients! depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1] depth = depth.squeeze(0) # [H, W, 1] xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3] xyzs = xyzs.view(-1, 3) mask = (alpha > 0).view(-1) color = torch.zeros_like(xyzs, dtype=torch.float32) if mask.any(): masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1))) color[mask] = masked_albedo.float() color = color.view(1, h, w, 3) # antialias color = dr.antialias(color, rast, v_clip, f).squeeze(0) # [H, W, 3] color = alpha * color + (1 - alpha) * bg_color # get vn and render normal i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long() v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :] face_normals = torch.cross(v1 - v0, v2 - v0) face_normals = safe_normalize(face_normals) vn = torch.zeros_like(v) vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)) normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f) normal = safe_normalize(normal[0]) # rotated normal (where [0, 0, 1] always faces camera) rot_normal = normal @ pose[:3, :3] viewcos = rot_normal[..., [2]] # ssaa if ssaa != 1: color = scale_img_hwc(color, (h0, w0)) alpha = scale_img_hwc(alpha, (h0, w0)) depth = scale_img_hwc(depth, (h0, w0)) normal = scale_img_hwc(normal, (h0, w0)) viewcos = scale_img_hwc(viewcos, (h0, w0)) results['image'] = color.clamp(0, 1) results['alpha'] = alpha results['depth'] = depth results['normal'] = (normal + 1) / 2 results['viewcos'] = viewcos return results ================================================ FILE: threefiner/renderer/mesh_renderer.py ================================================ import os import tqdm import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import nvdiffrast.torch as dr from kiui.mesh import Mesh from kiui.mesh_utils import clean_mesh, decimate_mesh from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding from kiui.cam import orbit_camera, get_perspective from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder def render_mesh( glctx, v, f, vt, ft, albedo, vc, vn, fn, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear', color_activation=None, ): # do super-sampling if ssaa != 1: h = make_divisible(h0 * ssaa, 8) w = make_divisible(w0 * ssaa, 8) else: h, w = h0, w0 results = {} pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) proj = torch.from_numpy(proj.astype(np.float32)).to(v.device) # get v_clip and render rgb v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0) v_clip = v_cam @ proj.T rast, rast_db = dr.rasterize(glctx, v_clip, f, (h, w)) alpha = (rast[0, ..., 3:] > 0).float() depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1] depth = depth.squeeze(0) # [H, W, 1] if vc is not None: # use vertex color color, _ = dr.interpolate(vc.unsqueeze(0).contiguous(), rast, f) else: # use texture image texc, texc_db = dr.interpolate(vt.unsqueeze(0).contiguous(), rast, ft, rast_db=rast_db, diff_attrs='all') color = dr.texture(albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3] if color_activation is not None: color = color_activation(color) # antialias color = dr.antialias(color, rast, v_clip, f).squeeze(0) # [H, W, 3] color = alpha * color + (1 - alpha) * bg_color # get vn and render normal if vn is None: i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long() v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :] face_normals = torch.cross(v1 - v0, v2 - v0) face_normals = safe_normalize(face_normals) vn = torch.zeros_like(v) vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)) normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, fn) normal = safe_normalize(normal[0]) # rotated normal (where [0, 0, 1] always faces camera) rot_normal = normal @ pose[:3, :3] viewcos = rot_normal[..., [2]] # ssaa if ssaa != 1: color = scale_img_hwc(color, (h0, w0)) alpha = scale_img_hwc(alpha, (h0, w0)) depth = scale_img_hwc(depth, (h0, w0)) normal = scale_img_hwc(normal, (h0, w0)) viewcos = scale_img_hwc(viewcos, (h0, w0)) results['image'] = color.clamp(0, 1) results['alpha'] = alpha results['depth'] = depth results['normal'] = (normal + 1) / 2 results['viewcos'] = viewcos return results class Renderer(nn.Module): def __init__(self, opt, device): super().__init__() self.opt = opt self.device = device self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir) # it's necessary to clean the mesh to facilitate later remeshing! vertices = self.mesh.v.detach().cpu().numpy() triangles = self.mesh.f.detach().cpu().numpy() vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False) self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device) self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device) if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'): self.glctx = dr.RasterizeGLContext() else: self.glctx = dr.RasterizeCudaContext() # extract trainable parameters self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v)) # texture if self.opt.tex_mode == 'hashgrid': self.encoder = HashGridEncoder().to(self.device) elif self.opt.tex_mode == 'mlp': self.encoder = FrequencyEncoder().to(self.device) elif self.opt.tex_mode == 'triplane': self.encoder = TriplaneEncoder().to(self.device) else: raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}") self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=True).to(self.device) # init hashgrid texture from mesh if self.opt.fit_tex: self.fit_texture_from_mesh(self.opt.fit_tex_iters) def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1): return render_mesh( self.glctx, self.mesh.v, self.mesh.f, self.mesh.vt, self.mesh.ft, self.mesh.albedo, self.mesh.vc, self.mesh.vn, self.mesh.fn, pose, proj, h, w, ssaa=ssaa, bg_color=bg_color, ) def fit_texture_from_mesh(self, iters=512): # a small training loop... loss_fn = torch.nn.MSELoss() optimizer = torch.optim.Adam([ {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr}, {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr}, ]) resolution = 512 print(f"[INFO] fitting texture...") pbar = tqdm.trange(iters) for i in pbar: ver = np.random.randint(-45, 45) hor = np.random.randint(-180, 180) pose = orbit_camera(ver, hor, self.opt.radius) proj = get_perspective(self.opt.fovy) image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image'] image_pred = self.render(pose, proj, resolution, resolution)['image'] loss = loss_fn(image_pred, image_mesh) optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_description(f"MSE = {loss.item():.6f}") print(f"[INFO] finished fitting texture!") def get_params(self): params = [ {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr}, {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr}, ] if not self.opt.fix_geo: params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr}) return params @torch.no_grad() def export_mesh(self, save_path, texture_resolution=2048, padding=16): mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device) print(f"[INFO] uv unwrapping...") mesh.auto_normal() mesh.auto_uv() # render uv maps h = w = texture_resolution uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1] uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4] # masked query xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3] mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1] xyzs = xyzs.view(-1, 3) mask = (mask > 0).view(-1) albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32) if mask.any(): print(f"[INFO] querying texture...") xyzs = xyzs[mask] # [M, 3] # batched inference to avoid OOM batch = [] head = 0 while head < xyzs.shape[0]: tail = min(head + 640000, xyzs.shape[0]) batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()) head += 640000 albedo[mask] = torch.cat(batch, dim=0) albedo = albedo.view(h, w, -1) mask = mask.view(h, w) print(f"[INFO] uv padding...") albedo = uv_padding(albedo, mask, padding) mesh.albedo = albedo mesh.write(save_path) @property def v(self): if self.opt.fix_geo: return self.mesh.v else: return self.mesh.v + self.v_offsets @property def f(self): return self.mesh.f @torch.no_grad() def remesh(self): vertices = self.v.detach().cpu().numpy() triangles = self.f.detach().cpu().numpy() vertices, triangles = clean_mesh(vertices, triangles, repair=False, remesh=True, remesh_size=self.opt.remesh_size) if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target: vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target, optimalplacement=False) self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device) self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device) self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v)).to(self.device) def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1): # do super-sampling if ssaa != 1: h = make_divisible(h0 * ssaa, 8) w = make_divisible(w0 * ssaa, 8) else: h, w = h0, w0 results = {} # get v v = self.v f = self.f pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) proj = torch.from_numpy(proj.astype(np.float32)).to(v.device) # get v_clip and render rgb v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0) v_clip = v_cam @ proj.T rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w)) alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1] alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients! depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1] depth = depth.squeeze(0) # [H, W, 1] xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3] xyzs = xyzs.view(-1, 3) mask = (alpha > 0).view(-1) color = torch.zeros_like(xyzs, dtype=torch.float32) if mask.any(): masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1))) color[mask] = masked_albedo.float() color = color.view(1, h, w, 3) # antialias color = dr.antialias(color, rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3] color = alpha * color + (1 - alpha) * bg_color # get vn and render normal if self.opt.fix_geo: vn = self.mesh.vn else: i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long() v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :] face_normals = torch.cross(v1 - v0, v2 - v0) face_normals = safe_normalize(face_normals) vn = torch.zeros_like(v) vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)) normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f) normal = safe_normalize(normal[0]) # rotated normal (where [0, 0, 1] always faces camera) rot_normal = normal @ pose[:3, :3] viewcos = rot_normal[..., [2]] # ssaa if ssaa != 1: color = scale_img_hwc(color, (h0, w0)) alpha = scale_img_hwc(alpha, (h0, w0)) depth = scale_img_hwc(depth, (h0, w0)) normal = scale_img_hwc(normal, (h0, w0)) viewcos = scale_img_hwc(viewcos, (h0, w0)) results['image'] = color results['alpha'] = alpha results['depth'] = depth results['normal'] = (normal + 1) / 2 results['viewcos'] = viewcos return results ================================================ FILE: threefiner/renderer/pbr_diffmc_renderer.py ================================================ import os import tqdm import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import envlight import nvdiffrast.torch as dr import kiui from kiui.mesh import Mesh from kiui.mesh_utils import clean_mesh, decimate_mesh from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding from kiui.cam import orbit_camera, get_perspective from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder from threefiner.renderer.mesh_renderer import render_mesh from diso import DiffMC, DiffDMC class Renderer(nn.Module): def __init__(self, opt, device): super().__init__() self.opt = opt self.device = device if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'): self.glctx = dr.RasterizeGLContext() else: self.glctx = dr.RasterizeCudaContext() # diffmc self.verts = torch.stack( torch.meshgrid( torch.linspace(-1, 1, self.opt.mc_grid_size, device=device), torch.linspace(-1, 1, self.opt.mc_grid_size, device=device), torch.linspace(-1, 1, self.opt.mc_grid_size, device=device), indexing="ij", ), dim=-1, ) # [N, N, N, 3] self.grid_scale = 1 self.diffmc = DiffMC(dtype=torch.float32).to(device) # vert sdf and deform self.sdf = nn.Parameter(torch.zeros_like(self.verts[..., 0])) self.deform = nn.Parameter(torch.zeros_like(self.verts)) # init diffmc from mesh self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir) vertices = self.mesh.v.detach().cpu().numpy() triangles = self.mesh.f.detach().cpu().numpy() vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False) self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device) self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device) self.grid_scale = self.mesh.v.abs().max() + 1e-1 self.verts = self.verts * self.grid_scale try: import cubvh BVH = cubvh.cuBVH(self.mesh.v, self.mesh.f) sdf, _, _ = BVH.signed_distance(self.verts.reshape(-1, 3), return_uvw=False, mode='raystab') # some mesh may not be watertight... except: from pysdf import SDF sdf_func = SDF(self.mesh.v.detach().cpu().numpy(), self.mesh.f.detach().cpu().numpy()) sdf = sdf_func(self.verts.detach().cpu().numpy().reshape(-1, 3)) sdf = torch.from_numpy(sdf).to(self.device) sdf *= -1 # OUTER is POSITIVE self.sdf.data += sdf.reshape(*self.sdf.data.shape).to(self.sdf.data.dtype) # texture if self.opt.tex_mode == 'hashgrid': self.encoder = HashGridEncoder().to(self.device) elif self.opt.tex_mode == 'mlp': self.encoder = FrequencyEncoder().to(self.device) elif self.opt.tex_mode == 'triplane': self.encoder = TriplaneEncoder().to(self.device) else: raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}") self.mlp = MLP(self.encoder.output_dim, 3+2, 32, 2, bias=True).to(self.device) self.v, self.f = None, None # placeholder # env light if self.opt.env_texture is None: hdr_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../lights/mud_road_puresky_1k.hdr') else: hdr_path = self.opt.env_texture self.light = envlight.EnvLight(hdr_path, scale=self.opt.env_scale, device=self.device) FG_LUT = torch.from_numpy(np.fromfile(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../lights/bsdf_256_256.bin"), dtype=np.float32).reshape(1, 256, 256, 2)).to(self.device) self.register_buffer("FG_LUT", FG_LUT) # init hashgrid texture from mesh if self.opt.fit_tex: self.fit_texture_from_mesh(self.opt.fit_tex_iters) def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1): return render_mesh( self.glctx, self.mesh.v, self.mesh.f, self.mesh.vt, self.mesh.ft, self.mesh.albedo, self.mesh.vc, self.mesh.vn, self.mesh.fn, pose, proj, h, w, ssaa=ssaa, bg_color=bg_color, ) def fit_texture_from_mesh(self, iters=512): # a small training loop... loss_fn = torch.nn.MSELoss() optimizer = torch.optim.Adam([ {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr}, {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr}, ]) resolution = 512 print(f"[INFO] fitting texture...") pbar = tqdm.trange(iters) for i in pbar: ver = np.random.randint(-45, 45) hor = np.random.randint(-180, 180) pose = orbit_camera(ver, hor, self.opt.radius) proj = get_perspective(self.opt.fovy) image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image'] image_pred = self.render(pose, proj, resolution, resolution)['image'] loss = loss_fn(image_pred, image_mesh) optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_description(f"MSE = {loss.item():.6f}") print(f"[INFO] finished fitting texture!") def get_params(self): params = [ {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr}, {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr}, ] if not self.opt.fix_geo: params.append({'params': self.sdf, 'lr': self.opt.sdf_lr}) params.append({'params': self.deform, 'lr': self.opt.deform_lr}) return params @torch.no_grad() def export_mesh(self, save_path, texture_resolution=2048, padding=16): # get v sdf = self.sdf deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5] v, f = self.diffmc(sdf, deform) v = (2 * v - 1) * self.grid_scale f = f.int() self.v, self.f = v, f vertices = v.detach().cpu().numpy() triangles = f.detach().cpu().numpy() # clean vertices = vertices.astype(np.float32) triangles = triangles.astype(np.int32) vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size) # decimation if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target: vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target) v = torch.from_numpy(vertices).contiguous().float().to(self.device) f = torch.from_numpy(triangles).contiguous().int().to(self.device) mesh = Mesh(v=v, f=f, albedo=None, device=self.device) print(f"[INFO] uv unwrapping...") mesh.auto_normal() mesh.auto_uv() # render uv maps h = w = texture_resolution uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1] uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4] xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3] mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1] # masked query xyzs = xyzs.view(-1, 3) mask = (mask > 0).view(-1) material = torch.zeros(h * w, 5, device=self.device, dtype=torch.float32) if mask.any(): print(f"[INFO] querying texture...") xyzs = xyzs[mask] # [M, 3] # batched inference to avoid OOM batch = [] head = 0 while head < xyzs.shape[0]: tail = min(head + 640000, xyzs.shape[0]) batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()) head += 640000 material[mask] = torch.cat(batch, dim=0) material = material.view(h, w, -1) mask = mask.view(h, w) print(f"[INFO] uv padding...") material = uv_padding(material, mask, padding) mesh.albedo = material[..., :3] mesh.metallicRoughness = torch.cat([torch.zeros_like(material[..., 3:4]), material[..., 4:5], material[..., 3:4]], dim=-1) mesh.write(save_path) def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1): # do super-sampling if ssaa != 1: h = make_divisible(h0 * ssaa, 8) w = make_divisible(w0 * ssaa, 8) else: h, w = h0, w0 results = {} # get v sdf = self.sdf deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5] v, f = self.diffmc(sdf, deform) v = (2 * v - 1) * self.grid_scale f = f.int() self.v, self.f = v, f pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) proj = torch.from_numpy(proj.astype(np.float32)).to(v.device) # get v_clip and render rgb v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0) v_clip = v_cam @ proj.T rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w)) alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [V, H, W, 1] alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients! depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1] depth = depth.squeeze(0) # [H, W, 1] xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3] viewdir = safe_normalize(xyzs - pose[:3, 3]).squeeze(0) xyzs = xyzs.view(-1, 3) mask = (alpha > 0).view(-1) material = torch.zeros(xyzs.shape[0], 5, dtype=torch.float32, device=xyzs.device) if mask.any(): masked_material = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1))) material[mask] = masked_material.float() material = material.view(h, w, -1) # get vn and render normal i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long() v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :] face_normals = torch.cross(v1 - v0, v2 - v0) face_normals = safe_normalize(face_normals) vn = torch.zeros_like(v) vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)) normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f) normal = safe_normalize(normal[0]) # rotated normal (where [0, 0, 1] always faces camera) rot_normal = normal @ pose[:3, :3] viewcos = rot_normal[..., [2]] # shading albedo = material[..., :3] metallic = material[..., 3:4] roughness = material[..., 4:5] n_dot_v = (normal * viewdir).sum(-1, keepdim=True) # [H, W, 1] reflective = n_dot_v * normal * 2 - viewdir diffuse_albedo = (1 - metallic) * albedo fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) # [H, W, 2] fg = dr.texture( self.FG_LUT, fg_uv.reshape(1, -1, 1, 2).contiguous(), filter_mode="linear", boundary_mode="clamp", ).reshape(h, w, 2) F0 = (1 - metallic) * 0.04 + metallic * albedo specular_albedo = F0 * fg[..., 0:1] + fg[..., 1:2] diffuse_light = self.light(normal) specular_light = self.light(reflective, roughness) color = diffuse_albedo * diffuse_light + specular_albedo * specular_light # [H, W, 3] # antialias color = dr.antialias(color.unsqueeze(0), rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3] color = alpha * color + (1 - alpha) * bg_color # ssaa if ssaa != 1: color = scale_img_hwc(color, (h0, w0)) alpha = scale_img_hwc(alpha, (h0, w0)) depth = scale_img_hwc(depth, (h0, w0)) normal = scale_img_hwc(normal, (h0, w0)) viewcos = scale_img_hwc(viewcos, (h0, w0)) results['image'] = color results['alpha'] = alpha results['depth'] = depth results['normal'] = (normal + 1) / 2 results['viewcos'] = viewcos return results ================================================ FILE: threefiner/renderer/pbr_mesh_renderer.py ================================================ import os import tqdm import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import envlight import nvdiffrast.torch as dr from kiui.mesh import Mesh from kiui.mesh_utils import clean_mesh, decimate_mesh from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding from kiui.cam import orbit_camera, get_perspective from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder from threefiner.renderer.mesh_renderer import render_mesh class Renderer(nn.Module): def __init__(self, opt, device): super().__init__() self.opt = opt self.device = device self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir) # it's necessary to clean the mesh to facilitate later remeshing! vertices = self.mesh.v.detach().cpu().numpy() triangles = self.mesh.f.detach().cpu().numpy() vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False) self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device) self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device) if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'): self.glctx = dr.RasterizeGLContext() else: self.glctx = dr.RasterizeCudaContext() # extract trainable parameters self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v)) # texture if self.opt.tex_mode == 'hashgrid': self.encoder = HashGridEncoder().to(self.device) elif self.opt.tex_mode == 'mlp': self.encoder = FrequencyEncoder().to(self.device) elif self.opt.tex_mode == 'triplane': self.encoder = TriplaneEncoder().to(self.device) else: raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}") self.mlp = MLP(self.encoder.output_dim, 3+2, 32, 2, bias=True).to(self.device) # env light if self.opt.env_texture is None: hdr_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../lights/mud_road_puresky_1k.hdr') else: hdr_path = self.opt.env_texture self.light = envlight.EnvLight(hdr_path, scale=self.opt.env_scale, device=self.device) FG_LUT = torch.from_numpy(np.fromfile(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../lights/bsdf_256_256.bin"), dtype=np.float32).reshape(1, 256, 256, 2)).to(self.device) self.register_buffer("FG_LUT", FG_LUT) # init hashgrid texture from mesh if self.opt.fit_tex: self.fit_texture_from_mesh(self.opt.fit_tex_iters) def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1): return render_mesh( self.glctx, self.mesh.v, self.mesh.f, self.mesh.vt, self.mesh.ft, self.mesh.albedo, self.mesh.vc, self.mesh.vn, self.mesh.fn, pose, proj, h, w, ssaa=ssaa, bg_color=bg_color, ) def fit_texture_from_mesh(self, iters=512): # a small training loop... loss_fn = torch.nn.MSELoss() optimizer = torch.optim.Adam([ {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr}, {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr}, ]) resolution = 512 print(f"[INFO] fitting texture...") pbar = tqdm.trange(iters) for i in pbar: ver = np.random.randint(-45, 45) hor = np.random.randint(-180, 180) pose = orbit_camera(ver, hor, self.opt.radius) proj = get_perspective(self.opt.fovy) image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image'] image_pred = self.render(pose, proj, resolution, resolution)['image'] loss = loss_fn(image_pred, image_mesh) optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_description(f"MSE = {loss.item():.6f}") print(f"[INFO] finished fitting texture!") def get_params(self): params = [ {'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr}, {'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr}, ] if not self.opt.fix_geo: params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr}) return params @torch.no_grad() def export_mesh(self, save_path, texture_resolution=2048, padding=16): mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device) print(f"[INFO] uv unwrapping...") mesh.auto_normal() mesh.auto_uv() # render uv maps h = w = texture_resolution uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1] uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4] # masked query xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3] mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1] xyzs = xyzs.view(-1, 3) mask = (mask > 0).view(-1) material = torch.zeros(h * w, 5, device=self.device, dtype=torch.float32) if mask.any(): print(f"[INFO] querying texture...") xyzs = xyzs[mask] # [M, 3] # batched inference to avoid OOM batch = [] head = 0 while head < xyzs.shape[0]: tail = min(head + 640000, xyzs.shape[0]) batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()) head += 640000 material[mask] = torch.cat(batch, dim=0) material = material.view(h, w, -1) mask = mask.view(h, w) print(f"[INFO] uv padding...") material = uv_padding(material, mask, padding) mesh.albedo = material[..., :3] mesh.metallicRoughness = torch.cat([torch.zeros_like(material[..., 3:4]), material[..., 4:5], material[..., 3:4]], dim=-1) mesh.write(save_path) @property def v(self): if self.opt.fix_geo: return self.mesh.v else: return self.mesh.v + self.v_offsets @property def f(self): return self.mesh.f @torch.no_grad() def remesh(self): vertices = self.v.detach().cpu().numpy() triangles = self.f.detach().cpu().numpy() vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size) if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target: vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target, optimalplacement=False) self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device) self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device) self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v)).to(self.device) def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1): # do super-sampling if ssaa != 1: h = make_divisible(h0 * ssaa, 8) w = make_divisible(w0 * ssaa, 8) else: h, w = h0, w0 results = {} # get v v = self.v f = self.f pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) proj = torch.from_numpy(proj.astype(np.float32)).to(v.device) # get v_clip and render rgb v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0) v_clip = v_cam @ proj.T rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w)) alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1] alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients! depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1] depth = depth.squeeze(0) # [H, W, 1] xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3] viewdir = safe_normalize(xyzs - pose[:3, 3]).squeeze(0) xyzs = xyzs.view(-1, 3) mask = (alpha > 0).view(-1) material = torch.zeros(xyzs.shape[0], 5, dtype=torch.float32, device=xyzs.device) if mask.any(): masked_material = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1))) material[mask] = masked_material.float() material = material.view(h, w, -1) # get vn and render normal if self.opt.fix_geo: vn = self.mesh.vn else: i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long() v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :] face_normals = torch.cross(v1 - v0, v2 - v0) face_normals = safe_normalize(face_normals) vn = torch.zeros_like(v) vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)) normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f) normal = safe_normalize(normal[0]) # rotated normal (where [0, 0, 1] always faces camera) rot_normal = normal @ pose[:3, :3] viewcos = rot_normal[..., [2]] # shading albedo = material[..., :3] metallic = material[..., 3:4] roughness = material[..., 4:5] n_dot_v = (normal * viewdir).sum(-1, keepdim=True) # [H, W, 1] reflective = n_dot_v * normal * 2 - viewdir diffuse_albedo = (1 - metallic) * albedo fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) # [H, W, 2] fg = dr.texture( self.FG_LUT, fg_uv.reshape(1, -1, 1, 2).contiguous(), filter_mode="linear", boundary_mode="clamp", ).reshape(h, w, 2) F0 = (1 - metallic) * 0.04 + metallic * albedo specular_albedo = F0 * fg[..., 0:1] + fg[..., 1:2] diffuse_light = self.light(normal) specular_light = self.light(reflective, roughness) color = diffuse_albedo * diffuse_light + specular_albedo * specular_light # [H, W, 3] # antialias color = dr.antialias(color.unsqueeze(0), rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3] color = alpha * color + (1 - alpha) * bg_color # ssaa if ssaa != 1: color = scale_img_hwc(color, (h0, w0)) alpha = scale_img_hwc(alpha, (h0, w0)) depth = scale_img_hwc(depth, (h0, w0)) normal = scale_img_hwc(normal, (h0, w0)) viewcos = scale_img_hwc(viewcos, (h0, w0)) results['image'] = color results['alpha'] = alpha results['depth'] = depth results['normal'] = (normal + 1) / 2 results['viewcos'] = viewcos return results