Repository: 3DTopia/threefiner
Branch: main
Commit: 6c34f089e61a
Files: 32
Total size: 196.8 KB
Directory structure:
gitextract_wp6txn7z/
├── .github/
│   └── workflows/
│       └── pypi-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── data/
│   ├── car.glb
│   └── chair.ply
├── gradio_app.py
├── readme.md
├── scripts/
│   ├── run.sh
│   └── test_all.sh
├── setup.py
└── threefiner/
    ├── __init__.py
    ├── cli.py
    ├── gui.py
    ├── guidance/
    │   ├── __init__.py
    │   ├── if2_ism_utils.py
    │   ├── if2_nfsd_utils.py
    │   ├── if2_utils.py
    │   ├── if_utils.py
    │   ├── sd_ism_utils.py
    │   ├── sd_nfsd_utils.py
    │   ├── sd_utils.py
    │   └── sdcn_utils.py
    ├── lights/
    │   ├── LICENSE.txt
    │   └── mud_road_puresky_1k.hdr
    ├── nn.py
    ├── opt.py
    └── renderer/
        ├── __init__.py
        ├── diffmc_renderer.py
        ├── mesh_renderer.py
        ├── pbr_diffmc_renderer.py
        └── pbr_mesh_renderer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/pypi-publish.yml
================================================
name: Upload Python Package
on:
  release:
    types: [created]
  workflow_dispatch:
jobs:
  deploy:
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/project/threefiner/
    permissions:
      id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: '3.10'
      # prepare distributions in dist/
      - name: Install dependencies and Build
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel
          python setup.py sdist bdist_wheel
      # publish by trusted publishers (need to first setup in pypi.org projects-manage-publishing!)
      # ref: https://github.com/marketplace/actions/pypi-publish
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
================================================
FILE: .gitignore
================================================
__pycache__
tmp*
data_*
logs
logs*
videos*
*.egg-info
build/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MANIFEST.in
================================================
recursive-include threefiner/lights *
================================================
FILE: gradio_app.py
================================================
import os
import tyro
import tqdm
import torch
import gradio as gr
import kiui
from threefiner.opt import config_defaults, config_doc, check_options
from threefiner.gui import GUI
GRADIO_SAVE_PATH_MESH = 'gradio_output.glb'
GRADIO_SAVE_PATH_VIDEO = 'gradio_output.mp4'
opt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc))
# hacks for not loading mesh at initialization
opt.save = GRADIO_SAVE_PATH_MESH
opt.prompt = ''
opt.text_dir = True
opt.front_dir = '+z'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gui = GUI(opt)
# process function
def process(input_model, input_text, input_dir, iters):
    # set front facing direction (map from gradio model3D's mysterious coordinate system to OpenGL...)
    opt.text_dir = True
    if input_dir == 'front':
        opt.front_dir = '-z'
    elif input_dir == 'back':
        opt.front_dir = '+z'
    elif input_dir == 'left':
        opt.front_dir = '+x'
    elif input_dir == 'right':
        opt.front_dir = '-x'
    elif input_dir == 'up':
        opt.front_dir = '+y'
    elif input_dir == 'down':
        opt.front_dir = '-y'
    else:
        # turn off text_dir
        opt.text_dir = False
        opt.front_dir = '+z'
    # set mesh path
    opt.mesh = input_model
    # load mesh!
    gui.renderer = gui.renderer_class(opt, device).to(device)
    # set prompt
    gui.prompt = opt.positive_prompt + ', ' + input_text
    # train
    gui.prepare_train() # update optimizer and prompt embeddings
    for i in tqdm.trange(iters):
        gui.train_step()
    # save mesh & video
    gui.save_model(GRADIO_SAVE_PATH_MESH)
    gui.save_model(GRADIO_SAVE_PATH_VIDEO)
    # return 3d model & video
    return GRADIO_SAVE_PATH_MESH, GRADIO_SAVE_PATH_VIDEO
# gradio UI
block = gr.Blocks().queue()
with block:
    gr.Markdown("""
    ## Threefiner: Text-guided mesh refinement.
    """)
    with gr.Row(variant='panel'):
        with gr.Column(scale=1):
            input_model = gr.Model3D(label="input mesh")
            input_text = gr.Text(label="prompt")
            input_dir = gr.Radio(['front', 'back', 'left', 'right', 'up', 'down'], label="front-facing direction")
            iters = gr.Slider(minimum=100, maximum=1000, step=100, value=400, label="training iterations")
            button_gen = gr.Button("Refine!")
        with gr.Column(scale=1):
            output_model = gr.Model3D(label="output mesh")
            output_video = gr.Video(label="output video")
        button_gen.click(process, inputs=[input_model, input_text, input_dir, iters], outputs=[output_model, output_video])
block.launch(server_name="0.0.0.0", share=True)
================================================
FILE: readme.md
================================================
Threefiner
An interface for text-guided mesh refinement.
https://github.com/3DTopia/threefiner/assets/25863658/a4abe725-b542-4a4a-a6d4-e4c4821f7d96
### Features
* **Mesh in, mesh out**: we support `ply` with vertex colors, `obj`, and single object `glb/gltf` with textures!
* **Easy to use**: both a CLI and a GUI are available.
* **Performant**: Refine your texture in 1 minute with Deepfloyd-IF-II.
### Install
We rely on `torch` and several CUDA extensions. Please make sure you install them correctly first!
```bash
# tiny-cuda-nn
pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
# nvdiffrast
pip install git+https://github.com/NVlabs/nvdiffrast
# [optional, will use pysdf if unavailable] cubvh:
pip install git+https://github.com/ashawkey/cubvh
```
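Before installing `threefiner` itself, it can help to confirm that the extensions above are importable. The snippet below is a minimal sketch, not part of this package: the helper `check_modules` is hypothetical, and the import names (`tinycudann`, `nvdiffrast`, `cubvh`, `pysdf`) are assumed from each project's own documentation.

```python
import importlib.util

# Import names are assumptions taken from the upstream projects' docs.
REQUIRED = ["torch", "tinycudann", "nvdiffrast"]
OPTIONAL = ["cubvh", "pysdf"]  # cubvh is optional; pysdf is the fallback


def check_modules(names):
    """Return {module_name: bool}, True when the top-level module can be located."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_modules(REQUIRED + OPTIONAL).items():
        print(f"{name:12s} {'found' if found else 'MISSING'}")
```

This only checks that the modules can be located; it does not verify that the CUDA kernels were built against the installed `torch` version.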
To use [Deepfloyd-IF](https://github.com/deep-floyd/IF) models, please log in to your Hugging Face account and accept the [license](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0).
To install this package:
```bash
# install from pypi
pip install threefiner
# install from github
pip install git+https://github.com/3DTopia/threefiner
# local install
git clone https://github.com/3DTopia/threefiner
cd threefiner
pip install .
```
### Usage
```bash
### command line interface
threefiner --help
# this is short for
python -m threefiner.cli --help
### refine a coarse mesh ('input.obj') using Stable-diffusion and save to 'logs/hamburger.glb'
threefiner sd --mesh input.obj --prompt 'a hamburger' --outdir logs --save hamburger.glb
### if the initial texture is good, we recommend using IF2 for refinement.
# by default, it will save to './name_fine.glb'
threefiner if2 --mesh name.glb --prompt 'description'
### if the initial texture is not good, we recommend using SD or IF first.
threefiner sd --mesh name.glb --prompt 'description'
threefiner if --mesh name.glb --prompt 'description'
### if the initial geometry is good, you can fix the geometry.
threefiner sd_fixgeo --mesh name.glb --prompt 'description'
threefiner if_fixgeo --mesh name.glb --prompt 'description'
threefiner if2_fixgeo --mesh name.glb --prompt 'description'
### advanced
# directional text prompt (append front/side/back view in text prompt)
# you need to know the mesh's front facing direction and specify it by '--front_dir'
# we use the OpenGL coordinate system, i.e., +x is right, +y is up, +z is front (more details: https://kit.kiui.moe/camera/)
# clockwise rotation can be specified in 90-degree steps, e.g., +z1, -y2
threefiner if2 --mesh input.glb --prompt 'description' --text_dir --front_dir='+z'
# adjust training iterations
threefiner if2 --mesh input.glb --prompt 'description' --iters 1000
# explicitly fix the geometry and only refine texture
threefiner if2 --fix_geo --geom_mode mesh --mesh input.glb --prompt 'description' # equals if2_fixgeo
# open a GUI to visualize the training progress (needs a desktop)
threefiner if2 --mesh input.glb --prompt 'description' --gui
```
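The `--front_dir` values used above (`+z`, `-y2`, `+z1`) combine a sign, an axis, and an optional rotation count. As an illustration of that format only, here is a hypothetical parser; `parse_front_dir` is not a function of this package, and reading the trailing digit as "number of 90-degree clockwise turns, 0 to 3" is an assumption based on the comment in the usage block.

```python
import re


def parse_front_dir(spec: str):
    """Split a spec such as '+z', '-y2' or '+z1' into (sign, axis, quarter_turns).

    Hypothetical helper: the accepted grammar is an assumption, not the
    package's actual parsing code.
    """
    match = re.fullmatch(r"([+-])([xyz])([0-3])?", spec.strip())
    if match is None:
        raise ValueError(f"invalid front_dir: {spec!r}")
    sign = 1 if match.group(1) == "+" else -1
    axis = match.group(2)
    quarter_turns = int(match.group(3) or 0)  # no digit means no extra rotation
    return sign, axis, quarter_turns


if __name__ == "__main__":
    for spec in ["+z", "-y2", "+z1"]:
        print(spec, parse_front_dir(spec))
```

Validating the string up front gives a clear error for inputs such as `front` or `z+`, instead of a failure deeper in the renderer.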
Gradio demo:
```bash
# requires gradio 4
python gradio_app.py if2
```
For more examples, please see [scripts](./scripts/).
### Q&A
* **How to determine `--front_dir` for your model?**
You may first visualize it in a 3D viewer that follows the OpenGL coordinate system.
For example, a chair facing down the Y axis (green) needs `--front_dir="-y"` to rectify it to face the +Z axis (blue).
* **fatal error: EGL/egl.h: No such file or directory**
By default, we use the OpenGL rasterizer. This error means there is no OpenGL installation, which is often the case for headless servers.
It's recommended to install OpenGL (along with NVIDIA driver) as it brings better performance.
Otherwise, you can append `--force_cuda_rast` to use the CUDA rasterizer instead.
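
Returning to the `--front_dir` question above: "rectifying" a model amounts to rotating it so that its current front axis lands on +Z. The sketch below shows one way to build such a rotation with Rodrigues' formula in plain Python. It is an illustration under stated assumptions, not the package's implementation; all helper names (`axis_vector`, `rotation_between`, `matvec`) are hypothetical.

```python
import math


def axis_vector(spec):
    """Turn '+x', '-y', '+z', ... into a unit vector (tuple of 3 floats)."""
    sign = 1.0 if spec[0] == "+" else -1.0
    index = "xyz".index(spec[1])
    return tuple(sign if i == index else 0.0 for i in range(3))


def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(R, v):
    """Multiply a 3x3 matrix (tuple of rows) by a 3-vector."""
    return tuple(dot(row, v) for row in R)


def rotation_between(a, b):
    """Return a 3x3 rotation R with R @ a == b, for unit vectors a and b."""
    v, c = cross(a, b), dot(a, b)
    if c < -1.0 + 1e-9:
        # a and b are opposite: turn 180 degrees about an axis perpendicular to a.
        k = min(range(3), key=lambda i: abs(a[i]))
        u = cross(a, tuple(1.0 if i == k else 0.0 for i in range(3)))
        n = math.sqrt(dot(u, u))
        u = tuple(x / n for x in u)
        return tuple(tuple(2.0 * u[i] * u[j] - (i == j) for j in range(3))
                     for i in range(3))
    # Rodrigues' formula: R = I + K + K^2 / (1 + c), K = skew matrix of a x b.
    K = ((0.0, -v[2], v[1]), (v[2], 0.0, -v[0]), (-v[1], v[0], 0.0))
    K2 = tuple(tuple(sum(K[i][m] * K[m][j] for m in range(3)) for j in range(3))
               for i in range(3))
    return tuple(tuple((i == j) + K[i][j] + K2[i][j] / (1.0 + c) for j in range(3))
                 for i in range(3))


if __name__ == "__main__":
    # The chair example: its front currently points along -Y; map it onto +Z.
    R = rotation_between(axis_vector("-y"), axis_vector("+z"))
    print(matvec(R, axis_vector("-y")))
```

A rotation that aligns two axes is not unique, since any extra turn about the target axis still leaves the front pointing at +Z. That remaining freedom is presumably what the optional 90-degree suffix in `--front_dir` (e.g. `+z1`) is for.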
## Acknowledgement
This work is built on many amazing research works and open-source projects. Thanks a lot to all the authors for sharing!
- SDS `guidance` classes are based on [diffusers](https://github.com/huggingface/diffusers).
- `diffmc` geometry is based on [diso](https://github.com/SarahWeiii/diso).
- `mesh` geometry is based on [nerf2mesh](https://github.com/ashawkey/nerf2mesh).
- Texture encoding is based on [tinycudann](https://github.com/NVlabs/tiny-cuda-nn).
- Mesh renderer is based on [nvdiffrast](https://github.com/NVlabs/nvdiffrast).
- GUI is based on [dearpygui](https://github.com/hoffstadt/DearPyGui).
- The coarse models used in demo are generated by [Genie](https://lumalabs.ai/genie?view=create) and [3DTopia](https://github.com/3DTopia/3DTopia).
================================================
FILE: scripts/run.sh
================================================
export CUDA_VISIBLE_DEVICES=0
# the mesh is already with good initial texture, just refine it using IF2
threefiner if2 --mesh data/car.glb --prompt 'a red car' --outdir logs --save car_fine.glb --text_dir --front_dir='+x'
# the mesh is coarse, using SD for diverse texture generation and IF2 for refinement
threefiner sd --mesh data/chair.ply --prompt 'a swivel chair' --outdir logs --save chair_coarse.glb --text_dir --front_dir='-y'
threefiner if2 --mesh logs/chair_coarse.glb --prompt 'a swivel chair' --outdir logs --save chair_fine.glb --text_dir --front_dir='+z'
================================================
FILE: scripts/test_all.sh
================================================
export CUDA_VISIBLE_DEVICES=1
# geom_mode
threefiner if2 --geom_mode diffmc --save car_diffmc.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if2 --geom_mode mesh --save car_mesh.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if2 --geom_mode pbr_diffmc --save car_pbr_diffmc.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if2 --geom_mode pbr_mesh --save car_pbr_mesh.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
# tex_mode
threefiner if2 --tex_mode mlp --save car_mlp.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if2 --tex_mode triplane --save car_triplane.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
# guidance mode
threefiner sd --save car_SD.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
threefiner if --save car_IF.glb --mesh data/car.glb --prompt 'a red car' --outdir logs_test --text_dir --front_dir='+x'
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
setup(
name = 'threefiner',
packages = find_packages(exclude=[]),
include_package_data = True,
entry_points={
# CLI tools
'console_scripts': [
'threefiner = threefiner.cli:main'
],
},
version = '0.1.2',
license='MIT',
description = 'Threefiner: a text-guided mesh refiner',
author = 'kiui',
author_email = 'ashawkey1999@gmail.com',
long_description=open("readme.md", encoding="utf-8").read(),
long_description_content_type = 'text/markdown',
url = 'https://github.com/3DTopia/threefiner',
keywords = [
'generative mesh refinement',
],
install_requires=[
'tyro',
'tqdm',
'rich',
'ninja',
'numpy',
'pandas',
'matplotlib',
'opencv-python',
'imageio',
'imageio-ffmpeg',
'scipy',
'scikit-learn',
'torch',
'einops',
'huggingface_hub',
'diffusers',
'accelerate',
'transformers',
"sentencepiece", # required by deepfloyd-if T5 encoder
'plyfile',
'pygltflib',
'xatlas',
'trimesh',
'PyMCubes',
'pymeshlab',
"pysdf",
"diso",
"envlight",
'dearpygui',
'kiui >= 0.2.1',
],
classifiers=[
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
],
)
================================================
FILE: threefiner/__init__.py
================================================
================================================
FILE: threefiner/cli.py
================================================
import os
import tyro
from threefiner.opt import config_defaults, config_doc, check_options
from threefiner.gui import GUI
def main():
    opt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc))
    opt = check_options(opt)
    gui = GUI(opt)
    if gui.gui:
        gui.render()
    else:
        gui.train(opt.iters)
if __name__ == "__main__":
    main()
================================================
FILE: threefiner/gui.py
================================================
import os
import tqdm
import random
import imageio
import numpy as np
import torch
import torch.nn.functional as F
GUI_AVAILABLE = True
try:
    import dearpygui.dearpygui as dpg
except Exception as e:
    GUI_AVAILABLE = False
import kiui
from kiui.cam import orbit_camera, OrbitCamera
from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency
from threefiner.opt import Options
class GUI:
    def __init__(self, opt: Options):
        self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        if not GUI_AVAILABLE and opt.gui:
            print(f'[WARN] cannot import dearpygui, assume without --gui')
        self.gui = opt.gui and GUI_AVAILABLE # enable gui
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.mode = "image"
        self.seed = "random"
        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True # update buffer_image
        self.save_path = os.path.join(self.opt.outdir, self.opt.save)
        os.makedirs(self.opt.outdir, exist_ok=True)
        # models
        self.device = torch.device("cuda")
        self.guidance = None
        # renderer
        if self.opt.geom_mode == 'mesh':
            from threefiner.renderer.mesh_renderer import Renderer
        elif self.opt.geom_mode == 'diffmc':
            from threefiner.renderer.diffmc_renderer import Renderer
        elif self.opt.geom_mode == 'pbr_mesh':
            from threefiner.renderer.pbr_mesh_renderer import Renderer
        elif self.opt.geom_mode == 'pbr_diffmc':
            from threefiner.renderer.pbr_diffmc_renderer import Renderer
        else:
            raise NotImplementedError(f"unknown geometry mode: {self.opt.geom_mode}")
        self.renderer_class = Renderer
        if self.opt.mesh is None:
            self.renderer = None
        else:
            self.renderer = Renderer(opt, self.device).to(self.device)
        # input prompt
        self.prompt = self.opt.prompt
        self.negative_prompt = ""
        if self.opt.positive_prompt is not None:
            self.prompt = self.opt.positive_prompt + ', ' + self.prompt
        if self.opt.negative_prompt is not None:
            self.negative_prompt = self.opt.negative_prompt
        # training stuff
        self.training = False
        self.optimizer = None
        self.step = 0
        self.train_steps = 1 # steps per rendering loop
        if self.gui:
            dpg.create_context()
            self.register_dpg()
            self.test_step()
    def __del__(self):
        if self.gui:
            dpg.destroy_context()
    def seed_everything(self):
        try:
            seed = int(self.seed)
        except (TypeError, ValueError):
            # non-numeric seed (e.g. "random"): draw a fresh one
            seed = np.random.randint(0, 1000000)
        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        self.last_seed = seed
    def prepare_train(self):
        assert self.renderer is not None, 'no mesh loaded!'
        self.step = 0
        # setup training
        self.optimizer = torch.optim.Adam(self.renderer.get_params())
        # lazy load guidance model
        if self.guidance is None:
            print(f"[INFO] loading guidance...")
            if self.opt.mode == 'SD':
                from threefiner.guidance.sd_utils import StableDiffusion
                self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'SD_NFSD':
                from threefiner.guidance.sd_nfsd_utils import StableDiffusion
                self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'SDCN':
                from threefiner.guidance.sdcn_utils import StableDiffusionControlNet
                self.guidance = StableDiffusionControlNet(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'IF':
                from threefiner.guidance.if_utils import IF
                self.guidance = IF(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'IF2':
                from threefiner.guidance.if2_utils import IF2
                self.guidance = IF2(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'IF2_NFSD':
                from threefiner.guidance.if2_nfsd_utils import IF2
                self.guidance = IF2(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'SD_ISM':
                from threefiner.guidance.sd_ism_utils import StableDiffusion
                self.guidance = StableDiffusion(self.device, vram_O=self.opt.vram_O)
            elif self.opt.mode == 'IF2_ISM':
                from threefiner.guidance.if2_ism_utils import IF2
                self.guidance = IF2(self.device, vram_O=self.opt.vram_O)
            else:
                raise NotImplementedError(f"unknown guidance mode {self.opt.mode}!")
            print(f"[INFO] loaded guidance!")
        # prepare embeddings
        with torch.no_grad():
            self.guidance.get_text_embeds([self.prompt], [self.negative_prompt])
    def train_step(self):
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()
        self.renderer.train()
        for _ in range(self.train_steps):
            self.step += 1
            step_ratio = min(1, self.step / self.opt.iters)
            loss = 0
            ### novel view (manual batch)
            images = []
            poses = []
            normals = []
            ori_images = []
            vers, hors, radii = [], [], []
            for _ in range(self.opt.batch_size):
                # render random view
                ver = np.random.randint(-60, 30)
                hor = np.random.randint(-180, 180)
                radius = np.random.uniform() - 0.5 # [-0.5, 0.5]
                pose = orbit_camera(ver, hor, self.opt.radius + radius)
                vers.append(ver)
                hors.append(hor)
                radii.append(radius)
                poses.append(pose)
                # random render resolution
                ssaa = min(2.0, max(0.125, 2 * np.random.random()))
                out = self.renderer.render(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=ssaa)
                image = out["image"] # [H, W, 3] in [0, 1]
                image = image.permute(2,0,1).contiguous().unsqueeze(0) # [1, 3, H, W] in [0, 1]
                images.append(image)
                # mix_normal
                if not self.opt.fix_geo and self.opt.mix_normal:
                    normal = out['normal']
                    normal = normal.permute(2,0,1).contiguous().unsqueeze(0)
                    normals.append(normal)
                # IF SR model requires the original rendering
                if self.opt.mode in ['IF2', 'IF2_NFSD', 'IF2_ISM', 'SDCN']:
                    out_mesh = self.renderer.render_mesh(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=1)
                    ori_image = out_mesh["image"] # [H, W, 3] in [0, 1]
                    ori_image = ori_image.permute(2,0,1).contiguous().unsqueeze(0)
                    ori_images.append(ori_image)
                    # ori_images.append(image.clone())
            # guidance loss
            guidance_input = {'pred_rgb': torch.cat(images, dim=0)}
            if not self.opt.fix_geo and self.opt.mix_normal:
                if random.random() > 0.5:
                    ratio = random.random()
                    guidance_input['pred_rgb'] = guidance_input['pred_rgb'] * ratio + torch.cat(normals, dim=0) * (1 - ratio)
            # guidance_input['step_ratio'] = step_ratio
            if self.opt.mode in ['IF2', 'IF2_NFSD', 'IF2_ISM']:
                guidance_input['ori_rgb'] = torch.cat(ori_images, dim=0)
            if self.opt.mode == 'SDCN':
                guidance_input['control_images'] = {'tile': torch.cat(ori_images, dim=0)}
            if self.opt.text_dir:
                guidance_input['vers'] = vers
                guidance_input['hors'] = hors
            loss = loss + self.opt.lambda_sd * self.guidance.train_step(**guidance_input)
            # geom regularizations
            if self.opt.geom_mode in ['diffmc', 'pbr_diffmc', 'mesh', 'pbr_mesh'] and not self.opt.fix_geo:
                if self.opt.lambda_lap > 0:
                    lap_loss = laplacian_smooth_loss(self.renderer.v, self.renderer.f)
                    loss = loss + self.opt.lambda_lap * lap_loss
                if self.opt.lambda_normal > 0:
                    normal_loss = normal_consistency(self.renderer.v, self.renderer.f)
                    loss = loss + self.opt.lambda_normal * normal_loss
                if self.opt.geom_mode in ['mesh', 'pbr_mesh'] and self.opt.lambda_offsets > 0:
                    offset_loss = (self.renderer.v_offsets ** 2).sum(-1).mean()
                    loss = loss + self.opt.lambda_offsets * offset_loss
            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # for mesh geom_mode: periodically remesh
            if self.opt.geom_mode in ['mesh', 'pbr_mesh'] and not self.opt.fix_geo:
                if self.step > 0 and self.step % self.opt.remesh_interval == 0:
                    self.renderer.remesh()
                    # reset optimizer
                    self.optimizer = torch.optim.Adam(self.renderer.get_params())
        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)
        self.need_update = True
        if self.gui:
            dpg.set_value("_log_train_time", f"{t:.4f}ms")
            dpg.set_value(
                "_log_train_log",
                f"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {loss.item():.4f}",
            )
    @torch.no_grad()
    def test_step(self):
        # ignore if no need to update
        if not self.need_update:
            return
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()
        # should update image
        if self.need_update:
            # render image
            self.renderer.eval()
            out = self.renderer.render(self.cam.pose, self.cam.perspective, self.H, self.W)
            buffer_image = out[self.mode] # [H, W, 3]
            if self.mode in ['depth', 'alpha']:
                buffer_image = buffer_image.repeat(1, 1, 3)
                if self.mode == 'depth':
                    buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)
            self.buffer_image = buffer_image.contiguous().clamp(0, 1).detach().cpu().numpy()
            self.need_update = False
        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)
        if self.gui:
            dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)")
            dpg.set_value(
                "_texture", self.buffer_image
            ) # buffer must be contiguous, else seg fault!
def save_model(self, save_path=None):
if save_path is None:
save_path = self.save_path
# export video
if save_path.endswith(".mp4"):
images = []
elevation = 0
azimuth = np.arange(0, 360, 3, dtype=np.int32) # front-->back-->front
for azi in tqdm.tqdm(azimuth):
pose = orbit_camera(elevation, azi, self.opt.radius)
out = self.renderer.render(pose, self.cam.perspective, self.opt.render_resolution, self.opt.render_resolution, ssaa=1)
image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8)
images.append(image)
images = np.stack(images, axis=0)
# ~4 seconds, 120 frames at 30 fps
imageio.mimwrite(save_path, images, fps=30, quality=8, macro_block_size=1)
# export mesh
else:
self.renderer.export_mesh(save_path, texture_resolution=self.opt.texture_resolution)
print(f"[INFO] save model to {save_path}.")
def register_dpg(self):
### register texture
with dpg.texture_registry(show=False):
dpg.add_raw_texture(
self.W,
self.H,
self.buffer_image,
format=dpg.mvFormat_Float_rgb,
tag="_texture",
)
### register window
# the rendered image, as the primary window
with dpg.window(
tag="_primary_window",
width=self.W,
height=self.H,
pos=[0, 0],
no_move=True,
no_title_bar=True,
no_scrollbar=True,
):
# add the texture
dpg.add_image("_texture")
# dpg.set_primary_window("_primary_window", True)
# control window
with dpg.window(
label="Control",
tag="_control_window",
width=600,
height=self.H,
pos=[self.W, 0],
no_move=True,
no_title_bar=True,
):
# button theme
with dpg.theme() as theme_button:
with dpg.theme_component(dpg.mvButton):
dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
# timer stuff
with dpg.group(horizontal=True):
dpg.add_text("Infer time: ")
dpg.add_text("no data", tag="_log_infer_time")
def callback_setattr(sender, app_data, user_data):
setattr(self, user_data, app_data)
# init stuff
with dpg.collapsing_header(label="Initialize", default_open=True):
# seed stuff
def callback_set_seed(sender, app_data):
self.seed = app_data
self.seed_everything()
dpg.add_input_text(
label="seed",
default_value=self.seed,
on_enter=True,
callback=callback_set_seed,
)
# input stuff
def callback_select_input(sender, app_data):
# only one item
for k, v in app_data["selections"].items():
dpg.set_value("_log_input", k)
self.load_input(v)
self.need_update = True
with dpg.file_dialog(
directory_selector=False,
show=False,
callback=callback_select_input,
file_count=1,
tag="file_dialog_tag",
width=700,
height=400,
):
dpg.add_file_extension("Images{.jpg,.jpeg,.png}")
with dpg.group(horizontal=True):
dpg.add_button(
label="input",
callback=lambda: dpg.show_item("file_dialog_tag"),
)
dpg.add_text("", tag="_log_input")
# prompt stuff
dpg.add_input_text(
label="prompt",
default_value=self.prompt,
callback=callback_setattr,
user_data="prompt",
)
dpg.add_input_text(
label="negative",
default_value=self.negative_prompt,
callback=callback_setattr,
user_data="negative_prompt",
)
# save current model
with dpg.group(horizontal=True):
dpg.add_text("Save: ")
dpg.add_button(
label="model",
tag="_button_save_model",
callback=self.save_model,
)
dpg.bind_item_theme("_button_save_model", theme_button)
dpg.add_input_text(
label="",
default_value=self.save_path,
callback=callback_setattr,
user_data="save_path",
)
# training stuff
with dpg.collapsing_header(label="Train", default_open=True):
# lr and train button
with dpg.group(horizontal=True):
dpg.add_text("Train: ")
def callback_train(sender, app_data):
if self.training:
self.training = False
dpg.configure_item("_button_train", label="start")
else:
self.prepare_train()
self.training = True
dpg.configure_item("_button_train", label="stop")
# dpg.add_button(
# label="init", tag="_button_init", callback=self.prepare_train
# )
# dpg.bind_item_theme("_button_init", theme_button)
dpg.add_button(
label="start", tag="_button_train", callback=callback_train
)
dpg.bind_item_theme("_button_train", theme_button)
with dpg.group(horizontal=True):
dpg.add_text("", tag="_log_train_time")
dpg.add_text("", tag="_log_train_log")
# rendering options
with dpg.collapsing_header(label="Rendering", default_open=True):
# mode combo
def callback_change_mode(sender, app_data):
self.mode = app_data
self.need_update = True
dpg.add_combo(
("image", "depth", "alpha", "normal"),
label="mode",
default_value=self.mode,
callback=callback_change_mode,
)
# fov slider
def callback_set_fovy(sender, app_data):
self.cam.fovy = np.deg2rad(app_data)
self.need_update = True
dpg.add_slider_int(
label="FoV (vertical)",
min_value=1,
max_value=120,
format="%d deg",
default_value=np.rad2deg(self.cam.fovy),
callback=callback_set_fovy,
)
### register camera handler
def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
dx = app_data[1]
dy = app_data[2]
self.cam.orbit(dx, dy)
self.need_update = True
def callback_camera_wheel_scale(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
delta = app_data
self.cam.scale(delta)
self.need_update = True
def callback_camera_drag_pan(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
dx = app_data[1]
dy = app_data[2]
self.cam.pan(dx, dy)
self.need_update = True
with dpg.handler_registry():
# for camera moving
dpg.add_mouse_drag_handler(
button=dpg.mvMouseButton_Left,
callback=callback_camera_drag_rotate_or_draw_mask,
)
dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
dpg.add_mouse_drag_handler(
button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
)
dpg.create_viewport(
title="Threefiner",
width=self.W + 600,
height=self.H + (45 if os.name == "nt" else 0),
resizable=False,
)
### global theme
with dpg.theme() as theme_no_padding:
with dpg.theme_component(dpg.mvAll):
# set all padding to 0 to avoid scroll bar
dpg.add_theme_style(
dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
)
dpg.add_theme_style(
dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
)
dpg.add_theme_style(
dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
)
dpg.bind_item_theme("_primary_window", theme_no_padding)
dpg.setup_dearpygui()
### register a larger font
# get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
if os.path.exists("LXGWWenKai-Regular.ttf"):
with dpg.font_registry():
with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
dpg.bind_font(default_font)
# dpg.show_metrics()
dpg.show_viewport()
def render(self):
assert self.gui
while dpg.is_dearpygui_running():
# update texture every frame
if self.training:
self.train_step()
self.test_step()
dpg.render_dearpygui_frame()
# no gui mode
def train(self, iters=500):
if iters > 0:
self.prepare_train()
for i in tqdm.trange(iters):
self.train_step()
# save
self.save_model()
================================================
FILE: threefiner/guidance/__init__.py
================================================
================================================
FILE: threefiner/guidance/if2_ism_utils.py
================================================
from diffusers import (
PNDMScheduler,
DDIMScheduler,
IFPipeline,
IFSuperResolutionPipeline,
)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def invert_noise(scheduler, noisy_samples, noise, timesteps):
alphas_cumprod = scheduler.alphas_cumprod.to(device=noisy_samples.device, dtype=noisy_samples.dtype)
timesteps = timesteps.to(noisy_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(noisy_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(noisy_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
original_samples = 1 / sqrt_alpha_prod * (noisy_samples - sqrt_one_minus_alpha_prod * noise)
return original_samples
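

# Illustrative sketch (not part of the original module): invert_noise recovers x_0
# from x_t = sqrt(a) * x_0 + sqrt(1 - a) * eps when the same eps is known, with
# a standing for alphas_cumprod[t]. The two scalar helpers below are hypothetical
# stand-ins for scheduler.add_noise and invert_noise, assuming that forward formula.
def add_noise_scalar(x0, eps, alpha_cumprod):
    # forward diffusion for a single value
    return alpha_cumprod ** 0.5 * x0 + (1 - alpha_cumprod) ** 0.5 * eps


def invert_noise_scalar(x_t, eps, alpha_cumprod):
    # same formula as invert_noise above, specialised to scalars
    return (x_t - (1 - alpha_cumprod) ** 0.5 * eps) / alpha_cumprod ** 0.5


# Usage: noising and then inverting with the same eps returns the input
# up to floating-point error, e.g.
#   x_t = add_noise_scalar(0.3, -1.2, 0.8)
#   invert_noise_scalar(x_t, -1.2, 0.8)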
class IF2(nn.Module):
def __init__(
self,
device,
fp16=True,
vram_O=False,
model_key = "DeepFloyd/IF-II-M-v1.0",
t_range=[0.02, 0.50],
):
super().__init__()
self.device = device
self.model_key = model_key
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
pipe = IFSuperResolutionPipeline.from_pretrained(
model_key, variant="fp16", torch_dtype=torch.float16,
watermarker=None, safety_checker=None, requires_safety_checker=False,
)
if vram_O:
pipe.unet.to(memory_format=torch.channels_last)
pipe.enable_attention_slicing(1)
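# pipe.enable_model_cpu_offload()
# without offloading the pipeline would stay on the CPU in this branch (same call as in if2_utils.py)
pipe.enable_sequential_cpu_offload()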
else:
pipe.to(device)
self.unet = pipe.unet
self.tokenizer = pipe.tokenizer
self.text_encoder = pipe.text_encoder
self.scheduler = pipe.scheduler
self.image_noising_scheduler = pipe.image_noising_scheduler
self.pipe = pipe
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = {}
@torch.no_grad()
def get_text_embeds(self, prompts, negative_prompts):
pos_embeds = self.encode_text(prompts) # [B, seq_len, hidden_dim] text encoder embeddings
neg_embeds = self.encode_text(negative_prompts)
null_embeds = self.encode_text([""])
self.embeddings['pos'] = pos_embeds
self.embeddings['neg'] = neg_embeds
self.embeddings['null'] = null_embeds
# directional embeddings
for d in ['front', 'side', 'back']:
embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
self.embeddings[d] = embeds
def encode_text(self, prompt):
# prompt: [str]
prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
return embeddings
def train_step(
self,
pred_rgb,
ori_rgb,
step_ratio=None,
guidance_scale=5,
vers=None, hors=None,
delta_t=50, delta_s=200,
):
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
ori_rgb = ori_rgb.to(self.dtype)
images = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) * 2 - 1
with torch.no_grad():
max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device)
# images_upscaled = images.clone()
images_upscaled = F.interpolate(ori_rgb, (256, 256), mode="bilinear", align_corners=False).clamp(0, 1) * 2 - 1
noise = torch.randn_like(images_upscaled)
images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t)
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype)
######### debug
# imagesx = self.produce_imgs(
# images_upscaled=images_upscaled,
# max_t=max_t,
# images=torch.randn_like(images),
# num_inference_steps=50,
# guidance_scale=4.0,
# ) # [1, 3, 64, 64]
# import kiui
# kiui.vis.plot_image(images_upscaled * 0.5 + 0.5)
# kiui.vis.plot_image(imagesx * 0.5 + 0.5)
#########
null_embeddings = self.embeddings['null'].expand(batch_size, -1, -1)
if hors is None:
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
else:
def _get_dir_ind(h):
if abs(h) < 60: return 'front'
elif abs(h) < 120: return 'side'
else: return 'back'
embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])
########### ISM
# steps
t = t.clamp(min=delta_t)
s = t - delta_t
n = s // delta_s
r = s % delta_s
# construct trajectory
images_noisy = images.clone()
cur_t = torch.full((batch_size,), 0, dtype=torch.long, device=self.device)
noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0]
images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t)
cur_t += r
images_noisy = self.scheduler.add_noise(images_original, noise, cur_t)
for i in range(n):
noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0]
images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t)
cur_t += delta_s
images_noisy = self.scheduler.add_noise(images_original, noise, cur_t) # x_s
# construct last step
noise = self.unet(torch.cat([images_noisy, images_upscaled], dim=1), cur_t, encoder_hidden_states=null_embeddings, class_labels=max_t).sample.split(images_noisy.shape[1], dim=1)[0]
images_original = invert_noise(self.scheduler, images_noisy, noise, cur_t) # \hat x_0^s
# perform guidance
images_noisy = self.scheduler.add_noise(images_original, noise, t)
model_input = torch.cat([images_noisy, images_upscaled], dim=1)
model_input = torch.cat([model_input] * 2)
model_input = self.scheduler.scale_model_input(model_input, t)
tt = torch.cat([t] * 2)
max_tt = torch.cat([max_t] * 2)
noise_pred = self.unet(
model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt,
).sample
# perform guidance (high scale from paper!)
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
# grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
# grad = grad_norm.clamp(max=0.1) * grad / grad_norm
target = (images - grad).detach()
loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0]
return loss
@torch.no_grad()
def produce_imgs(
self,
images_upscaled,
max_t,
height=256,
width=256,
num_inference_steps=50,
guidance_scale=4.0,
images=None,
):
if images is None:
images = torch.randn(
(
1,
self.unet.config.in_channels // 2, # the UNet input stacks the noisy image and the upscaled conditioning image
height,
width,
),
device=self.device,
)
batch_size = images.shape[0]
self.scheduler.set_timesteps(num_inference_steps)
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps):
# expand the images if we are doing classifier-free guidance to avoid doing two forward passes.
model_input = torch.cat([images, images_upscaled], dim=1)
model_input = torch.cat([model_input] * 2)
model_input = self.scheduler.scale_model_input(model_input, t)
max_tt = torch.cat([max_t] * 2)
# predict the noise residual
noise_pred = self.unet(
model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt,
).sample
# perform guidance
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
# compute the previous noisy sample x_t -> x_t-1
images = self.scheduler.step(noise_pred, t, images).prev_sample
return images
def prompt_to_img(
self,
images_upscaled,
max_t,
prompts,
negative_prompts="",
height=256,
width=256,
num_inference_steps=50,
guidance_scale=4.0,
images=None,
):
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(negative_prompts, str):
negative_prompts = [negative_prompts]
# Prompts -> text embeds
self.get_text_embeds(prompts, negative_prompts)
# Text embeds -> images
images = self.produce_imgs(
images_upscaled=images_upscaled,
max_t=max_t,
height=height,
width=width,
images=images,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
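) # [B, 3, H, W], pixel space in [-1, 1]
# Img to Numpy: rescale from [-1, 1] to [0, 1] before converting to uint8
images = (images * 0.5 + 0.5).clamp(0, 1)
images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (images * 255).round().astype("uint8")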
return images
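

# Illustrative sketch (not part of the original module): which timesteps the ISM
# branch of train_step visits before the guided prediction at t. It mirrors the
# integer arithmetic above (clamp t, s = t - delta_t, n = s // delta_s,
# r = s % delta_s) for a single plain-int timestep. The name ism_timesteps is hypothetical.
def ism_timesteps(t, delta_t=50, delta_s=200):
    t = max(t, delta_t)          # t.clamp(min=delta_t)
    s = t - delta_t              # end of the inversion trajectory
    n, r = divmod(s, delta_s)    # n full strides of delta_s after a first stride of r
    # null-prompt UNet calls happen at 0, then r, r + delta_s, ..., and finally at s
    visited = [0] + [r + i * delta_s for i in range(n)] + [s]
    return visited, t


# Usage: ism_timesteps(500) gives ([0, 50, 250, 450], 500), i.e. four null-prompt
# UNet evaluations followed by one guided evaluation at t = 500.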
if __name__ == "__main__":
import kiui
import argparse
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument("prompt", type=str)
parser.add_argument("--negative", default="", type=str)
parser.add_argument("--fp16", action="store_true", help="use float16 for training")
parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
parser.add_argument("-H", type=int, default=512)
parser.add_argument("-W", type=int, default=512)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
opt = parser.parse_args()
kiui.seed_everything(opt.seed)
device = torch.device("cuda")
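sd = IF2(device, opt.fp16, opt.vram_O)
# prompt_to_img expects (images_upscaled, max_t, prompts, ...); the random tensors below are
# placeholders only (an assumption for this demo), so the output is not a meaningful upscale
dtype = sd.unet.dtype
images_upscaled = torch.randn(1, 3, opt.H, opt.W, device=device, dtype=dtype)
images = torch.randn(1, 3, opt.H, opt.W, device=device, dtype=dtype)
max_t = torch.full((1,), sd.max_step, dtype=torch.long, device=device)
imgs = sd.prompt_to_img(images_upscaled, max_t, opt.prompt, opt.negative, height=opt.H, width=opt.W, num_inference_steps=opt.steps, images=images)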
# visualize image
plt.imshow(imgs[0])
plt.show()
================================================
FILE: threefiner/guidance/if2_nfsd_utils.py
================================================
from diffusers import (
PNDMScheduler,
DDIMScheduler,
IFPipeline,
IFSuperResolutionPipeline,
)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class IF2(nn.Module):
def __init__(
self,
device,
fp16=True,
vram_O=False,
model_key = "DeepFloyd/IF-II-L-v1.0",
t_range=[0.02, 0.50],
):
super().__init__()
self.device = device
self.model_key = model_key
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
pipe = IFSuperResolutionPipeline.from_pretrained(
model_key, variant="fp16", torch_dtype=torch.float16,
watermarker=None, safety_checker=None, requires_safety_checker=False,
)
if vram_O:
pipe.unet.to(memory_format=torch.channels_last)
pipe.enable_attention_slicing(1)
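# pipe.enable_model_cpu_offload()
# without offloading the pipeline would stay on the CPU in this branch (same call as in if2_utils.py)
pipe.enable_sequential_cpu_offload()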
else:
pipe.to(device)
self.unet = pipe.unet
self.tokenizer = pipe.tokenizer
self.text_encoder = pipe.text_encoder
self.scheduler = pipe.scheduler
self.image_noising_scheduler = pipe.image_noising_scheduler
self.pipe = pipe
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = {}
@torch.no_grad()
def get_text_embeds(self, prompts, negative_prompts):
pos_embeds = self.encode_text(prompts) # [B, seq_len, hidden_dim] text encoder embeddings
neg_embeds = self.encode_text(negative_prompts)
null_embeds = self.encode_text([''])
self.embeddings['pos'] = pos_embeds
self.embeddings['neg'] = neg_embeds
self.embeddings['null'] = null_embeds
# directional embeddings
for d in ['front', 'side', 'back']:
embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
self.embeddings[d] = embeds
def encode_text(self, prompt):
# prompt: [str]
prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
return embeddings
def train_step(
self,
pred_rgb,
ori_rgb,
step_ratio=None,
guidance_scale=5,
vers=None, hors=None,
):
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
ori_rgb = ori_rgb.to(self.dtype)
images = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) * 2 - 1
with torch.no_grad():
max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device)
# images_upscaled = images.clone()
images_upscaled = F.interpolate(ori_rgb, (256, 256), mode="bilinear", align_corners=False).clamp(0, 1) * 2 - 1
noise = torch.randn_like(images_upscaled)
images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t)
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype)
######### debug
# imagesx = self.produce_imgs(
# images_upscaled=images_upscaled,
# max_t=max_t,
# images=torch.randn_like(images),
# num_inference_steps=50,
# guidance_scale=4.0,
# ) # [1, 3, 64, 64]
# import kiui
# kiui.vis.plot_image(images_upscaled * 0.5 + 0.5)
# kiui.vis.plot_image(imagesx * 0.5 + 0.5)
#########
# add noise
noise = torch.randn_like(images)
images_noisy = self.scheduler.add_noise(images, noise, t)
# pred noise
model_input = torch.cat([images_noisy, images_upscaled], dim=1)
model_input = torch.cat([model_input] * 3)
model_input = self.scheduler.scale_model_input(model_input, t)
tt = torch.cat([t] * 3)
max_tt = torch.cat([max_t] * 3)
if hors is None:
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])
else:
def _get_dir_ind(h):
if abs(h) < 60: return 'front'
elif abs(h) < 120: return 'side'
else: return 'back'
embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])
noise_pred = self.unet(
model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt,
).sample
# perform guidance (high scale from paper!)
noise_pred_cond, noise_pred_uncond, noise_pred_null = noise_pred.chunk(3)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
noise_pred_cond, _ = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
noise_pred_null, _ = noise_pred_null.split(model_input.shape[1] // 2, dim=1)
delta_c = guidance_scale * (noise_pred_cond - noise_pred_null)
mask = (t < 200).int().view(batch_size, 1, 1, 1)
delta_d = mask * noise_pred_null + (1 - mask) * (noise_pred_null - noise_pred_uncond)
grad = w * (delta_c + delta_d)
grad = torch.nan_to_num(grad)
# grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
# grad = grad_norm.clamp(max=0.1) * grad / grad_norm
target = (images - grad).detach()
loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0]
return loss
@torch.no_grad()
def produce_imgs(
self,
images_upscaled,
max_t,
height=256,
width=256,
num_inference_steps=50,
guidance_scale=4.0,
images=None,
):
if images is None:
images = torch.randn(
(
1,
self.unet.config.in_channels // 2, # the UNet input stacks the noisy image and the upscaled conditioning image
height,
width,
),
device=self.device,
)
batch_size = images.shape[0]
self.scheduler.set_timesteps(num_inference_steps)
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps):
# expand the images if we are doing classifier-free guidance to avoid doing two forward passes.
model_input = torch.cat([images, images_upscaled], dim=1)
model_input = torch.cat([model_input] * 2)
model_input = self.scheduler.scale_model_input(model_input, t)
max_tt = torch.cat([max_t] * 2)
# predict the noise residual
noise_pred = self.unet(
model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt,
).sample
# perform guidance
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
# compute the previous noisy sample x_t -> x_t-1
images = self.scheduler.step(noise_pred, t, images).prev_sample
return images
def prompt_to_img(
self,
images_upscaled,
max_t,
prompts,
negative_prompts="",
height=256,
width=256,
num_inference_steps=50,
guidance_scale=4.0,
images=None,
):
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(negative_prompts, str):
negative_prompts = [negative_prompts]
# Prompts -> text embeds
self.get_text_embeds(prompts, negative_prompts)
# Text embeds -> images
images = self.produce_imgs(
images_upscaled=images_upscaled,
max_t=max_t,
height=height,
width=width,
images=images,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
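) # [B, 3, H, W], pixel space in [-1, 1]
# Img to Numpy: rescale from [-1, 1] to [0, 1] before converting to uint8
images = (images * 0.5 + 0.5).clamp(0, 1)
images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (images * 255).round().astype("uint8")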
return images
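

# Illustrative sketch (not part of the original module): the NFSD update direction
# built in train_step, reduced to scalars. eps_cond, eps_neg and eps_null stand for
# the predictions under the positive, negative and empty prompt. The name
# nfsd_direction and the t_switch parameter are hypothetical; 200 is the threshold
# used in the mask above.
def nfsd_direction(eps_cond, eps_neg, eps_null, t, guidance_scale=5, t_switch=200):
    # condition term: scaled difference between the prompt and the empty prompt
    delta_c = guidance_scale * (eps_cond - eps_null)
    # domain term: the empty-prompt prediction at small t, otherwise its offset
    # from the negative-prompt prediction
    delta_d = eps_null if t < t_switch else eps_null - eps_neg
    return delta_c + delta_d


# Usage: train_step multiplies this sum by w(t) = 1 - alphas_cumprod[t] to get grad.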
if __name__ == "__main__":
import kiui
import argparse
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument("prompt", type=str)
parser.add_argument("--negative", default="", type=str)
parser.add_argument("--fp16", action="store_true", help="use float16 for training")
parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
parser.add_argument("-H", type=int, default=512)
parser.add_argument("-W", type=int, default=512)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
opt = parser.parse_args()
kiui.seed_everything(opt.seed)
device = torch.device("cuda")
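sd = IF2(device, opt.fp16, opt.vram_O)
# prompt_to_img expects (images_upscaled, max_t, prompts, ...); the random tensors below are
# placeholders only (an assumption for this demo), so the output is not a meaningful upscale
dtype = sd.unet.dtype
images_upscaled = torch.randn(1, 3, opt.H, opt.W, device=device, dtype=dtype)
images = torch.randn(1, 3, opt.H, opt.W, device=device, dtype=dtype)
max_t = torch.full((1,), sd.max_step, dtype=torch.long, device=device)
imgs = sd.prompt_to_img(images_upscaled, max_t, opt.prompt, opt.negative, height=opt.H, width=opt.W, num_inference_steps=opt.steps, images=images)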
# visualize image
plt.imshow(imgs[0])
plt.show()
================================================
FILE: threefiner/guidance/if2_utils.py
================================================
from diffusers import (
PNDMScheduler,
DDIMScheduler,
IFPipeline,
IFSuperResolutionPipeline,
)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class IF2(nn.Module):
def __init__(
self,
device,
fp16=True,
vram_O=False,
# model_key = "DeepFloyd/IF-II-L-v1.0",
model_key = "DeepFloyd/IF-II-M-v1.0",
t_range=[0.02, 0.50],
):
super().__init__()
self.device = device
self.model_key = model_key
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
pipe = IFSuperResolutionPipeline.from_pretrained(
model_key, variant="fp16", torch_dtype=torch.float16,
watermarker=None, safety_checker=None, requires_safety_checker=False,
)
if vram_O:
pipe.unet.to(memory_format=torch.channels_last)
pipe.enable_attention_slicing(1)
# pipe.enable_model_cpu_offload()
pipe.enable_sequential_cpu_offload()
else:
pipe.to(device)
self.unet = pipe.unet
self.tokenizer = pipe.tokenizer
self.text_encoder = pipe.text_encoder
self.scheduler = pipe.scheduler
self.image_noising_scheduler = pipe.image_noising_scheduler
self.pipe = pipe
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = {}
@torch.no_grad()
def get_text_embeds(self, prompts, negative_prompts):
pos_embeds = self.encode_text(prompts) # [B, seq_len, hidden_dim] text encoder embeddings
neg_embeds = self.encode_text(negative_prompts)
self.embeddings['pos'] = pos_embeds
self.embeddings['neg'] = neg_embeds
# directional embeddings
for d in ['front', 'side', 'back']:
embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
self.embeddings[d] = embeds
def encode_text(self, prompt):
# prompt: [str]
prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
return embeddings
def train_step(
self,
pred_rgb,
ori_rgb,
step_ratio=None,
guidance_scale=50,
vers=None, hors=None,
):
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
ori_rgb = ori_rgb.to(self.dtype)
images = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) * 2 - 1
with torch.no_grad():
max_t = torch.full((batch_size,), self.max_step, dtype=torch.long, device=self.device)
# images_upscaled = images.clone()
images_upscaled = F.interpolate(ori_rgb, (256, 256), mode="bilinear", align_corners=False).clamp(0, 1) * 2 - 1
noise = torch.randn_like(images_upscaled)
images_upscaled = self.image_noising_scheduler.add_noise(images_upscaled, noise, max_t)
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1).to(self.dtype)
######### debug
# imagesx = self.produce_imgs(
# images_upscaled=images_upscaled,
# max_t=max_t,
# images=torch.randn_like(images),
# num_inference_steps=50,
# guidance_scale=4.0,
# ) # [1, 3, 64, 64]
# import kiui
# kiui.vis.plot_image(images_upscaled * 0.5 + 0.5)
# kiui.vis.plot_image(imagesx * 0.5 + 0.5)
#########
# add noise
noise = torch.randn_like(images)
images_noisy = self.scheduler.add_noise(images, noise, t)
# pred noise
model_input = torch.cat([images_noisy, images_upscaled], dim=1)
model_input = torch.cat([model_input] * 2)
model_input = self.scheduler.scale_model_input(model_input, t)
tt = torch.cat([t] * 2)
max_tt = torch.cat([max_t] * 2)
if hors is None:
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
else:
def _get_dir_ind(h):
if abs(h) < 60: return 'front'
elif abs(h) < 120: return 'side'
else: return 'back'
embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])
noise_pred = self.unet(
model_input, tt, encoder_hidden_states=embeddings, class_labels=max_tt,
).sample
# perform guidance (high scale from paper!)
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
# grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
# grad = grad_norm.clamp(max=0.1) * grad / grad_norm
target = (images - grad).detach()
loss = 0.5 * F.mse_loss(images, target, reduction='sum') / images.shape[0]
return loss
@torch.no_grad()
def produce_imgs(
self,
images_upscaled,
max_t,
height=256,
width=256,
num_inference_steps=50,
guidance_scale=4.0,
images=None,
):
if images is None:
images = torch.randn(
(
1,
self.unet.config.in_channels // 2, # the UNet input stacks the noisy image and the upscaled conditioning image
height,
width,
),
device=self.device,
)
batch_size = images.shape[0]
self.scheduler.set_timesteps(num_inference_steps)
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps):
# expand the images if we are doing classifier-free guidance to avoid doing two forward passes.
model_input = torch.cat([images, images_upscaled], dim=1)
model_input = torch.cat([model_input] * 2)
model_input = self.scheduler.scale_model_input(model_input, t)
max_tt = torch.cat([max_t] * 2)
# predict the noise residual
noise_pred = self.unet(
model_input, t, encoder_hidden_states=embeddings, class_labels=max_tt,
).sample
# perform guidance
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1] // 2, dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
# compute the previous noisy sample x_t -> x_t-1
images = self.scheduler.step(noise_pred, t, images).prev_sample
return images
def prompt_to_img(
self,
images_upscaled,
max_t,
prompts,
negative_prompts="",
height=256,
width=256,
num_inference_steps=50,
guidance_scale=4.0,
images=None,
):
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(negative_prompts, str):
negative_prompts = [negative_prompts]
# Prompts -> text embeds
self.get_text_embeds(prompts, negative_prompts)
# Text embeds -> images
images = self.produce_imgs(
images_upscaled=images_upscaled,
max_t=max_t,
height=height,
width=width,
images=images,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
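) # [B, 3, height, width], pixel space in [-1, 1]
# Img to Numpy: map [-1, 1] to [0, 1] before quantizing to uint8
images = (images / 2 + 0.5).clamp(0, 1)
images = images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")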
return images
if __name__ == "__main__":
import kiui
import argparse
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument("prompt", type=str)
parser.add_argument("--negative", default="", type=str)
parser.add_argument("--fp16", action="store_true", help="use float16 for training")
parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
parser.add_argument("-H", type=int, default=512)
parser.add_argument("-W", type=int, default=512)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
opt = parser.parse_args()
kiui.seed_everything(opt.seed)
device = torch.device("cuda")
sd = IF2(device, opt.fp16, opt.vram_O)
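# FIXME: IF2.prompt_to_img expects (images_upscaled, max_t, prompts, ...); this positional call omits the first two arguments.
imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)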
# visualize image
plt.imshow(imgs[0])
plt.show()
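# Illustrative sketch, not part of the original module. It restates the guidance
# step used in train_step/produce_imgs above on plain Python lists: the UNet
# output carries noise channels followed by learned-variance channels, the two
# halves are split, and the conditional and unconditional noise are blended.
# The helper names split_noise_and_variance and cfg_combine are hypothetical.

```python
def split_noise_and_variance(channels):
    """Split a channel list into its (noise, variance) halves."""
    half = len(channels) // 2
    return channels[:half], channels[half:]


def cfg_combine(cond, uncond, guidance_scale):
    """Classifier-free guidance: uncond + scale * (cond - uncond), elementwise."""
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]
```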
================================================
FILE: threefiner/guidance/if_utils.py
================================================
from diffusers import IFPipeline
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class IF(nn.Module):
def __init__(
self,
device,
fp16=True,
vram_O=False,
model_key = "DeepFloyd/IF-I-XL-v1.0",
# model_key = "DeepFloyd/IF-I-M-v1.0",
t_range=(0.02, 0.98),
):
super().__init__()
self.device = device
self.model_key = model_key
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
pipe = IFPipeline.from_pretrained(
model_key, variant="fp16", torch_dtype=torch.float16,
watermarker=None, safety_checker=None, requires_safety_checker=False,
)
if vram_O:
pipe.unet.to(memory_format=torch.channels_last)
pipe.enable_attention_slicing(1)
# pipe.enable_model_cpu_offload()
pipe.enable_sequential_cpu_offload()
else:
pipe.to(device)
self.unet = pipe.unet
self.tokenizer = pipe.tokenizer
self.text_encoder = pipe.text_encoder
self.scheduler = pipe.scheduler
self.pipe = pipe
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = {}
@torch.no_grad()
def get_text_embeds(self, prompts, negative_prompts):
pos_embeds = self.encode_text(prompts) # [B, L, C] T5 text embeddings
neg_embeds = self.encode_text(negative_prompts)
self.embeddings['pos'] = pos_embeds
self.embeddings['neg'] = neg_embeds
# directional embeddings
for d in ['front', 'side', 'back']:
embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
self.embeddings[d] = embeds
def encode_text(self, prompt):
# prompt: [str]
prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
return embeddings
@torch.no_grad()
def refine(self, pred_rgb,
guidance_scale=100, steps=50, strength=0.8,
):
batch_size = pred_rgb.shape[0]
images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False)
self.scheduler.set_timesteps(steps)
init_step = int(steps * strength)
images = self.scheduler.add_noise(images, torch.randn_like(images), self.scheduler.timesteps[init_step])
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps[init_step:]):
model_input = torch.cat([images] * 2)
noise_pred = self.unet(
model_input, t, encoder_hidden_states=embeddings,
).sample
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
images = self.scheduler.step(noise_pred, t, images).prev_sample
return images
def train_step(
self,
pred_rgb,
step_ratio=None,
guidance_scale=50,
vers=None, hors=None,
):
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
images = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
with torch.no_grad():
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)
######### debug
# # Text embeds -> img latents
# imagesx = self.produce_imgs(
# height=64,
# width=64,
# images=torch.randn_like(images),
# num_inference_steps=50,
# guidance_scale=7.5,
# ) # [1, 3, 64, 64]
# import kiui
# kiui.vis.plot_image(imagesx)
#########
# add noise
noise = torch.randn_like(images)
images_noisy = self.scheduler.add_noise(images, noise, t)
# pred noise
model_input = torch.cat([images_noisy] * 2)
model_input = self.scheduler.scale_model_input(model_input, t)
tt = torch.cat([t] * 2)
if hors is None:
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
else:
def _get_dir_ind(h):
if abs(h) < 60: return 'front'
elif abs(h) < 120: return 'side'
else: return 'back'
embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])
noise_pred = self.unet(
model_input, tt, encoder_hidden_states=embeddings
).sample
# perform guidance (high scale from paper!)
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
# grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
# grad = grad_norm.clamp(max=0.1) * grad / grad_norm
target = (images - grad).detach()
loss = 0.5 * F.mse_loss(images.float(), target, reduction='sum') / images.shape[0]
return loss
@torch.no_grad()
def produce_imgs(
self,
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
images=None,
):
if images is None:
images = torch.randn(
(
1,
self.unet.config.in_channels,
height,
width,
),
device=self.device,
)
batch_size = images.shape[0]
self.scheduler.set_timesteps(num_inference_steps)
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps):
# expand the images if we are doing classifier-free guidance to avoid doing two forward passes.
model_input = torch.cat([images] * 2)
model_input = self.scheduler.scale_model_input(model_input, t)
# predict the noise residual
noise_pred = self.unet(
model_input, t, encoder_hidden_states=embeddings
).sample
# perform guidance
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
noise_pred_cond, predicted_variance = noise_pred_cond.split(model_input.shape[1], dim=1)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
# compute the previous noisy sample x_t -> x_t-1
images = self.scheduler.step(noise_pred, t, images).prev_sample
return images
def prompt_to_img(
self,
prompts,
negative_prompts="",
height=64,
width=64,
num_inference_steps=50,
guidance_scale=7.5,
images=None,
):
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(negative_prompts, str):
negative_prompts = [negative_prompts]
# Prompts -> text embeds
self.get_text_embeds(prompts, negative_prompts)
# Text embeds -> images
images = self.produce_imgs(
height=height,
width=width,
images=images,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
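) # [B, 3, height, width], pixel space in [-1, 1]
# Img to Numpy: map [-1, 1] to [0, 1] before quantizing to uint8
images = (images / 2 + 0.5).clamp(0, 1)
images = images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")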
return images
if __name__ == "__main__":
import kiui
import argparse
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument("prompt", type=str)
parser.add_argument("--negative", default="", type=str)
parser.add_argument("--fp16", action="store_true", help="use float16 for training")
parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
parser.add_argument("-H", type=int, default=64) # the stage-I model generates 64x64 images
parser.add_argument("-W", type=int, default=64)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
opt = parser.parse_args()
kiui.seed_everything(opt.seed)
device = torch.device("cuda")
sd = IF(device, opt.fp16, opt.vram_O)
imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
# visualize image
plt.imshow(imgs[0])
plt.show()
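# Illustrative sketch, not part of the original module. It restates the
# "dreamtime-like" timestep choice in train_step above in plain Python:
# step_ratio in [0, 1] is mapped linearly from high to low noise and clipped to
# [min_step, max_step]. The function name annealed_timestep is hypothetical,
# and the default of 1000 training timesteps is an assumption; the real value
# comes from scheduler.config.num_train_timesteps.

```python
def annealed_timestep(step_ratio, num_train_timesteps=1000, t_range=(0.02, 0.98)):
    """Map training progress (0 = start, 1 = end) to a diffusion timestep."""
    min_step = int(num_train_timesteps * t_range[0])
    max_step = int(num_train_timesteps * t_range[1])
    # Early in training pick large t (coarse structure), later small t (detail).
    t = round((1 - step_ratio) * num_train_timesteps)
    return max(min_step, min(max_step, t))
```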
================================================
FILE: threefiner/guidance/sd_ism_utils.py
================================================
from diffusers import (
DDIMScheduler,
StableDiffusionPipeline,
)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def invert_noise(scheduler, noisy_samples, noise, timesteps):
    """Invert scheduler.add_noise: recover x_0 from x_t given the noise and the timesteps."""
    alphas_cumprod = scheduler.alphas_cumprod.to(device=noisy_samples.device, dtype=noisy_samples.dtype)
    timesteps = timesteps.to(noisy_samples.device)
    # sqrt(alpha_bar_t), reshaped to broadcast against noisy_samples
    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(noisy_samples.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
    # sqrt(1 - alpha_bar_t), reshaped the same way
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(noisy_samples.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
    # x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    original_samples = 1 / sqrt_alpha_prod * (noisy_samples - sqrt_one_minus_alpha_prod * noise)
    return original_samples
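# Illustrative sketch, not part of the original module. It checks on scalars
# that the formula in invert_noise above undoes the standard forward noising
# x_t = sqrt(a) * x_0 + sqrt(1 - a) * eps, where a stands for alphas_cumprod[t].
# The names add_noise_scalar and invert_noise_scalar are hypothetical.

```python
import math


def add_noise_scalar(x0, eps, alpha_cumprod):
    """Forward diffusion for one value: blend the clean sample with noise."""
    return math.sqrt(alpha_cumprod) * x0 + math.sqrt(1 - alpha_cumprod) * eps


def invert_noise_scalar(x_t, eps, alpha_cumprod):
    """Recover the clean value from the noisy one when the noise is known."""
    return (x_t - math.sqrt(1 - alpha_cumprod) * eps) / math.sqrt(alpha_cumprod)
```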
class StableDiffusion(nn.Module):
def __init__(
self,
device,
fp16=True,
vram_O=False,
model_key="stabilityai/stable-diffusion-2-1-base",
# model_key="philz1337/revanimated",
t_range=(0.02, 0.98),
):
super().__init__()
self.device = device
self.model_key = model_key
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
pipe = StableDiffusionPipeline.from_pretrained(
model_key, torch_dtype=self.dtype
)
if vram_O:
pipe.enable_sequential_cpu_offload()
pipe.enable_vae_slicing()
pipe.unet.to(memory_format=torch.channels_last)
pipe.enable_attention_slicing(1)
# pipe.enable_model_cpu_offload()
else:
pipe.to(device)
self.vae = pipe.vae
self.tokenizer = pipe.tokenizer
self.text_encoder = pipe.text_encoder
self.unet = pipe.unet
self.scheduler = DDIMScheduler.from_pretrained(
model_key, subfolder="scheduler", torch_dtype=self.dtype
)
del pipe
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = {}
@torch.no_grad()
def get_text_embeds(self, prompts, negative_prompts):
pos_embeds = self.encode_text(prompts) # [B, 77, 1024] for the default SD 2.1 text encoder
neg_embeds = self.encode_text(negative_prompts)
null_embeds = self.encode_text([""])
self.embeddings['pos'] = pos_embeds
self.embeddings['neg'] = neg_embeds
self.embeddings['null'] = null_embeds
# directional embeddings
for d in ['front', 'side', 'back']:
embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
self.embeddings[d] = embeds
def encode_text(self, prompt):
# prompt: [str]
inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
return embeddings
@torch.no_grad()
def refine(self, pred_rgb,
guidance_scale=100, steps=50, strength=0.8,
):
batch_size = pred_rgb.shape[0]
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
latents = self.encode_imgs(pred_rgb_512.to(self.dtype))
# latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)
self.scheduler.set_timesteps(steps)
init_step = int(steps * strength)
latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps[init_step:]):
latent_model_input = torch.cat([latents] * 2)
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings,
).sample
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
return imgs
def train_step(
self,
pred_rgb,
step_ratio=None,
guidance_scale=7.5,
as_latent=False,
vers=None, hors=None,
delta_t=50, delta_s=200,
):
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
if as_latent:
latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
else:
# interp to 512x512 to be fed into vae.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
# encode image into latents with vae, requires grad!
latents = self.encode_imgs(pred_rgb_512)
with torch.no_grad():
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)
######### debug
# # Text embeds -> img latents
# latentsx = self.produce_latents(
# height=512,
# width=512,
# latents=torch.randn_like(latents),
# num_inference_steps=50,
# guidance_scale=7.5,
# ) # [1, 4, 64, 64]
# # Img latents -> imgs
# imgs = self.decode_latents(latentsx) # [1, 3, 512, 512]
# import kiui
# kiui.vis.plot_image(imgs)
#########
null_embeddings = self.embeddings['null'].expand(batch_size, -1, -1)
if hors is None:
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
else:
def _get_dir_ind(h):
if abs(h) < 60: return 'front'
elif abs(h) < 120: return 'side'
else: return 'back'
embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])
########### ISM
# Split the walk from 0 to s = t - delta_t into one step of size r followed by n steps of size delta_s.
t = t.clamp(min=delta_t)
s = t - delta_t
n = s // delta_s
r = s % delta_s
# construct trajectory
latents_noisy = latents.clone()
cur_t = torch.full((batch_size,), 0, dtype=torch.long, device=self.device)
noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample
latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t)
cur_t += r
latents_noisy = self.scheduler.add_noise(latents_original, noise, cur_t)
for i in range(n): # n is a tensor, so this only works when it holds a single value (batch_size == 1)
noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample
latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t)
cur_t += delta_s
latents_noisy = self.scheduler.add_noise(latents_original, noise, cur_t) # x_s
# construct last step
noise = self.unet(latents_noisy, cur_t, encoder_hidden_states=null_embeddings).sample # \epsilon_s
latents_original = invert_noise(self.scheduler, latents_noisy, noise, cur_t) # \hat x_0^s
# perform guidance
latents_noisy = self.scheduler.add_noise(latents_original, noise, t) # x_t
latent_model_input = torch.cat([latents_noisy] * 2)
tt = torch.cat([t] * 2)
noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=embeddings).sample
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
# grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
# grad = grad_norm.clamp(max=0.1) * grad / grad_norm
target = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]
return loss
@torch.no_grad()
def produce_latents(
self,
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
latents=None,
):
if latents is None:
latents = torch.randn(
(
1,
self.unet.config.in_channels,
height // 8,
width // 8,
),
device=self.device,
)
batch_size = latents.shape[0]
self.scheduler.set_timesteps(num_inference_steps)
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings
).sample
# perform guidance
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
return latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
imgs = self.vae.decode(latents).sample
imgs = (imgs / 2 + 0.5).clamp(0, 1)
return imgs
def encode_imgs(self, imgs):
# imgs: [B, 3, H, W]
imgs = 2 * imgs - 1
posterior = self.vae.encode(imgs).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents
def prompt_to_img(
self,
prompts,
negative_prompts="",
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
latents=None,
):
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(negative_prompts, str):
negative_prompts = [negative_prompts]
# Prompts -> text embeds
self.get_text_embeds(prompts, negative_prompts)
# Text embeds -> img latents
latents = self.produce_latents(
height=height,
width=width,
latents=latents,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
) # [1, 4, 64, 64]
# Img latents -> imgs
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
# Img to Numpy
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
imgs = (imgs * 255).round().astype("uint8")
return imgs
if __name__ == "__main__":
import kiui
import argparse
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument("prompt", type=str)
parser.add_argument("--negative", default="", type=str)
parser.add_argument("--fp16", action="store_true", help="use float16 for training")
parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
parser.add_argument("-H", type=int, default=512)
parser.add_argument("-W", type=int, default=512)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
opt = parser.parse_args()
kiui.seed_everything(opt.seed)
device = torch.device("cuda")
sd = StableDiffusion(device, opt.fp16, opt.vram_O)
imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
# visualize image
plt.imshow(imgs[0])
plt.show()
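# Illustrative sketch, not part of the original module. It traces, with plain
# integers, which timesteps the ISM branch of train_step above visits: the
# UNet is queried with the empty prompt at 0, r, r + delta_s, ..., s, and the
# final guidance step is evaluated at t. The function name ism_schedule is
# hypothetical; the defaults mirror delta_t=50 and delta_s=200 from train_step.

```python
def ism_schedule(t, delta_t=50, delta_s=200):
    """Return (timesteps queried while building the trajectory, final timestep t)."""
    t = max(t, delta_t)          # mirrors t.clamp(min=delta_t)
    s = t - delta_t
    n, r = divmod(s, delta_s)    # n full strides of delta_s plus a remainder r
    visited = [0, r] + [r + (i + 1) * delta_s for i in range(n)]
    return visited, t
```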
================================================
FILE: threefiner/guidance/sd_nfsd_utils.py
================================================
from diffusers import (
DDIMScheduler,
StableDiffusionPipeline,
)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class StableDiffusion(nn.Module):
def __init__(
self,
device,
fp16=True,
vram_O=False,
model_key="stabilityai/stable-diffusion-2-1-base",
# model_key="philz1337/revanimated",
t_range=(0.02, 0.50),
):
super().__init__()
self.device = device
self.model_key = model_key
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
pipe = StableDiffusionPipeline.from_pretrained(
model_key, torch_dtype=self.dtype
)
if vram_O:
pipe.enable_sequential_cpu_offload()
pipe.enable_vae_slicing()
pipe.unet.to(memory_format=torch.channels_last)
pipe.enable_attention_slicing(1)
# pipe.enable_model_cpu_offload()
else:
pipe.to(device)
self.vae = pipe.vae
self.tokenizer = pipe.tokenizer
self.text_encoder = pipe.text_encoder
self.unet = pipe.unet
self.scheduler = DDIMScheduler.from_pretrained(
model_key, subfolder="scheduler", torch_dtype=self.dtype
)
del pipe
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = {}
@torch.no_grad()
def get_text_embeds(self, prompts, negative_prompts):
pos_embeds = self.encode_text(prompts) # [B, 77, 1024] for the default SD 2.1 text encoder
neg_embeds = self.encode_text(negative_prompts)
null_embeds = self.encode_text([''])
self.embeddings['pos'] = pos_embeds
self.embeddings['neg'] = neg_embeds
self.embeddings['null'] = null_embeds
# directional embeddings
for d in ['front', 'side', 'back']:
embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
self.embeddings[d] = embeds
def encode_text(self, prompt):
# prompt: [str]
inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
return embeddings
@torch.no_grad()
def refine(self, pred_rgb,
guidance_scale=100, steps=50, strength=0.8,
):
batch_size = pred_rgb.shape[0]
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
latents = self.encode_imgs(pred_rgb_512.to(self.dtype))
# latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)
self.scheduler.set_timesteps(steps)
init_step = int(steps * strength)
latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps[init_step:]):
latent_model_input = torch.cat([latents] * 2)
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings,
).sample
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
return imgs
def train_step(
self,
pred_rgb,
step_ratio=None,
guidance_scale=7.5,
as_latent=False,
vers=None, hors=None,
):
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
if as_latent:
latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
else:
# interp to 512x512 to be fed into vae.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
# encode image into latents with vae, requires grad!
latents = self.encode_imgs(pred_rgb_512)
with torch.no_grad():
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)
######### debug
# # Text embeds -> img latents
# latentsx = self.produce_latents(
# height=512,
# width=512,
# latents=torch.randn_like(latents),
# num_inference_steps=50,
# guidance_scale=7.5,
# ) # [1, 4, 64, 64]
# # Img latents -> imgs
# imgs = self.decode_latents(latentsx) # [1, 3, 512, 512]
# import kiui
# kiui.vis.plot_image(imgs)
#########
# add noise
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
# pred noise
latent_model_input = torch.cat([latents_noisy] * 3)
tt = torch.cat([t] * 3)
if hors is None:
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])
else:
def _get_dir_ind(h):
if abs(h) < 60: return 'front'
elif abs(h) < 120: return 'side'
else: return 'back'
embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1), self.embeddings['null'].expand(batch_size, -1, -1)])
noise_pred = self.unet(
latent_model_input, tt, encoder_hidden_states=embeddings
).sample
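# NFSD: combine a condition direction (delta_c) with a denoising direction (delta_d).
# Note that noise_pred_uncond is the negative-prompt prediction; noise_pred_null uses the empty prompt.
noise_pred_cond, noise_pred_uncond, noise_pred_null = noise_pred.chunk(3)
delta_c = guidance_scale * (noise_pred_cond - noise_pred_null)
# For small timesteps (t < 200) use the empty-prompt prediction directly; otherwise use its scaled offset from the negative prompt.
mask = (t < 200).int().view(batch_size, 1, 1, 1)
delta_d = mask * noise_pred_null + (1 - mask) * (noise_pred_null - noise_pred_uncond) * 3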
grad = w * (delta_c + delta_d)
grad = torch.nan_to_num(grad)
# grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
# grad = grad_norm.clamp(max=0.1) * grad / grad_norm
target = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]
return loss
@torch.no_grad()
def produce_latents(
self,
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
latents=None,
):
if latents is None:
latents = torch.randn(
(
1,
self.unet.config.in_channels,
height // 8,
width // 8,
),
device=self.device,
)
batch_size = latents.shape[0]
self.scheduler.set_timesteps(num_inference_steps)
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings
).sample
# perform guidance
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
return latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
imgs = self.vae.decode(latents).sample
imgs = (imgs / 2 + 0.5).clamp(0, 1)
return imgs
def encode_imgs(self, imgs):
# imgs: [B, 3, H, W]
imgs = 2 * imgs - 1
posterior = self.vae.encode(imgs).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents
def prompt_to_img(
self,
prompts,
negative_prompts="",
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
latents=None,
):
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(negative_prompts, str):
negative_prompts = [negative_prompts]
# Prompts -> text embeds
self.get_text_embeds(prompts, negative_prompts)
# Text embeds -> img latents
latents = self.produce_latents(
height=height,
width=width,
latents=latents,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
) # [1, 4, 64, 64]
# Img latents -> imgs
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
# Img to Numpy
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
imgs = (imgs * 255).round().astype("uint8")
return imgs
if __name__ == "__main__":
import kiui
import argparse
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument("prompt", type=str)
parser.add_argument("--negative", default="", type=str)
parser.add_argument("--fp16", action="store_true", help="use float16 for training")
parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
parser.add_argument("-H", type=int, default=512)
parser.add_argument("-W", type=int, default=512)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
opt = parser.parse_args()
kiui.seed_everything(opt.seed)
device = torch.device("cuda")
sd = StableDiffusion(device, opt.fp16, opt.vram_O)
imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
# visualize image
plt.imshow(imgs[0])
plt.show()
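# Illustrative sketch, not part of the original module. It restates the update
# direction computed in train_step above for a single scalar prediction:
# delta_c scales the gap between the prompt and empty-prompt predictions, and
# delta_d switches at t = 200 between the empty-prompt prediction and its scaled
# offset from the negative-prompt prediction. The function name nfsd_direction
# and the parameter names t_switch and neg_scale are hypothetical; their
# defaults mirror the constants 200 and 3 used in the code above.

```python
def nfsd_direction(eps_cond, eps_neg, eps_null, t,
                   guidance_scale=7.5, t_switch=200, neg_scale=3.0):
    """Return delta_c + delta_d for one scalar noise prediction."""
    delta_c = guidance_scale * (eps_cond - eps_null)
    if t < t_switch:
        delta_d = eps_null                          # low-noise regime
    else:
        delta_d = neg_scale * (eps_null - eps_neg)  # high-noise regime
    return delta_c + delta_d
```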
================================================
FILE: threefiner/guidance/sd_utils.py
================================================
from diffusers import (
PNDMScheduler,
DDIMScheduler,
StableDiffusionPipeline,
)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class StableDiffusion(nn.Module):
def __init__(
self,
device,
fp16=True,
vram_O=False,
model_key="stabilityai/stable-diffusion-2-1-base",
# model_key="philz1337/revanimated",
t_range=(0.02, 0.50),
):
super().__init__()
self.device = device
self.model_key = model_key
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
pipe = StableDiffusionPipeline.from_pretrained(
model_key, torch_dtype=self.dtype
)
if vram_O:
pipe.enable_sequential_cpu_offload()
pipe.enable_vae_slicing()
pipe.unet.to(memory_format=torch.channels_last)
pipe.enable_attention_slicing(1)
# pipe.enable_model_cpu_offload()
else:
pipe.to(device)
self.vae = pipe.vae
self.tokenizer = pipe.tokenizer
self.text_encoder = pipe.text_encoder
self.unet = pipe.unet
self.scheduler = DDIMScheduler.from_pretrained(
model_key, subfolder="scheduler", torch_dtype=self.dtype
)
del pipe
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = {}
@torch.no_grad()
def get_text_embeds(self, prompts, negative_prompts):
pos_embeds = self.encode_text(prompts) # [B, 77, C]; C = 1024 for the default SD 2.1 text encoder
neg_embeds = self.encode_text(negative_prompts)
self.embeddings['pos'] = pos_embeds
self.embeddings['neg'] = neg_embeds
# directional embeddings
for d in ['front', 'side', 'back']:
embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
self.embeddings[d] = embeds
def encode_text(self, prompt):
# prompt: [str]
inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
return embeddings
@torch.no_grad()
def refine(self, pred_rgb,
guidance_scale=100, steps=50, strength=0.8,
):
batch_size = pred_rgb.shape[0]
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
latents = self.encode_imgs(pred_rgb_512.to(self.dtype))
# latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)
self.scheduler.set_timesteps(steps)
init_step = int(steps * strength)
latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps[init_step:]):
latent_model_input = torch.cat([latents] * 2)
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings,
).sample
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
return imgs
def train_step(
self,
pred_rgb,
step_ratio=None,
guidance_scale=100,
as_latent=False,
vers=None, hors=None,
):
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
if as_latent:
latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
else:
# interp to 512x512 to be fed into vae.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
# encode image into latents with vae, requires grad!
latents = self.encode_imgs(pred_rgb_512)
with torch.no_grad():
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)
######### debug
# # Text embeds -> img latents
# latentsx = self.produce_latents(
# height=512,
# width=512,
# latents=torch.randn_like(latents),
# num_inference_steps=50,
# guidance_scale=7.5,
# ) # [1, 4, 64, 64]
# # Img latents -> imgs
# imgs = self.decode_latents(latentsx) # [1, 3, 512, 512]
# import kiui
# kiui.vis.plot_image(imgs)
#########
# add noise
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
# pred noise
latent_model_input = torch.cat([latents_noisy] * 2)
tt = torch.cat([t] * 2)
if hors is None:
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
else:
def _get_dir_ind(h):
if abs(h) < 60: return 'front'
elif abs(h) < 120: return 'side'
else: return 'back'
embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])
noise_pred = self.unet(
latent_model_input, tt, encoder_hidden_states=embeddings
).sample
# perform guidance (high scale from paper!)
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
# grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
# grad = grad_norm.clamp(max=0.1) * grad / grad_norm
target = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]
return loss
@torch.no_grad()
def produce_latents(
self,
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
latents=None,
):
if latents is None:
latents = torch.randn(
(
1,
self.unet.config.in_channels, # read from config; direct attribute access is deprecated in recent diffusers
height // 8,
width // 8,
),
device=self.device,
)
batch_size = latents.shape[0]
self.scheduler.set_timesteps(num_inference_steps)
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings
).sample
# perform guidance
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
return latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
imgs = self.vae.decode(latents).sample
imgs = (imgs / 2 + 0.5).clamp(0, 1)
return imgs
def encode_imgs(self, imgs):
# imgs: [B, 3, H, W]
imgs = 2 * imgs - 1
posterior = self.vae.encode(imgs).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents
def prompt_to_img(
self,
prompts,
negative_prompts="",
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
latents=None,
):
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(negative_prompts, str):
negative_prompts = [negative_prompts]
# Prompts -> text embeds
self.get_text_embeds(prompts, negative_prompts)
# Text embeds -> img latents
latents = self.produce_latents(
height=height,
width=width,
latents=latents,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
) # [1, 4, 64, 64]
# Img latents -> imgs
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
# Img to Numpy
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
imgs = (imgs * 255).round().astype("uint8")
return imgs
if __name__ == "__main__":
import kiui
import argparse
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument("prompt", type=str)
parser.add_argument("--negative", default="", type=str)
parser.add_argument("--fp16", action="store_true", help="use float16 for training")
parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
parser.add_argument("-H", type=int, default=512)
parser.add_argument("-W", type=int, default=512)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
opt = parser.parse_args()
kiui.seed_everything(opt.seed)
device = torch.device("cuda")
sd = StableDiffusion(device, opt.fp16, opt.vram_O)
imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
# visualize image
plt.imshow(imgs[0])
plt.show()
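# Illustrative sketch (not part of the repository): train_step above builds
# target = (latents - grad).detach() and returns 0.5 * sum((latents - target)^2) / batch,
# so the gradient of the loss with respect to latents equals grad. The pure-Python
# check below verifies that identity numerically for batch size 1, and also mirrors the
# step_ratio -> timestep mapping. The function names, the sample values, and
# num_train_timesteps=1000 are assumptions made for illustration only.

```python
def annealed_timestep(step_ratio, num_train_timesteps=1000, t_range=(0.02, 0.50)):
    """Map training progress in [0, 1] to a timestep clipped to [min_step, max_step]."""
    min_step = int(num_train_timesteps * t_range[0])
    max_step = int(num_train_timesteps * t_range[1])
    t = round((1 - step_ratio) * num_train_timesteps)
    return max(min_step, min(max_step, t))


def surrogate_loss(latents, target):
    """0.5 * sum of squared differences, with target treated as a constant."""
    return 0.5 * sum((x - y) ** 2 for x, y in zip(latents, target))


def numeric_gradient(latents, target, eps=1e-4):
    """Central-difference gradient of surrogate_loss with respect to latents."""
    grads = []
    for i in range(len(latents)):
        plus = list(latents)
        plus[i] += eps
        minus = list(latents)
        minus[i] -= eps
        diff = surrogate_loss(plus, target) - surrogate_loss(minus, target)
        grads.append(diff / (2 * eps))
    return grads


latents = [0.3, -1.2, 0.7]
sds_grad = [0.05, -0.02, 0.10]  # stands in for w * (noise_pred - noise)
target = [x - g for x, g in zip(latents, sds_grad)]  # plays the role of (latents - grad).detach()
recovered = numeric_gradient(latents, target)  # should match sds_grad
timesteps = [annealed_timestep(r) for r in (0.0, 0.75, 1.0)]
```

# Early iterations (step_ratio near 0) are clipped to max_step and late ones to min_step,
# so the annealing only has an effect in the middle of training.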
================================================
FILE: threefiner/guidance/sdcn_utils.py
================================================
from diffusers import (
PNDMScheduler,
DDIMScheduler,
StableDiffusionPipeline,
ControlNetModel,
)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class StableDiffusionControlNet(nn.Module):
def __init__(
self,
device,
fp16=True,
vram_O=False,
control_mode=["tile"],
model_key="runwayml/stable-diffusion-v1-5",
# model_key="philz1337/revanimated",
t_range=[0.02, 0.50],
):
super().__init__()
self.device = device
self.control_mode = control_mode
self.dtype = torch.float16 if fp16 else torch.float32
# Create model
pipe = StableDiffusionPipeline.from_pretrained(
model_key, torch_dtype=self.dtype
)
if vram_O:
pipe.enable_sequential_cpu_offload()
pipe.enable_vae_slicing()
pipe.unet.to(memory_format=torch.channels_last)
pipe.enable_attention_slicing(1)
# pipe.enable_model_cpu_offload()
else:
pipe.to(device)
self.vae = pipe.vae
self.tokenizer = pipe.tokenizer
self.text_encoder = pipe.text_encoder
self.unet = pipe.unet
# controlnet
if self.control_mode is not None:
self.controlnet = {}
self.controlnet_conditioning_scale = {}
if "normal" in self.control_mode:
self.controlnet['normal'] = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_normalbae",torch_dtype=self.dtype).to(self.device)
self.controlnet_conditioning_scale['normal'] = 1.0
if "depth" in self.control_mode:
self.controlnet['depth'] = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth",torch_dtype=self.dtype).to(self.device)
self.controlnet_conditioning_scale['depth'] = 1.0
if "ip2p" in self.control_mode:
self.controlnet['ip2p'] = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_ip2p",torch_dtype=self.dtype).to(self.device)
self.controlnet_conditioning_scale['ip2p'] = 1.0
if "inpaint" in self.control_mode:
self.controlnet['inpaint'] = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint",torch_dtype=self.dtype).to(self.device)
self.controlnet_conditioning_scale['inpaint'] = 1.0
if "depth_inpaint" in self.control_mode:
self.controlnet['depth_inpaint'] = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_depth_aware_inpaint",torch_dtype=self.dtype).to(self.device)
self.controlnet_conditioning_scale['depth_inpaint'] = 1.0
if "pose" in self.control_mode:
self.controlnet['pose'] = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose",torch_dtype=self.dtype).to(self.device)
self.controlnet_conditioning_scale['pose'] = 1.0
if "tile" in self.control_mode:
self.controlnet['tile'] = ControlNetModel.from_pretrained("lllyasviel/control_v11f1e_sd15_tile",torch_dtype=self.dtype).to(self.device)
self.controlnet_conditioning_scale['tile'] = 1.0
self.scheduler = DDIMScheduler.from_pretrained(
model_key, subfolder="scheduler", torch_dtype=self.dtype
)
del pipe
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.embeddings = {}
@torch.no_grad()
def get_text_embeds(self, prompts, negative_prompts):
pos_embeds = self.encode_text(prompts) # [1, 77, 768]
neg_embeds = self.encode_text(negative_prompts)
self.embeddings['pos'] = pos_embeds
self.embeddings['neg'] = neg_embeds
# directional embeddings
for d in ['front', 'side', 'back']:
embeds = self.encode_text([f'{p}, {d} view' for p in prompts])
self.embeddings[d] = embeds
def encode_text(self, prompt):
# prompt: [str]
inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
return embeddings
@torch.no_grad()
def refine(self, pred_rgb,
guidance_scale=100, steps=50, strength=0.8,
control_images=None
):
batch_size = pred_rgb.shape[0]
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
latents = self.encode_imgs(pred_rgb_512.to(self.dtype))
# latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)
self.scheduler.set_timesteps(steps)
init_step = int(steps * strength)
latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps[init_step:]):
latent_model_input = torch.cat([latents] * 2)
if self.control_mode is not None and control_images is not None:
noise_pred = 0
for mode, controlnet in self.controlnet.items():
# may omit control mode if input is not provided
if mode not in control_images: continue
control_image = control_images[mode].to(self.dtype)
weight = 1 / len(self.controlnet)
control_image_input = torch.cat([control_image] * 2)
down_samples, mid_sample = controlnet(
latent_model_input, t, encoder_hidden_states=embeddings,
controlnet_cond=control_image_input,
conditioning_scale=self.controlnet_conditioning_scale[mode],
return_dict=False
)
# predict the noise residual
noise_pred_cur = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings,
down_block_additional_residuals=down_samples,
mid_block_additional_residual=mid_sample
).sample
# merge after unet
noise_pred = noise_pred + weight * noise_pred_cur
else:
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings,
).sample
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
return imgs
def train_step(
self,
pred_rgb,
step_ratio=None,
guidance_scale=100,
as_latent=False,
control_images=None,
vers=None, hors=None,
):
batch_size = pred_rgb.shape[0]
pred_rgb = pred_rgb.to(self.dtype)
if as_latent:
latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
else:
# interp to 512x512 to be fed into vae.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
# encode image into latents with vae, requires grad!
latents = self.encode_imgs(pred_rgb_512)
with torch.no_grad():
if step_ratio is not None:
# dreamtime-like
# t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
else:
t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
# w(t), sigma_t^2
w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)
######### debug
# # Text embeds -> img latents
# latentsx = self.produce_latents(
# height=512,
# width=512,
# latents=torch.randn_like(latents),
# num_inference_steps=50,
# guidance_scale=7.5,
# control_images=control_images,
# ) # [1, 4, 64, 64]
# # Img latents -> imgs
# imgs = self.decode_latents(latentsx) # [1, 3, 512, 512]
# import kiui
# kiui.vis.plot_image(control_images['tile'], imgs)
#########
# add noise
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
# pred noise
latent_model_input = torch.cat([latents_noisy] * 2)
tt = torch.cat([t] * 2)
if hors is None:
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
else:
def _get_dir_ind(h):
if abs(h) < 60: return 'front'
elif abs(h) < 120: return 'side'
else: return 'back'
embeddings = torch.cat([self.embeddings[_get_dir_ind(h)] for h in hors] + [self.embeddings['neg'].expand(batch_size, -1, -1)])
if self.control_mode is not None and control_images is not None:
noise_pred = 0
for mode, controlnet in self.controlnet.items():
# may omit control mode if input is not provided
if mode not in control_images: continue
control_image = control_images[mode].to(self.dtype)
weight = 1 / len(self.controlnet)
control_image_input = torch.cat([control_image] * 2)
down_samples, mid_sample = controlnet(
latent_model_input, tt, encoder_hidden_states=embeddings,
controlnet_cond=control_image_input,
conditioning_scale=self.controlnet_conditioning_scale[mode],
return_dict=False
)
# predict the noise residual
noise_pred_cur = self.unet(
latent_model_input, tt, encoder_hidden_states=embeddings,
down_block_additional_residuals=down_samples,
mid_block_additional_residual=mid_sample
).sample
# merge after unet
noise_pred = noise_pred + weight * noise_pred_cur
else:
noise_pred = self.unet(
latent_model_input, tt, encoder_hidden_states=embeddings,
).sample
# perform guidance (high scale from paper!)
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
# grad_norm = torch.norm(grad, dim=-1, keepdim=True) + 1e-8
# grad = grad_norm.clamp(max=0.1) * grad / grad_norm
target = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]
return loss
@torch.no_grad()
def produce_latents(
self,
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
latents=None,
control_images=None,
):
if latents is None:
latents = torch.randn(
(
1,
self.unet.config.in_channels, # read from config; direct attribute access is deprecated in recent diffusers
height // 8,
width // 8,
),
device=self.device,
)
batch_size = latents.shape[0]
self.scheduler.set_timesteps(num_inference_steps)
embeddings = torch.cat([self.embeddings['pos'].expand(batch_size, -1, -1), self.embeddings['neg'].expand(batch_size, -1, -1)])
for i, t in enumerate(self.scheduler.timesteps):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
# predict the noise residual
if self.control_mode is not None and control_images is not None:
noise_pred = 0
for mode, controlnet in self.controlnet.items():
# may omit control mode if input is not provided
if mode not in control_images: continue
control_image = control_images[mode].to(self.dtype)
weight = 1 / len(self.controlnet)
control_image_input = torch.cat([control_image] * 2)
down_samples, mid_sample = controlnet(
latent_model_input, t, encoder_hidden_states=embeddings,
controlnet_cond=control_image_input,
conditioning_scale=self.controlnet_conditioning_scale[mode],
return_dict=False
)
# predict the noise residual
noise_pred_cur = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings,
down_block_additional_residuals=down_samples,
mid_block_additional_residual=mid_sample
).sample
# merge after unet
noise_pred = noise_pred + weight * noise_pred_cur
else:
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=embeddings,
).sample
# perform guidance
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
return latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
imgs = self.vae.decode(latents).sample
imgs = (imgs / 2 + 0.5).clamp(0, 1)
return imgs
def encode_imgs(self, imgs):
# imgs: [B, 3, H, W]
imgs = 2 * imgs - 1
posterior = self.vae.encode(imgs).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents
def prompt_to_img(
self,
prompts,
negative_prompts="",
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
latents=None,
control_images=None,
):
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(negative_prompts, str):
negative_prompts = [negative_prompts]
# Prompts -> text embeds
self.get_text_embeds(prompts, negative_prompts)
# Text embeds -> img latents
latents = self.produce_latents(
height=height,
width=width,
latents=latents,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
control_images=control_images,
) # [1, 4, 64, 64]
# Img latents -> imgs
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
# Img to Numpy
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
imgs = (imgs * 255).round().astype("uint8")
return imgs
if __name__ == "__main__":
import kiui
import argparse
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument("image", type=str)
parser.add_argument("prompt", nargs="?", default="", type=str) # nargs="?" makes the positional optional so the default applies
parser.add_argument("--control", default='tile', type=str)
parser.add_argument("--negative", default="", type=str)
parser.add_argument("--fp16", action="store_true", help="use float16 for training")
parser.add_argument("--vram_O", action="store_true", help="optimization for low VRAM usage")
parser.add_argument("-H", type=int, default=512)
parser.add_argument("-W", type=int, default=512)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
opt = parser.parse_args()
kiui.seed_everything(opt.seed)
device = torch.device("cuda")
# load image
control_image = kiui.read_image(opt.image, mode='tensor').permute(2,0,1).contiguous().unsqueeze(0).to(device)
control_image = F.interpolate(control_image, (opt.H, opt.W), mode='bilinear', align_corners=False)
kiui.lo(control_image)
control_images = {}
control_images[opt.control] = control_image
sd = StableDiffusionControlNet(device, opt.fp16, opt.vram_O, control_mode=[opt.control])
while True:
imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, control_images=control_images)
# visualize image
plt.imshow(imgs[0])
plt.show()
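# Illustrative sketch (not part of the repository): the ControlNet loops above use
# weight = 1 / len(self.controlnet) but skip any mode without an entry in control_images,
# so the weights sum to less than 1 when a loaded mode has no control image. The
# pure-Python helper below reproduces that arithmetic on scalars and shows an optional
# renormalized variant. merge_predictions and its normalize flag are hypothetical names,
# and renormalizing is a possible alternative, not the repository's behavior.

```python
def merge_predictions(predictions, loaded_modes, normalize=False):
    """Weighted sum of per-mode predictions.

    predictions: dict mapping mode -> value, only for modes that have a control image.
    loaded_modes: every mode for which a ControlNet was loaded.
    """
    active = [m for m in loaded_modes if m in predictions]
    if not active:
        raise ValueError("no control image matches a loaded ControlNet mode")
    # Source behavior divides by all loaded modes; the variant divides by active ones.
    denom = len(active) if normalize else len(loaded_modes)
    return sum(predictions[m] / denom for m in active)


loaded = ["tile", "depth"]
as_in_source = merge_predictions({"tile": 1.0}, loaded)
renormalized = merge_predictions({"tile": 1.0}, loaded, normalize=True)
both_given = merge_predictions({"tile": 1.0, "depth": 3.0}, loaded)
```

# With every loaded mode supplied, both variants agree; they differ only when a mode is skipped.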
================================================
FILE: threefiner/lights/LICENSE.txt
================================================
The mud_road_puresky_1k.hdr HDR probe is from https://polyhaven.com/a/mud_road_puresky
CC0 License.
================================================
FILE: threefiner/nn.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tinycudann as tcnn
class HashGridEncoder(nn.Module):
def __init__(self,
input_dim=3,
num_levels=16,
level_dim=2,
log2_hashmap_size=18,
base_resolution=16,
desired_resolution=1024,
interpolation='linear'
):
super().__init__()
self.encoder = tcnn.Encoding(
n_input_dims=input_dim,
encoding_config={
"otype": "HashGrid",
"n_levels": num_levels,
"n_features_per_level": level_dim,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": base_resolution,
# geometric growth factor from base_resolution to desired_resolution
"per_level_scale": np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)),
"interpolation": "Smoothstep" if interpolation == 'smoothstep' else "Linear",
},
dtype=torch.float32,
)
self.input_dim = input_dim
self.output_dim = self.encoder.n_output_dims # patch
def forward(self, x, bound=1):
return self.encoder((x + bound) / (2 * bound))
class FrequencyEncoder(nn.Module):
def __init__(self,
input_dim=3,
output_dim=32,
n_frequencies=12,
):
super().__init__()
self.encoder = tcnn.Encoding(
n_input_dims=input_dim,
encoding_config={
"otype": "Frequency",
"n_frequencies": n_frequencies,
},
dtype=torch.float32,
)
self.implicit_mlp = MLP(self.encoder.n_output_dims, output_dim, 128, 5, bias=True)
self.input_dim = input_dim
self.output_dim = output_dim
def forward(self, x, **kwargs):
return self.implicit_mlp(self.encoder(x))
class TriplaneEncoder(nn.Module):
def __init__(self,
input_dim=3,
output_dim=32,
resolution=256,
):
super().__init__()
self.C_mat = nn.Parameter(torch.randn(3, output_dim, resolution, resolution))
torch.nn.init.kaiming_normal_(self.C_mat)
self.mat_ids = [[0, 1], [0, 2], [1, 2]]
self.input_dim = input_dim
self.output_dim = output_dim
def forward(self, x, bound=1):
N = x.shape[0]
x = x / bound # to [-1, 1]
mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2]
feat = F.grid_sample(self.C_mat[[0]], mat_coord[[0]], align_corners=False).view(-1, N) + \
F.grid_sample(self.C_mat[[1]], mat_coord[[1]], align_corners=False).view(-1, N) + \
F.grid_sample(self.C_mat[[2]], mat_coord[[2]], align_corners=False).view(-1, N) # [r, N]
# density
feat = feat.transpose(0, 1).contiguous() # [N, C]
return feat
class MLP(nn.Module):
def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
super().__init__()
self.dim_in = dim_in
self.dim_out = dim_out
self.dim_hidden = dim_hidden
self.num_layers = num_layers
net = []
for l in range(num_layers):
net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
self.net = nn.ModuleList(net)
def forward(self, x):
for l in range(self.num_layers):
x = self.net[l](x)
if l != self.num_layers - 1:
x = F.relu(x, inplace=True)
return x
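# Illustrative sketch (not part of the repository): HashGridEncoder passes a
# per_level_scale so that grid resolution grows geometrically across levels. The helper
# below computes the factor that takes base_resolution to desired_resolution over
# num_levels levels, using the class defaults. The function names are hypothetical, and
# the resolutions are idealized: the exact rounding used inside tiny-cuda-nn may differ.

```python
import math


def per_level_scale(base_resolution=16, desired_resolution=1024, num_levels=16):
    """Growth factor b such that base_resolution * b ** (num_levels - 1) == desired_resolution."""
    return 2.0 ** (math.log2(desired_resolution / base_resolution) / (num_levels - 1))


def level_resolutions(base_resolution=16, desired_resolution=1024, num_levels=16):
    """Idealized (unrounded) resolution of each level, coarsest first."""
    scale = per_level_scale(base_resolution, desired_resolution, num_levels)
    return [base_resolution * scale ** level for level in range(num_levels)]


scale = per_level_scale()
resolutions = level_resolutions()
```

# With the defaults, base_resolution and num_levels are both 16, so dividing
# desired_resolution by either one gives the same factor; the two only diverge for
# non-default settings.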
================================================
FILE: threefiner/opt.py
================================================
import os
from dataclasses import dataclass
from typing import Tuple, Literal, Dict, Optional
@dataclass
class Options:
# path to input mesh
mesh: Optional[str] = None
# input text prompt
prompt: Optional[str] = None
# additional positive prompt
positive_prompt: str = "best quality, extremely detailed, masterpiece, high resolution, high quality"
# additional negative prompt
negative_prompt: str = "blur, lowres, cropped, low quality, worst quality, ugly, dark, shadow, oversaturated"
# whether to append directional text prompt
text_dir: bool = False
# set mesh front-facing direction (camera front=+z, right=+x, up=+y; an optional trailing digit adds a clockwise rotation: 1=90, 2=180, 3=270 degrees, e.g., +z, -y1)
front_dir: str = "+z"
# training iterations
iters: int = 500
# training resolution
render_resolution: int = 512
# training camera radius
radius: float = 2.5
# training camera fovy in degree
fovy: float = 49.1
# whether to fix the geometry (True disables geometry training)
fix_geo: bool = False
# whether to mix normal with rgb for geometry training
mix_normal: bool = True
# whether to pretrain texture first
fit_tex: bool = True
# pretrain texture iterations
fit_tex_iters: int = 512
# output folder
outdir: str = '.'
# output filename, defaults to {name}_fine.glb
save: Optional[str] = None
# guidance mode
mode: Literal['SD', 'IF', 'IF2', 'SDCN', 'SD_NFSD', 'IF2_NFSD', 'SD_ISM', 'IF2_ISM'] = 'IF2'
# renderer geometry mode
geom_mode: Literal['mesh', 'diffmc', 'pbr_mesh', 'pbr_diffmc'] = 'diffmc'
# renderer texture mode
tex_mode: Literal['hashgrid', 'mlp', 'triplane'] = 'hashgrid'
# training batch size per iter
batch_size: int = 1
# environmental texture
env_texture: Optional[str] = None
# environmental light scale
env_scale: float = 2
# DiffMC grid size
mc_grid_size: int = 128
# Mesh remeshing interval
remesh_interval: int = 200
# mesh decimation target face number
decimate_target: int = 50000
# remesh target edge length (smaller values lead to a finer mesh)
remesh_size: float = 0.015
# texture resolution
texture_resolution: int = 1024
# learning rate for hashgrid
hashgrid_lr: float = 0.01
# learning rate for feature MLP
mlp_lr: float = 0.001
# learning rate for SDF
sdf_lr: float = 0.0001
# learning rate for deformation
deform_lr: float = 0.0001
# learning rate for mesh geometry
geom_lr: float = 0.0001
# guidance loss weights
lambda_sd: float = 1
# mesh laplacian regularization weight
lambda_lap: float = 0
# mesh normal consistency weight (should be large enough)
lambda_normal: float = 10000
# mesh vertices offset penalty weight
lambda_offsets: float = 100
# whether to open a GUI
gui: bool = False
# GUI height
H: int = 800
# GUI width
W: int = 800
# whether to use CUDA rasterizer (in case OpenGL fails)
force_cuda_rast: bool = False
# whether to use GPU memory-optimized mode (slower, but uses less GPU memory)
vram_O: bool = False
# all the default settings
config_defaults: Dict[str, Options] = {}
config_doc: Dict[str, str] = {}
config_doc['sd'] = 'coarse-level generation with stable-diffusion 2.'
config_defaults['sd'] = Options(
mode='SD',
iters=800,
)
config_doc['if'] = 'coarse-level generation with deepfloyd-if I.'
config_defaults['if'] = Options(
mode='IF',
iters=400,
)
config_doc['if2'] = 'fine-level refinement with deepfloyd-if II.'
config_defaults['if2'] = Options(
mode='IF2',
iters=400,
)
config_doc['sd_fixgeo'] = 'coarse-level generation with stable-diffusion 2, fixed geometry.'
config_defaults['sd_fixgeo'] = Options(
mode='SD',
iters=800,
fix_geo=True,
geom_mode='mesh',
)
config_doc['if_fixgeo'] = 'coarse-level generation with deepfloyd-if I, fixed geometry.'
config_defaults['if_fixgeo'] = Options(
mode='IF',
iters=400,
fix_geo=True,
geom_mode='mesh',
)
config_doc['if2_fixgeo'] = 'fine-level refinement with deepfloyd-if II, fixed geometry.'
config_defaults['if2_fixgeo'] = Options(
mode='IF2',
iters=400,
fix_geo=True,
geom_mode='mesh',
)
def check_options(opt: Options):
assert opt.mesh is not None, 'mesh path must be specified!'
assert opt.prompt is not None, 'prompt must be specified!'
if opt.save is None:
input_name, input_ext = os.path.splitext(os.path.basename(opt.mesh))
opt.save = input_name + '_fine' + '.glb'
print(f'[INFO] save to default output path: {os.path.join(opt.outdir, opt.save)}.')
return opt
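# Illustrative sketch (not part of the repository): check_options derives the default
# output name from the input mesh path and always uses the .glb extension, whatever the
# input format. The standalone helper below mirrors that logic; default_save_name is a
# hypothetical name, and the sample paths come from the data/ folder in the directory listing.

```python
import os


def default_save_name(mesh_path, suffix="_fine", ext=".glb"):
    """Return '<input stem><suffix><ext>', dropping the directory and original extension."""
    name, _ = os.path.splitext(os.path.basename(mesh_path))
    return name + suffix + ext


car_name = default_save_name("data/car.glb")
chair_name = default_save_name("data/chair.ply")
```

# A .ply input therefore produces a .glb output name unless the save option is set explicitly.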
================================================
FILE: threefiner/renderer/__init__.py
================================================
================================================
FILE: threefiner/renderer/diffmc_renderer.py
================================================
import os
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nvdiffrast.torch as dr
import kiui
from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding
from kiui.cam import orbit_camera, get_perspective
from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder
from threefiner.renderer.mesh_renderer import render_mesh
from diso import DiffMC, DiffDMC
class Renderer(nn.Module):
def __init__(self, opt, device):
super().__init__()
self.opt = opt
self.device = device
if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
self.glctx = dr.RasterizeGLContext()
else:
self.glctx = dr.RasterizeCudaContext()
# diffmc
self.verts = torch.stack(
torch.meshgrid(
torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
indexing="ij",
), dim=-1,
) # [N, N, N, 3]
self.grid_scale = 1
self.diffmc = DiffMC(dtype=torch.float32).to(device)
# vert sdf and deform
self.sdf = nn.Parameter(torch.zeros_like(self.verts[..., 0]))
self.deform = nn.Parameter(torch.zeros_like(self.verts))
# init diffmc from mesh
self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)
vertices = self.mesh.v.detach().cpu().numpy()
triangles = self.mesh.f.detach().cpu().numpy()
vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)
self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
self.grid_scale = self.mesh.v.abs().max() + 1e-1
self.verts = self.verts * self.grid_scale
try:
import cubvh
BVH = cubvh.cuBVH(self.mesh.v, self.mesh.f)
sdf, _, _ = BVH.signed_distance(self.verts.reshape(-1, 3), return_uvw=False, mode='raystab') # some mesh may not be watertight...
except Exception: # cubvh is missing or failed, fall back to pysdf
from pysdf import SDF
sdf_func = SDF(self.mesh.v.detach().cpu().numpy(), self.mesh.f.detach().cpu().numpy())
sdf = sdf_func(self.verts.detach().cpu().numpy().reshape(-1, 3))
sdf = torch.from_numpy(sdf).to(self.device)
sdf *= -1
# OUTER is POSITIVE
self.sdf.data += sdf.reshape(*self.sdf.data.shape).to(self.sdf.data.dtype)
# texture
if self.opt.tex_mode == 'hashgrid':
self.encoder = HashGridEncoder().to(self.device)
elif self.opt.tex_mode == 'mlp':
self.encoder = FrequencyEncoder().to(self.device)
elif self.opt.tex_mode == 'triplane':
self.encoder = TriplaneEncoder().to(self.device)
else:
raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}")
self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=True).to(self.device)
self.v, self.f = None, None # placeholder
# init hashgrid texture from mesh
if self.opt.fit_tex:
self.fit_texture_from_mesh(self.opt.fit_tex_iters)
def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
return render_mesh(
self.glctx,
self.mesh.v, self.mesh.f, self.mesh.vt,
self.mesh.ft, self.mesh.albedo,
self.mesh.vc, self.mesh.vn, self.mesh.fn,
pose, proj, h, w,
ssaa=ssaa, bg_color=bg_color,
)
def fit_texture_from_mesh(self, iters=512):
# a small training loop...
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam([
{'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
{'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
])
resolution = 512
print(f"[INFO] fitting texture...")
pbar = tqdm.trange(iters)
for i in pbar:
ver = np.random.randint(-45, 45)
hor = np.random.randint(-180, 180)
pose = orbit_camera(ver, hor, self.opt.radius)
proj = get_perspective(self.opt.fovy)
image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']
image_pred = self.render(pose, proj, resolution, resolution)['image']
loss = loss_fn(image_pred, image_mesh)
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"MSE = {loss.item():.6f}")
print(f"[INFO] finished fitting texture!")
def get_params(self):
params = [
{'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
{'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
]
if not self.opt.fix_geo:
params.append({'params': self.sdf, 'lr': self.opt.sdf_lr})
params.append({'params': self.deform, 'lr': self.opt.deform_lr})
return params
@torch.no_grad()
def export_mesh(self, save_path, texture_resolution=2048, padding=16):
# get v
sdf = self.sdf
deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]
v, f = self.diffmc(sdf, deform)
v = (2 * v - 1) * self.grid_scale
f = f.int()
self.v, self.f = v, f
vertices = v.detach().cpu().numpy()
triangles = f.detach().cpu().numpy()
# clean
vertices = vertices.astype(np.float32)
triangles = triangles.astype(np.int32)
vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size)
# decimation
if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:
vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target)
v = torch.from_numpy(vertices).contiguous().float().to(self.device)
f = torch.from_numpy(triangles).contiguous().int().to(self.device)
mesh = Mesh(v=v, f=f, albedo=None, device=self.device)
print(f"[INFO] uv unwrapping...")
mesh.auto_normal()
mesh.auto_uv()
# render uv maps
h = w = texture_resolution
uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
# masked query
xyzs = xyzs.view(-1, 3)
mask = (mask > 0).view(-1)
albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)
if mask.any():
print(f"[INFO] querying texture...")
xyzs = xyzs[mask] # [M, 3]
# batched inference to avoid OOM
batch = []
head = 0
while head < xyzs.shape[0]:
tail = min(head + 640000, xyzs.shape[0])
batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
head += 640000
albedo[mask] = torch.cat(batch, dim=0)
albedo = albedo.view(h, w, -1)
mask = mask.view(h, w)
print(f"[INFO] uv padding...")
albedo = uv_padding(albedo, mask, padding)
mesh.albedo = albedo
mesh.write(save_path)
def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):
# do super-sampling
if ssaa != 1:
h = make_divisible(h0 * ssaa, 8)
w = make_divisible(w0 * ssaa, 8)
else:
h, w = h0, w0
results = {}
# get v
sdf = self.sdf
deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]
v, f = self.diffmc(sdf, deform)
v = (2 * v - 1) * self.grid_scale
f = f.int()
self.v, self.f = v, f
pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)
# get v_clip and render rgb
v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
v_clip = v_cam @ proj.T
rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!
depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]
depth = depth.squeeze(0) # [H, W, 1]
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
xyzs = xyzs.view(-1, 3)
mask = (alpha > 0).view(-1)
color = torch.zeros_like(xyzs, dtype=torch.float32)
if mask.any():
masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))
color[mask] = masked_albedo.float()
color = color.view(1, h, w, 3)
# antialias
color = dr.antialias(color, rast, v_clip, f).squeeze(0) # [H, W, 3]
color = alpha * color + (1 - alpha) * bg_color
# get vn and render normal
i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # explicit dim: the implicit default is deprecated and ambiguous for [3, 3] inputs
face_normals = safe_normalize(face_normals)
vn = torch.zeros_like(v)
vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)
normal = safe_normalize(normal[0])
# rotated normal (where [0, 0, 1] always faces camera)
rot_normal = normal @ pose[:3, :3]
viewcos = rot_normal[..., [2]]
# ssaa
if ssaa != 1:
color = scale_img_hwc(color, (h0, w0))
alpha = scale_img_hwc(alpha, (h0, w0))
depth = scale_img_hwc(depth, (h0, w0))
normal = scale_img_hwc(normal, (h0, w0))
viewcos = scale_img_hwc(viewcos, (h0, w0))
results['image'] = color.clamp(0, 1)
results['alpha'] = alpha
results['depth'] = depth
results['normal'] = (normal + 1) / 2
results['viewcos'] = viewcos
return results
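# Illustrative sketch (not part of threefiner).
# The renderer above bounds the per-vertex deformation with tanh(x) / 2 and maps
# the extracted vertices with v = (2 * v - 1) * grid_scale. The two helpers below
# restate those mappings in plain Python, assuming (as that formula implies) that
# DiffMC returns vertices in the unit cube [0, 1]^3. The names `_bounded_deform`
# and `_unit_to_world` are hypothetical and exist only for this example.
import math


def _bounded_deform(raw):
    """Squash an unconstrained value into [-0.5, 0.5], like torch.tanh(x) / 2."""
    return math.tanh(raw) / 2.0


def _unit_to_world(coord, grid_scale):
    """Map a coordinate from [0, 1] to [-grid_scale, grid_scale]."""
    return (2.0 * coord - 1.0) * grid_scale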
================================================
FILE: threefiner/renderer/mesh_renderer.py
================================================
import os
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nvdiffrast.torch as dr
from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding
from kiui.cam import orbit_camera, get_perspective
from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder
def render_mesh(
glctx,
v, f,
vt, ft, albedo, vc,
vn, fn,
pose, proj,
h0, w0,
ssaa=1, bg_color=1,
texture_filter='linear-mipmap-linear',
color_activation=None,
):
# do super-sampling
if ssaa != 1:
h = make_divisible(h0 * ssaa, 8)
w = make_divisible(w0 * ssaa, 8)
else:
h, w = h0, w0
results = {}
pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)
# get v_clip and render rgb
v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
v_clip = v_cam @ proj.T
rast, rast_db = dr.rasterize(glctx, v_clip, f, (h, w))
alpha = (rast[0, ..., 3:] > 0).float()
depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]
depth = depth.squeeze(0) # [H, W, 1]
if vc is not None:
# use vertex color
color, _ = dr.interpolate(vc.unsqueeze(0).contiguous(), rast, f)
else:
# use texture image
texc, texc_db = dr.interpolate(vt.unsqueeze(0).contiguous(), rast, ft, rast_db=rast_db, diff_attrs='all')
color = dr.texture(albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3]
if color_activation is not None:
color = color_activation(color)
# antialias
color = dr.antialias(color, rast, v_clip, f).squeeze(0) # [H, W, 3]
color = alpha * color + (1 - alpha) * bg_color
# get vn and render normal
if vn is None:
i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # explicit dim: the implicit default is deprecated and ambiguous for [3, 3] inputs
face_normals = safe_normalize(face_normals)
vn = torch.zeros_like(v)
vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, fn)
normal = safe_normalize(normal[0])
# rotated normal (where [0, 0, 1] always faces camera)
rot_normal = normal @ pose[:3, :3]
viewcos = rot_normal[..., [2]]
# ssaa
if ssaa != 1:
color = scale_img_hwc(color, (h0, w0))
alpha = scale_img_hwc(alpha, (h0, w0))
depth = scale_img_hwc(depth, (h0, w0))
normal = scale_img_hwc(normal, (h0, w0))
viewcos = scale_img_hwc(viewcos, (h0, w0))
results['image'] = color.clamp(0, 1)
results['alpha'] = alpha
results['depth'] = depth
results['normal'] = (normal + 1) / 2
results['viewcos'] = viewcos
return results
class Renderer(nn.Module):
def __init__(self, opt, device):
super().__init__()
self.opt = opt
self.device = device
self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)
# it's necessary to clean the mesh to facilitate later remeshing!
vertices = self.mesh.v.detach().cpu().numpy()
triangles = self.mesh.f.detach().cpu().numpy()
vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)
self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
self.glctx = dr.RasterizeGLContext()
else:
self.glctx = dr.RasterizeCudaContext()
# extract trainable parameters
self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
# texture
if self.opt.tex_mode == 'hashgrid':
self.encoder = HashGridEncoder().to(self.device)
elif self.opt.tex_mode == 'mlp':
self.encoder = FrequencyEncoder().to(self.device)
elif self.opt.tex_mode == 'triplane':
self.encoder = TriplaneEncoder().to(self.device)
else:
raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}")
self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=True).to(self.device)
# init hashgrid texture from mesh
if self.opt.fit_tex:
self.fit_texture_from_mesh(self.opt.fit_tex_iters)
def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
return render_mesh(
self.glctx,
self.mesh.v, self.mesh.f, self.mesh.vt,
self.mesh.ft, self.mesh.albedo,
self.mesh.vc, self.mesh.vn, self.mesh.fn,
pose, proj, h, w,
ssaa=ssaa, bg_color=bg_color,
)
def fit_texture_from_mesh(self, iters=512):
# a small training loop...
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam([
{'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
{'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
])
resolution = 512
print(f"[INFO] fitting texture...")
pbar = tqdm.trange(iters)
for i in pbar:
ver = np.random.randint(-45, 45)
hor = np.random.randint(-180, 180)
pose = orbit_camera(ver, hor, self.opt.radius)
proj = get_perspective(self.opt.fovy)
image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']
image_pred = self.render(pose, proj, resolution, resolution)['image']
loss = loss_fn(image_pred, image_mesh)
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"MSE = {loss.item():.6f}")
print(f"[INFO] finished fitting texture!")
def get_params(self):
params = [
{'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
{'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
]
if not self.opt.fix_geo:
params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})
return params
@torch.no_grad()
def export_mesh(self, save_path, texture_resolution=2048, padding=16):
mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
print(f"[INFO] uv unwrapping...")
mesh.auto_normal()
mesh.auto_uv()
# render uv maps
h = w = texture_resolution
uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]
# masked query
xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]
mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1]
xyzs = xyzs.view(-1, 3)
mask = (mask > 0).view(-1)
albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)
if mask.any():
print(f"[INFO] querying texture...")
xyzs = xyzs[mask] # [M, 3]
# batched inference to avoid OOM
batch = []
head = 0
while head < xyzs.shape[0]:
tail = min(head + 640000, xyzs.shape[0])
batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
head += 640000
albedo[mask] = torch.cat(batch, dim=0)
albedo = albedo.view(h, w, -1)
mask = mask.view(h, w)
print(f"[INFO] uv padding...")
albedo = uv_padding(albedo, mask, padding)
mesh.albedo = albedo
mesh.write(save_path)
@property
def v(self):
if self.opt.fix_geo:
return self.mesh.v
else:
return self.mesh.v + self.v_offsets
@property
def f(self):
return self.mesh.f
@torch.no_grad()
def remesh(self):
vertices = self.v.detach().cpu().numpy()
triangles = self.f.detach().cpu().numpy()
vertices, triangles = clean_mesh(vertices, triangles, repair=False, remesh=True, remesh_size=self.opt.remesh_size)
if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:
vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target, optimalplacement=False)
self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v)) # self.mesh.v is already on self.device, so no .to() is needed
def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):
# do super-sampling
if ssaa != 1:
h = make_divisible(h0 * ssaa, 8)
w = make_divisible(w0 * ssaa, 8)
else:
h, w = h0, w0
results = {}
# get v
v = self.v
f = self.f
pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)
# get v_clip and render rgb
v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
v_clip = v_cam @ proj.T
rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!
depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]
depth = depth.squeeze(0) # [H, W, 1]
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
xyzs = xyzs.view(-1, 3)
mask = (alpha > 0).view(-1)
color = torch.zeros_like(xyzs, dtype=torch.float32)
if mask.any():
masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))
color[mask] = masked_albedo.float()
color = color.view(1, h, w, 3)
# antialias
color = dr.antialias(color, rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3]
color = alpha * color + (1 - alpha) * bg_color
# get vn and render normal
if self.opt.fix_geo:
vn = self.mesh.vn
else:
i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # explicit dim: the implicit default is deprecated and ambiguous for [3, 3] inputs
face_normals = safe_normalize(face_normals)
vn = torch.zeros_like(v)
vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)
normal = safe_normalize(normal[0])
# rotated normal (where [0, 0, 1] always faces camera)
rot_normal = normal @ pose[:3, :3]
viewcos = rot_normal[..., [2]]
# ssaa
if ssaa != 1:
color = scale_img_hwc(color, (h0, w0))
alpha = scale_img_hwc(alpha, (h0, w0))
depth = scale_img_hwc(depth, (h0, w0))
normal = scale_img_hwc(normal, (h0, w0))
viewcos = scale_img_hwc(viewcos, (h0, w0))
results['image'] = color
results['alpha'] = alpha
results['depth'] = depth
results['normal'] = (normal + 1) / 2
results['viewcos'] = viewcos
return results
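# Illustrative sketch (not part of threefiner).
# export_mesh() above queries the texture network in slices of 640000 points to
# avoid running out of memory. The generator below reproduces only the index
# arithmetic of that loop in plain Python. `_chunk_ranges` is a hypothetical
# helper name, and the default batch size simply copies the constant used above.
def _chunk_ranges(total, batch_size=640000):
    """Yield (head, tail) index pairs that cover range(total) in order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    head = 0
    while head < total:
        # the last slice is clipped so it never runs past the end
        tail = min(head + batch_size, total)
        yield head, tail
        head = tail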
================================================
FILE: threefiner/renderer/pbr_diffmc_renderer.py
================================================
import os
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import envlight
import nvdiffrast.torch as dr
import kiui
from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding
from kiui.cam import orbit_camera, get_perspective
from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder
from threefiner.renderer.mesh_renderer import render_mesh
from diso import DiffMC, DiffDMC
class Renderer(nn.Module):
def __init__(self, opt, device):
super().__init__()
self.opt = opt
self.device = device
if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
self.glctx = dr.RasterizeGLContext()
else:
self.glctx = dr.RasterizeCudaContext()
# diffmc
self.verts = torch.stack(
torch.meshgrid(
torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
torch.linspace(-1, 1, self.opt.mc_grid_size, device=device),
indexing="ij",
), dim=-1,
) # [N, N, N, 3]
self.grid_scale = 1
self.diffmc = DiffMC(dtype=torch.float32).to(device)
# vert sdf and deform
self.sdf = nn.Parameter(torch.zeros_like(self.verts[..., 0]))
self.deform = nn.Parameter(torch.zeros_like(self.verts))
# init diffmc from mesh
self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)
vertices = self.mesh.v.detach().cpu().numpy()
triangles = self.mesh.f.detach().cpu().numpy()
vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)
self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
self.grid_scale = self.mesh.v.abs().max() + 1e-1
self.verts = self.verts * self.grid_scale
try:
import cubvh
BVH = cubvh.cuBVH(self.mesh.v, self.mesh.f)
sdf, _, _ = BVH.signed_distance(self.verts.reshape(-1, 3), return_uvw=False, mode='raystab') # some mesh may not be watertight...
except Exception: # cubvh missing or failed: fall back to the CPU pysdf implementation
from pysdf import SDF
sdf_func = SDF(self.mesh.v.detach().cpu().numpy(), self.mesh.f.detach().cpu().numpy())
sdf = sdf_func(self.verts.detach().cpu().numpy().reshape(-1, 3))
sdf = torch.from_numpy(sdf).to(self.device)
sdf *= -1
# OUTER is POSITIVE
self.sdf.data += sdf.reshape(*self.sdf.data.shape).to(self.sdf.data.dtype)
# texture
if self.opt.tex_mode == 'hashgrid':
self.encoder = HashGridEncoder().to(self.device)
elif self.opt.tex_mode == 'mlp':
self.encoder = FrequencyEncoder().to(self.device)
elif self.opt.tex_mode == 'triplane':
self.encoder = TriplaneEncoder().to(self.device)
else:
raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}")
self.mlp = MLP(self.encoder.output_dim, 3+2, 32, 2, bias=True).to(self.device)
self.v, self.f = None, None # placeholder
# env light
if self.opt.env_texture is None:
hdr_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../lights/mud_road_puresky_1k.hdr')
else:
hdr_path = self.opt.env_texture
self.light = envlight.EnvLight(hdr_path, scale=self.opt.env_scale, device=self.device)
FG_LUT = torch.from_numpy(np.fromfile(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../lights/bsdf_256_256.bin"), dtype=np.float32).reshape(1, 256, 256, 2)).to(self.device)
self.register_buffer("FG_LUT", FG_LUT)
# init hashgrid texture from mesh
if self.opt.fit_tex:
self.fit_texture_from_mesh(self.opt.fit_tex_iters)
def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
return render_mesh(
self.glctx,
self.mesh.v, self.mesh.f, self.mesh.vt,
self.mesh.ft, self.mesh.albedo,
self.mesh.vc, self.mesh.vn, self.mesh.fn,
pose, proj, h, w,
ssaa=ssaa, bg_color=bg_color,
)
def fit_texture_from_mesh(self, iters=512):
# a small training loop...
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam([
{'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
{'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
])
resolution = 512
print(f"[INFO] fitting texture...")
pbar = tqdm.trange(iters)
for i in pbar:
ver = np.random.randint(-45, 45)
hor = np.random.randint(-180, 180)
pose = orbit_camera(ver, hor, self.opt.radius)
proj = get_perspective(self.opt.fovy)
image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']
image_pred = self.render(pose, proj, resolution, resolution)['image']
loss = loss_fn(image_pred, image_mesh)
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"MSE = {loss.item():.6f}")
print(f"[INFO] finished fitting texture!")
def get_params(self):
params = [
{'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
{'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
]
if not self.opt.fix_geo:
params.append({'params': self.sdf, 'lr': self.opt.sdf_lr})
params.append({'params': self.deform, 'lr': self.opt.deform_lr})
return params
@torch.no_grad()
def export_mesh(self, save_path, texture_resolution=2048, padding=16):
# get v
sdf = self.sdf
deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]
v, f = self.diffmc(sdf, deform)
v = (2 * v - 1) * self.grid_scale
f = f.int()
self.v, self.f = v, f
vertices = v.detach().cpu().numpy()
triangles = f.detach().cpu().numpy()
# clean
vertices = vertices.astype(np.float32)
triangles = triangles.astype(np.int32)
vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size)
# decimation
if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:
vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target)
v = torch.from_numpy(vertices).contiguous().float().to(self.device)
f = torch.from_numpy(triangles).contiguous().int().to(self.device)
mesh = Mesh(v=v, f=f, albedo=None, device=self.device)
print(f"[INFO] uv unwrapping...")
mesh.auto_normal()
mesh.auto_uv()
# render uv maps
h = w = texture_resolution
uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
# masked query
xyzs = xyzs.view(-1, 3)
mask = (mask > 0).view(-1)
material = torch.zeros(h * w, 5, device=self.device, dtype=torch.float32)
if mask.any():
print(f"[INFO] querying texture...")
xyzs = xyzs[mask] # [M, 3]
# batched inference to avoid OOM
batch = []
head = 0
while head < xyzs.shape[0]:
tail = min(head + 640000, xyzs.shape[0])
batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
head += 640000
material[mask] = torch.cat(batch, dim=0)
material = material.view(h, w, -1)
mask = mask.view(h, w)
print(f"[INFO] uv padding...")
material = uv_padding(material, mask, padding)
mesh.albedo = material[..., :3]
mesh.metallicRoughness = torch.cat([torch.zeros_like(material[..., 3:4]), material[..., 4:5], material[..., 3:4]], dim=-1)
mesh.write(save_path)
def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):
# do super-sampling
if ssaa != 1:
h = make_divisible(h0 * ssaa, 8)
w = make_divisible(w0 * ssaa, 8)
else:
h, w = h0, w0
results = {}
# get v
sdf = self.sdf
deform = torch.tanh(self.deform) / 2 # [-0.5, 0.5]
v, f = self.diffmc(sdf, deform)
v = (2 * v - 1) * self.grid_scale
f = f.int()
self.v, self.f = v, f
pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)
# get v_clip and render rgb
v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
v_clip = v_cam @ proj.T
rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!
depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]
depth = depth.squeeze(0) # [H, W, 1]
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
viewdir = safe_normalize(xyzs - pose[:3, 3]).squeeze(0)
xyzs = xyzs.view(-1, 3)
mask = (alpha > 0).view(-1)
material = torch.zeros(xyzs.shape[0], 5, dtype=torch.float32, device=xyzs.device)
if mask.any():
masked_material = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))
material[mask] = masked_material.float()
material = material.view(h, w, -1)
# get vn and render normal
i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # explicit dim: the implicit default is deprecated and ambiguous for [3, 3] inputs
face_normals = safe_normalize(face_normals)
vn = torch.zeros_like(v)
vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)
normal = safe_normalize(normal[0])
# rotated normal (where [0, 0, 1] always faces camera)
rot_normal = normal @ pose[:3, :3]
viewcos = rot_normal[..., [2]]
# shading
albedo = material[..., :3]
metallic = material[..., 3:4]
roughness = material[..., 4:5]
n_dot_v = (normal * viewdir).sum(-1, keepdim=True) # [H, W, 1]
reflective = n_dot_v * normal * 2 - viewdir
diffuse_albedo = (1 - metallic) * albedo
fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) # [H, W, 2]
fg = dr.texture(
self.FG_LUT,
fg_uv.reshape(1, -1, 1, 2).contiguous(),
filter_mode="linear",
boundary_mode="clamp",
).reshape(h, w, 2)
F0 = (1 - metallic) * 0.04 + metallic * albedo
specular_albedo = F0 * fg[..., 0:1] + fg[..., 1:2]
diffuse_light = self.light(normal)
specular_light = self.light(reflective, roughness)
color = diffuse_albedo * diffuse_light + specular_albedo * specular_light # [H, W, 3]
# antialias
color = dr.antialias(color.unsqueeze(0), rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3]
color = alpha * color + (1 - alpha) * bg_color
# ssaa
if ssaa != 1:
color = scale_img_hwc(color, (h0, w0))
alpha = scale_img_hwc(alpha, (h0, w0))
depth = scale_img_hwc(depth, (h0, w0))
normal = scale_img_hwc(normal, (h0, w0))
viewcos = scale_img_hwc(viewcos, (h0, w0))
results['image'] = color
results['alpha'] = alpha
results['depth'] = depth
results['normal'] = (normal + 1) / 2
results['viewcos'] = viewcos
return results
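# Illustrative sketch (not part of threefiner).
# The shading step in render() above derives a diffuse and a specular albedo
# from (albedo, metallic) plus the two channels sampled from FG_LUT, and
# export_mesh() packs the material as (0, roughness, metallic). The helpers
# below restate that arithmetic for a single pixel in plain Python. The names
# `_split_sum_albedos` and `_pack_metallic_roughness` are hypothetical, and
# `fg_scale` / `fg_bias` stand in for fg[..., 0:1] and fg[..., 1:2], passed
# here as plain floats instead of being looked up from the table.
def _split_sum_albedos(albedo, metallic, fg_scale, fg_bias):
    """Return (diffuse_albedo, specular_albedo) for one RGB albedo tuple."""
    diffuse = tuple((1.0 - metallic) * c for c in albedo)
    # F0 blends a dielectric base reflectance of 0.04 with the albedo for metals
    f0 = tuple((1.0 - metallic) * 0.04 + metallic * c for c in albedo)
    specular = tuple(f * fg_scale + fg_bias for f in f0)
    return diffuse, specular


def _pack_metallic_roughness(metallic, roughness):
    """Pack as (0, roughness, metallic), the channel order export_mesh() writes."""
    return (0.0, roughness, metallic)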
================================================
FILE: threefiner/renderer/pbr_mesh_renderer.py
================================================
import os
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import envlight
import nvdiffrast.torch as dr
from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.op import safe_normalize, scale_img_hwc, make_divisible, uv_padding
from kiui.cam import orbit_camera, get_perspective
from threefiner.nn import MLP, HashGridEncoder, FrequencyEncoder, TriplaneEncoder
from threefiner.renderer.mesh_renderer import render_mesh
class Renderer(nn.Module):
def __init__(self, opt, device):
super().__init__()
self.opt = opt
self.device = device
self.mesh = Mesh.load(self.opt.mesh, bound=0.9, front_dir=self.opt.front_dir)
# it's necessary to clean the mesh to facilitate later remeshing!
vertices = self.mesh.v.detach().cpu().numpy()
triangles = self.mesh.f.detach().cpu().numpy()
vertices, triangles = clean_mesh(vertices, triangles, min_f=32, min_d=10, remesh=False)
self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
self.glctx = dr.RasterizeGLContext()
else:
self.glctx = dr.RasterizeCudaContext()
# extract trainable parameters
self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
# texture
if self.opt.tex_mode == 'hashgrid':
self.encoder = HashGridEncoder().to(self.device)
elif self.opt.tex_mode == 'mlp':
self.encoder = FrequencyEncoder().to(self.device)
elif self.opt.tex_mode == 'triplane':
self.encoder = TriplaneEncoder().to(self.device)
else:
raise NotImplementedError(f"unsupported texture mode: {self.opt.tex_mode} for {self.opt.geom_mode}")
self.mlp = MLP(self.encoder.output_dim, 3+2, 32, 2, bias=True).to(self.device)
# env light
if self.opt.env_texture is None:
hdr_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../lights/mud_road_puresky_1k.hdr')
else:
hdr_path = self.opt.env_texture
self.light = envlight.EnvLight(hdr_path, scale=self.opt.env_scale, device=self.device)
FG_LUT = torch.from_numpy(np.fromfile(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../lights/bsdf_256_256.bin"), dtype=np.float32).reshape(1, 256, 256, 2)).to(self.device)
self.register_buffer("FG_LUT", FG_LUT)
# init hashgrid texture from mesh
if self.opt.fit_tex:
self.fit_texture_from_mesh(self.opt.fit_tex_iters)
def render_mesh(self, pose, proj, h, w, ssaa=1, bg_color=1):
return render_mesh(
self.glctx,
self.mesh.v, self.mesh.f, self.mesh.vt,
self.mesh.ft, self.mesh.albedo,
self.mesh.vc, self.mesh.vn, self.mesh.fn,
pose, proj, h, w,
ssaa=ssaa, bg_color=bg_color,
)
def fit_texture_from_mesh(self, iters=512):
# a small training loop...
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam([
{'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
{'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
])
resolution = 512
print(f"[INFO] fitting texture...")
pbar = tqdm.trange(iters)
for i in pbar:
ver = np.random.randint(-45, 45)
hor = np.random.randint(-180, 180)
pose = orbit_camera(ver, hor, self.opt.radius)
proj = get_perspective(self.opt.fovy)
image_mesh = self.render_mesh(pose, proj, resolution, resolution)['image']
image_pred = self.render(pose, proj, resolution, resolution)['image']
loss = loss_fn(image_pred, image_mesh)
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"MSE = {loss.item():.6f}")
print(f"[INFO] finished fitting texture!")
def get_params(self):
params = [
{'params': self.encoder.parameters(), 'lr': self.opt.hashgrid_lr},
{'params': self.mlp.parameters(), 'lr': self.opt.mlp_lr},
]
if not self.opt.fix_geo:
params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})
return params
@torch.no_grad()
def export_mesh(self, save_path, texture_resolution=2048, padding=16):
mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
print(f"[INFO] uv unwrapping...")
mesh.auto_normal()
mesh.auto_uv()
# render uv maps
h = w = texture_resolution
uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]
# masked query
xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]
mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1]
xyzs = xyzs.view(-1, 3)
mask = (mask > 0).view(-1)
material = torch.zeros(h * w, 5, device=self.device, dtype=torch.float32)
if mask.any():
print(f"[INFO] querying texture...")
xyzs = xyzs[mask] # [M, 3]
# batched inference to avoid OOM
batch = []
head = 0
while head < xyzs.shape[0]:
tail = min(head + 640000, xyzs.shape[0])
batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
head += 640000
material[mask] = torch.cat(batch, dim=0)
material = material.view(h, w, -1)
mask = mask.view(h, w)
print("[INFO] uv padding...")
material = uv_padding(material, mask, padding)
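# material channels: [0:3] albedo, [3] metallic, [4] roughness
mesh.albedo = material[..., :3]
# glTF packs roughness into G and metallic into B; R is unused
mesh.metallicRoughness = torch.cat([torch.zeros_like(material[..., 3:4]), material[..., 4:5], material[..., 3:4]], dim=-1)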
mesh.write(save_path)
@property
def v(self):
if self.opt.fix_geo:
return self.mesh.v
else:
return self.mesh.v + self.v_offsets
@property
def f(self):
return self.mesh.f
@torch.no_grad()
def remesh(self):
vertices = self.v.detach().cpu().numpy()
triangles = self.f.detach().cpu().numpy()
vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=self.opt.remesh_size)
if self.opt.decimate_target > 0 and triangles.shape[0] > self.opt.decimate_target:
vertices, triangles = decimate_mesh(vertices, triangles, self.opt.decimate_target, optimalplacement=False)
self.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
self.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
# self.mesh.v is already on self.device; avoid .to() here, which can turn the Parameter into a non-leaf tensor
self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
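def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1):
# pose: camera-to-world matrix [4, 4] (numpy); proj: projection matrix [4, 4] (numpy)
# h0, w0: output resolution; ssaa: super-sampling factor; bg_color: background value blended behind the mesh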
# do super-sampling
if ssaa != 1:
h = make_divisible(h0 * ssaa, 8)
w = make_divisible(w0 * ssaa, 8)
else:
h, w = h0, w0
results = {}
# get v
v = self.v
f = self.f
pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)
# get v_clip and render rgb
v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
v_clip = v_cam @ proj.T
rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(0) # important to enable gradients!
depth, _ = dr.interpolate(-v_cam[..., [2]], rast, f) # [1, H, W, 1]
depth = depth.squeeze(0) # [H, W, 1]
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
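# view direction from the surface towards the camera, so that n_dot_v is positive on visible, front-facing surfaces
viewdir = safe_normalize(pose[:3, 3] - xyzs).squeeze(0)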
xyzs = xyzs.view(-1, 3)
mask = (alpha > 0).view(-1)
material = torch.zeros(xyzs.shape[0], 5, dtype=torch.float32, device=xyzs.device)
if mask.any():
masked_material = torch.sigmoid(self.mlp(self.encoder(xyzs[mask], bound=1)))
material[mask] = masked_material.float()
material = material.view(h, w, -1)
# get vn and render normal
if self.opt.fix_geo:
vn = self.mesh.vn
else:
i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]
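# pass dim explicitly: the implicit default picks the first size-3 dimension, which is wrong for a mesh with exactly 3 faces
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)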
face_normals = safe_normalize(face_normals)
vn = torch.zeros_like(v)
vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, f)
normal = safe_normalize(normal[0])
# rotated normal (where [0, 0, 1] always faces camera)
rot_normal = normal @ pose[:3, :3]
viewcos = rot_normal[..., [2]]
# shading
albedo = material[..., :3]
metallic = material[..., 3:4]
roughness = material[..., 4:5]
n_dot_v = (normal * viewdir).sum(-1, keepdim=True) # [H, W, 1]
reflective = n_dot_v * normal * 2 - viewdir
diffuse_albedo = (1 - metallic) * albedo
fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) # [H, W, 2]
fg = dr.texture(
self.FG_LUT,
fg_uv.reshape(1, -1, 1, 2).contiguous(),
filter_mode="linear",
boundary_mode="clamp",
).reshape(h, w, 2)
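# split-sum specular term: F0 is 0.04 for dielectrics and the albedo for metals; fg holds the LUT scale and bias
F0 = (1 - metallic) * 0.04 + metallic * albedo
specular_albedo = F0 * fg[..., 0:1] + fg[..., 1:2]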
diffuse_light = self.light(normal)
specular_light = self.light(reflective, roughness)
color = diffuse_albedo * diffuse_light + specular_albedo * specular_light # [H, W, 3]
# antialias
color = dr.antialias(color.unsqueeze(0), rast, v_clip, f).clamp(0, 1).squeeze(0) # [H, W, 3]
color = alpha * color + (1 - alpha) * bg_color
# ssaa
if ssaa != 1:
color = scale_img_hwc(color, (h0, w0))
alpha = scale_img_hwc(alpha, (h0, w0))
depth = scale_img_hwc(depth, (h0, w0))
normal = scale_img_hwc(normal, (h0, w0))
viewcos = scale_img_hwc(viewcos, (h0, w0))
results['image'] = color
results['alpha'] = alpha
results['depth'] = depth
results['normal'] = (normal + 1) / 2
results['viewcos'] = viewcos
return results