Repository: ashawkey/stable-dreamfusion
Branch: main
Commit: 5550b91862a3
Files: 132
Total size: 1016.8 KB
Directory structure:
gitextract_l054hyr6/
├── .github/
│   └── ISSUE_TEMPLATE/
│       ├── bug_report.yaml
│       └── feature_request.md
├── .gitignore
├── LICENSE
├── activation.py
├── assets/
│   ├── advanced.md
│   └── update_logs.md
├── config/
│   ├── anya.csv
│   ├── car.csv
│   └── corgi.csv
├── docker/
│   ├── Dockerfile
│   └── README.md
├── dpt.py
├── encoding.py
├── evaluation/
│   ├── Prompt.py
│   ├── mesh_to_video.py
│   ├── r_precision.py
│   └── readme.md
├── freqencoder/
│   ├── __init__.py
│   ├── backend.py
│   ├── freq.py
│   ├── setup.py
│   └── src/
│       ├── bindings.cpp
│       ├── freqencoder.cu
│       └── freqencoder.h
├── gridencoder/
│   ├── __init__.py
│   ├── backend.py
│   ├── grid.py
│   ├── setup.py
│   └── src/
│       ├── bindings.cpp
│       ├── gridencoder.cu
│       └── gridencoder.h
├── guidance/
│   ├── clip_utils.py
│   ├── if_utils.py
│   ├── perpneg_utils.py
│   ├── sd_utils.py
│   └── zero123_utils.py
├── ldm/
│   ├── extras.py
│   ├── guidance.py
│   ├── lr_scheduler.py
│   ├── models/
│   │   ├── autoencoder.py
│   │   └── diffusion/
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       ├── plms.py
│   │       └── sampling_util.py
│   ├── modules/
│   │   ├── attention.py
│   │   ├── diffusionmodules/
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions/
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders/
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── evaluate/
│   │   │   ├── adm_evaluator.py
│   │   │   ├── evaluate_perceptualsim.py
│   │   │   ├── frechet_video_distance.py
│   │   │   ├── ssim.py
│   │   │   └── torch_frechet_video_distance.py
│   │   ├── image_degradation/
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   └── utils_image.py
│   │   ├── losses/
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── x_transformer.py
│   ├── thirdp/
│   │   └── psp/
│   │       ├── helpers.py
│   │       ├── id_loss.py
│   │       └── model_irse.py
│   └── util.py
├── main.py
├── meshutils.py
├── nerf/
│   ├── gui.py
│   ├── network.py
│   ├── network_grid.py
│   ├── network_grid_taichi.py
│   ├── network_grid_tcnn.py
│   ├── provider.py
│   ├── renderer.py
│   └── utils.py
├── optimizer.py
├── preprocess_image.py
├── pretrained/
│   └── zero123/
│       └── sd-objaverse-finetune-c_concat-256.yaml
├── raymarching/
│   ├── __init__.py
│   ├── backend.py
│   ├── raymarching.py
│   ├── setup.py
│   └── src/
│       ├── bindings.cpp
│       ├── raymarching.cu
│       └── raymarching.h
├── readme.md
├── requirements.txt
├── scripts/
│   ├── install_ext.sh
│   ├── res64.args
│   ├── run.sh
│   ├── run2.sh
│   ├── run3.sh
│   ├── run4.sh
│   ├── run5.sh
│   ├── run6.sh
│   ├── run_if.sh
│   ├── run_if2.sh
│   ├── run_if2_perpneg.sh
│   ├── run_image.sh
│   ├── run_image_anya.sh
│   ├── run_image_hard_examples.sh
│   ├── run_image_procedure.sh
│   ├── run_image_text.sh
│   └── run_images.sh
├── shencoder/
│   ├── __init__.py
│   ├── backend.py
│   ├── setup.py
│   ├── sphere_harmonics.py
│   └── src/
│       ├── bindings.cpp
│       ├── shencoder.cu
│       └── shencoder.h
├── taichi_modules/
│   ├── __init__.py
│   ├── hash_encoder.py
│   ├── intersection.py
│   ├── ray_march.py
│   ├── utils.py
│   ├── volume_render_test.py
│   └── volume_train.py
└── tets/
    ├── 128_tets.npz
    ├── 32_tets.npz
    ├── 64_tets.npz
    ├── README.md
    └── generate_tets.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yaml
================================================
name: Bug Report
description: File a bug report
title: "<title>"
labels: ["bug"]
body:
  - type: markdown
    attributes:
      value: |
        Before filing a bug report, [search for an existing issue](https://github.com/ashawkey/stable-dreamfusion/issues).
        Also, ensure you are running the latest version.
  - type: textarea
    id: description
    attributes:
      label: Description
      description: Provide a clear and concise description of what the bug is.
      placeholder: Description
    validations:
      required: true
  - type: textarea
    id: steps
    attributes:
      label: Steps to Reproduce
      description: List the steps needed to reproduce the issue.
      placeholder: |
        1. Go to '...'
        2. Click on '...'
    validations:
      required: true
  - type: textarea
    id: expected-behavior
    attributes:
      label: Expected Behavior
      description: Describe what you expected to happen.
      placeholder: |
        The 'action' would do 'some amazing thing'.
    validations:
      required: true
  - type: textarea
    id: environment
    attributes:
      label: Environment
      description: Describe your environment.
      placeholder: |
        Ubuntu 22.04, PyTorch 1.13, CUDA 11.6
    validations:
      required: true
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: enhancement
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
================================================
FILE: .gitignore
================================================
__pycache__/
build/
*.egg-info/
*.so
venv_*/
tmp*
# data/
ldm/data/
data2
scripts2
trial*/
.vs/
TOKEN
*.ckpt
densegridencoder
tets/256_tets.npz
.vscode/launch.json
data/car*
data/chair*
data/warrior*
data/wd*
data/space*
data/corgi*
data/turtle*
# Only keep the original image, not the automatically-generated depth, normals, rgba
data/baby_phoenix_on_ice_*
data/bollywood_actress_*
data/beach_house_1_*
data/beach_house_2_*
data/mona_lisa_*
data/futuristic_car_*
data/church_ruins_*
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: activation.py
================================================
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
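class _trunc_exp(Function):
    # exp() in the forward pass; in the backward pass the saved input is clamped
    # to at most 15 before exponentiating, so very large inputs do not produce
    # exploding gradients.
    @staticmethod
    @custom_fwd(cast_inputs=torch.float)
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)
    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(max=15))
trunc_exp = _trunc_exp.apply
def biased_softplus(x, bias=0):
    # softplus shifted to the right by `bias`
    return torch.nn.functional.softplus(x - bias)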
================================================
FILE: assets/advanced.md
================================================
# Code organization & Advanced tips
This is a simple description of the most important implementation details.
If you are interested in improving this repo, this might be a starting point.
Any contribution would be greatly appreciated!
* The SDS loss is located at `./guidance/sd_utils.py > StableDiffusion > train_step`:
```python
## 1. we need to interpolate the NeRF rendering to 512x512, to feed it to SD's VAE.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
## 2. image (512x512) --- VAE --> latents (64x64), this is SD's difference from Imagen.
latents = self.encode_imgs(pred_rgb_512)
... # timestep sampling, noise adding and UNet noise predicting
## 3. the SDS loss
w = (1 - self.alphas[t])
grad = w * (noise_pred - noise)
# since the UNet Jacobian is skipped (we do not backprop through the UNet), we cannot simply rely on autodiff; there are several ways to set the grad:

# 3.1. call backward and set the grad now (need to retain graph since we will call a second backward for the other losses later)
latents.backward(gradient=grad, retain_graph=True)
return 0 # dummy loss

# 3.2. use a custom function to set a hook in backward, so we only call backward once (credits to @elliottzheng)
class SpecifyGradient(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        gt_grad, = ctx.saved_tensors
        gt_grad = gt_grad * grad_scale
        return gt_grad, None

loss = SpecifyGradient.apply(latents, grad)
return loss # functional loss

# 3.3. reparameterization (credits to @Xallt)
# d(loss)/d(latents) = grad; since grad is already detached, it's this simple.
loss = (grad * latents).sum()
return loss

# 3.4. reparameterization (credits to threestudio)
# this gives the same gradient as 3.3, but the loss value (0.5 * ||grad||^2) reflects the magnitude of grad, which is more informative.
targets = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents, targets, reduction='sum')
return loss
```
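Options 3.3 and 3.4 work because `grad` (and, in 3.4, `targets`) are treated as constants, so the derivative of the surrogate loss with respect to `latents` is exactly `grad`. The sketch below checks this numerically with plain Python lists standing in for tensors and central finite differences standing in for autodiff; the helper names `loss_33`, `loss_34` and `numerical_grad` are hypothetical and not functions from the repo.

```python
# Framework-free check (plain lists stand in for tensors) that surrogate losses
# 3.3 and 3.4 both have d(loss)/d(latents) == grad. Helper names are hypothetical.

def loss_33(latents, grad):
    # 3.3: loss = sum(grad * latents), with grad held constant (detached)
    return sum(g * z for g, z in zip(grad, latents))


def loss_34(latents, targets):
    # 3.4: loss = 0.5 * sum((latents - targets)^2), with targets held constant
    return 0.5 * sum((z - t) ** 2 for z, t in zip(latents, targets))


def numerical_grad(f, latents, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    result = []
    for i in range(len(latents)):
        plus, minus = list(latents), list(latents)
        plus[i] += eps
        minus[i] -= eps
        result.append((f(plus) - f(minus)) / (2 * eps))
    return result


latents = [0.3, -1.2, 2.0]
grad = [0.5, -0.25, 1.0]
# targets = (latents - grad).detach(): computed once, then frozen
targets = [z - g for z, g in zip(latents, grad)]

g33 = numerical_grad(lambda z: loss_33(z, grad), latents)
g34 = numerical_grad(lambda z: loss_34(z, targets), latents)
print(g33)  # approximately [0.5, -0.25, 1.0]
print(g34)  # approximately [0.5, -0.25, 1.0]
print(loss_34(latents, targets))  # 0.5 * sum(g**2), about 0.65625
```

Freezing `targets` before differentiating is the list equivalent of `.detach()`: if `targets` were recomputed from the perturbed latents inside the lambda, the 3.4 loss would be constant and its gradient would vanish.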
* Other regularizations are in `./nerf/utils.py > Trainer > train_step`.
* Generation seems quite sensitive to regularization on weights_sum (the alpha of each ray). The original opacity loss tends to make the NeRF disappear (zero density everywhere), so for now we replace it with an entropy loss, which encourages alpha to be either 0 or 1.
* NeRF Rendering core function: `./nerf/renderer.py > NeRFRenderer > run & run_cuda`.
* Shading & normal evaluation: `./nerf/network*.py > NeRFNetwork > forward`.
  * Light direction: the current implementation uses a plane light source instead of a point light source.
* View-dependent prompting: `./nerf/provider.py > get_view_direction`.
  * Use `--angle_overhead` and `--angle_front` to set the borders between views.
* Network backbone (`./nerf/network*.py`) can be chosen by the `--backbone` option.
* Spatial density bias (density blob): `./nerf/network*.py > NeRFNetwork > density_blob`.
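The entropy regularizer on weights_sum can be sketched as a binary-entropy penalty per ray. This is an illustration under assumptions: the clamping epsilon, the base-2 logarithm and the plain mean over rays are choices made here, and the actual loss in `./nerf/utils.py > Trainer > train_step` may differ in form and weighting. `binary_entropy` and `entropy_loss` are hypothetical helpers.

```python
import math


def binary_entropy(alpha, eps=1e-5):
    # Clamp so that log2 never sees exactly 0 or 1.
    a = min(max(alpha, eps), 1.0 - eps)
    return -(a * math.log2(a) + (1.0 - a) * math.log2(1.0 - a))


def entropy_loss(weights_sum):
    # Mean binary entropy over rays (hypothetical helper, not the repo's code).
    return sum(binary_entropy(a) for a in weights_sum) / len(weights_sum)


print(binary_entropy(0.5))  # 1.0, the maximum: the ray is half transparent
print(binary_entropy(0.0), binary_entropy(1.0))  # both near 0
# Confident rays (near 0 or 1) are penalized less than ambiguous ones.
print(entropy_loss([0.05, 0.95]) < entropy_loss([0.4, 0.6]))  # True
```

Unlike an L2 opacity penalty, which always pulls alpha toward 0, this penalty is symmetric around 0.5: it pushes each ray toward whichever end (empty or opaque) it is already closer to.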
# Debugging
`debugpy-run` is a convenient way to remotely debug this project. Simply replace a command like this one:
```bash
python main.py --text "a hamburger" --workspace trial -O --vram_O
```
... with:
```bash
debugpy-run main.py -- --text "a hamburger" --workspace trial -O --vram_O
```
For more details: https://github.com/bulletmark/debugpy-run
# Axes and directions of polar, azimuth, etc. in NeRF and Zero123
<img width="1119" alt="NeRF_Zero123" src="https://github.com/ashawkey/stable-dreamfusion/assets/22424247/a0f432ff-2d08-45a4-a390-bda64f5cbc94">
This code uses theta for the polar angle and phi for the azimuth angle.
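View-dependent prompting (`./nerf/provider.py > get_view_direction`, tuned by `--angle_overhead` and `--angle_front`) classifies each sampled camera by its polar and azimuth angles. The following is a hypothetical sketch of that kind of classification in degrees. The label names, the default angles, and the assumption that azimuth 0 faces the object's front are illustrative and not taken from the repo.

```python
def view_direction(polar_deg, azimuth_deg, angle_overhead=30.0, angle_front=60.0):
    """Hypothetical sketch: classify a camera pose into a view label for prompting.

    Assumes polar (theta) is measured from the top pole, so polar=90 is the
    horizontal plane, and azimuth (phi) = 0 points at the object's front.
    The default angles are illustrative, not the repo's defaults.
    """
    # Cameras close to the top pole are labeled overhead, regardless of azimuth.
    if polar_deg <= angle_overhead:
        return "overhead"
    phi = azimuth_deg % 360.0  # wrap into [0, 360)
    half = angle_front / 2.0
    if phi < half or phi >= 360.0 - half:
        return "front"
    if 180.0 - half <= phi < 180.0 + half:
        return "back"
    return "side"


for polar, azimuth in [(90, 0), (90, 90), (90, 180), (90, -90), (15, 45)]:
    print(polar, azimuth, view_direction(polar, azimuth))
```

Checking the polar angle first means overhead views override the azimuth-based label, and wrapping azimuth with `% 360` lets negative angles such as -90 fall into the same bins as 270.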
================================================
FILE: assets/update_logs.md
================================================
### 2023.4.19
* Fix depth supervision, migrate depth estimation model to omnidata.
* Add normal supervision (also by omnidata).
https://user-images.githubusercontent.com/25863658/232403294-b77409bf-ddc7-4bb8-af32-ee0cc123825a.mp4
### 2023.4.7
Improvement on mesh quality & DMTet finetuning support.
https://user-images.githubusercontent.com/25863658/230535363-298c960e-bf9c-4906-8b96-cd60edcb24dd.mp4
### 2023.3.30
* Adopt ideas from [Fantasia3D](https://fantasia3d.github.io/) to concatenate the normal and mask as the latent code in a warm-up stage, which gives faster convergence of the shape.
https://user-images.githubusercontent.com/25863658/230535373-6ee28f16-bb21-4ec4-bc86-d46597361a04.mp4
### 2023.1.30
* Use an MLP to predict the surface normals as in Magic3D, avoiding finite differences / second-order gradients; generation quality is greatly improved.
* More efficient two-pass raymarching in training inspired by nerfacc.
https://user-images.githubusercontent.com/25863658/215996308-9fd959f5-b5c7-4a8e-a241-0fe63ec86a4a.mp4
### 2022.12.3
* Support Stable-diffusion 2.0 base.
### 2022.11.15
* Add the vanilla backbone that is pure-pytorch.
### 2022.10.9
* Shading (partially) starts to work; at least it no longer makes the scene empty. For some prompts it gives better results (a less severe Janus problem). The textureless rendering mode is still disabled.
* Enable shading by default (--latent_iter_ratio 1000).
### 2022.10.5
* Basic reproduction finished.
* Running without --cuda_ray, as well as --tcnn, is not working yet and needs fixing.
* Shading is not working, disabled in utils.py for now. Surface normals are bad.
* Use an entropy loss to regularize weights_sum (alpha); the original L2 regularization always leads to degenerate geometry...
https://user-images.githubusercontent.com/25863658/194241493-f3e68f78-aefe-479e-a4a8-001424a61b37.mp4
================================================
FILE: config/anya.csv
================================================
zero123_weight, radius, polar, azimuth, image
1, 3, 90, 0, data/anya_front_rgba.png
1, 3, 90, 180, data/anya_back_rgba.png
================================================
FILE: config/car.csv
================================================
zero123_weight, radius, polar, azimuth, image
4, 3.2, 90, 0, data/car_left_rgba.png
1, 3, 90, 90, data/car_front_rgba.png
4, 3.2, 90, 180, data/car_right_rgba.png
1, 3, 90, -90, data/car_back_rgba.png
================================================
FILE: config/corgi.csv
================================================
zero123_weight, radius, polar, azimuth, image
1, 3.2, 90, 0, data/corgi_puppy_sitting_looking_up_rgba.png
================================================
FILE: docker/Dockerfile
================================================
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
# Remove any third-party apt sources to avoid issues with expiring keys.
RUN rm -f /etc/apt/sources.list.d/*.list
RUN apt-get update
RUN DEBIAN_FRONTEND=noninteractive TZ=Europe/Madrid apt-get install -y tzdata
# Install some basic utilities
RUN apt-get install -y \
curl \
ca-certificates \
sudo \
git \
bzip2 \
libx11-6 \
python3 \
python3-pip \
libglfw3-dev \
libgles2-mesa-dev \
libglib2.0-0 \
&& rm -rf /var/lib/apt/lists/*
# Create a working directory
RUN mkdir /app
WORKDIR /app
RUN cd /app
RUN git clone https://github.com/ashawkey/stable-dreamfusion.git
RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
WORKDIR /app/stable-dreamfusion
RUN pip3 install -r requirements.txt
RUN pip3 install git+https://github.com/NVlabs/nvdiffrast/
# Needs nvidia runtime, if you have "No CUDA runtime is found" error: https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime, first answer
RUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
RUN pip3 install git+https://github.com/openai/CLIP.git
RUN bash scripts/install_ext.sh
# Set the default command to python3
#CMD ["python3"]
================================================
FILE: docker/README.md
================================================
# Docker installation
## Build image
To build the Docker image on your own machine (this may take 15-30 minutes), run the following from the `docker/` directory:
```
docker build -t stable-dreamfusion:latest .
```
If you get the error **No CUDA runtime is found** when building the wheels for tiny-cuda-nn, you need to set up the NVIDIA runtime for Docker:
```
sudo apt-get install nvidia-container-runtime
```
Then edit `/etc/docker/daemon.json` and add the default-runtime:
```
{
  "runtimes": {
    "nvidia": {
      "path": "nvidia-container-runtime",
      "runtimeArgs": []
    }
  },
  "default-runtime": "nvidia"
}
```
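Overwriting `/etc/docker/daemon.json` with the snippet above would discard any settings already in the file. As a hedged sketch (the helper `add_nvidia_default_runtime` is hypothetical, not part of Docker or this repo), the runtime entry can instead be merged into the existing config:

```python
import json


def add_nvidia_default_runtime(config):
    """Return a copy of a daemon.json dict with the NVIDIA runtime set as default.

    Keys already present (log settings, other runtimes, ...) are preserved.
    """
    merged = dict(config)
    runtimes = dict(merged.get("runtimes", {}))
    runtimes["nvidia"] = {"path": "nvidia-container-runtime", "runtimeArgs": []}
    merged["runtimes"] = runtimes
    merged["default-runtime"] = "nvidia"
    return merged


# Example: a daemon.json that already configures a log driver.
existing = {"log-driver": "json-file"}
print(json.dumps(add_nvidia_default_runtime(existing), indent=2))
```

To apply it, load the current file with `json.load`, pass the dict through the helper, and write the result back with `json.dump` (root privileges are needed to write under `/etc/docker`).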
And restart docker:
```
sudo systemctl restart docker
```
Now you can build tiny-cuda-nn inside docker.
## Download image
To download the image (~6GB) instead:
```
docker pull supercabb/stable-dreamfusion:3080_0.0.1
docker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion
```
## Use image
You can launch an interactive shell inside the container:
```
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash
```
From this shell, all the code in the repo should work.
To run any single command `<command...>` inside the docker container:
```
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "<command...>"
```
To train:
```
export TOKEN="#HUGGING FACE ACCESS TOKEN#"
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "echo ${TOKEN} > TOKEN \
&& python3 main.py --text \"a hamburger\" --workspace trial -O"
```
Run a test without the GUI:
```
export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
main.py --workspace trial -O --test"
```
Run a test with the GUI:
```
export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
xhost +
docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
main.py --workspace trial -O --test --gui"
xhost -
```
================================================
FILE: dpt.py
================================================
import math
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.
        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))
        if "optimizer" in parameters:
            parameters = parameters["model"]
        self.load_state_dict(parameters)
def unflatten_with_named_tensor(input, dim, sizes):
    """Workaround for unflattening with named tensor."""
    # tracer acts up with unflatten. See https://github.com/pytorch/pytorch/issues/49538
    new_shape = list(input.shape)[:dim] + list(sizes) + list(input.shape)[dim+1:]
    return input.view(*new_shape)
class Slice(nn.Module):
    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index
    def forward(self, x):
        return x[:, self.start_index :]
class AddReadout(nn.Module):
    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index
    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index :] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)
        return self.project(features)
class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1
    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x
def forward_vit(pretrained, x):
    b, c, h, w = x.shape
    glob = pretrained.model.forward_flex(x)
    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]
    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)
    unflattened_dim = 2
    unflattened_size = (
        int(torch.div(h, pretrained.model.patch_size[1], rounding_mode='floor')),
        int(torch.div(w, pretrained.model.patch_size[0], rounding_mode='floor')),
    )
    unflatten = nn.Sequential(nn.Unflatten(unflattened_dim, unflattened_size))
    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten_with_named_tensor(layer_3, unflattened_dim, unflattened_size)
    if layer_4.ndim == 3:
        layer_4 = unflatten_with_named_tensor(layer_4, unflattened_dim, unflattened_size)
    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
    return layer_1, layer_2, layer_3, layer_4
def _resize_pos_embed(self, posemb, gs_h, gs_w):
    posemb_tok, posemb_grid = (
        posemb[:, : self.start_index],
        posemb[0, self.start_index :],
    )
    gs_old = int(math.sqrt(posemb_grid.shape[0]))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb
def forward_flex(self, x):
    b, c, h, w = x.shape
    pos_embed = self._resize_pos_embed(
        self.pos_embed, torch.div(h, self.patch_size[1], rounding_mode='floor'), torch.div(w, self.patch_size[0], rounding_mode='floor')
    )
    B = x.shape[0]
    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
    x = x + pos_embed
    x = self.pos_drop(x)
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)
    return x
activations = {}
def get_activation(name):
def hook(model, input, output):
activations[name] = output
return hook
def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == "ignore":
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == "add":
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == "project":
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
        raise ValueError(
            "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
        )
return readout_oper
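The three `use_readout` options differ only in how the readout (class) token is combined with the patch tokens. Below is a plain-Python sketch on one image's token list. It assumes the `Slice`, `AddReadout` and `ProjectReadout` modules defined earlier in this file (not shown in this excerpt) follow the usual MiDaS behaviour. `apply_readout` is a hypothetical name. For `"project"` the sketch stops at the concatenation, whereas the real module is expected to follow it with a learned projection back to `vit_features` channels.

```python
def apply_readout(tokens, mode, start_index=1):
    """Hypothetical sketch: tokens is a list of per-token feature lists for one image."""
    patches = tokens[start_index:]
    if mode == "ignore":
        # drop the readout token(s), keep the patch tokens unchanged
        return [list(t) for t in patches]
    if mode == "add":
        # add the readout token to every patch token; with two readout tokens
        # (distilled DeiT, start_index=2) their average is used
        if start_index == 2:
            readout = [(a + b) / 2 for a, b in zip(tokens[0], tokens[1])]
        else:
            readout = tokens[0]
        return [[p + r for p, r in zip(t, readout)] for t in patches]
    if mode == "project":
        # concatenate the readout token to every patch token (2C features);
        # the real module would then project back to C channels
        return [list(t) + list(tokens[0]) for t in patches]
    raise ValueError(f"unknown readout mode: {mode}")
```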
def _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
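The four `act_postprocess` heads turn the 1/16-resolution token grid into a feature pyramid. Stage 1 is upsampled x4 to 1/4 resolution, stage 2 is upsampled x2 to 1/8, stage 3 stays at 1/16, and stage 4 is downsampled by the stride-2 convolution to 1/32. Their channel counts are `features[0..3]`, for example `[96, 192, 384, 768]` for ViT-B/16. A sketch of the spatial sizes using the standard PyTorch output-size formulas; the helper names are hypothetical.

```python
def conv_out(n, k, s, p):
    """Output size of nn.Conv2d along one axis (dilation=1)."""
    return (n + 2 * p - k) // s + 1


def conv_transpose_out(n, k, s, p):
    """Output size of nn.ConvTranspose2d along one axis (dilation=1, output_padding=0)."""
    return (n - 1) * s - 2 * p + k


def reassemble_sizes(image_size=384, patch=16):
    """Hypothetical helper: side length of each act_postprocess output."""
    g = image_size // patch         # token grid after Unflatten
    g = conv_out(g, k=1, s=1, p=0)  # the 1x1 projection keeps the size
    return [
        conv_transpose_out(g, k=4, s=4, p=0),  # stage 1: x4 -> 1/4 resolution
        conv_transpose_out(g, k=2, s=2, p=0),  # stage 2: x2 -> 1/8
        g,                                     # stage 3: unchanged -> 1/16
        conv_out(g, k=3, s=2, p=1),            # stage 4: stride 2 -> 1/32
    ]


print(reassemble_sizes())  # [96, 48, 24, 12]
```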
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
    hooks = [5, 11, 17, 23] if hooks is None else hooks
return _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
    hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
    hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model(
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
)
    hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
start_index=2,
)
def _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=[0, 1, 8, 11],
vit_features=768,
use_vit_only=False,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
if use_vit_only == True:
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
else:
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
get_activation("1")
)
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
get_activation("2")
)
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
if use_vit_only == True:
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
else:
pretrained.act_postprocess1 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess2 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitb_rn50_384(
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
    hooks = [0, 1, 8, 11] if hooks is None else hooks
return _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
def _make_encoder(
    backbone,
    features,
    use_pretrained,
    groups=1,
    expand=False,
    exportable=True,
    hooks=None,
    use_vit_only=False,
    use_readout="ignore",
):
if backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == "vitb_rn50_384":
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups, expand=expand
        )  # ViT-Hybrid (ResNet-50 + ViT-B/16) - 85.0% Top1 (backbone)
elif backbone == "vitb16_384":
pretrained = _make_pretrained_vitb16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
elif backbone == "efficientnet_lite3":
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
    else:
        raise NotImplementedError(f"Backbone '{backbone}' not implemented")
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand==True:
out_shape1 = out_shape
out_shape2 = out_shape*2
out_shape3 = out_shape*4
out_shape4 = out_shape*8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
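`_make_scratch` projects each of the four encoder stages to a common channel count with a 3x3 convolution. With `expand=True` the width instead doubles at each deeper stage. `DPT` below always calls `_make_encoder` with `expand=False`. A plain-Python sketch of the resulting output channels; `scratch_channels` is a hypothetical helper.

```python
def scratch_channels(out_shape, expand=False):
    """Hypothetical helper: out channels of layer1_rn..layer4_rn in _make_scratch."""
    if expand:
        return [out_shape, out_shape * 2, out_shape * 4, out_shape * 8]
    return [out_shape] * 4


print(scratch_channels(256))  # [256, 256, 256, 256]
```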
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
efficientnet = torch.hub.load(
"rwightman/gen-efficientnet-pytorch",
"tf_efficientnet_lite3",
pretrained=use_pretrained,
exportable=exportable
)
return _make_efficientnet_backbone(efficientnet)
def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
)
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
return pretrained
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
"""Interpolation module.
"""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
            mode (str): interpolation mode
            align_corners (bool): passed to nn.functional.interpolate
        """
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
)
return x
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
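        # not in-place: forward() applies it to x first, and an in-place ReLU
        # would overwrite x before the residual addition
        self.relu = nn.ReLU()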
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=True
)
return output
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
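            features (int): number of features
            activation (nn.Module): activation applied before each convolution
            bn (bool): whether to apply BatchNorm after each convolution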
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
if self.bn==True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn==True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn==True:
out = self.bn2(out)
        if self.groups > 1:
            # Unreachable: groups is fixed to 1 and no conv_merge layer is created
            out = self.conv_merge(out)
return self.skip_add.add(out, x)
# return out + x
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand==True:
out_features = features//2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
def _make_fusion_block(features, use_bn):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
)
class DPT(BaseModel):
def __init__(
self,
head,
features=256,
backbone="vitb_rn50_384",
readout="project",
channels_last=False,
use_bn=False,
):
super(DPT, self).__init__()
self.channels_last = channels_last
hooks = {
"vitb_rn50_384": [0, 1, 8, 11],
"vitb16_384": [2, 5, 8, 11],
"vitl16_384": [5, 11, 17, 23],
}
# Instantiate backbone and reassemble blocks
self.pretrained, self.scratch = _make_encoder(
backbone,
features,
            True,  # use_pretrained: load ImageNet weights for the backbone (set False to train from scratch)
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
use_readout=readout,
)
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.output_conv = head
def forward(self, x):
        if self.channels_last:
            x = x.contiguous(memory_format=torch.channels_last)
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return out
class DPTDepthModel(DPT):
def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs):
        features = kwargs.get("features", 256)
head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
super().__init__(head, **kwargs)
if path is not None:
self.load(path)
def forward(self, x):
return super().forward(x).squeeze(dim=1)
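In `DPT.forward`, the decoder starts from the coarsest stage. `refinenet4` upsamples `layer_4_rn` x2. Each later fusion block adds the next skip connection, which must have the same size, and then upsamples x2 again. `DPTDepthModel`'s head adds a final x2 `Interpolate`, so the depth map comes out at input resolution. A sketch of that size bookkeeping; `dpt_decoder_sizes` is a hypothetical helper, and it assumes square inputs whose side is a multiple of 32 so that every stage is exactly half the previous one.

```python
def dpt_decoder_sizes(image_size=384):
    """Hypothetical sketch: spatial size after each DPT decoder step (square input)."""
    if image_size % 32 != 0:
        raise ValueError("this sketch assumes a multiple of 32")
    # Reassembled encoder stages layer_1..layer_4 sit at 1/4, 1/8, 1/16, 1/32.
    skips = {d: image_size // d for d in (4, 8, 16, 32)}
    path = 2 * skips[32]          # path_4: refinenet4 upsamples layer_4_rn x2
    sizes = [path]
    for d in (16, 8, 4):          # refinenet3, refinenet2, refinenet1
        assert path == skips[d]   # sanity check: skip must match before adding
        path *= 2                 # fuse, then x2 bilinear upsample
        sizes.append(path)
    sizes.append(2 * path)        # Interpolate(scale_factor=2) in the depth head
    return sizes


print(dpt_decoder_sizes())  # [24, 48, 96, 192, 384]
```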
================================================
FILE: encoding.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class FreqEncoder_torch(nn.Module):
def __init__(self, input_dim, max_freq_log2, N_freqs,
log_sampling=True, include_input=True,
periodic_fns=(torch.sin, torch.cos)):
super().__init__()
self.input_dim = input_dim
self.include_input = include_input
self.periodic_fns = periodic_fns
self.N_freqs = N_freqs
self.output_dim = 0
if self.include_input:
self.output_dim += self.input_dim
self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)
if log_sampling:
self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
else:
self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)
self.freq_bands = self.freq_bands.numpy().tolist()
def forward(self, input, max_level=None, **kwargs):
if max_level is None:
max_level = self.N_freqs
else:
max_level = int(max_level * self.N_freqs)
out = []
if self.include_input:
out.append(input)
for i in range(max_level):
freq = self.freq_bands[i]
for p_fn in self.periodic_fns:
out.append(p_fn(input * freq))
        # zero-pad the masked-out frequency bands so the output width stays output_dim
        if self.N_freqs - max_level > 0:
            out.append(torch.zeros(*input.shape[:-1], (self.N_freqs - max_level) * len(self.periodic_fns) * input.shape[-1], device=input.device, dtype=input.dtype))
out = torch.cat(out, dim=-1)
return out
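`FreqEncoder_torch` maps each coordinate to `[x, sin(f_1 x), cos(f_1 x), ...]` with log-spaced bands `f_k`. Its output width is `input_dim * (1 + 2 * N_freqs)` when `include_input` is set. `max_level`, a fraction in [0, 1], zeroes the upper bands instead of dropping them, so the width never changes. This suits coarse-to-fine schedules. Below is a minimal pure-Python sketch for one point. `freq_encode` is a hypothetical name, and the sketch assumes `log_sampling=True`, the default `(sin, cos)` pair, and `max_freq_log2 = N_freqs - 1` as passed by `get_encoder`.

```python
import math


def freq_encode(x, n_freqs, max_level=None, include_input=True):
    """Hypothetical sketch of FreqEncoder_torch for one point x (a list of floats)."""
    max_freq_log2 = n_freqs - 1  # as get_encoder passes multires - 1
    if n_freqs == 1:
        bands = [1.0]
    else:
        # log sampling: 2 ** linspace(0, max_freq_log2, n_freqs)
        bands = [2.0 ** (max_freq_log2 * i / (n_freqs - 1)) for i in range(n_freqs)]
    level = n_freqs if max_level is None else int(max_level * n_freqs)
    out = list(x) if include_input else []
    for f in bands[:level]:
        out += [math.sin(v * f) for v in x]
        out += [math.cos(v * f) for v in x]
    # zero-fill the masked bands (2 = number of periodic functions: sin, cos)
    out += [0.0] * ((n_freqs - level) * 2 * len(x))
    return out


print(freq_encode([0.0], n_freqs=2, max_level=0.5))  # [0.0, 0.0, 1.0, 0.0, 0.0]
```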
def get_encoder(encoding, input_dim=3,
multires=6,
degree=4,
num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, interpolation='linear',
**kwargs):
if encoding == 'None':
return lambda x, **kwargs: x, input_dim
elif encoding == 'frequency_torch':
encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
elif encoding == 'frequency': # CUDA implementation, faster than torch.
from freqencoder import FreqEncoder
encoder = FreqEncoder(input_dim=input_dim, degree=multires)
elif encoding == 'sphere_harmonics':
from shencoder import SHEncoder
encoder = SHEncoder(input_dim=input_dim, degree=degree)
elif encoding == 'hashgrid':
from gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation)
elif encoding == 'tiledgrid':
from gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation)
elif encoding == 'hashgrid_taichi':
from taichi_modules.hash_encoder import HashEncoderTaichi
encoder = HashEncoderTaichi(batch_size=4096) #TODO: hard encoded batch size
else:
        raise NotImplementedError('Unknown encoding mode, choose from [None, frequency_torch, frequency, sphere_harmonics, hashgrid, tiledgrid, hashgrid_taichi]')
return encoder, encoder.output_dim
================================================
FILE: evaluation/Prompt.py
================================================
import textwrap
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification
from transformers import pipeline
import argparse
import sys
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
#python Prompt.py --text "a dog is in front of a rabbit" --model vlt5
if __name__ == '__main__':
    # Mimic how the main script is called (same --text argument)
parser = argparse.ArgumentParser()
parser.add_argument('--text', default="", type=str, help="text prompt")
#parser.add_argument('--workspace', default="trial", type=str, help="workspace")
    parser.add_argument('--model', default='vlt5', type=str, choices=['vlt5', 'bert', 'XLNet'], help="model choices - vlt5, bert, XLNet")
opt = parser.parse_args()
if opt.model == "vlt5":
tokenizer = AutoTokenizer.from_pretrained("Voicelab/vlt5-base-keywords")
model = AutoModelForSeq2SeqLM.from_pretrained("Voicelab/vlt5-base-keywords")
task_prefix = "Keywords: "
inputs = [
opt.text
]
for sample in inputs:
input_sequences = [task_prefix + sample]
input_ids = tokenizer(
input_sequences, return_tensors="pt", truncation=True
).input_ids
output = model.generate(input_ids, no_repeat_ngram_size=3, num_beams=4)
output_text = tokenizer.decode(output[0], skip_special_tokens=True)
#print(sample, "\n --->", output_text)
elif opt.model == "bert":
tokenizer = AutoTokenizer.from_pretrained("yanekyuk/bert-uncased-keyword-extractor")
model = AutoModelForTokenClassification.from_pretrained("yanekyuk/bert-uncased-keyword-extractor")
text = opt.text
input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
# Classify tokens
outputs = model(input_ids)
predictions = outputs.logits.detach().numpy()[0]
labels = predictions.argmax(axis=1)
labels = labels[1:-1]
print(labels)
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
tokens = tokens[1:-1]
output_tokens = [tokens[i] for i in range(len(tokens)) if labels[i] != 0]
output_text = tokenizer.convert_tokens_to_string(output_tokens)
#print(output_text)
elif opt.model == "XLNet":
tokenizer = AutoTokenizer.from_pretrained("jasminejwebb/KeywordIdentifier")
model = AutoModelForTokenClassification.from_pretrained("jasminejwebb/KeywordIdentifier")
text = opt.text
input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
# Classify tokens
outputs = model(input_ids)
predictions = outputs.logits.detach().numpy()[0]
labels = predictions.argmax(axis=1)
labels = labels[1:-1]
print(labels)
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
tokens = tokens[1:-1]
output_tokens = [tokens[i] for i in range(len(tokens)) if labels[i] != 0]
output_text = tokenizer.convert_tokens_to_string(output_tokens)
#print(output_text)
wrapped_text = textwrap.fill(output_text, width=50)
print('+' + '-'*52 + '+')
for line in wrapped_text.split('\n'):
print('| {} |'.format(line.ljust(50)))
print('+' + '-'*52 + '+')
#print(result)
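The `bert` and `XLNet` branches share the same post-processing. They drop the first and last positions, keep every token whose predicted label is non-zero, and join the kept tokens back into text. A pure-Python sketch follows. `keep_keyword_tokens` is a hypothetical name, and the tokens and labels in the test are invented. It assumes BERT-style special tokens at both ends (`[CLS] ... [SEP]`) and WordPiece `##` continuations. Other tokenizers place special tokens and mark subwords differently, and `convert_tokens_to_string` handles those cases itself.

```python
def keep_keyword_tokens(tokens, labels):
    """Hypothetical sketch of the bert/XLNet post-processing in Prompt.py."""
    # drop the first and last positions (the special tokens added by encode())
    tokens, labels = tokens[1:-1], labels[1:-1]
    kept = [tok for tok, lab in zip(tokens, labels) if lab != 0]
    # simplified WordPiece detokenisation: "rab ##bit" -> "rabbit"
    return " ".join(kept).replace(" ##", "")
```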
================================================
FILE: evaluation/mesh_to_video.py
================================================
import os
import numpy as np
import trimesh
import argparse
from pathlib import Path
from tqdm import tqdm
import pyvista as pv
def render_video(anim_mesh, frame_dir="result", output_file="result/output.mp4", n_frames=360):
    # Frames are written to frame_dir so that the ffmpeg call below can find them
    os.makedirs(frame_dir, exist_ok=True)
    center = anim_mesh.center_mass
    plotter = pv.Plotter(off_screen=True)
    plotter.add_mesh(anim_mesh)
    radius = 10
    angle_step = 2 * np.pi / n_frames
    for i in tqdm(range(n_frames)):
        camera_pos = [center[0] + radius * np.cos(i * angle_step), center[1] + radius * np.sin(i * angle_step), center[2]]
        plotter.camera_position = (camera_pos, center, (0, 0, 1))
        plotter.show(screenshot=os.path.join(frame_dir, f'frame_{i}.png'), auto_close=False)
    plotter.close()
    frame_pattern = os.path.join(frame_dir, 'frame_%d.png')
    os.system(f'ffmpeg -r 30 -f image2 -s 1920x1080 -i "{frame_pattern}" -vcodec libx264 -crf 25 -pix_fmt yuv420p "{output_file}"')
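The camera path in `render_video` is a horizontal circle around the mesh's center of mass. Frame `i` sits at angle `2*pi*i/n_frames` in the xy-plane, at the center's height, and looks at the center with +z as the up vector. A minimal standard-library sketch of those positions; `orbit_positions` is a hypothetical helper.

```python
import math


def orbit_positions(center, radius=10.0, n_frames=360):
    """Hypothetical helper: camera positions used by render_video's orbit."""
    step = 2 * math.pi / n_frames
    return [
        (center[0] + radius * math.cos(i * step),
         center[1] + radius * math.sin(i * step),
         center[2])
        for i in range(n_frames)
    ]
```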
def generate_mesh(obj1,obj2,transform_vector):
# Read 2 objects
filename1 = obj1 # Central Object
filename2 = obj2 # Surrounding Object
mesh1 = trimesh.load_mesh(filename1)
mesh2 = trimesh.load_mesh(filename2)
extents1 = mesh1.extents
    extents2 = mesh2.extents
radius1 = sum(extents1) / 3.0
radius2 = sum(extents2) / 3.0
center1 = mesh1.center_mass
center2 = mesh2.center_mass
# Move
T1 = -center1
    # Parse a string such as "[1,0,0]" and scale it by the central object's mean extent
    transform_vector = np.array([float(v) for v in transform_vector.strip("[]() ").split(",") if v.strip()]) * radius1
print(T1, transform_vector, radius1)
T2 = -center2 + transform_vector
# Transform
mesh1.apply_translation(T1)
mesh2.apply_translation(T2)
# merge mesh
merged_mesh = trimesh.util.concatenate((mesh1, mesh2))
# save mesh
merged_mesh.export('merged_mesh.obj')
print("----> merge mesh done")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Generate rotating mesh animation.')
parser.add_argument('--center_obj', type=str, help='Input OBJ1 file.')
parser.add_argument('--surround_obj', type=str, help='Input OBJ2 file.')
parser.add_argument('--transform_vector', help='Transform_vector.')
parser.add_argument('--output_file', type=str, default="result/Demo.mp4", help='Output MP4 file.')
parser.add_argument('--num_frames', type=int, default=100, help='Number of frames to render.')
args = parser.parse_args()
#mesh = obj.Obj("wr.obj")
generate_mesh(args.center_obj,args.surround_obj,args.transform_vector)
input_file = Path("merged_mesh.obj")
output_file = Path(args.output_file)
out_dir = output_file.parent.joinpath('frames')
out_dir.mkdir(parents=True, exist_ok=True)
anim_mesh = trimesh.load_mesh(str(input_file))
render_video(anim_mesh)
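`--transform_vector` arrives as a raw string such as `[1,0,0]`. Iterating over it one character at a time cannot handle multi-character components like `-1` or `0.5`. The sketch below parses the whole string and scales it by the central mesh's mean extent, computed the same way as `radius1` above. `parse_transform_vector` and `mean_extent` are hypothetical helpers. Scaling by the mesh size is an assumption about the intent of `float(i))*radius1`.

```python
def parse_transform_vector(text, scale=1.0):
    """Hypothetical helper: parse "[x,y,z]" and scale each component."""
    parts = text.strip().strip("[]()").split(",")
    values = [float(p) for p in parts if p.strip()]
    if len(values) != 3:
        raise ValueError(f"expected 3 components, got {len(values)}: {text!r}")
    return [v * scale for v in values]


def mean_extent(extents):
    """Average bounding-box side length, as generate_mesh computes radius1."""
    return sum(extents) / 3.0


print(parse_transform_vector("[1, -0.5, 2]", scale=mean_extent([2.0, 2.0, 2.0])))  # [2.0, -1.0, 4.0]
```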
================================================
FILE: evaluation/r_precision.py
================================================
from sentence_transformers import SentenceTransformer, util
from PIL import Image
import argparse
import sys
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--text', default="", type=str, help="text prompt")
    parser.add_argument('--workspace', default="trial", type=str, help="workspace folder of the Stable DreamFusion run")
parser.add_argument('--latest', default='ep0001', type=str, help="which epoch result you want to use for image path")
    parser.add_argument('--mode', default='rgb', type=str, help="mode of result, 'rgb' (color) or 'depth' (textureless)")
parser.add_argument('--clip', default="clip-ViT-B-32", type=str, help="CLIP model to encode the img and prompt")
opt = parser.parse_args()
#Load CLIP model
    model = SentenceTransformer(opt.clip)
#Encode an image:
img_emb = model.encode(Image.open(f'../results/{opt.workspace}/validation/df_{opt.latest}_0005_{opt.mode}.png'))
#Encode text descriptions
text_emb = model.encode([f'{opt.text}'])
    #Compute cosine similarity between the image and text embeddings (a CLIP score for this one prompt)
cos_scores = util.cos_sim(img_emb, text_emb)
print("The final CLIP R-Precision is:", cos_scores[0][0].cpu().numpy())
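The script prints the cosine similarity between one rendered view and its own prompt. In the usual CLIP R-Precision setup, each rendering is instead ranked against a pool of candidate prompts and counts as a hit when its own prompt scores highest. The reported value is the hit rate. A minimal sketch on plain embedding vectors; the helper names are hypothetical and the vectors in the test are made-up toy values, not CLIP outputs.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def r_precision(image_embs, prompt_embs, true_idxs):
    """Fraction of images whose own prompt is the top-1 match among prompt_embs."""
    hits = 0
    for emb, true_idx in zip(image_embs, true_idxs):
        scores = [cosine(emb, p) for p in prompt_embs]
        best = max(range(len(scores)), key=scores.__getitem__)
        hits += int(best == true_idx)
    return hits / len(image_embs)
```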
================================================
FILE: evaluation/readme.md
================================================
### Improvement:
- Usage
- r_precision.py <br>
For scoring a rendered result against its prompt with CLIP <br>
--text is the text prompt, written as it was passed to Stable DreamFusion <br>
--workspace is the workspace folder that Stable DreamFusion creates for each prompt <br>
--latest is the epoch whose validation image is scored. Stable DreamFusion records results for every epoch; this is normally ep0100 unless training did not finish or was extended <br>
--mode is rgb or depth, corresponding to the color and textureless results in Figure 5 (Qualitative comparison with baselines) of the original paper <br>
--clip is one of clip-ViT-B-32, CLIP B/16, or CLIP L/14, the same models as in the original paper <br>
```bash
python r_precision.py --text "matte painting of a castle made of cheesecake surrounded by a moat made of ice cream" --workspace ../castle --latest ep0100 --mode rgb --clip clip-ViT-B-32
```
- Prompt.py (model name case sensitive) <br>
For prompt separation (keyword extraction from the prompt) <br>
--text is the text prompt, written as it was passed to Stable DreamFusion <br>
--model selects the pretrained keyword-extraction model: vlt5, bert, or XLNet <br>
```bash
python Prompt.py --text "a dog is in front of a rabbit" --model vlt5
python Prompt.py --text "a dog is in front of a rabbit" --model bert
python Prompt.py --text "a dog is in front of a rabbit" --model XLNet
```
- mesh_to_video.py <br>
--center_obj is the central object <br>
--surround_obj is the surrounding object (subject to change) <br>
--transform_vector is the x, y, z 3D vector used to translate the surrounding object
```bash
python mesh_to_video.py --center_obj 'mesh_whiterabbit/mesh.obj' --surround_obj 'mesh_snake/mesh.obj' --transform_vector '[1,0,0]'
```
================================================
FILE: freqencoder/__init__.py
================================================
from .freq import FreqEncoder
================================================
FILE: freqencoder/backend.py
================================================
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
'-use_fast_math'
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
_backend = load(name='_freqencoder',
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[os.path.join(_src_path, 'src', f) for f in [
'freqencoder.cu',
'bindings.cpp',
]],
)
__all__ = ['_backend']
================================================
FILE: freqencoder/freq.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _freqencoder as _backend
except ImportError:
from .backend import _backend
class _freq_encoder(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
def forward(ctx, inputs, degree, output_dim):
# inputs: [B, input_dim], float
# RETURN: [B, F], float
if not inputs.is_cuda: inputs = inputs.cuda()
inputs = inputs.contiguous()
B, input_dim = inputs.shape # batch size, coord dim
outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
_backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
ctx.save_for_backward(inputs, outputs)
ctx.dims = [B, input_dim, degree, output_dim]
return outputs
@staticmethod
#@once_differentiable
@custom_bwd
def backward(ctx, grad):
        # grad: [B, output_dim]
grad = grad.contiguous()
inputs, outputs = ctx.saved_tensors
B, input_dim, degree, output_dim = ctx.dims
grad_inputs = torch.zeros_like(inputs)
_backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
return grad_inputs, None, None
freq_encode = _freq_encoder.apply
class FreqEncoder(nn.Module):
def __init__(self, input_dim=3, degree=4):
super().__init__()
self.input_dim = input_dim
self.degree = degree
self.output_dim = input_dim + input_dim * 2 * degree
def __repr__(self):
return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
def forward(self, inputs, **kwargs):
# inputs: [..., input_dim]
        # return: [..., output_dim]
prefix_shape = list(inputs.shape[:-1])
inputs = inputs.reshape(-1, self.input_dim)
outputs = freq_encode(inputs, self.degree, self.output_dim)
outputs = outputs.reshape(prefix_shape + [self.output_dim])
return outputs
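The channel layout written by `kernel_freq` can be checked with a minimal pure-Python sketch for a single point. The helper names `freq_encode_reference` and `freq_encode_elementwise` are hypothetical. The second helper follows the kernel's per-element index math: `col = c / D - 1`, `freq = col / 2`, and a phase shift of pi/2 on odd columns. Both use double-precision `math.sin`/`math.cos`, while the kernel uses the fast float32 intrinsic `__sinf`, so CUDA results should only match approximately.

```python
import math
from typing import List


def freq_encode_reference(x: List[float], degree: int) -> List[float]:
    """Hypothetical reference: [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(deg-1) x), cos(2^(deg-1) x)].

    Every sin/cos entry is a block of len(x) values, so the output has
    C = D + 2 * D * degree channels, matching FreqEncoder.output_dim.
    """
    out = list(x)  # the first D channels copy the input
    for f in range(degree):
        scale = 2.0 ** f
        out.extend(math.sin(scale * v) for v in x)  # even column: phase 0
        out.extend(math.cos(scale * v) for v in x)  # odd column: sin(. + pi/2) == cos
    return out


def freq_encode_elementwise(x: List[float], degree: int) -> List[float]:
    """Hypothetical per-channel version that follows kernel_freq's indexing."""
    D = len(x)
    C = D + 2 * D * degree
    out = []
    for c in range(C):
        if c < D:
            out.append(x[c])
        else:
            col = c // D - 1
            d = c % D
            freq = col // 2
            phase_shift = (col % 2) * (math.pi / 2)
            # math.ldexp(v, k) == v * 2**k, the same scaling scalbnf applies in the kernel
            out.append(math.sin(math.ldexp(x[d], freq) + phase_shift))
    return out


if __name__ == "__main__":
    x = [0.25, -1.5, 2.0]
    a = freq_encode_reference(x, degree=4)
    b = freq_encode_elementwise(x, degree=4)
    print(len(a))  # 27 == 3 + 2 * 3 * 4
    print(max(abs(p - q) for p, q in zip(a, b)))
```

Storing cos as a phase-shifted sin lets the kernel use a single branch-free `__sinf` call for every non-identity channel.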
================================================
FILE: freqencoder/setup.py
================================================
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
'-use_fast_math'
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
setup(
name='freqencoder', # package name, import this to use python API
ext_modules=[
CUDAExtension(
name='_freqencoder', # extension name, import this to use CUDA API
sources=[os.path.join(_src_path, 'src', f) for f in [
'freqencoder.cu',
'bindings.cpp',
]],
extra_compile_args={
'cxx': c_flags,
'nvcc': nvcc_flags,
}
),
],
cmdclass={
'build_ext': BuildExtension,
}
)
================================================
FILE: freqencoder/src/bindings.cpp
================================================
#include <torch/extension.h>
#include "freqencoder.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
}
================================================
FILE: freqencoder/src/freqencoder.cu
================================================
#include <stdint.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <algorithm>
#include <stdexcept>
#include <cstdio>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
inline constexpr __device__ float PI() { return 3.141592653589793f; }
template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
return (val + divisor - 1) / divisor;
}
// inputs: [B, D]
// outputs: [B, C], C = D + D * deg * 2
__global__ void kernel_freq(
const float * __restrict__ inputs,
uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
float * outputs
) {
// parallel on per-element
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (t >= B * C) return;
// get index
const uint32_t b = t / C;
const uint32_t c = t - b * C; // t % C;
// locate
inputs += b * D;
outputs += t;
// write self
if (c < D) {
outputs[0] = inputs[c];
// write freq
} else {
const uint32_t col = c / D - 1;
const uint32_t d = c % D;
const uint32_t freq = col / 2;
const float phase_shift = (col % 2) * (PI() / 2);
outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
}
}
// grad: [B, C], C = D + D * deg * 2
// outputs: [B, C]
// grad_inputs: [B, D]
__global__ void kernel_freq_backward(
const float * __restrict__ grad,
const float * __restrict__ outputs,
uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
float * grad_inputs
) {
// parallel on per-element
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (t >= B * D) return;
const uint32_t b = t / D;
const uint32_t d = t - b * D; // t % D;
// locate
grad += b * C;
outputs += b * C;
grad_inputs += t;
// register
float result = grad[d];
grad += D;
outputs += D;
for (uint32_t f = 0; f < deg; f++) {
result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
grad += 2 * D;
outputs += 2 * D;
}
// write
grad_inputs[0] = result;
}
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
CHECK_CUDA(inputs);
CHECK_CUDA(outputs);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(outputs);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(outputs);
static constexpr uint32_t N_THREADS = 128;
kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
}
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
CHECK_CUDA(grad);
CHECK_CUDA(outputs);
CHECK_CUDA(grad_inputs);
CHECK_CONTIGUOUS(grad);
CHECK_CONTIGUOUS(outputs);
CHECK_CONTIGUOUS(grad_inputs);
CHECK_IS_FLOATING(grad);
CHECK_IS_FLOATING(outputs);
CHECK_IS_FLOATING(grad_inputs);
static constexpr uint32_t N_THREADS = 128;
kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
}
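`kernel_freq_backward` does not recompute any trigonometric function. It reuses the saved `outputs`, because d/dx sin(2^f x) = 2^f cos(2^f x) and d/dx cos(2^f x) = -2^f sin(2^f x), and both values already sit in the forward output. The sketch below reproduces that accumulation for one row and compares it with a central finite difference. The helper names are hypothetical, and exact `math.cos` stands in for the kernel's phase-shifted `__sinf`.

```python
import math
from typing import List


def freq_forward(x: List[float], degree: int) -> List[float]:
    """Same channel layout as kernel_freq: [x, sin(2^0 x), cos(2^0 x), ...]."""
    out = list(x)
    for f in range(degree):
        out += [math.sin(math.ldexp(v, f)) for v in x]
        out += [math.cos(math.ldexp(v, f)) for v in x]
    return out


def freq_backward(grad: List[float], outputs: List[float], D: int, degree: int) -> List[float]:
    """Mirror of kernel_freq_backward for one row: dL/dx from dL/dy and the saved y."""
    grad_inputs = []
    for d in range(D):
        result = grad[d]  # identity channels pass the gradient straight through
        base = D          # skip the identity block
        for f in range(degree):
            sin_i, cos_i = base + d, base + D + d
            result += math.ldexp(1.0, f) * (grad[sin_i] * outputs[cos_i] - grad[cos_i] * outputs[sin_i])
            base += 2 * D  # move to the next (sin, cos) pair of blocks
        grad_inputs.append(result)
    return grad_inputs


def numerical_grad(x: List[float], degree: int, grad: List[float], eps: float = 1e-6) -> List[float]:
    """Central-difference estimate of dL/dx for L = sum(grad * freq_forward(x))."""
    result = []
    for d in range(len(x)):
        xp, xm = list(x), list(x)
        xp[d] += eps
        xm[d] -= eps
        yp, ym = freq_forward(xp, degree), freq_forward(xm, degree)
        result.append(sum(g * (p - m) / (2 * eps) for g, p, m in zip(grad, yp, ym)))
    return result


if __name__ == "__main__":
    x, degree = [0.3, -0.7], 3
    y = freq_forward(x, degree)
    g = [0.1 * (i + 1) for i in range(len(y))]
    print(freq_backward(g, y, len(x), degree))
    print(numerical_grad(x, degree, g))
```

This design keeps the full `[B, C]` output tensor alive through `ctx.save_for_backward`, which costs memory, but the backward pass then needs only multiply-adds.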
================================================
FILE: freqencoder/src/freqencoder.h
================================================
# pragma once
#include <stdint.h>
#include <torch/torch.h>
// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
================================================
FILE: gridencoder/__init__.py
================================================
from .grid import GridEncoder
================================================
FILE: gridencoder/backend.py
================================================
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
_backend = load(name='_grid_encoder',
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[os.path.join(_src_path, 'src', f) for f in [
'gridencoder.cu',
'bindings.cpp',
]],
)
__all__ = ['_backend']
================================================
FILE: gridencoder/grid.py
================================================
import math
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _gridencoder as _backend
except ImportError:
from .backend import _backend
_gridtype_to_id = {
'hash': 0,
'tiled': 1,
}
_interp_to_id = {
'linear': 0,
'smoothstep': 1,
}
class _grid_encode(Function):
@staticmethod
@custom_fwd
def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0, max_level=None):
# inputs: [B, D], float in [0, 1]
# embeddings: [sO, C], float
# offsets: [L + 1], int
# RETURN: [B, F], float
inputs = inputs.contiguous()
B, D = inputs.shape # batch size, coord dim
L = offsets.shape[0] - 1 # level
C = embeddings.shape[1] # embedding dim for each level
S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
H = base_resolution # base resolution
max_level = L if max_level is None else max(min(int(math.ceil(max_level * L)), L), 1)
# manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
# if C % 2 != 0, force float, since half for atomicAdd is very slow.
if torch.is_autocast_enabled() and C % 2 == 0:
embeddings = embeddings.to(torch.half)
# L first, optimize cache for cuda kernel, but needs an extra permute later
outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
# zero init if we only calculate partial levels
if max_level < L: outputs.zero_()
if calc_grad_inputs:
dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
if max_level < L: dy_dx.zero_()
else:
dy_dx = None
_backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interpolation)
# permute back to [B, L * C]
outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
ctx.dims = [B, D, C, L, S, H, gridtype, interpolation, max_level]
ctx.align_corners = align_corners
return outputs
@staticmethod
#@once_differentiable
@custom_bwd
def backward(ctx, grad):
inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
B, D, C, L, S, H, gridtype, interpolation, max_level = ctx.dims
align_corners = ctx.align_corners
# grad: [B, L * C] --> [L, B, C]
grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
grad_embeddings = torch.zeros_like(embeddings)
if dy_dx is not None:
grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
else:
grad_inputs = None
_backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)
if dy_dx is not None:
grad_inputs = grad_inputs.to(inputs.dtype)
return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None
grid_encode = _grid_encode.apply
class GridEncoder(nn.Module):
def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):
super().__init__()
# the finest resolution desired at the last level; if provided, it overrides per_level_scale
if desired_resolution is not None:
per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
self.input_dim = input_dim # coord dims, 2 or 3
self.num_levels = num_levels # number of levels; each level multiplies the resolution by per_level_scale
self.level_dim = level_dim # encode channels per level
self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
self.log2_hashmap_size = log2_hashmap_size
self.base_resolution = base_resolution
self.output_dim = num_levels * level_dim
self.gridtype = gridtype
self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
self.interpolation = interpolation
self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep"
self.align_corners = align_corners
# allocate parameters
offsets = []
offset = 0
self.max_params = 2 ** log2_hashmap_size
for i in range(num_levels):
resolution = int(np.ceil(base_resolution * per_level_scale ** i))
params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number
params_in_level = int(np.ceil(params_in_level / 8) * 8) # round up to a multiple of 8
offsets.append(offset)
offset += params_in_level
offsets.append(offset)
offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
self.register_buffer('offsets', offsets)
self.n_params = offsets[-1] * level_dim
# parameters
self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
self.reset_parameters()
def reset_parameters(self):
std = 1e-4
self.embeddings.data.uniform_(-std, std)
def __repr__(self):
return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"
def forward(self, inputs, bound=1, max_level=None):
# inputs: [..., input_dim], normalized real world positions in [-bound, bound]
# max_level: only calculate first max_level levels (None will use all levels)
# return: [..., num_levels * level_dim]
inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
#print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
prefix_shape = list(inputs.shape[:-1])
inputs = inputs.view(-1, self.input_dim)
outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id, max_level)
outputs = outputs.view(prefix_shape + [self.output_dim])
#print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
return outputs
# always run in float precision!
@torch.cuda.amp.autocast(enabled=False)
def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
# inputs: [..., input_dim], float in [-bound, bound], locations at which to compute the TV loss (random if None).
D = self.input_dim
C = self.embeddings.shape[1] # embedding dim for each level
L = self.offsets.shape[0] - 1 # level
S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
H = self.base_resolution # base resolution
if inputs is None:
# randomized in [0, 1]
inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
else:
inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
inputs = inputs.view(-1, self.input_dim)
B = inputs.shape[0]
if self.embeddings.grad is None:
raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
_backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)
@torch.cuda.amp.autocast(enabled=False)
def grad_weight_decay(self, weight=0.1):
# level-wise mean weight decay (ref: Zip-NeRF)
B = self.embeddings.shape[0] # size of embedding
C = self.embeddings.shape[1] # embedding dim for each level
L = self.offsets.shape[0] - 1 # level
if self.embeddings.grad is None:
raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
_backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L)
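The parameter layout built in `GridEncoder.__init__` and the fractional `max_level` clamp in `_grid_encode.forward` can be reproduced without a GPU. The pure-Python sketch below uses hypothetical helper names and mirrors only the arithmetic above:

- the per-level resolution `ceil(base_resolution * per_level_scale ** i)`;
- at most `2 ** log2_hashmap_size` entries per level, rounded up to a multiple of 8;
- the `desired_resolution` override of `per_level_scale`.

```python
import math
from typing import List, Optional


def per_level_scale_from_desired(desired_resolution: float, base_resolution: float, num_levels: int) -> float:
    # Same formula as the desired_resolution branch of GridEncoder.__init__.
    return 2.0 ** (math.log2(desired_resolution / base_resolution) / (num_levels - 1))


def level_offsets(input_dim: int, num_levels: int, base_resolution: int,
                  per_level_scale: float, log2_hashmap_size: int) -> List[int]:
    """Start offset of every level in the embeddings table, plus the total size at the end."""
    max_params = 2 ** log2_hashmap_size
    offsets = [0]
    for i in range(num_levels):
        resolution = int(math.ceil(base_resolution * per_level_scale ** i))
        params_in_level = min(max_params, resolution ** input_dim)  # dense grid, or capped hash table
        params_in_level = int(math.ceil(params_in_level / 8) * 8)   # round up to a multiple of 8
        offsets.append(offsets[-1] + params_in_level)
    return offsets


def levels_to_compute(max_level: Optional[float], num_levels: int) -> int:
    # max_level is a fraction of all levels; the result is clamped to [1, num_levels].
    if max_level is None:
        return num_levels
    return max(min(int(math.ceil(max_level * num_levels)), num_levels), 1)


if __name__ == "__main__":
    print(level_offsets(3, 4, 16, 2.0, 19))  # [0, 4096, 36864, 299008, 823296]
    print(per_level_scale_from_desired(2048, 16, 16))
    print(levels_to_compute(0.5, 16))  # 8
```

In the example, the fourth level (128^3 cells) exceeds 2^19 entries, so it is capped at 524288 and must fall back to hashing.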
================================================
FILE: gridencoder/setup.py
================================================
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
setup(
name='gridencoder', # package name, import this to use python API
ext_modules=[
CUDAExtension(
name='_gridencoder', # extension name, import this to use CUDA API
sources=[os.path.join(_src_path, 'src', f) for f in [
'gridencoder.cu',
'bindings.cpp',
]],
extra_compile_args={
'cxx': c_flags,
'nvcc': nvcc_flags,
}
),
],
cmdclass={
'build_ext': BuildExtension,
}
)
================================================
FILE: gridencoder/src/bindings.cpp
================================================
#include <torch/extension.h>
#include "gridencoder.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
m.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)");
}
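`grid_encode_forward` dispatches to `kernel_grid` in `gridencoder/src/gridencoder.cu`. That kernel performs three steps per level: it locates the cell (controlled by `align_corners`), weights the 2^D corners, and computes indices with `get_grid_index`/`fast_hash`. The pure-Python sketch below walks through those steps. The helper names are hypothetical, and a 32-bit mask emulates `uint32_t` wraparound. It is a readability aid, not a drop-in replacement for the kernel.

```python
import math
from typing import List, Tuple

PRIMES = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]
MASK32 = 0xFFFFFFFF  # emulate uint32_t overflow


def fast_hash(pos_grid: List[int]) -> int:
    result = 0
    for p, prime in zip(pos_grid, PRIMES):
        result ^= (p * prime) & MASK32
    return result


def get_grid_index(gridtype: int, ch: int, hashmap_size: int, resolution: int,
                   pos_grid: List[int], C: int) -> int:
    stride, index = 1, 0
    for p in pos_grid:
        if stride > hashmap_size:  # dense indexing no longer fits this level
            break
        index = (index + p * stride) & MASK32
        stride = (stride * resolution) & MASK32
    if gridtype == 0 and stride > hashmap_size:  # 0 == hash, 1 == tiled
        index = fast_hash(pos_grid)
    return (index % hashmap_size) * C + ch


def locate(x: float, resolution: int, align_corners: bool) -> Tuple[int, float]:
    """Lower cell corner and fractional offset of a coordinate x in [0, 1]."""
    if align_corners:
        pos = x * (resolution - 1)
        cell = min(int(math.floor(pos)), resolution - 2)
    else:
        pos = min(max(x * resolution - 0.5, 0.0), float(resolution - 1))
        cell = int(math.floor(pos))
    return cell, pos - cell


def corner_weights(frac: List[float]) -> List[Tuple[Tuple[int, ...], float]]:
    """(corner offset, weight) for each of the 2^D corners with linear interpolation.

    With interpolation='smoothstep' the kernel first maps each fraction t to t*t*(3 - 2t).
    """
    corners = []
    for idx in range(1 << len(frac)):
        w, offs = 1.0, []
        for d, t in enumerate(frac):
            bit = (idx >> d) & 1
            w *= t if bit else 1.0 - t
            offs.append(bit)
        corners.append((tuple(offs), w))
    return corners


if __name__ == "__main__":
    print(locate(0.3, 16, align_corners=False))  # cell 4, fraction roughly 0.3
    print(get_grid_index(1, 0, 4096, 16, [1, 2, 3], C=2))  # 1602 == (1 + 2*16 + 3*256) * 2
    print(sum(w for _, w in corner_weights([0.2, 0.7, 0.5])))
```

The sketch also shows that with `gridtype='tiled'`, coordinates beyond the point where `stride` exceeds `hashmap_size` are ignored rather than hashed.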
================================================
FILE: gridencoder/src/gridencoder.cu
================================================
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <cstdio>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
// only for compatibility with half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF; this stub is never actually called!
__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
// requires CUDA >= 10 and ARCH >= 70
// this is very slow compared to float or __half2, never use it.
//return atomicAdd(reinterpret_cast<__half*>(address), val);
}
template <typename T>
__host__ __device__ inline T div_round_up(T val, T divisor) {
return (val + divisor - 1) / divisor;
}
template <typename T>
__device__ inline T smoothstep(T val) {
return val*val*(3.0f - 2.0f * val);
}
template <typename T>
__device__ inline T smoothstep_derivative(T val) {
return 6*val*(1.0f - val);
}
template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
// coherent type of hashing
constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };
uint32_t result = 0;
#pragma unroll
for (uint32_t i = 0; i < D; ++i) {
result ^= pos_grid[i] * primes[i];
}
return result;
}
template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
uint32_t stride = 1;
uint32_t index = 0;
#pragma unroll
for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
index += pos_grid[d] * stride;
stride *= resolution;
}
// NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
// gridtype: 0 == hash, 1 == tiled
if (gridtype == 0 && stride > hashmap_size) {
index = fast_hash<D>(pos_grid);
}
return (index % hashmap_size) * C + ch;
}
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
const float * __restrict__ inputs,
const scalar_t * __restrict__ grid,
const int * __restrict__ offsets,
scalar_t * __restrict__ outputs,
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
scalar_t * __restrict__ dy_dx,
const uint32_t gridtype,
const bool align_corners,
const uint32_t interp
) {
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B) return;
const uint32_t level = blockIdx.y;
// locate
grid += (uint32_t)offsets[level] * C;
inputs += b * D;
outputs += level * B * C + b * C;
// check input range (should be in [0, 1])
bool flag_oob = false;
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if (inputs[d] < 0 || inputs[d] > 1) {
flag_oob = true;
}
}
// if the input is out of bounds, set the outputs (and dy_dx, if requested) to 0
if (flag_oob) {
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
outputs[ch] = 0;
}
if (dy_dx) {
dy_dx += b * D * L * C + level * D * C; // B L D C
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
dy_dx[d * C + ch] = 0;
}
}
}
return;
}
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
// calculate coordinate (always use float for precision!)
float pos[D];
float pos_deriv[D];
uint32_t pos_grid[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
// align_corners
if (align_corners) {
pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
} else {
pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
}
pos[d] -= (float)pos_grid[d];
// smoothstep instead of linear
if (interp == 1) {
pos_deriv[d] = smoothstep_derivative(pos[d]);
pos[d] = smoothstep(pos[d]);
} else {
pos_deriv[d] = 1.0f;
}
}
// verification of alignment
// if (level == L - 1 && b < 4) {
// printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
// }
// interpolate
scalar_t results[C] = {0}; // temp results in register
#pragma unroll
for (uint32_t idx = 0; idx < (1 << D); idx++) {
float w = 1;
uint32_t pos_grid_local[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if ((idx & (1 << d)) == 0) {
w *= 1 - pos[d];
pos_grid_local[d] = pos_grid[d];
} else {
w *= pos[d];
pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
}
}
uint32_t index = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);
// writing to register (fast)
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
results[ch] += w * grid[index + ch];
}
//printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
}
// writing to global memory (slow)
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
outputs[ch] = results[ch];
}
// prepare dy_dx
// differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
if (dy_dx) {
dy_dx += b * D * L * C + level * D * C; // B L D C
#pragma unroll
for (uint32_t gd = 0; gd < D; gd++) {
scalar_t results_grad[C] = {0};
#pragma unroll
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
float w = (float)(align_corners ? resolution - 1 : resolution);
uint32_t pos_grid_local[D];
#pragma unroll
for (uint32_t nd = 0; nd < D - 1; nd++) {
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
if ((idx & (1 << nd)) == 0) {
w *= 1 - pos[d];
pos_grid_local[d] = pos_grid[d];
} else {
w *= pos[d];
pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
}
}
pos_grid_local[gd] = pos_grid[gd];
uint32_t index_left = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);
pos_grid_local[gd] = min(pos_grid[gd] + 1, resolution - 1);
uint32_t index_right = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
}
}
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
dy_dx[gd * C + ch] = results_grad[ch];
}
}
}
}
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
const scalar_t * __restrict__ grad,
const float * __restrict__ inputs,
const scalar_t * __restrict__ grid,
const int * __restrict__ offsets,
scalar_t * __restrict__ grad_grid,
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
const uint32_t gridtype,
const bool align_corners,
const uint32_t interp
) {
const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
if (b >= B) return;
const uint32_t level = blockIdx.y;
const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
// locate
grad_grid += offsets[level] * C;
inputs += b * D;
grad += level * B * C + b * C + ch; // L, B, C
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
// check input range (should be in [0, 1])
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if (inputs[d] < 0 || inputs[d] > 1) {
return; // grad_grid is zero-initialized, so an out-of-bounds input simply contributes nothing.
}
}
// calculate coordinate
float pos[D];
uint32_t pos_grid[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
// align_corners
if (align_corners) {
pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
} else {
pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
}
pos[d] -= (float)pos_grid[d];
// smoothstep instead of linear
if (interp == 1) {
pos[d] = smoothstep(pos[d]);
}
}
scalar_t grad_cur[N_C] = {0}; // fetch to register
#pragma unroll
for (uint32_t c = 0; c < N_C; c++) {
grad_cur[c] = grad[c];
}
// interpolate
#pragma unroll
for (uint32_t idx = 0; idx < (1 << D); idx++) {
float w = 1;
uint32_t pos_grid_local[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if ((idx & (1 << d)) == 0) {
w *= 1 - pos[d];
pos_grid_local[d] = pos_grid[d];
} else {
w *= pos[d];
pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
}
}
uint32_t index = get_grid_index<D, C>(gridtype, ch, hashmap_size, resolution, pos_grid_local);
// atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
// TODO: use float which is better than __half, if N_C % 2 != 0
if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
#pragma unroll
for (uint32_t c = 0; c < N_C; c += 2) {
// process two __half at once (by interpreting as a __half2)
__half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
atomicAdd((__half2*)&grad_grid[index + c], v);
}
// float, or __half when N_C % 2 != 0 (which means C == 1)
} else {
#pragma unroll
for (uint32_t c = 0; c < N_C; c++) {
atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
}
}
}
}
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(
const scalar_t * __restrict__ grad,
const scalar_t * __restrict__ dy_dx,
scalar_t * __restrict__ grad_inputs,
uint32_t B, uint32_t L
) {
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (t >= B * D) return;
const uint32_t b = t / D;
const uint32_t d = t - b * D;
dy_dx += b * L * D * C;
scalar_t result = 0;
# pragma unroll
for (uint32_t l = 0; l < L; l++) {
# pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
}
}
grad_inputs[t] = result;
}
template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
static constexpr uint32_t N_THREAD = 512;
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), max_level, 1 };
switch (C) {
case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
case 16: kernel_grid<scalar_t, D, 16><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
case 32: kernel_grid<scalar_t, D, 32><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
}
}
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float or half
// offsets: [L + 1], int
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
// H: base resolution
// dy_dx: [B, L * D * C]
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
switch (D) {
case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
}
}
template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
static constexpr uint32_t N_THREAD = 256;
const uint32_t N_C = std::min(2u, C); // n_features_per_thread
const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), max_level, 1 };
switch (C) {
case 1:
kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 2:
kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 4:
kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 8:
kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 16:
kernel_grid_backward<scalar_t, D, 16, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
if (dy_dx) kernel_input_backward<scalar_t, D, 16><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 32:
kernel_grid_backward<scalar_t, D, 32, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
if (dy_dx) kernel_input_backward<scalar_t, D, 32><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
}
}
// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int
// grad_embeddings: [sO, C]
// H: base resolution
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
switch (D) {
case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
}
}
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
CHECK_CUDA(inputs);
CHECK_CUDA(embeddings);
CHECK_CUDA(offsets);
CHECK_CUDA(outputs);
// CHECK_CUDA(dy_dx);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(embeddings);
CHECK_CONTIGUOUS(offsets);
CHECK_CONTIGUOUS(outputs);
// CHECK_CONTIGUOUS(dy_dx);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(embeddings);
CHECK_IS_INT(offsets);
CHECK_IS_FLOATING(outputs);
// CHECK_IS_FLOATING(dy_dx);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
embeddings.scalar_type(), "grid_encode_forward", ([&] {
grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
}));
}
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
CHECK_CUDA(grad);
CHECK_CUDA(inputs);
CHECK_CUDA(embeddings);
CHECK_CUDA(offsets);
CHECK_CUDA(grad_embeddings);
// CHECK_CUDA(dy_dx);
// CHECK_CUDA(grad_inputs);
CHECK_CONTIGUOUS(grad);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(embeddings);
CHECK_CONTIGUOUS(offsets);
CHECK_CONTIGUOUS(grad_embeddings);
// CHECK_CONTIGUOUS(dy_dx);
// CHECK_CONTIGUOUS(grad_inputs);
CHECK_IS_FLOATING(grad);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(embeddings);
CHECK_IS_INT(offsets);
CHECK_IS_FLOATING(grad_embeddings);
// CHECK_IS_FLOATING(dy_dx);
// CHECK_IS_FLOATING(grad_inputs);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "grid_encode_backward", ([&] {
grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
}));
}
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grad_tv(
const scalar_t * __restrict__ inputs,
const scalar_t * __restrict__ grid,
scalar_t * __restrict__ grad,
const int * __restrict__ offsets,
const float weight,
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
const uint32_t gridtype,
const bool align_corners
) {
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B) return;
const uint32_t level = blockIdx.y;
// locate
inputs += b * D;
grid += (uint32_t)offsets[level] * C;
grad += (uint32_t)offsets[level] * C;
// check input range (should be in [0, 1])
bool flag_oob = false;
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if (inputs[d] < 0 || inputs[d] > 1) {
flag_oob = true;
}
}
// if input out of bound, do nothing
if (flag_oob) return;
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
// calculate coordinate
float pos[D];
uint32_t pos_grid[D]; // [0, resolution]
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
// align_corners
if (align_corners) {
pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
} else {
pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
}
}
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
// total variation on pos_grid
scalar_t results[C] = {0}; // temp results in register
scalar_t idelta[C] = {0};
uint32_t index = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);
scalar_t w = weight / (2 * D);
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
uint32_t cur_d = pos_grid[d];
scalar_t grad_val;
// right side
if (cur_d < resolution) {
pos_grid[d] = cur_d + 1;
uint32_t index_right = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
grad_val = (grid[index + ch] - grid[index_right + ch]);
results[ch] += grad_val;
idelta[ch] += grad_val * grad_val;
}
}
// left side
if (cur_d > 0) {
pos_grid[d] = cur_d - 1;
uint32_t index_left = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
grad_val = (grid[index + ch] - grid[index_left + ch]);
results[ch] += grad_val;
idelta[ch] += grad_val * grad_val;
}
}
// reset
pos_grid[d] = cur_d;
}
// writing to global memory (slow)
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
// index may collide, so use atomic!
atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
}
}
template <typename scalar_t, uint32_t D>
void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
static constexpr uint32_t N_THREAD = 512;
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
switch (C) {
case 1: kernel_grad_tv<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 2: kernel_grad_tv<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 4: kernel_grad_tv<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 8: kernel_grad_tv<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 16: kernel_grad_tv<scalar_t, D, 16><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 32: kernel_grad_tv<scalar_t, D, 32><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
}
}
template <typename scalar_t>
void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
switch (D) {
case 2: kernel_grad_tv_wrapper<scalar_t, 2>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
case 3: kernel_grad_tv_wrapper<scalar_t, 3>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
case 4: kernel_grad_tv_wrapper<scalar_t, 4>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
case 5: kernel_grad_tv_wrapper<scalar_t, 5>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
}
}
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
embeddings.scalar_type(), "grad_total_variation", ([&] {
grad_total_variation_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, D, C, L, S, H, gridtype, align_corners);
}));
}
template <typename scalar_t>
__global__ void kernel_grad_wd(
const scalar_t * __restrict__ grid,
scalar_t * __restrict__ grad,
const int * __restrict__ offsets,
const float weight,
const uint32_t B, const uint32_t L, const uint32_t C
) {
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B * C) return;
// locate
grid += b;
grad += b;
// decide in which level is this thread...
uint32_t level = 0;
const uint32_t n = b / C;
// binary search b in offsets
uint32_t l = 0, r = L;
while (l < r) {
uint32_t m = (l + r) / 2;
if (offsets[m] <= n) {
level = m;
l = m + 1;
} else {
r = m;
}
}
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
grad[0] += 2 * weight * grid[0] / hashmap_size;
}
void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
embeddings.scalar_type(), "grad_weight_decay", ([&] {
static constexpr uint32_t N_THREAD = 1024;
const dim3 blocks_hashgrid = { div_round_up(B * C, N_THREAD), 1, 1 };
kernel_grad_wd<scalar_t><<<blocks_hashgrid, N_THREAD>>>(embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, L, C);
}));
}
================================================
FILE: gridencoder/src/gridencoder.h
================================================
#ifndef _HASH_ENCODE_H
#define _HASH_ENCODE_H
#include <stdint.h>
#include <torch/torch.h>
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int
// outputs: [B, L * C], float
// H: base resolution
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L);
#endif
================================================
FILE: guidance/clip_utils.py
================================================
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import clip
class CLIP(nn.Module):
    def __init__(self, device, **kwargs):
        super().__init__()

        self.device = device
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)

        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def get_text_embeds(self, prompt, **kwargs):
        text = clip.tokenize(prompt).to(self.device)
        text_z = self.clip_model.encode_text(text)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)

        return text_z

    def get_img_embeds(self, image, **kwargs):
        image_z = self.clip_model.encode_image(self.aug(image))
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)

        return image_z

    def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs):
        """
        Args:
            grad_scale: scalar or 1-tensor of size [B], i.e. 1 grad_scale per batch item.
        """
        # TODO: resize the image from the NeRF-rendered resolution (e.g. 128x128) to what CLIP expects (224x224), to prevent the PyTorch warning about `antialias=None`.
        image_z = self.clip_model.encode_image(self.aug(pred_rgb))
        image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features

        loss = 0
        if 'image' in clip_z:
            loss -= ((image_z * clip_z['image']).sum(-1) * grad_scale).mean()

        if 'text' in clip_z:
            loss -= ((image_z * clip_z['text']).sum(-1) * grad_scale).mean()

        return loss
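
# Illustrative sketch, not part of the original module: `_clip_cosine_loss_ref`
# is a hypothetical helper that nothing above calls. Because `train_step`
# L2-normalises both embeddings, the per-item dot product it computes is the
# cosine similarity. For a single target ('image' or 'text') the loss is
# therefore -grad_scale * mean(cos(image_z, target_z)). Plain-Python lists
# stand in for the [B, D] tensors here, and grad_scale is assumed to be a scalar.
import math


def _clip_cosine_loss_ref(image_z, target_z, grad_scale=10):
    """Negative, scaled cosine similarity averaged over the batch."""
    def normalize(v):
        n = math.sqrt(sum(x * x for x in v))
        return [x / n for x in v]

    sims = [
        sum(a * b for a, b in zip(normalize(u), normalize(v)))
        for u, v in zip(image_z, target_z)
    ]
    return -grad_scale * sum(sims) / len(sims)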
================================================
FILE: guidance/if_utils.py
================================================
from transformers import logging
from diffusers import IFPipeline, DDPMScheduler
# suppress partial model loading warning
logging.set_verbosity_error()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from .perpneg_utils import weighted_perpendicular_aggregator
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = True
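
# Illustrative sketch, not part of the original module: the helpers below are
# hypothetical and are never called by IF. They show why `train_step` builds
# `targets = (images - grad).detach()` and returns
# 0.5 * sum((images - targets)**2) / B. Since `targets` is held constant,
# d loss / d images = (images - targets) / B = grad / B. Backprop therefore
# injects the score-distillation gradient w(t) * (noise_pred - noise), scaled
# by 1/B, into the rendered images without differentiating through the UNet.
# Flat Python lists stand in for tensors, and the gradient is checked with
# central finite differences.


def _sds_surrogate_loss_ref(images, targets, batch_size):
    # mirrors 0.5 * F.mse_loss(images, targets, reduction='sum') / B
    return 0.5 * sum((x - t) ** 2 for x, t in zip(images, targets)) / batch_size


def _sds_surrogate_grad_check(images, grad, batch_size, eps=1e-6):
    """Finite-difference gradient of the surrogate loss w.r.t. images."""
    targets = [x - g for x, g in zip(images, grad)]  # computed once, like .detach()
    numeric = []
    for i in range(len(images)):
        plus = list(images)
        plus[i] += eps
        minus = list(images)
        minus[i] -= eps
        diff = (_sds_surrogate_loss_ref(plus, targets, batch_size)
                - _sds_surrogate_loss_ref(minus, targets, batch_size))
        numeric.append(diff / (2 * eps))
    return numeric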
class IF(nn.Module):
    def __init__(self, device, vram_O, t_range=[0.02, 0.98]):
        super().__init__()

        self.device = device

        print(f'[INFO] loading DeepFloyd IF-I-XL...')

        model_key = "DeepFloyd/IF-I-XL-v1.0"

        is_torch2 = torch.__version__[0] == '2'

        # Create model
        pipe = IFPipeline.from_pretrained(model_key, variant="fp16", torch_dtype=torch.float16)
        if not is_torch2:
            pipe.enable_xformers_memory_efficient_attention()

        if vram_O:
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.unet = pipe.unet
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.scheduler = pipe.scheduler

        self.pipe = pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience

        print(f'[INFO] loaded DeepFloyd IF-I-XL!')

    @torch.no_grad()
    def get_text_embeds(self, prompt):
        # prompt: [str]

        # apply the same text preprocessing as the diffusers IF pipeline: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py#LL486C10-L486C28
        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
        inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt')
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]

        return embeddings

    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1):

        # [0, 1] to [-1, 1] and make sure shape is [64, 64]
        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(images)
            images_noisy = self.scheduler.add_noise(images, noise, t)

            # pred noise
            model_input = torch.cat([images_noisy] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * 2)
            noise_pred = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        # the IF UNet stacks [noise, variance] along the channel dim; keep the noise part
        noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
        noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # TODO: how to use the variance here?
        # noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        targets = (images - grad).detach()
        loss = 0.5 * F.mse_loss(images.float(), targets, reduction='sum') / images.shape[0]

        return loss

    def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, grad_scale=1):

        B = pred_rgb.shape[0]
        K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts

        # [0, 1] to [-1, 1] and make sure shape is [64, 64]
        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(images)
            images_noisy = self.scheduler.add_noise(images, noise, t)

            # pred noise
            model_input = torch.cat([images_noisy] * (1 + K))
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * (1 + K))
            unet_output = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:]
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
            # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)
            noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        targets = (images - grad).detach()
        loss = 0.5 * F.mse_loss(images.float(), targets, reduction='sum') / images.shape[0]

        return loss

    @torch.no_grad()
    def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5):

        images = torch.randn((1, 3, height, width), device=text_embeddings.device, dtype=text_embeddings.dtype)
        images = images * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(self.scheduler.timesteps):
            # duplicate the images for classifier-free guidance so both branches run in one forward pass.
            model_input = torch.cat([images] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            images = self.scheduler.step(noise_pred, t, images).prev_sample

        images = (images + 1) / 2

        return images

    def prompt_to_img(self, prompts, negative_prompts='', height=64, width=64, num_inference_steps=50, guidance_scale=7.5, latents=None):

        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds (T5 encoder)
        pos_embeds = self.get_text_embeds(prompts) # [1, 77, 4096]
        neg_embeds = self.get_text_embeds(negative_prompts)
        text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 4096]

        # Text embeds -> img (IF stage I generates 64x64 pixels directly)
        imgs = self.produce_imgs(text_embeds, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 3, 64, 64]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs
if __name__ == '__main__':

    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
    parser.add_argument('-H', type=int, default=64)
    parser.add_argument('-W', type=int, default=64)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device('cuda')

    sd = IF(device, opt.vram_O)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()
================================================
FILE: guidance/perpneg_utils.py
================================================
import torch
# Please refer to the https://perp-neg.github.io/ for details about the paper and algorithm
def get_perpendicular_component(x, y):
    assert x.shape == y.shape
    return x - ((torch.mul(x, y).sum()) / max(torch.norm(y)**2, 1e-6)) * y


def batch_get_perpendicular_component(x, y):
    assert x.shape == y.shape
    result = []
    for i in range(x.shape[0]):
        result.append(get_perpendicular_component(x[i], y[i]))
    return torch.stack(result)


def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size):
    """
    Notes:
     - weights: an array with the weights for combining the noise predictions
     - delta_noise_preds: [B x K, C, H, W] (e.g. [B x K, 4, 64, 64] for Stable Diffusion latents), K = max_prompts_per_dir
    """
    delta_noise_preds = delta_noise_preds.split(batch_size, dim=0) # K x [B, C, H, W]
    weights = weights.split(batch_size, dim=0) # K x [B]
    # print(f"{weights[0].shape = } {weights = }")

    assert torch.all(weights[0] == 1.0)

    main_positive = delta_noise_preds[0] # [B, C, H, W]

    accumulated_output = torch.zeros_like(main_positive)
    for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1):
        # print(f"\n{i = }, {weights[i] = }, {weights[i].shape = }\n")

        idx_non_zero = torch.abs(weights[i]) > 1e-4

        # print(f"{idx_non_zero.shape = }, {idx_non_zero = }")
        # print(f"{weights[i][idx_non_zero].shape = }, {weights[i][idx_non_zero] = }")
        # print(f"{complementary_noise_pred.shape = }, {complementary_noise_pred[idx_non_zero].shape = }")
        # print(f"{main_positive.shape = }, {main_positive[idx_non_zero].shape = }")
        if not idx_non_zero.any():
            continue
        accumulated_output[idx_non_zero] += weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * batch_get_perpendicular_component(complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero])

    assert accumulated_output.shape == main_positive.shape, f"{accumulated_output.shape = }, {main_positive.shape = }"

    return accumulated_output + main_positive
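

# Illustrative sketch, not part of the original module: `_perp_ref` and
# `_perpneg_aggregate_ref` are hypothetical helpers that reproduce, for a
# single batch item on plain-Python lists, the math of
# `get_perpendicular_component` and `weighted_perpendicular_aggregator` above.
# x_perp = x - (<x, y> / max(||y||^2, 1e-6)) * y is orthogonal to y, so each
# complementary (e.g. negative) direction only contributes the part that the
# main positive prompt does not already explain. Weights with |w| <= 1e-4 are
# skipped, as in the original.


def _perp_ref(x, y, eps=1e-6):
    dot = sum(a * b for a, b in zip(x, y))
    norm_sq = max(sum(b * b for b in y), eps)
    return [a - dot / norm_sq * b for a, b in zip(x, y)]


def _perpneg_aggregate_ref(main_positive, others, weights, tol=1e-4):
    """main_positive + sum_i w_i * perp(others[i], main_positive)."""
    out = list(main_positive)
    for delta, w in zip(others, weights):
        if abs(w) <= tol:
            continue
        out = [o + w * p for o, p in zip(out, _perp_ref(delta, main_positive))]
    return out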
================================================
FILE: guidance/sd_utils.py
================================================
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from os.path import isfile
from pathlib import Path
# suppress partial model loading warning
logging.set_verbosity_error()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
from torch.cuda.amp import custom_bwd, custom_fwd
from .perpneg_utils import weighted_perpendicular_aggregator
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = True
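
# Illustrative sketch, not part of the original module: the scalar helpers
# below are hypothetical and are not used by StableDiffusion. They state the
# standard DDPM relations behind `scheduler.add_noise` and the
# `predict_start_from_noise` idea referenced in `train_step`, where
# alpha_bar_t = scheduler.alphas_cumprod[t]:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
#   x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t)
# They also show the classifier-free guidance combination applied to the two
# UNet outputs in `train_step`.
import math


def _add_noise_ref(x0, eps, alpha_bar_t):
    return math.sqrt(alpha_bar_t) * x0 + math.sqrt(1 - alpha_bar_t) * eps


def _predict_x0_ref(x_t, eps_hat, alpha_bar_t):
    return (x_t - math.sqrt(1 - alpha_bar_t) * eps_hat) / math.sqrt(alpha_bar_t)


def _cfg_ref(eps_uncond, eps_cond, guidance_scale):
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)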
class StableDiffusion(nn.Module):
    def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None, t_range=[0.02, 0.98]):
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')

        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        self.precision_t = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t)

        if vram_O:
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler", torch_dtype=self.precision_t)

        del pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience

        print(f'[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_text_embeds(self, prompt):
        # prompt: [str]

        inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]

        return embeddings

    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1,
                   save_guidance_path:Path=None):

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
        else:
            # interp to 512x512 to be fed into vae.
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_512)

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)
            noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample

            # perform guidance (high scale from paper!)
            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)

            # import kiui
            # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device)
            # latents_tmp = latents_tmp.detach()
            # kiui.lo(latents_tmp)
            # self.scheduler.set_timesteps(30)
            # for i, t in enumerate(self.scheduler.timesteps):
            #     latent_model_input = torch.cat([latents_tmp] * 3)
            #     noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
            #     noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
            #     noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond)
            #     latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample']
            # imgs = self.decode_latents(latents_tmp)
            # kiui.vis.plot_image(imgs)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        if save_guidance_path:
            with torch.no_grad():
                if as_latent:
                    pred_rgb_512 = self.decode_latents(latents)

                # visualize predicted denoised image
                # The following block of code is equivalent to `predict_start_from_noise`...
                # see zero123_utils.py's version for a simpler implementation.
                alphas = self.scheduler.alphas.to(latents)
                total_timesteps = self.max_step - self.min_step + 1
                index = total_timesteps - t.to(latents.device) - 1
                b = len(noise_pred)
                a_t = alphas[index].reshape(b,1,1,1).to(self.device)
                sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
                sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device)
                pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0
                result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))

                # visualize noisier image
                result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))

                # TODO: also denoise all-the-way

                # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
                viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)
                save_image(viz_images, save_guidance_path)

        targets = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]

        return loss

    def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1,
                           save_guidance_path:Path=None):

        B = pred_rgb.shape[0]
        K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
        else:
            # interp to 512x512 to be fed into vae.
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_512)

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * (1 + K))
            tt = torch.cat([t] * (1 + K))
            unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample

            # perform guidance (high scale from paper!)
            noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:]
            delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)
            noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B)

            # import kiui
            # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device)
            # latents_tmp = latents_tmp.detach()
            # kiui.lo(latents_tmp)
            # self.scheduler.set_timesteps(30)
            # for i, t in enumerate(self.scheduler.timesteps):
            #     latent_model_input = torch.cat([latents_tmp] * 3)
            #     noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
            #     noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
            #     noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond)
            #     latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample']
            # imgs = self.decode_latents(latents_tmp)
            # kiui.vis.plot_image(imgs)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        if save_guidance_path:
            with torch.no_grad():
                if as_latent:
                    pred_rgb_512 = self.decode_latents(latents)

                # visualize predicted denoised image
# The following is equivalent to `predict_start_from_noise`:
# x_0 = (x_t - sqrt(1 - alphas_cumprod[t]) * eps) / sqrt(alphas_cumprod[t]).
# See zero123_utils.py's version for a simpler implementation.
b = len(noise_pred)
a_t = self.alphas[t].reshape(b, 1, 1, 1) # self.alphas holds alphas_cumprod, indexed directly by t
pred_x0 = (latents_noisy - torch.sqrt(1 - a_t) * noise_pred) / a_t.sqrt() # current prediction for x_0
result_hopefully_less_noisy_image = self.decode_latents(pred_x0.type(self.precision_t))
# visualize noisier image
result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))
# all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)
save_image(viz_images, save_guidance_path)
targets = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
return loss
@torch.no_grad()
def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
if latents is None:
latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), device=self.device, dtype=self.precision_t)
self.scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(self.scheduler.timesteps):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
# perform guidance
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
return latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
imgs = self.vae.decode(latents).sample
imgs = (imgs / 2 + 0.5).clamp(0, 1)
return imgs
def encode_imgs(self, imgs):
# imgs: [B, 3, H, W]
imgs = 2 * imgs - 1
posterior = self.vae.encode(imgs).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents
def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(negative_prompts, str):
negative_prompts = [negative_prompts]
# Prompts -> text embeds
pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768] for SD 1.5, [1, 77, 1024] for SD 2.x
neg_embeds = self.get_text_embeds(negative_prompts)
text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] or [2, 77, 1024]
# Text embeds -> img latents
latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
# Img latents -> imgs
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
# Img to Numpy
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
imgs = (imgs * 255).round().astype('uint8')
return imgs
if __name__ == '__main__':
import argparse
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument('prompt', type=str)
parser.add_argument('--negative', default='', type=str)
parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
parser.add_argument('--fp16', action='store_true', help="use float16 for training")
parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
parser.add_argument('-H', type=int, default=512)
parser.add_argument('-W', type=int, default=512)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--steps', type=int, default=50)
opt = parser.parse_args()
seed_everything(opt.seed)
device = torch.device('cuda')
sd = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key)
imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
# visualize image
plt.imshow(imgs[0])
plt.show()
================================================
FILE: guidance/zero123_utils.py
================================================
import math
import numpy as np
from omegaconf import OmegaConf
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from torchvision.utils import save_image
from diffusers import DDIMScheduler
import sys
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
from ldm.util import instantiate_from_config
# load model
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=False):
pl_sd = torch.load(ckpt, map_location='cpu')
if 'global_step' in pl_sd and verbose:
print(f'[INFO] Global Step: {pl_sd["global_step"]}')
sd = pl_sd['state_dict']
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print('[INFO] missing keys: \n', m)
if len(u) > 0 and verbose:
print('[INFO] unexpected keys: \n', u)
# manually load ema and delete it to save GPU memory
if model.use_ema:
if verbose:
print('[INFO] loading EMA...')
model.model_ema.copy_to(model.model)
del model.model_ema
if vram_O:
# we don't need decoder
del model.first_stage_model.decoder
torch.cuda.empty_cache()
model.eval().to(device)
return model
class Zero123(nn.Module):
def __init__(self, device, fp16,
config='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml',
ckpt='./pretrained/zero123/zero123-xl.ckpt', vram_O=False, t_range=[0.02, 0.98], opt=None):
super().__init__()
self.device = device
self.fp16 = fp16
self.vram_O = vram_O
self.t_range = t_range
self.opt = opt
self.config = OmegaConf.load(config)
# TODO: seems it cannot load into fp16...
self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O)
# timesteps: use diffuser for convenience... hope it's alright.
self.num_train_timesteps = self.config.model.params.timesteps
self.scheduler = DDIMScheduler(
self.num_train_timesteps,
self.config.model.params.linear_start,
self.config.model.params.linear_end,
beta_schedule='scaled_linear',
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
@torch.no_grad()
def get_img_embeds(self, x):
# x: image tensor [B, 3, 256, 256] in [0, 1]
x = x * 2 - 1
c = [self.model.get_learned_conditioning(xx.unsqueeze(0)) for xx in x] #.tile(n_samples, 1, 1)
v = [self.model.encode_first_stage(xx.unsqueeze(0)).mode() for xx in x]
return c, v
def angle_between(self, sph_v1, sph_v2):
def sph2cart(sv):
r, theta, phi = sv[0], sv[1], sv[2]
return torch.tensor([r * torch.sin(theta) * torch.cos(phi), r * torch.sin(theta) * torch.sin(phi), r * torch.cos(theta)])
def unit_vector(v):
return v / torch.linalg.norm(v)
def angle_between_2_sph(sv1, sv2):
v1, v2 = sph2cart(sv1), sph2cart(sv2)
v1_u, v2_u = unit_vector(v1), unit_vector(v2)
return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))
angles = torch.empty(len(sph_v1), len(sph_v2))
for i, sv1 in enumerate(sph_v1):
for j, sv2 in enumerate(sph_v2):
angles[i][j] = angle_between_2_sph(sv1, sv2)
return angles
def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scale=3, as_latent=False, grad_scale=1, save_guidance_path:Path=None):
# pred_rgb: tensor [1, 3, H, W] in [0, 1]
# adjust SDS scale based on how far the novel view is from the known view
ref_radii = embeddings['ref_radii']
ref_polars = embeddings['ref_polars']
ref_azimuths = embeddings['ref_azimuths']
v1 = torch.stack([radius + ref_radii[0], torch.deg2rad(polar + ref_polars[0]), torch.deg2rad(azimuth + ref_azimuths[0])], dim=-1) # polar,azimuth,radius are all actually delta wrt default
v2 = torch.stack([torch.tensor(ref_radii), torch.deg2rad(torch.tensor(ref_polars)), torch.deg2rad(torch.tensor(ref_azimuths))], dim=-1)
angles = torch.rad2deg(self.angle_between(v1, v2)).to(self.device)
if self.opt.zero123_grad_scale == 'angle':
grad_scale = (angles.min(dim=1)[0] / (180/len(ref_azimuths))) * grad_scale # rethink 180/len(ref_azimuths) # claforte: try inverting grad_scale or just fixing it to 1.0
elif self.opt.zero123_grad_scale == 'None':
grad_scale = 1.0 # claforte: I think this might converge faster...?
else:
raise ValueError(f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}')
if as_latent:
latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
else:
pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
latents = self.encode_imgs(pred_rgb_256)
t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)
# Set weights acc to closeness in angle
if len(ref_azimuths) > 1:
inv_angles = 1/angles
inv_angles[inv_angles > 100] = 100
inv_angles /= inv_angles.max(dim=-1, keepdim=True)[0]
inv_angles[inv_angles < 0.1] = 0
else:
inv_angles = torch.tensor([1.]).to(self.device)
# Multiply closeness-weight by user-given weights
zero123_ws = torch.tensor(embeddings['zero123_ws'])[None, :].to(self.device) * inv_angles
zero123_ws /= zero123_ws.max(dim=-1, keepdim=True)[0]
zero123_ws[zero123_ws < 0.1] = 0
with torch.no_grad():
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
x_in = torch.cat([latents_noisy] * 2)
t_in = torch.cat([t] * 2)
noise_preds = []
# Loop through each ref image
for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T,
embeddings['c_crossattn'], embeddings['c_concat'],
ref_polars, ref_azimuths, ref_radii):
# polar,azimuth,radius are all actually delta wrt default
p = polar + ref_polars[0] - ref_polar
a = azimuth + ref_azimuths[0] - ref_azimuth
a[a > 180] -= 360 # range in [-180, 180]
r = radius + ref_radii[0] - ref_radius
# T = torch.tensor([math.radians(p), math.sin(math.radians(-a)), math.cos(math.radians(a)), r])
# T = T[None, None, :].to(self.device)
T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :]
cond = {}
clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1))
cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), c_concat.repeat(len(T), 1, 1, 1)], dim=0)]
noise_pred = self.model.apply_model(x_in, t_in, cond)
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
noise_preds.append(zero123_w[:, None, None, None] * noise_pred)
noise_pred = torch.stack(noise_preds).sum(dim=0) / zero123_ws.sum(dim=-1)[:, None, None, None]
w = (1 - self.alphas[t])
grad = (grad_scale * w)[:, None, None, None] * (noise_pred - noise)
grad = torch.nan_to_num(grad)
# import kiui
# if not as_latent:
# kiui.vis.plot_image(pred_rgb_256)
# kiui.vis.plot_matrix(latents)
# kiui.vis.plot_matrix(grad)
# import kiui
# latents = torch.randn((1, 4, 32, 32), device=self.device)
# kiui.lo(latents)
# self.scheduler.set_timesteps(30)
# with torch.no_grad():
# for i, t in enumerate(self.scheduler.timesteps):
# x_in = torch.cat([latents] * 2)
# t_in = torch.cat([t.view(1)] * 2).to(self.device)
# noise_pred = self.model.apply_model(x_in, t_in, cond)
# noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
# noise_pred = noise_pred_uncond + 3 * (noise_pred_cond - noise_pred_uncond)
# latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
# imgs = self.decode_latents(latents)
# print(polar, azimuth, radius)
# kiui.vis.plot_image(pred_rgb_256, imgs)
if save_guidance_path:
with torch.no_grad():
if as_latent:
pred_rgb_256 = self.decode_latents(latents) # claforte: test!
# visualize predicted denoised image
result_hopefully_less_noisy_image = self.decode_latents(self.model.predict_start_from_noise(latents_noisy, t, noise_pred))
# visualize noisier image
result_noisier_image = self.decode_latents(latents_noisy)
# TODO: also denoise all-the-way
# all 3 images are [B, 3, H, W], e.g. [1, 3, 256, 256]; they are concatenated side by side along the width (dim=-1)
viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1)
save_image(viz_images, save_guidance_path)
targets = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
return loss
# verification
@torch.no_grad()
def __call__(self,
image, # image tensor [1, 3, H, W] in [0, 1]
polar=0, azimuth=0, radius=0, # new view params
scale=3, ddim_steps=50, ddim_eta=1, h=256, w=256, # diffusion params
c_crossattn=None, c_concat=None, post_process=True,
):
# get_img_embeds returns per-image lists (c, v); only the first image is used here
embeds = self.get_img_embeds(image) if (c_crossattn is None or c_concat is None) else None
c_crossattn = embeds[0][0] if c_crossattn is None else c_crossattn
c_concat = embeds[1][0] if c_concat is None else c_concat
T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius])
T = T[None, None, :].to(self.device)
cond = {}
clip_emb = self.model.cc_projection(torch.cat([c_crossattn, T], dim=-1))
cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]
# produce latents loop
latents = torch.randn((1, 4, h // 8, w // 8), device=self.device)
self.scheduler.set_timesteps(ddim_steps)
for i, t in enumerate(self.scheduler.timesteps):
x_in = torch.cat([latents] * 2)
t_in = torch.cat([t.view(1)] * 2).to(self.device)
noise_pred = self.model.apply_model(x_in, t_in, cond)
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']
imgs = self.decode_latents(latents)
imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs
return imgs
def decode_latents(self, latents):
# zs: [B, 4, 32, 32] Latent space image
# with self.model.ema_scope():
imgs = self.model.decode_first_stage(latents)
imgs = (imgs / 2 + 0.5).clamp(0, 1)
return imgs # [B, 3, 256, 256] RGB space image
def encode_imgs(self, imgs):
# imgs: [B, 3, 256, 256] RGB space image
# with self.model.ema_scope():
imgs = imgs * 2 - 1
latents = torch.cat([self.model.get_first_stage_encoding(self.model.encode_first_stage(img.unsqueeze(0))) for img in imgs], dim=0)
return latents # [B, 4, 32, 32] Latent space image
if __name__ == '__main__':
import cv2
import argparse
import numpy as np
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument('input', type=str)
parser.add_argument('--fp16', action='store_true', help="use float16 for training") # no use now, can only run in fp32
parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')
opt = parser.parse_args()
device = torch.device('cuda')
print(f'[INFO] loading image from {opt.input} ...')
image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB if image.shape[-1] == 4 else cv2.COLOR_BGR2RGB) # IMREAD_UNCHANGED keeps alpha; drop it here
image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
image = image.astype(np.float32) / 255.0
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
print(f'[INFO] loading model ...')
zero123 = Zero123(device, opt.fp16, opt=opt)
print(f'[INFO] running model ...')
outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius)
plt.imshow(outputs[0])
plt.show()
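`Zero123.train_step` visualizes its guidance with `self.model.predict_start_from_noise`, which inverts the forward process that `scheduler.add_noise` applies. The self-contained sketch below shows that inversion in plain PyTorch. It assumes a Stable-Diffusion-style `scaled_linear` beta schedule with 1000 steps and `linear_start=0.00085`, `linear_end=0.012`; the real values come from `self.config.model.params`. With the true noise, the inversion recovers `x0` up to floating-point error.

```python
import torch

# Assumed schedule: "scaled_linear" betas (linspace in sqrt space, then squared).
T = 1000
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, T, dtype=torch.float64) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(x0, noise, t):
    # forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a.sqrt() * x0 + (1 - a).sqrt() * noise

def predict_start_from_noise(x_t, t, eps):
    # inverse: x0 = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t)
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return (x_t - (1 - a).sqrt() * eps) / a.sqrt()

torch.manual_seed(0)
x0 = torch.randn(2, 4, 32, 32, dtype=torch.float64)  # [B, 4, 32, 32] latents, as in Zero123
noise = torch.randn_like(x0)
t = torch.tensor([20, 980])  # timesteps inside [min_step, max_step] for t_range=[0.02, 0.98]
x_t = add_noise(x0, noise, t)
print(torch.allclose(predict_start_from_noise(x_t, t, noise), x0))
```

In training, `noise_pred` replaces the true noise, so the decoded `pred_x0` is the model's current guess at the clean image rather than an exact reconstruction.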
================================================
FILE: ldm/extras.py
================================================
from pathlib import Path
from omegaconf import OmegaConf
import torch
from ldm.util import instantiate_from_config
import logging
from contextlib import contextmanager
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
"""
A context manager that will prevent any logging messages
triggered during the body from being processed.
:param highest_level: the maximum logging level in use.
This would only need to be changed if a custom level greater than CRITICAL
is defined.
https://gist.github.com/simon-weber/7853144
"""
# two kind-of hacks here:
# * can't get the highest logging level in effect => delegate to the user
# * can't get the current module-level override => use an undocumented
# (but non-private!) interface
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
def load_training_dir(train_dir, device, epoch="last"):
"""Load a checkpoint and config from training directory"""
train_dir = Path(train_dir)
ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
config = list(train_dir.rglob(f"*-project.yaml"))
assert len(config) > 0, f"didn't find any config in {train_dir}"
if len(config) > 1:
print(f"found {len(config)} matching config files")
config = sorted(config)[-1]
print(f"selecting {config}")
else:
config = config[0]
config = OmegaConf.load(config)
return load_model_from_config(config, ckpt[0], device)
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
"""Loads a model from config and a ckpt
if config is a path will use omegaconf to load
"""
if isinstance(config, (str, Path)):
config = OmegaConf.load(config)
with all_logging_disabled():
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.to(device)
model.eval()
model.cond_stage_model.device = device
return model
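`all_logging_disabled` works by raising the global `logging.disable` threshold for the duration of the `with` block and restoring the previous value afterwards. The minimal demonstration below re-declares the same context manager so that it runs without `ldm`. It also adds a hypothetical in-memory handler (`ListHandler`, illustration only) to show which records get through.

```python
import logging
from contextlib import contextmanager

@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    # same pattern as ldm/extras.py: save the global disable level,
    # suppress everything up to `highest_level`, then restore it
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(previous_level)

records = []

class ListHandler(logging.Handler):  # hypothetical helper for this demo
    def emit(self, record):
        records.append(record.getMessage())

logger = logging.getLogger("demo")
logger.addHandler(ListHandler())
logger.setLevel(logging.DEBUG)
logger.propagate = False

with all_logging_disabled():
    logger.warning("suppressed")
logger.warning("visible")
print(records)
```

The context manager only affects the `logging` module. Plain `print` calls, such as `print(f"Loading model from {ckpt}")` in `load_model_from_config`, still appear.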
================================================
FILE: ldm/guidance.py
================================================
from typing import List, Tuple
from scipy import interpolate
import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython.display import clear_output
import abc
class GuideModel(torch.nn.Module, abc.ABC):
def __init__(self) -> None:
super().__init__()
@abc.abstractmethod
def preprocess(self, x_img):
pass
@abc.abstractmethod
def compute_loss(self, inp):
pass
class Guider(torch.nn.Module):
def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
"""Apply classifier guidance
Specify a guidance scale as either a scalar
Or a schedule as a list of tuples t = 0->1 and scale, e.g.
[(0, 10), (0.5, 20), (1, 50)]
"""
super().__init__()
self.sampler = sampler
self.index = 0
self.show = verbose
self.guide_model = guide_model
self.history = []
if isinstance(scale, (Tuple, List)):
times = np.array([x[0] for x in scale])
values = np.array([x[1] for x in scale])
self.scale_schedule = {"times": times, "values": values}
else:
self.scale_schedule = float(scale)
self.ddim_timesteps = sampler.ddim_timesteps
self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
def get_scales(self):
if isinstance(self.scale_schedule, float):
return len(self.ddim_timesteps)*[self.scale_schedule]
interpolator = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
return interpolator(fractional_steps)
def modify_score(self, model, e_t, x, t, c):
# TODO look up index by t
scale = self.get_scales()[self.index]
if (scale == 0):
return e_t
sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)
inp = self.guide_model.preprocess(x_img)
loss = self.guide_model.compute_loss(inp)
grads = torch.autograd.grad(loss.sum(), x_in)[0]
correction = grads * scale
if self.show:
clear_output(wait=True)
print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
plt.axis('off')
plt.show()
plt.imshow(correction[0][0].detach().cpu())
plt.axis('off')
plt.show()
e_t_mod = e_t - sqrt_1ma*correction
if self.show:
fig, axs = plt.subplots(1, 3)
axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
plt.show()
self.index += 1
return e_t_mod
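`Guider.get_scales` converts a schedule of `(t, scale)` pairs into one scale per DDIM step. It does this by linearly interpolating at each DDIM timestep's fraction of the DDPM timeline. The sketch below reproduces that computation with `np.interp` instead of `scipy.interpolate.interp1d`. The two agree for points inside `[times[0], times[-1]]`; outside that range, `np.interp` clamps while `interp1d` raises by default. The DDIM timesteps used here are an assumed example, not values taken from a real sampler.

```python
import numpy as np

# Schedule in the format the Guider docstring describes: (fraction of timeline, scale)
schedule = [(0, 10), (0.5, 20), (1, 50)]
times = np.array([x[0] for x in schedule], dtype=float)
values = np.array([x[1] for x in schedule], dtype=float)

# Assumed: 5 DDIM steps out of 1000 DDPM steps
# (in Guider these come from sampler.ddim_timesteps and sampler.ddpm_num_timesteps)
ddim_timesteps = np.array([1, 201, 401, 601, 801])
ddpm_num_timesteps = 1000

# Piecewise-linear interpolation of the scale at each step's fractional position
fractional_steps = ddim_timesteps / ddpm_num_timesteps
scales = np.interp(fractional_steps, times, values)
print(scales)
```

`modify_score` reads these scales with a running `self.index` counter rather than by `t`, as its TODO notes. The step order the sampler uses therefore determines which scale each step receives.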
================================================
FILE: ldm/lr_scheduler.py
================================================
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi))
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n,**kwargs)
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi))
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
self.last_f = f
return f
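The docstring note "use with a base_lr of 1.0" means that the scheduler returns a learning-rate multiplier meant for `torch.optim.lr_scheduler.LambdaLR`. The sketch below mirrors `LambdaWarmUpCosineScheduler.schedule` as a plain function, a linear warm-up followed by a half-cosine decay, and plugs it into `LambdaLR`. The hyperparameter values are arbitrary examples, and the logging branch is omitted.

```python
import math
import torch

def warmup_cosine(n, warm_up_steps=100, lr_start=0.0, lr_max=1.0, lr_min=0.1, max_decay_steps=1000):
    # Linear ramp lr_start -> lr_max, then half-cosine decay lr_max -> lr_min
    if n < warm_up_steps:
        return (lr_max - lr_start) / warm_up_steps * n + lr_start
    t = min((n - warm_up_steps) / (max_decay_steps - warm_up_steps), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(t * math.pi))

# With base lr 1.0, the optimizer's lr equals the multiplier directly.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1.0)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warmup_cosine)

lrs = []
for step in range(1200):
    lrs.append(opt.param_groups[0]["lr"])  # lr in effect at this step
    opt.step()
    sched.step()

print(lrs[0], lrs[100], lrs[550], lrs[1000])
```

Past `max_decay_steps`, `t` is clamped to 1.0 and the multiplier stays at `lr_min`.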
================================================
FILE: ldm/models/autoencoder.py
================================================
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from packaging import version
from torch.optim.lr_scheduler import LambdaLR
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            # rec_loss was already logged above; drop it so the same key is not logged twice
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # return the logged metrics (a dict), not the bound method self.log_dict
        return {**log_dict_ae, **log_dict_disc}

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        x_in = x  # keep the raw input for the EMA pass; x may be projected to RGB below
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x_in)
                if x_in.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        # project the C-channel map to 3 channels with a fixed random 1x1 convolution,
        # then min-max rescale the result to [-1, 1]
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


class AutoencoderKL(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        # with double_z the encoder emits 2*z_channels (mean and log-variance),
        # which are mapped to 2*embed_dim Gaussian moments
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
                    break  # the key is gone; avoid a KeyError if another prefix also matches
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # return the logged metrics (a dict), not the bound method self.log_dict
        return {**log_dict_ae, **log_dict_disc}

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        # project the C-channel map to 3 channels with a fixed random 1x1 convolution,
        # then min-max rescale the result to [-1, 1]
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
================================================
FILE: ldm/models/diffusion/__init__.py
================================================
================================================
FILE: ldm/models/diffusion/classifier.py
================================================
import os
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
__models__ = {
'class_label': EncoderUNetModel,
'segmentation': UNetModel
}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class NoisyLatentImageClassifier(pl.LightningModule):

    def __init__(self,
                 diffusion_path,
                 num_classes,
                 ckpt_path=None,
                 pool='attention',
                 label_key=None,
                 diffusion_ckpt_path=None,
                 scheduler_config=None,
                 weight_decay=1.e-2,
                 log_steps=10,
                 monitor='val/loss',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # use the last '*-project.yaml' config (in natural sort order) of the diffusion run
        diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
            else self.diffusion_model.cond_stage_key

        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError(
                f"label_key '{self.label_key}' is not supported; expected one of {list(__models__)}")

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in sd:
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
                    break  # the key is gone; avoid a KeyError if another prefix also matches
        # load into the whole LightningModule, or only into the classifier network
        target = self.model if only_model else self
        missing, unexpected = target.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print('#####################################################################')
            print(f'load from ckpt "{ckpt_path}"')
            print('#####################################################################')
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)

    def forward(self, x_noisy, t, *args, **kwargs):
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            # halve the spatial size once per first-stage encoder downsampling level
            # so the targets match the latent resolution
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == "mean":
            return
SYMBOL INDEX (1102 symbols across 75 files)
FILE: activation.py
class _trunc_exp (line 5) | class _trunc_exp(Function):
method forward (line 8) | def forward(ctx, x):
method backward (line 14) | def backward(ctx, g):
function biased_softplus (line 20) | def biased_softplus(x, bias=0):
FILE: dpt.py
class BaseModel (line 10) | class BaseModel(torch.nn.Module):
method load (line 11) | def load(self, path):
function unflatten_with_named_tensor (line 24) | def unflatten_with_named_tensor(input, dim, sizes):
class Slice (line 30) | class Slice(nn.Module):
method __init__ (line 31) | def __init__(self, start_index=1):
method forward (line 35) | def forward(self, x):
class AddReadout (line 39) | class AddReadout(nn.Module):
method __init__ (line 40) | def __init__(self, start_index=1):
method forward (line 44) | def forward(self, x):
class ProjectReadout (line 52) | class ProjectReadout(nn.Module):
method __init__ (line 53) | def __init__(self, in_features, start_index=1):
method forward (line 59) | def forward(self, x):
class Transpose (line 66) | class Transpose(nn.Module):
method __init__ (line 67) | def __init__(self, dim0, dim1):
method forward (line 72) | def forward(self, x):
function forward_vit (line 77) | def forward_vit(pretrained, x):
function _resize_pos_embed (line 118) | def _resize_pos_embed(self, posemb, gs_h, gs_w):
function forward_flex (line 135) | def forward_flex(self, x):
function get_activation (line 177) | def get_activation(name):
function get_readout_oper (line 184) | def get_readout_oper(vit_features, features, use_readout, start_index=1):
function _make_vit_b16_backbone (line 201) | def _make_vit_b16_backbone(
function _make_pretrained_vitl16_384 (line 315) | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=...
function _make_pretrained_vitb16_384 (line 328) | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=...
function _make_pretrained_deitb16_384 (line 337) | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks...
function _make_pretrained_deitb16_distil_384 (line 346) | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore"...
function _make_vit_b_rn50_backbone (line 361) | def _make_vit_b_rn50_backbone(
function _make_pretrained_vitb_rn50_384 (line 496) | def _make_pretrained_vitb_rn50_384(
function _make_encoder (line 511) | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=F...
function _make_scratch (line 549) | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
function _make_pretrained_efficientnet_lite3 (line 578) | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
function _make_efficientnet_backbone (line 588) | def _make_efficientnet_backbone(effnet):
function _make_resnet_backbone (line 601) | def _make_resnet_backbone(resnet):
function _make_pretrained_resnext101_wsl (line 614) | def _make_pretrained_resnext101_wsl(use_pretrained):
class Interpolate (line 620) | class Interpolate(nn.Module):
method __init__ (line 624) | def __init__(self, scale_factor, mode, align_corners=False):
method forward (line 637) | def forward(self, x):
class ResidualConvUnit (line 652) | class ResidualConvUnit(nn.Module):
method __init__ (line 656) | def __init__(self, features):
method forward (line 673) | def forward(self, x):
class FeatureFusionBlock (line 688) | class FeatureFusionBlock(nn.Module):
method __init__ (line 692) | def __init__(self, features):
method forward (line 702) | def forward(self, *xs):
class ResidualConvUnit_custom (line 723) | class ResidualConvUnit_custom(nn.Module):
method __init__ (line 727) | def __init__(self, features, activation, bn):
method forward (line 754) | def forward(self, x):
class FeatureFusionBlock_custom (line 780) | class FeatureFusionBlock_custom(nn.Module):
method __init__ (line 784) | def __init__(self, features, activation, deconv=False, bn=False, expan...
method forward (line 808) | def forward(self, *xs):
function _make_fusion_block (line 832) | def _make_fusion_block(features, use_bn):
class DPT (line 843) | class DPT(BaseModel):
method __init__ (line 844) | def __init__(
method forward (line 884) | def forward(self, x):
class DPTDepthModel (line 904) | class DPTDepthModel(DPT):
method __init__ (line 905) | def __init__(self, path=None, non_negative=True, num_channels=1, **kwa...
method forward (line 923) | def forward(self, x):
FILE: encoding.py
class FreqEncoder_torch (line 5) | class FreqEncoder_torch(nn.Module):
method __init__ (line 6) | def __init__(self, input_dim, max_freq_log2, N_freqs,
method forward (line 30) | def forward(self, input, max_level=None, **kwargs):
function get_encoder (line 54) | def get_encoder(encoding, input_dim=3,
FILE: evaluation/mesh_to_video.py
function render_video (line 9) | def render_video(anim_mesh):
function generate_mesh (line 26) | def generate_mesh(obj1,obj2,transform_vector):
FILE: freqencoder/backend.py
function find_cl_path (line 18) | def find_cl_path():
FILE: freqencoder/freq.py
class _freq_encoder (line 15) | class _freq_encoder(Function):
method forward (line 18) | def forward(ctx, inputs, degree, output_dim):
method backward (line 39) | def backward(ctx, grad):
class FreqEncoder (line 55) | class FreqEncoder(nn.Module):
method __init__ (line 56) | def __init__(self, input_dim=3, degree=4):
method __repr__ (line 63) | def __repr__(self):
method forward (line 66) | def forward(self, inputs, **kwargs):
FILE: freqencoder/setup.py
function find_cl_path (line 19) | def find_cl_path():
FILE: freqencoder/src/bindings.cpp
function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: gridencoder/backend.py
function find_cl_path (line 17) | def find_cl_path():
FILE: gridencoder/grid.py
class _grid_encode (line 25) | class _grid_encode(Function):
method forward (line 28) | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_re...
method backward (line 75) | def backward(ctx, grad):
class GridEncoder (line 103) | class GridEncoder(nn.Module):
method __init__ (line 104) | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_...
method reset_parameters (line 145) | def reset_parameters(self):
method __repr__ (line 149) | def __repr__(self):
method forward (line 152) | def forward(self, inputs, bound=1, max_level=None):
method grad_total_variation (line 173) | def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=10...
method grad_weight_decay (line 196) | def grad_weight_decay(self, weight=0.1):
FILE: gridencoder/setup.py
function find_cl_path (line 18) | def find_cl_path():
FILE: gridencoder/src/bindings.cpp
function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: guidance/clip_utils.py
class CLIP (line 9) | class CLIP(nn.Module):
method __init__ (line 10) | def __init__(self, device, **kwargs):
method get_text_embeds (line 21) | def get_text_embeds(self, prompt, **kwargs):
method get_img_embeds (line 29) | def get_img_embeds(self, image, **kwargs):
method train_step (line 37) | def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs):
FILE: guidance/if_utils.py
function seed_everything (line 15) | def seed_everything(seed):
class IF (line 22) | class IF(nn.Module):
method __init__ (line 23) | def __init__(self, device, vram_O, t_range=[0.02, 0.98]):
method get_text_embeds (line 62) | def get_text_embeds(self, prompt):
method train_step (line 73) | def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, gr...
method train_step_perpneg (line 110) | def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guida...
method produce_imgs (line 152) | def produce_imgs(self, text_embeddings, height=64, width=64, num_infer...
method prompt_to_img (line 182) | def prompt_to_img(self, prompts, negative_prompts='', height=512, widt...
FILE: guidance/perpneg_utils.py
function get_perpendicular_component (line 4) | def get_perpendicular_component(x, y):
function batch_get_perpendicular_component (line 9) | def batch_get_perpendicular_component(x, y):
function weighted_perpendicular_aggregator (line 17) | def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_...
FILE: guidance/sd_utils.py
function seed_everything (line 19) | def seed_everything(seed):
class StableDiffusion (line 25) | class StableDiffusion(nn.Module):
method __init__ (line 26) | def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None...
method get_text_embeds (line 77) | def get_text_embeds(self, prompt):
method train_step (line 86) | def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as...
method train_step_perpneg (line 166) | def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guida...
method produce_latents (line 251) | def produce_latents(self, text_embeddings, height=512, width=512, num_...
method decode_latents (line 273) | def decode_latents(self, latents):
method encode_imgs (line 282) | def encode_imgs(self, imgs):
method prompt_to_img (line 292) | def prompt_to_img(self, prompts, negative_prompts='', height=512, widt...
FILE: guidance/zero123_utils.py
function load_model_from_config (line 22) | def load_model_from_config(config, ckpt, device, vram_O=False, verbose=F...
class Zero123 (line 56) | class Zero123(nn.Module):
method __init__ (line 57) | def __init__(self, device, fp16,
method get_img_embeds (line 90) | def get_img_embeds(self, x):
method angle_between (line 97) | def angle_between(self, sph_v1, sph_v2):
method train_step (line 113) | def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, gui...
method __call__ (line 235) | def __call__(self,
method decode_latents (line 272) | def decode_latents(self, latents):
method encode_imgs (line 280) | def encode_imgs(self, imgs):
FILE: ldm/extras.py
function all_logging_disabled (line 12) | def all_logging_disabled(highest_level=logging.CRITICAL):
function load_training_dir (line 37) | def load_training_dir(train_dir, device, epoch="last"):
function load_model_from_config (line 55) | def load_model_from_config(config, ckpt, device="cpu", verbose=False):
FILE: ldm/guidance.py
class GuideModel (line 10) | class GuideModel(torch.nn.Module, abc.ABC):
method __init__ (line 11) | def __init__(self) -> None:
method preprocess (line 15) | def preprocess(self, x_img):
method compute_loss (line 19) | def compute_loss(self, inp):
class Guider (line 23) | class Guider(torch.nn.Module):
method __init__ (line 24) | def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
method get_scales (line 49) | def get_scales(self):
method modify_score (line 57) | def modify_score(self, model, e_t, x, t, c):
FILE: ldm/lr_scheduler.py
class LambdaWarmUpCosineScheduler (line 4) | class LambdaWarmUpCosineScheduler:
method __init__ (line 8) | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_...
method schedule (line 17) | def schedule(self, n, **kwargs):
method __call__ (line 32) | def __call__(self, n, **kwargs):
class LambdaWarmUpCosineScheduler2 (line 36) | class LambdaWarmUpCosineScheduler2:
method __init__ (line 41) | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths...
method find_in_interval (line 52) | def find_in_interval(self, n):
method schedule (line 59) | def schedule(self, n, **kwargs):
method __call__ (line 77) | def __call__(self, n, **kwargs):
class LambdaLinearScheduler (line 81) | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
method schedule (line 83) | def schedule(self, n, **kwargs):
FILE: ldm/models/autoencoder.py
class VQModel (line 14) | class VQModel(pl.LightningModule):
method __init__ (line 15) | def __init__(self,
method ema_scope (line 64) | def ema_scope(self, context=None):
method init_from_ckpt (line 78) | def init_from_ckpt(self, path, ignore_keys=list()):
method on_train_batch_end (line 92) | def on_train_batch_end(self, *args, **kwargs):
method encode (line 96) | def encode(self, x):
method encode_to_prequant (line 102) | def encode_to_prequant(self, x):
method decode (line 107) | def decode(self, quant):
method decode_code (line 112) | def decode_code(self, code_b):
method forward (line 117) | def forward(self, input, return_pred_indices=False):
method get_input (line 124) | def get_input(self, batch, k):
method training_step (line 142) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 164) | def validation_step(self, batch, batch_idx):
method _validation_step (line 170) | def _validation_step(self, batch, batch_idx, suffix=""):
method configure_optimizers (line 197) | def configure_optimizers(self):
method get_last_layer (line 230) | def get_last_layer(self):
method log_images (line 233) | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
method to_rgb (line 255) | def to_rgb(self, x):
class VQModelInterface (line 264) | class VQModelInterface(VQModel):
method __init__ (line 265) | def __init__(self, embed_dim, *args, **kwargs):
method encode (line 269) | def encode(self, x):
method decode (line 274) | def decode(self, h, force_not_quantize=False):
class AutoencoderKL (line 285) | class AutoencoderKL(pl.LightningModule):
method __init__ (line 286) | def __init__(self,
method init_from_ckpt (line 313) | def init_from_ckpt(self, path, ignore_keys=list()):
method encode (line 324) | def encode(self, x):
method decode (line 330) | def decode(self, z):
method forward (line 335) | def forward(self, input, sample_posterior=True):
method get_input (line 344) | def get_input(self, batch, k):
method training_step (line 351) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 372) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 386) | def configure_optimizers(self):
method get_last_layer (line 397) | def get_last_layer(self):
method log_images (line 401) | def log_images(self, batch, only_inputs=False, **kwargs):
method to_rgb (line 417) | def to_rgb(self, x):
class IdentityFirstStage (line 426) | class IdentityFirstStage(torch.nn.Module):
method __init__ (line 427) | def __init__(self, *args, vq_interface=False, **kwargs):
method encode (line 431) | def encode(self, x, *args, **kwargs):
method decode (line 434) | def decode(self, x, *args, **kwargs):
method quantize (line 437) | def quantize(self, x, *args, **kwargs):
method forward (line 442) | def forward(self, x, *args, **kwargs):
FILE: ldm/models/diffusion/classifier.py
function disabled_train (line 22) | def disabled_train(self, mode=True):
class NoisyLatentImageClassifier (line 28) | class NoisyLatentImageClassifier(pl.LightningModule):
method __init__ (line 30) | def __init__(self,
method init_from_ckpt (line 70) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method load_diffusion (line 88) | def load_diffusion(self):
method load_classifier (line 95) | def load_classifier(self, ckpt_path, pool):
method get_x_noisy (line 110) | def get_x_noisy(self, x, t, noise=None):
method forward (line 120) | def forward(self, x_noisy, t, *args, **kwargs):
method get_input (line 124) | def get_input(self, batch, k):
method get_conditioning (line 133) | def get_conditioning(self, batch, k=None):
method compute_top_k (line 150) | def compute_top_k(self, logits, labels, k, reduction="mean"):
method on_train_epoch_start (line 157) | def on_train_epoch_start(self):
method write_logs (line 162) | def write_logs(self, loss, logits, targets):
method shared_step (line 179) | def shared_step(self, batch, t=None):
method training_step (line 198) | def training_step(self, batch, batch_idx):
method reset_noise_accs (line 202) | def reset_noise_accs(self):
method on_validation_start (line 206) | def on_validation_start(self):
method validation_step (line 210) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 220) | def configure_optimizers(self):
method log_images (line 238) | def log_images(self, batch, N=8, *args, **kwargs):
FILE: ldm/models/diffusion/ddim.py
class DDIMSampler (line 13) | class DDIMSampler(object):
method __init__ (line 14) | def __init__(self, model, schedule="linear", **kwargs):
method to (line 20) | def to(self, device):
method register_buffer (line 29) | def register_buffer(self, name, attr):
method make_schedule (line 35) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 67) | def sample(self,
method ddim_sampling (line 128) | def ddim_sampling(self, cond, shape,
method p_sample_ddim (line 186) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
method encode (line 248) | def encode(self, x0, c, t_enc, use_original_steps=False, return_interm...
method stochastic_encode (line 294) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
method decode (line 310) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
FILE: ldm/models/diffusion/ddpm.py
function disabled_train (line 37) | def disabled_train(self, mode=True):
function uniform_on_device (line 43) | def uniform_on_device(r1, r2, shape, device):
class DDPM (line 47) | class DDPM(pl.LightningModule):
method __init__ (line 49) | def __init__(self,
method register_schedule (line 126) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method ema_scope (line 181) | def ema_scope(self, context=None):
method init_from_ckpt (line 196) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method q_mean_variance (line 254) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 266) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 272) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 281) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 294) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 303) | def p_sample_loop(self, shape, return_intermediates=False):
method sample (line 318) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 324) | def q_sample(self, x_start, t, noise=None):
method get_loss (line 329) | def get_loss(self, pred, target, mean=True):
method p_losses (line 344) | def p_losses(self, x_start, t, noise=None):
method forward (line 373) | def forward(self, x, *args, **kwargs):
method get_input (line 379) | def get_input(self, batch, k):
method shared_step (line 387) | def shared_step(self, batch):
method training_step (line 392) | def training_step(self, batch, batch_idx):
method validation_step (line 417) | def validation_step(self, batch, batch_idx):
method on_train_batch_end (line 425) | def on_train_batch_end(self, *args, **kwargs):
method _get_rows_from_list (line 429) | def _get_rows_from_list(self, samples):
method log_images (line 437) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
method configure_optimizers (line 474) | def configure_optimizers(self):
class LatentDiffusion (line 483) | class LatentDiffusion(DDPM):
method __init__ (line 485) | def __init__(self,
method make_cond_schedule (line 539) | def make_cond_schedule(self, ):
method on_train_batch_start (line 546) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
method register_schedule (line 561) | def register_schedule(self,
method instantiate_first_stage (line 570) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 577) | def instantiate_cond_stage(self, config):
method _get_denoise_row_from_list (line 598) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
method get_first_stage_encoding (line 610) | def get_first_stage_encoding(self, encoder_posterior):
method get_learned_conditioning (line 619) | def get_learned_conditioning(self, c):
method meshgrid (line 632) | def meshgrid(self, h, w):
method delta_border (line 639) | def delta_border(self, h, w):
method get_weighting (line 653) | def get_weighting(self, h, w, Ly, Lx, device):
method get_fold_unfold (line 669) | def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo...
method get_input (line 723) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
method decode_first_stage (line 763) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
method encode_first_stage (line 823) | def encode_first_stage(self, x):
method shared_step (line 862) | def shared_step(self, batch, **kwargs):
method forward (line 867) | def forward(self, x, c, *args, **kwargs):
method _rescale_annotations (line 878) | def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: mov...
method apply_model (line 888) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method _predict_eps_from_xstart (line 986) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
method _prior_bpd (line 990) | def _prior_bpd(self, x_start):
method p_losses (line 1004) | def p_losses(self, x_start, cond, t, noise=None):
method p_mean_variance (line 1039) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
method p_sample (line 1071) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
method progressive_denoising (line 1102) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
method p_sample_loop (line 1158) | def p_sample_loop(self, cond, shape, return_intermediates=False,
method sample (line 1209) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_log (line 1227) | def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
method get_unconditional_conditioning (line 1241) | def get_unconditional_conditioning(self, batch_size, null_label=None, ...
method log_images (line 1262) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
method configure_optimizers (line 1387) | def configure_optimizers(self):
method to_rgb (line 1432) | def to_rgb(self, x):
class DiffusionWrapper (line 1441) | class DiffusionWrapper(pl.LightningModule):
method __init__ (line 1442) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 1448) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = Non...
class LatentUpscaleDiffusion (line 1477) | class LatentUpscaleDiffusion(LatentDiffusion):
method __init__ (line 1478) | def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwar...
method instantiate_low_stage (line 1485) | def instantiate_low_stage(self, config):
method get_input (line 1493) | def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
method log_images (line 1515) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
class LatentInpaintDiffusion (line 1613) | class LatentInpaintDiffusion(LatentDiffusion):
method __init__ (line 1619) | def __init__(self,
method init_from_ckpt (line 1644) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method get_input (line 1674) | def get_input(self, batch, k, cond_key=None, bs=None, return_first_sta...
method log_images (line 1700) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
class Layout2ImgDiffusion (line 1783) | class Layout2ImgDiffusion(LatentDiffusion):
method __init__ (line 1785) | def __init__(self, cond_stage_key, *args, **kwargs):
method log_images (line 1789) | def log_images(self, batch, N=8, *args, **kwargs):
class SimpleUpscaleDiffusion (line 1807) | class SimpleUpscaleDiffusion(LatentDiffusion):
method __init__ (line 1808) | def __init__(self, *args, low_scale_key="LR", **kwargs):
method get_input (line 1815) | def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
method log_images (line 1838) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
class MultiCatFrameDiffusion (line 1899) | class MultiCatFrameDiffusion(LatentDiffusion):
method __init__ (line 1900) | def __init__(self, *args, low_scale_key="LR", **kwargs):
method get_input (line 1907) | def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
method log_images (line 1935) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
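The `DDPM.q_sample(self, x_start, t, noise=None)` entry above names the forward-diffusion step. The standard DDPM forward process samples x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε, where ᾱ_t is the cumulative product of (1 − β). Below is a minimal NumPy sketch of that formula, not the file's own code, under these assumptions:

- `alphas_cumprod` is passed in explicitly, whereas the real method reads it from registered buffers.
- The per-sample reshape plays the role that `extract_into_tensor` in `ldm/modules/diffusionmodules/util.py` appears to serve.
- The linear β values are illustrative only.

```python
import numpy as np

def q_sample(x_start, t, alphas_cumprod, noise=None, rng=None):
    """Sketch of DDPM forward noising: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps."""
    if noise is None:
        rng = rng if rng is not None else np.random.default_rng()
        noise = rng.standard_normal(x_start.shape)
    # Gather a_bar for each sample's timestep, then reshape to (B, 1, 1, ...) for broadcasting.
    a_bar = alphas_cumprod[np.asarray(t)].reshape(-1, *([1] * (x_start.ndim - 1)))
    return np.sqrt(a_bar) * x_start + np.sqrt(1.0 - a_bar) * noise

# Hypothetical linear beta schedule, for illustration only.
betas = np.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)
```

With `noise` set to zeros, the output is just the scaled input. This makes the signal attenuation at large `t` easy to inspect.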
FILE: ldm/models/diffusion/plms.py
class PLMSSampler (line 12) | class PLMSSampler(object):
method __init__ (line 13) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 19) | def register_buffer(self, name, attr):
method make_schedule (line 25) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 59) | def sample(self,
method plms_sampling (line 120) | def plms_sampling(self, cond, shape,
method p_sample_plms (line 180) | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_origin...
FILE: ldm/models/diffusion/sampling_util.py
function append_dims (line 5) | def append_dims(x, target_dims):
function renorm_thresholding (line 14) | def renorm_thresholding(x0, value):
function norm_thresholding (line 42) | def norm_thresholding(x0, value):
function spatial_norm_thresholding (line 47) | def spatial_norm_thresholding(x0, value):
FILE: ldm/modules/attention.py
function exists (line 11) | def exists(val):
function uniq (line 15) | def uniq(arr):
function default (line 19) | def default(val, d):
function max_neg_value (line 25) | def max_neg_value(t):
function init_ (line 29) | def init_(tensor):
class GEGLU (line 37) | class GEGLU(nn.Module):
method __init__ (line 38) | def __init__(self, dim_in, dim_out):
method forward (line 42) | def forward(self, x):
class FeedForward (line 47) | class FeedForward(nn.Module):
method __init__ (line 48) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 63) | def forward(self, x):
function zero_module (line 67) | def zero_module(module):
function Normalize (line 76) | def Normalize(in_channels):
class LinearAttention (line 80) | class LinearAttention(nn.Module):
method __init__ (line 81) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 88) | def forward(self, x):
class SpatialSelfAttention (line 99) | class SpatialSelfAttention(nn.Module):
method __init__ (line 100) | def __init__(self, in_channels):
method forward (line 126) | def forward(self, x):
class CrossAttention (line 152) | class CrossAttention(nn.Module):
method __init__ (line 153) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 170) | def forward(self, x, context=None, mask=None):
class BasicTransformerBlock (line 196) | class BasicTransformerBlock(nn.Module):
method __init__ (line 197) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 211) | def forward(self, x, context=None):
method _forward (line 214) | def _forward(self, x, context=None):
class SpatialTransformer (line 221) | class SpatialTransformer(nn.Module):
method __init__ (line 229) | def __init__(self, in_channels, n_heads, d_head,
method forward (line 255) | def forward(self, x, context=None):
FILE: ldm/modules/diffusionmodules/model.py
function get_timestep_embedding (line 12) | def get_timestep_embedding(timesteps, embedding_dim):
function nonlinearity (line 33) | def nonlinearity(x):
function Normalize (line 38) | def Normalize(in_channels, num_groups=32):
class Upsample (line 42) | class Upsample(nn.Module):
method __init__ (line 43) | def __init__(self, in_channels, with_conv):
method forward (line 53) | def forward(self, x):
class Downsample (line 60) | class Downsample(nn.Module):
method __init__ (line 61) | def __init__(self, in_channels, with_conv):
method forward (line 72) | def forward(self, x):
class ResnetBlock (line 82) | class ResnetBlock(nn.Module):
method __init__ (line 83) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
method forward (line 121) | def forward(self, x, temb):
class LinAttnBlock (line 144) | class LinAttnBlock(LinearAttention):
method __init__ (line 146) | def __init__(self, in_channels):
class AttnBlock (line 150) | class AttnBlock(nn.Module):
method __init__ (line 151) | def __init__(self, in_channels):
method forward (line 178) | def forward(self, x):
function make_attn (line 205) | def make_attn(in_channels, attn_type="vanilla"):
class Model (line 216) | class Model(nn.Module):
method __init__ (line 217) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 316) | def forward(self, x, t=None, context=None):
method get_last_layer (line 364) | def get_last_layer(self):
class Encoder (line 368) | class Encoder(nn.Module):
method __init__ (line 369) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 434) | def forward(self, x):
class Decoder (line 462) | class Decoder(nn.Module):
method __init__ (line 463) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 535) | def forward(self, z):
class SimpleDecoder (line 571) | class SimpleDecoder(nn.Module):
method __init__ (line 572) | def __init__(self, in_channels, out_channels, *args, **kwargs):
method forward (line 594) | def forward(self, x):
class UpsampleDecoder (line 607) | class UpsampleDecoder(nn.Module):
method __init__ (line 608) | def __init__(self, in_channels, out_channels, ch, num_res_blocks, reso...
method forward (line 641) | def forward(self, x):
class LatentRescaler (line 655) | class LatentRescaler(nn.Module):
method __init__ (line 656) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
method forward (line 680) | def forward(self, x):
class MergedRescaleEncoder (line 692) | class MergedRescaleEncoder(nn.Module):
method __init__ (line 693) | def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
method forward (line 705) | def forward(self, x):
class MergedRescaleDecoder (line 711) | class MergedRescaleDecoder(nn.Module):
method __init__ (line 712) | def __init__(self, z_channels, out_ch, resolution, num_res_blocks, att...
method forward (line 722) | def forward(self, x):
class Upsampler (line 728) | class Upsampler(nn.Module):
method __init__ (line 729) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
method forward (line 741) | def forward(self, x):
class Resize (line 747) | class Resize(nn.Module):
method __init__ (line 748) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
method forward (line 763) | def forward(self, x, scale_factor=1.0):
class FirstStagePostProcessor (line 770) | class FirstStagePostProcessor(nn.Module):
method __init__ (line 772) | def __init__(self, ch_mult:list, in_channels,
method instantiate_pretrained (line 807) | def instantiate_pretrained(self, config):
method encode_with_pretrained (line 816) | def encode_with_pretrained(self,x):
method forward (line 822) | def forward(self,x):
FILE: ldm/modules/diffusionmodules/openaimodel.py
function convert_module_to_f16 (line 25) | def convert_module_to_f16(x):
function convert_module_to_f32 (line 28) | def convert_module_to_f32(x):
class AttentionPool2d (line 33) | class AttentionPool2d(nn.Module):
method __init__ (line 38) | def __init__(
method forward (line 52) | def forward(self, x):
class TimestepBlock (line 63) | class TimestepBlock(nn.Module):
method forward (line 69) | def forward(self, x, emb):
class TimestepEmbedSequential (line 75) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 81) | def forward(self, x, emb, context=None):
class Upsample (line 92) | class Upsample(nn.Module):
method __init__ (line 101) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 110) | def forward(self, x):
class TransposedUpsample (line 122) | class TransposedUpsample(nn.Module):
method __init__ (line 124) | def __init__(self, channels, out_channels=None, ks=5):
method forward (line 131) | def forward(self,x):
class Downsample (line 135) | class Downsample(nn.Module):
method __init__ (line 144) | def __init__(self, channels, use_conv, dims=2, out_channels=None,paddi...
method forward (line 159) | def forward(self, x):
class ResBlock (line 164) | class ResBlock(TimestepBlock):
method __init__ (line 180) | def __init__(
method forward (line 244) | def forward(self, x, emb):
method _forward (line 256) | def _forward(self, x, emb):
class AttentionBlock (line 279) | class AttentionBlock(nn.Module):
method __init__ (line 286) | def __init__(
method forward (line 315) | def forward(self, x):
method _forward (line 319) | def _forward(self, x):
function count_flops_attn (line 328) | def count_flops_attn(model, _x, y):
class QKVAttentionLegacy (line 348) | class QKVAttentionLegacy(nn.Module):
method __init__ (line 353) | def __init__(self, n_heads):
method forward (line 357) | def forward(self, qkv):
method count_flops (line 376) | def count_flops(model, _x, y):
class QKVAttention (line 380) | class QKVAttention(nn.Module):
method __init__ (line 385) | def __init__(self, n_heads):
method forward (line 389) | def forward(self, qkv):
method count_flops (line 410) | def count_flops(model, _x, y):
class UNetModel (line 414) | class UNetModel(nn.Module):
method __init__ (line 444) | def __init__(
method convert_to_fp16 (line 729) | def convert_to_fp16(self):
method convert_to_fp32 (line 737) | def convert_to_fp32(self):
method forward (line 745) | def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
class EncoderUNetModel (line 780) | class EncoderUNetModel(nn.Module):
method __init__ (line 786) | def __init__(
method convert_to_fp16 (line 959) | def convert_to_fp16(self):
method convert_to_fp32 (line 966) | def convert_to_fp32(self):
method forward (line 973) | def forward(self, x, timesteps):
FILE: ldm/modules/diffusionmodules/util.py
function make_beta_schedule (line 21) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
function make_ddim_timesteps (line 46) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
function make_ddim_sampling_parameters (line 63) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
function betas_for_alpha_bar (line 77) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
function extract_into_tensor (line 96) | def extract_into_tensor(a, t, x_shape):
function checkpoint (line 102) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 119) | class CheckpointFunction(torch.autograd.Function):
method forward (line 121) | def forward(ctx, run_function, length, *args):
method backward (line 131) | def backward(ctx, *output_grads):
function timestep_embedding (line 151) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 174) | def zero_module(module):
function scale_module (line 183) | def scale_module(module, scale):
function mean_flat (line 192) | def mean_flat(tensor):
function normalization (line 199) | def normalization(channels):
class SiLU (line 209) | class SiLU(nn.Module):
method forward (line 210) | def forward(self, x):
class GroupNorm32 (line 214) | class GroupNorm32(nn.GroupNorm):
method forward (line 215) | def forward(self, x):
function conv_nd (line 218) | def conv_nd(dims, *args, **kwargs):
function linear (line 231) | def linear(*args, **kwargs):
function avg_pool_nd (line 238) | def avg_pool_nd(dims, *args, **kwargs):
class HybridConditioner (line 251) | class HybridConditioner(nn.Module):
method __init__ (line 253) | def __init__(self, c_concat_config, c_crossattn_config):
method forward (line 258) | def forward(self, c_concat, c_crossattn):
function noise_like (line 264) | def noise_like(shape, device, repeat=False):
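`timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...` in `util.py` names the sinusoidal timestep embedding used by the UNet. The sketch below is written from the standard definition rather than from the file. It builds cosine features followed by sine features over geometrically spaced frequencies and zero-pads odd `dim`. It uses NumPy instead of torch and omits the `repeat_only` path (assumptions), so treat it as illustrative.

```python
import math
import numpy as np

def timestep_embedding(timesteps, dim, max_period=10000):
    """Sketch: sinusoidal embedding of shape (N, dim) for N scalar timesteps."""
    half = dim // 2
    # Frequencies decay geometrically from 1 down to about 1/max_period.
    freqs = np.exp(-math.log(max_period) * np.arange(half, dtype=np.float64) / half)
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:
        # Pad one zero column so odd dims still return exactly `dim` features.
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb
```

At `t = 0` every cosine entry is 1 and every sine entry is 0. This is a quick sanity check on the layout.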
FILE: ldm/modules/distributions/distributions.py
class AbstractDistribution (line 5) | class AbstractDistribution:
method sample (line 6) | def sample(self):
method mode (line 9) | def mode(self):
class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
method __init__ (line 14) | def __init__(self, value):
method sample (line 17) | def sample(self):
method mode (line 20) | def mode(self):
class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
method sample (line 35) | def sample(self):
method kl (line 39) | def kl(self, other=None):
method nll (line 53) | def nll(self, sample, dims=[1,2,3]):
method mode (line 61) | def mode(self):
function normal_kl (line 65) | def normal_kl(mean1, logvar1, mean2, logvar2):
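`normal_kl(mean1, logvar1, mean2, logvar2)` names the KL divergence between two diagonal Gaussians parameterized by mean and log-variance. In closed form, KL(N₁‖N₂) = ½(−1 + logvar₂ − logvar₁ + exp(logvar₁ − logvar₂) + (μ₁ − μ₂)²·exp(−logvar₂)). The NumPy sketch below implements that standard formula elementwise. It is an assumption that the file's torch version matches it term for term.

```python
import numpy as np

def normal_kl(mean1, logvar1, mean2, logvar2):
    """Sketch: elementwise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))."""
    mean1, logvar1, mean2, logvar2 = map(np.asarray, (mean1, logvar1, mean2, logvar2))
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + np.exp(logvar1 - logvar2)          # variance ratio sigma1^2 / sigma2^2
        + (mean1 - mean2) ** 2 * np.exp(-logvar2)  # squared mean gap scaled by 1 / sigma2^2
    )
```

Working in log-variance keeps the computation stable when variances are very small. This is also why the signature takes `logvar` rather than `std`.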
FILE: ldm/modules/ema.py
class LitEma (line 5) | class LitEma(nn.Module):
method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
method forward (line 25) | def forward(self,model):
method copy_to (line 46) | def copy_to(self, model):
method store (line 55) | def store(self, parameters):
method restore (line 64) | def restore(self, parameters):
FILE: ldm/modules/encoders/modules.py
class AbstractEncoder (line 12) | class AbstractEncoder(nn.Module):
method __init__ (line 13) | def __init__(self):
method encode (line 16) | def encode(self, *args, **kwargs):
class IdentityEncoder (line 19) | class IdentityEncoder(AbstractEncoder):
method encode (line 21) | def encode(self, x):
class FaceClipEncoder (line 24) | class FaceClipEncoder(AbstractEncoder):
method __init__ (line 25) | def __init__(self, augment=True, retreival_key=None):
method forward (line 31) | def forward(self, img):
method encode (line 54) | def encode(self, img):
class FaceIdClipEncoder (line 61) | class FaceIdClipEncoder(AbstractEncoder):
method __init__ (line 62) | def __init__(self):
method forward (line 69) | def forward(self, img):
method encode (line 84) | def encode(self, img):
class ClassEmbedder (line 91) | class ClassEmbedder(nn.Module):
method __init__ (line 92) | def __init__(self, embed_dim, n_classes=1000, key='class'):
method forward (line 97) | def forward(self, batch, key=None):
class TransformerEmbedder (line 106) | class TransformerEmbedder(AbstractEncoder):
method __init__ (line 108) | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, devic...
method forward (line 114) | def forward(self, tokens):
method encode (line 119) | def encode(self, x):
class BERTTokenizer (line 123) | class BERTTokenizer(AbstractEncoder):
method __init__ (line 125) | def __init__(self, device="cuda", vq_interface=True, max_length=77):
method forward (line 133) | def forward(self, text):
method encode (line 140) | def encode(self, text):
method decode (line 146) | def decode(self, text):
class BERTEmbedder (line 150) | class BERTEmbedder(AbstractEncoder):
method __init__ (line 152) | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
method forward (line 163) | def forward(self, text):
method encode (line 171) | def encode(self, text):
function disabled_train (line 178) | def disabled_train(self, mode=True):
class FrozenT5Embedder (line 184) | class FrozenT5Embedder(AbstractEncoder):
method __init__ (line 186) | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_...
method freeze (line 194) | def freeze(self):
method forward (line 200) | def forward(self, text):
method encode (line 209) | def encode(self, text):
class FrozenFaceEncoder (line 215) | class FrozenFaceEncoder(AbstractEncoder):
method __init__ (line 216) | def __init__(self, model_path, augment=False):
method forward (line 237) | def forward(self, img):
method encode (line 251) | def encode(self, img):
class FrozenCLIPEmbedder (line 254) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 256) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method freeze (line 264) | def freeze(self):
method forward (line 270) | def forward(self, text):
method encode (line 279) | def encode(self, text):
class ClipImageProjector (line 284) | class ClipImageProjector(AbstractEncoder):
method __init__ (line 288) | def __init__(self, version="openai/clip-vit-large-patch14", max_length...
method get_null_cond (line 301) | def get_null_cond(self, version, max_length):
method preprocess (line 307) | def preprocess(self, x):
method forward (line 317) | def forward(self, x):
method encode (line 327) | def encode(self, im):
class ProjectedFrozenCLIPEmbedder (line 330) | class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 331) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method forward (line 336) | def forward(self, text):
method encode (line 340) | def encode(self, text):
class FrozenCLIPImageEmbedder (line 343) | class FrozenCLIPImageEmbedder(AbstractEncoder):
method __init__ (line 348) | def __init__(
method preprocess (line 363) | def preprocess(self, x):
method forward (line 373) | def forward(self, x):
method encode (line 381) | def encode(self, im):
class FrozenCLIPImageMutliEmbedder (line 387) | class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
method __init__ (line 392) | def __init__(
method preprocess (line 409) | def preprocess(self, x):
method forward (line 423) | def forward(self, x):
method encode (line 440) | def encode(self, im):
class SpatialRescaler (line 443) | class SpatialRescaler(nn.Module):
method __init__ (line 444) | def __init__(self,
method forward (line 462) | def forward(self,x):
method encode (line 471) | def encode(self, x):
class LowScaleEncoder (line 479) | class LowScaleEncoder(nn.Module):
method __init__ (line 480) | def __init__(self, model_config, linear_start, linear_end, timesteps=1...
method register_schedule (line 490) | def register_schedule(self, beta_schedule="linear", timesteps=1000,
method q_sample (line 517) | def q_sample(self, x_start, t, noise=None):
method forward (line 522) | def forward(self, x):
method decode (line 532) | def decode(self, z):
FILE: ldm/modules/evaluate/adm_evaluator.py
function main (line 31) | def main():
class InvalidFIDException (line 84) | class InvalidFIDException(Exception):
class FIDStatistics (line 88) | class FIDStatistics:
method __init__ (line 89) | def __init__(self, mu: np.ndarray, sigma: np.ndarray):
method frechet_distance (line 93) | def frechet_distance(self, other, eps=1e-6):
class Evaluator (line 139) | class Evaluator:
method __init__ (line 140) | def __init__(
method warmup (line 156) | def warmup(self):
method read_activations (line 159) | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndar...
method compute_activations (line 163) | def compute_activations(self, batches: Iterable[np.ndarray],silent=Fal...
method read_statistics (line 186) | def read_statistics(
method compute_statistics (line 196) | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
method compute_inception_score (line 201) | def compute_inception_score(self, activations: np.ndarray, split_size:...
method compute_prec_recall (line 216) | def compute_prec_recall(
class ManifoldEstimator (line 227) | class ManifoldEstimator:
method __init__ (line 234) | def __init__(
method warmup (line 263) | def warmup(self):
method manifold_radii (line 270) | def manifold_radii(self, features: np.ndarray) -> np.ndarray:
method evaluate (line 305) | def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_featu...
method evaluate_pr (line 347) | def evaluate_pr(
class DistanceBlock (line 384) | class DistanceBlock:
method __init__ (line 391) | def __init__(self, session):
method pairwise_distances (line 415) | def pairwise_distances(self, U, V):
method less_thans (line 424) | def less_thans(self, batch_1, radii_1, batch_2, radii_2):
function _batch_pairwise_distances (line 436) | def _batch_pairwise_distances(U, V):
class NpzArrayReader (line 455) | class NpzArrayReader(ABC):
method read_batch (line 457) | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
method remaining (line 461) | def remaining(self) -> int:
method read_batches (line 464) | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
class BatchIterator (line 477) | class BatchIterator:
method __init__ (line 478) | def __init__(self, gen_fn, length):
method __len__ (line 482) | def __len__(self):
method __iter__ (line 485) | def __iter__(self):
class StreamingNpzArrayReader (line 489) | class StreamingNpzArrayReader(NpzArrayReader):
method __init__ (line 490) | def __init__(self, arr_f, shape, dtype):
method read_batch (line 496) | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
method remaining (line 511) | def remaining(self) -> int:
class MemoryNpzArrayReader (line 515) | class MemoryNpzArrayReader(NpzArrayReader):
method __init__ (line 516) | def __init__(self, arr):
method load (line 521) | def load(cls, path: str, arr_name: str):
method read_batch (line 526) | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
method remaining (line 534) | def remaining(self) -> int:
function open_npz_array (line 539) | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
function _read_bytes (line 556) | def _read_bytes(fp, size, error_template="ran out of data"):
function _open_npy_file (line 586) | def _open_npy_file(path: str, arr_name: str):
function _download_inception_model (line 595) | def _download_inception_model():
function _create_feature_graph (line 608) | def _create_feature_graph(input_batch):
function _create_softmax_graph (line 625) | def _create_softmax_graph(input_batch):
function _update_shapes (line 639) | def _update_shapes(pool3):
function _numpy_partition (line 658) | def _numpy_partition(arr, kth, **kwargs):
FILE: ldm/modules/evaluate/evaluate_perceptualsim.py
function normalize_tensor (line 18) | def normalize_tensor(in_feat, eps=1e-10):
function cos_sim (line 25) | def cos_sim(in0, in1):
class squeezenet (line 40) | class squeezenet(torch.nn.Module):
method __init__ (line 41) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 72) | def forward(self, X):
class alexnet (line 98) | class alexnet(torch.nn.Module):
method __init__ (line 99) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 124) | def forward(self, X):
class vgg16 (line 143) | class vgg16(torch.nn.Module):
method __init__ (line 144) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 167) | def forward(self, X):
class resnet (line 187) | class resnet(torch.nn.Module):
method __init__ (line 188) | def __init__(self, requires_grad=False, pretrained=True, num=18):
method forward (line 211) | def forward(self, X):
class PNet (line 234) | class PNet(torch.nn.Module):
method __init__ (line 237) | def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True):
method forward (line 272) | def forward(self, in0, in1, retPerLayer=False):
function ssim_metric (line 299) | def ssim_metric(img1, img2, mask=None):
function psnr (line 304) | def psnr(img1, img2, mask=None,reshape=False):
function perceptual_sim (line 328) | def perceptual_sim(img1, img2, vgg16):
function load_img (line 334) | def load_img(img_name, size=None):
function compute_perceptual_similarity (line 353) | def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_...
function compute_perceptual_similarity_from_list (line 416) | def compute_perceptual_similarity_from_list(pred_imgs_list, tgt_imgs_list,
function compute_perceptual_similarity_from_list_topk (line 502) | def compute_perceptual_similarity_from_list_topk(pred_imgs_list, tgt_img...
FILE: ldm/modules/evaluate/frechet_video_distance.py
function preprocess (line 35) | def preprocess(videos, target_resolution):
function _is_in_graph (line 57) | def _is_in_graph(tensor_name):
function create_id3_embedding (line 66) | def create_id3_embedding(videos,warmup=False,batch_size=16):
function calculate_fvd (line 135) | def calculate_fvd(real_activations,
FILE: ldm/modules/evaluate/ssim.py
function gaussian (line 12) | def gaussian(window_size, sigma):
function create_window (line 22) | def create_window(window_size, channel):
function _ssim (line 31) | def _ssim(
class SSIM (line 79) | class SSIM(torch.nn.Module):
method __init__ (line 80) | def __init__(self, window_size=11, size_average=True):
method forward (line 87) | def forward(self, img1, img2, mask=None):
function ssim (line 116) | def ssim(img1, img2, window_size=11, mask=None, size_average=True):
FILE: ldm/modules/evaluate/torch_frechet_video_distance.py
function compute_frechet_distance (line 25) | def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) ->...
function compute_stats (line 34) | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
function open_url (line 41) | def open_url(url: str, num_attempts: int = 10, verbose: bool = True, ret...
function load_video (line 114) | def load_video(ip):
function get_data_from_str (line 119) | def get_data_from_str(input_str,nprc = None):
function get_stats (line 142) | def get_stats(stats):
function compute_fvd (line 155) | def compute_fvd(ref_input, sample_input, bs=32,
function compute_statistics (line 199) | def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=...
FILE: ldm/modules/image_degradation/bsrgan.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 339) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 369) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 386) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 404) | def add_Poisson_noise(img):
function add_JPEG_noise (line 418) | def add_JPEG_noise(img):
function random_crop (line 427) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 438) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 530) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
function degradation_bsrgan_plus (line 617) | def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True,...
FILE: ldm/modules/image_degradation/bsrgan_light.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 343) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 373) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 390) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 408) | def add_Poisson_noise(img):
function add_JPEG_noise (line 422) | def add_JPEG_noise(img):
function random_crop (line 431) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 442) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 534) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
FILE: ldm/modules/image_degradation/utils_image.py
function is_image_file (line 29) | def is_image_file(filename):
function get_timestamp (line 33) | def get_timestamp():
function imshow (line 37) | def imshow(x, title=None, cbar=False, figsize=None):
function surf (line 47) | def surf(Z, cmap='rainbow', figsize=None):
function get_image_paths (line 67) | def get_image_paths(dataroot):
function _get_paths_from_images (line 74) | def _get_paths_from_images(path):
function patches_from_image (line 93) | def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
function imssave (line 112) | def imssave(imgs, img_path):
function split_imageset (line 125) | def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_si...
function mkdir (line 153) | def mkdir(path):
function mkdirs (line 158) | def mkdirs(paths):
function mkdir_and_rename (line 166) | def mkdir_and_rename(path):
function imread_uint (line 185) | def imread_uint(path, n_channels=3):
function imsave (line 203) | def imsave(img, img_path):
function imwrite (line 209) | def imwrite(img, img_path):
function read_img (line 220) | def read_img(path):
function uint2single (line 249) | def uint2single(img):
function single2uint (line 254) | def single2uint(img):
function uint162single (line 259) | def uint162single(img):
function single2uint16 (line 264) | def single2uint16(img):
function uint2tensor4 (line 275) | def uint2tensor4(img):
function uint2tensor3 (line 282) | def uint2tensor3(img):
function tensor2uint (line 289) | def tensor2uint(img):
function single2tensor3 (line 302) | def single2tensor3(img):
function single2tensor4 (line 307) | def single2tensor4(img):
function tensor2single (line 312) | def tensor2single(img):
function tensor2single3 (line 320) | def tensor2single3(img):
function single2tensor5 (line 329) | def single2tensor5(img):
function single32tensor5 (line 333) | def single32tensor5(img):
function single42tensor4 (line 337) | def single42tensor4(img):
function tensor2img (line 342) | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
function augment_img (line 380) | def augment_img(img, mode=0):
function augment_img_tensor4 (line 401) | def augment_img_tensor4(img, mode=0):
function augment_img_tensor (line 422) | def augment_img_tensor(img, mode=0):
function augment_img_np3 (line 441) | def augment_img_np3(img, mode=0):
function augment_imgs (line 469) | def augment_imgs(img_list, hflip=True, rot=True):
function modcrop (line 494) | def modcrop(img_in, scale):
function shave (line 510) | def shave(img_in, border=0):
function rgb2ycbcr (line 529) | def rgb2ycbcr(img, only_y=True):
function ycbcr2rgb (line 553) | def ycbcr2rgb(img):
function bgr2ycbcr (line 573) | def bgr2ycbcr(img, only_y=True):
function channel_convert (line 597) | def channel_convert(in_c, tar_type, img_list):
function calculate_psnr (line 621) | def calculate_psnr(img1, img2, border=0):
function calculate_ssim (line 642) | def calculate_ssim(img1, img2, border=0):
function ssim (line 669) | def ssim(img1, img2):
function cubic (line 700) | def cubic(x):
function calculate_weights_indices (line 708) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
function imresize (line 766) | def imresize(img, scale, antialiasing=True):
function imresize_np (line 839) | def imresize_np(img, scale, antialiasing=True):
FILE: ldm/modules/losses/contperceptual.py
class LPIPSWithDiscriminator (line 7) | class LPIPSWithDiscriminator(nn.Module):
method __init__ (line 8) | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixello...
method calculate_adaptive_weight (line 32) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 45) | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
FILE: ldm/modules/losses/vqperceptual.py
function hinge_d_loss_with_exemplar_weights (line 11) | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
function adopt_weight (line 20) | def adopt_weight(weight, global_step, threshold=0, value=0.):
function measure_perplexity (line 26) | def measure_perplexity(predicted_indices, n_embed):
function l1 (line 35) | def l1(x, y):
function l2 (line 39) | def l2(x, y):
class VQLPIPSWithDiscriminator (line 43) | class VQLPIPSWithDiscriminator(nn.Module):
method __init__ (line 44) | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
method calculate_adaptive_weight (line 85) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 98) | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
FILE: ldm/modules/x_transformer.py
class AbsolutePositionalEmbedding (line 25) | class AbsolutePositionalEmbedding(nn.Module):
method __init__ (line 26) | def __init__(self, dim, max_seq_len):
method init_ (line 31) | def init_(self):
method forward (line 34) | def forward(self, x):
class FixedPositionalEmbedding (line 39) | class FixedPositionalEmbedding(nn.Module):
method __init__ (line 40) | def __init__(self, dim):
method forward (line 45) | def forward(self, x, seq_dim=1, offset=0):
function exists (line 54) | def exists(val):
function default (line 58) | def default(val, d):
function always (line 64) | def always(val):
function not_equals (line 70) | def not_equals(val):
function equals (line 76) | def equals(val):
function max_neg_value (line 82) | def max_neg_value(tensor):
function pick_and_pop (line 88) | def pick_and_pop(keys, d):
function group_dict_by_key (line 93) | def group_dict_by_key(cond, d):
function string_begins_with (line 102) | def string_begins_with(prefix, str):
function group_by_key_prefix (line 106) | def group_by_key_prefix(prefix, d):
function groupby_prefix_and_trim (line 110) | def groupby_prefix_and_trim(prefix, d):
class Scale (line 117) | class Scale(nn.Module):
method __init__ (line 118) | def __init__(self, value, fn):
method forward (line 123) | def forward(self, x, **kwargs):
class Rezero (line 128) | class Rezero(nn.Module):
method __init__ (line 129) | def __init__(self, fn):
method forward (line 134) | def forward(self, x, **kwargs):
class ScaleNorm (line 139) | class ScaleNorm(nn.Module):
method __init__ (line 140) | def __init__(self, dim, eps=1e-5):
method forward (line 146) | def forward(self, x):
class RMSNorm (line 151) | class RMSNorm(nn.Module):
method __init__ (line 152) | def __init__(self, dim, eps=1e-8):
method forward (line 158) | def forward(self, x):
class Residual (line 163) | class Residual(nn.Module):
method forward (line 164) | def forward(self, x, residual):
class GRUGating (line 168) | class GRUGating(nn.Module):
method __init__ (line 169) | def __init__(self, dim):
method forward (line 173) | def forward(self, x, residual):
class GEGLU (line 184) | class GEGLU(nn.Module):
method __init__ (line 185) | def __init__(self, dim_in, dim_out):
method forward (line 189) | def forward(self, x):
class FeedForward (line 194) | class FeedForward(nn.Module):
method __init__ (line 195) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 210) | def forward(self, x):
class Attention (line 215) | class Attention(nn.Module):
method __init__ (line 216) | def __init__(
method forward (line 268) | def forward(
class AttentionLayers (line 370) | class AttentionLayers(nn.Module):
method __init__ (line 371) | def __init__(
method forward (line 481) | def forward(
class Encoder (line 541) | class Encoder(AttentionLayers):
method __init__ (line 542) | def __init__(self, **kwargs):
class TransformerWrapper (line 548) | class TransformerWrapper(nn.Module):
method __init__ (line 549) | def __init__(
method init_ (line 595) | def init_(self):
method forward (line 598) | def forward(
FILE: ldm/thirdp/psp/helpers.py
class Flatten (line 12) | class Flatten(Module):
method forward (line 13) | def forward(self, input):
function l2_norm (line 17) | def l2_norm(input, axis=1):
class Bottleneck (line 23) | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
function get_block (line 27) | def get_block(in_channel, depth, num_units, stride=2):
function get_blocks (line 31) | def get_blocks(num_layers):
class SEModule (line 58) | class SEModule(Module):
method __init__ (line 59) | def __init__(self, channels, reduction):
method forward (line 67) | def forward(self, x):
class bottleneck_IR (line 77) | class bottleneck_IR(Module):
method __init__ (line 78) | def __init__(self, in_channel, depth, stride):
method forward (line 93) | def forward(self, x):
class bottleneck_IR_SE (line 99) | class bottleneck_IR_SE(Module):
method __init__ (line 100) | def __init__(self, in_channel, depth, stride):
method forward (line 118) | def forward(self, x):
FILE: ldm/thirdp/psp/id_loss.py
class IDFeatures (line 7) | class IDFeatures(nn.Module):
method __init__ (line 8) | def __init__(self, model_path):
method forward (line 16) | def forward(self, x, crop=False):
FILE: ldm/thirdp/psp/model_irse.py
class Backbone (line 11) | class Backbone(Module):
method __init__ (line 12) | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, ...
method forward (line 46) | def forward(self, x):
function IR_50 (line 53) | def IR_50(input_size):
function IR_101 (line 59) | def IR_101(input_size):
function IR_152 (line 65) | def IR_152(input_size):
function IR_SE_50 (line 71) | def IR_SE_50(input_size):
function IR_SE_101 (line 77) | def IR_SE_101(input_size):
function IR_SE_152 (line 83) | def IR_SE_152(input_size):
FILE: ldm/util.py
function pil_rectangle_crop (line 21) | def pil_rectangle_crop(im):
function log_txt_as_img (line 41) | def log_txt_as_img(wh, xc, size=10):
function ismap (line 65) | def ismap(x):
function isimage (line 71) | def isimage(x):
function exists (line 77) | def exists(x):
function default (line 81) | def default(val, d):
function mean_flat (line 87) | def mean_flat(tensor):
function count_params (line 95) | def count_params(model, verbose=False):
function instantiate_from_config (line 102) | def instantiate_from_config(config):
function get_obj_from_str (line 112) | def get_obj_from_str(string, reload=False):
class AdamWwithEMAandWings (line 120) | class AdamWwithEMAandWings(optim.Optimizer):
method __init__ (line 122) | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, #...
method __setstate__ (line 143) | def __setstate__(self, state):
method step (line 149) | def step(self, closure=None):
FILE: main.py
class LoadFromFile (line 13) | class LoadFromFile (argparse.Action):
method __call__ (line 14) | def __call__ (self, parser, namespace, values, option_string = None):
FILE: meshutils.py
function poisson_mesh_reconstruction (line 4) | def poisson_mesh_reconstruction(points, normals=None):
function decimate_mesh (line 39) | def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=Fals...
function clean_mesh (line 75) | def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, rem...
FILE: nerf/gui.py
class OrbitCamera (line 10) | class OrbitCamera:
method __init__ (line 11) | def __init__(self, W, H, r=2, fovy=60):
method pose (line 24) | def pose(self):
method intrinsics (line 38) | def intrinsics(self):
method mvp (line 43) | def mvp(self):
method orbit (line 54) | def orbit(self, dx, dy):
method scale (line 61) | def scale(self, delta):
method pan (line 64) | def pan(self, dx, dy, dz=0):
class NeRFGUI (line 69) | class NeRFGUI:
method __init__ (line 70) | def __init__(self, opt, trainer, loader=None, debug=True):
method __del__ (line 99) | def __del__(self):
method train_step (line 103) | def train_step(self):
method prepare_buffer (line 128) | def prepare_buffer(self, outputs):
method test_step (line 137) | def test_step(self):
method register_dpg (line 172) | def register_dpg(self):
method render (line 478) | def render(self):
FILE: nerf/network.py
class ResBlock (line 14) | class ResBlock(nn.Module):
method __init__ (line 15) | def __init__(self, dim_in, dim_out, bias=True):
method forward (line 29) | def forward(self, x):
class BasicBlock (line 44) | class BasicBlock(nn.Module):
method __init__ (line 45) | def __init__(self, dim_in, dim_out, bias=True):
method forward (line 53) | def forward(self, x):
class MLP (line 61) | class MLP(nn.Module):
method __init__ (line 62) | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True,...
method forward (line 81) | def forward(self, x):
class NeRFNetwork (line 89) | class NeRFNetwork(NeRFRenderer):
method __init__ (line 90) | def __init__(self,
method common_forward (line 118) | def common_forward(self, x):
method finite_difference_normal (line 132) | def finite_difference_normal(self, x, epsilon=1e-2):
method normal (line 149) | def normal(self, x):
method forward (line 164) | def forward(self, x, d, l=None, ratio=1, shading='albedo'):
method density (line 203) | def density(self, x):
method background (line 214) | def background(self, d):
method get_params (line 226) | def get_params(self, lr):
FILE: nerf/network_grid.py
class MLP (line 13) | class MLP(nn.Module):
method __init__ (line 14) | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
method forward (line 27) | def forward(self, x):
class NeRFNetwork (line 35) | class NeRFNetwork(NeRFRenderer):
method __init__ (line 36) | def __init__(self,
method common_forward (line 68) | def common_forward(self, x):
method finite_difference_normal (line 81) | def finite_difference_normal(self, x, epsilon=1e-2):
method normal (line 98) | def normal(self, x):
method forward (line 104) | def forward(self, x, d, l=None, ratio=1, shading='albedo'):
method density (line 133) | def density(self, x):
method background (line 144) | def background(self, d):
method get_params (line 156) | def get_params(self, lr):
FILE: nerf/network_grid_taichi.py
class MLP (line 13) | class MLP(nn.Module):
method __init__ (line 14) | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
method forward (line 27) | def forward(self, x):
class NeRFNetwork (line 35) | class NeRFNetwork(NeRFRenderer):
method __init__ (line 36) | def __init__(self,
method common_forward (line 67) | def common_forward(self, x):
method finite_difference_normal (line 80) | def finite_difference_normal(self, x, epsilon=1e-2):
method normal (line 97) | def normal(self, x):
method forward (line 103) | def forward(self, x, d, l=None, ratio=1, shading='albedo'):
method density (line 131) | def density(self, x):
method background (line 142) | def background(self, d):
method get_params (line 154) | def get_params(self, lr):
FILE: nerf/network_grid_tcnn.py
class MLP (line 15) | class MLP(nn.Module):
method __init__ (line 16) | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
method forward (line 29) | def forward(self, x):
class NeRFNetwork (line 37) | class NeRFNetwork(NeRFRenderer):
method __init__ (line 38) | def __init__(self,
method common_forward (line 82) | def common_forward(self, x):
method normal (line 93) | def normal(self, x):
method forward (line 108) | def forward(self, x, d, l=None, ratio=1, shading='albedo'):
method density (line 141) | def density(self, x):
method background (line 152) | def background(self, d):
method get_params (line 164) | def get_params(self, lr):
FILE: nerf/provider.py
function visualize_poses (line 27) | def visualize_poses(poses, dirs, size=0.1):
function get_view_direction (line 52) | def get_view_direction(thetas, phis, overhead, front):
function rand_poses (line 73) | def rand_poses(size, device, opt, radius_range=[1, 1.5], theta_range=[0,...
function circle_poses (line 152) | def circle_poses(device, radius=torch.tensor([3.2]), theta=torch.tensor(...
class NeRFDataset (line 183) | class NeRFDataset:
method __init__ (line 184) | def __init__(self, opt, device, type='train', H=256, W=256, size=100):
method get_default_view_data (line 207) | def get_default_view_data(self):
method collate (line 248) | def collate(self, index):
method dataloader (line 316) | def dataloader(self, batch_size=None):
FILE: nerf/renderer.py
function sample_pdf (line 19) | def sample_pdf(bins, weights, n_samples, det=False):
function near_far_from_bound (line 56) | def near_far_from_bound(rays_o, rays_d, bound, type='cube', min_near=0.05):
function plot_pointcloud (line 82) | def plot_pointcloud(pc, color=None):
class DMTet (line 94) | class DMTet():
method __init__ (line 95) | def __init__(self, device):
method sort_edges (line 118) | def sort_edges(self, edges_ex2):
method __call__ (line 128) | def __call__(self, pos_nx3, sdf_n, tet_fx4):
function compute_edge_to_face_mapping (line 176) | def compute_edge_to_face_mapping(attr_idx):
function normal_consistency (line 209) | def normal_consistency(face_normals, t_pos_idx):
function laplacian_uniform (line 224) | def laplacian_uniform(verts, faces):
function laplacian_smooth_loss (line 248) | def laplacian_smooth_loss(verts, faces):
class NeRFRenderer (line 257) | class NeRFRenderer(nn.Module):
method __init__ (line 258) | def __init__(self, opt):
method density_blob (line 339) | def density_blob(self, x):
method forward (line 351) | def forward(self, x, d):
method density (line 354) | def density(self, x):
method reset_extra_state (line 357) | def reset_extra_state(self):
method export_mesh (line 366) | def export_mesh(self, path, resolution=None, decimate_target=-1, S=128):
method run (line 560) | def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading...
method run_cuda (line 710) | def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, sh...
method init_tet (line 818) | def init_tet(self, mesh=None):
method run_dmtet (line 862) | def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_r...
method run_taichi (line 966) | def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, ...
method update_extra_state (line 1103) | def update_extra_state(self, decay=0.95, S=128):
method render (line 1154) | def render(self, rays_o, rays_d, mvp, h, w, staged=False, max_ray_batc...
FILE: nerf/utils.py
function adjust_text_embeddings (line 34) | def adjust_text_embeddings(embeddings, azimuth, opt):
function get_pos_neg_text_embeddings (line 60) | def get_pos_neg_text_embeddings(embeddings, azimuth_val, opt):
function custom_meshgrid (line 102) | def custom_meshgrid(*args):
function safe_normalize (line 109) | def safe_normalize(x, eps=1e-20):
function get_rays (line 113) | def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
function seed_everything (line 179) | def seed_everything(seed):
function linear_to_srgb (line 190) | def linear_to_srgb(x):
function srgb_to_linear (line 195) | def srgb_to_linear(x):
class Trainer (line 199) | class Trainer(object):
method __init__ (line 200) | def __init__(self,
method prepare_embeddings (line 353) | def prepare_embeddings(self):
method __del__ (line 423) | def __del__(self):
method log (line 428) | def log(self, *args, **kwargs):
method train_step (line 439) | def train_step(self, data, save_guidance_path:Path=None):
method post_train_step (line 725) | def post_train_step(self):
method eval_step (line 743) | def eval_step(self, data):
method test_step (line 765) | def test_step(self, data, bg_color=None, perturb=False):
method save_mesh (line 787) | def save_mesh(self, loader=None, save_path=None):
method train (line 802) | def train(self, train_loader, valid_loader, test_loader, max_epochs):
method evaluate (line 833) | def evaluate(self, loader, name=None):
method test (line 838) | def test(self, loader, save_path=None, name=None, write_video=True):
method train_gui (line 890) | def train_gui(self, train_loader, step=16):
method test_gui (line 949) | def test_gui(self, pose, intrinsics, mvp, W, H, bg_color=None, spp=1, ...
method train_one_epoch (line 1008) | def train_one_epoch(self, loader, max_epochs):
method evaluate_one_epoch (line 1115) | def evaluate_one_epoch(self, loader, name=None):
method save_checkpoint (line 1206) | def save_checkpoint(self, name=None, full=False, best=False):
method load_checkpoint (line 1266) | def load_checkpoint(self, checkpoint=None, model_only=False):
function get_CPU_mem (line 1337) | def get_CPU_mem():
function get_GPU_mem (line 1341) | def get_GPU_mem():
FILE: optimizer.py
class Adan (line 23) | class Adan(Optimizer):
method __init__ (line 47) | def __init__(self,
method __setstate__ (line 80) | def __setstate__(self, state):
method restart_opt (line 86) | def restart_opt(self):
method step (line 102) | def step(self, closure=None):
function _single_tensor_adan (line 201) | def _single_tensor_adan(
function _multi_tensor_adan (line 259) | def _multi_tensor_adan(
FILE: preprocess_image.py
class BackgroundRemoval (line 14) | class BackgroundRemoval():
method __init__ (line 15) | def __init__(self, device='cuda'):
method __call__ (line 32) | def __call__(self, image):
class BLIP2 (line 41) | class BLIP2():
method __init__ (line 42) | def __init__(self, device='cuda'):
method __call__ (line 49) | def __call__(self, image):
class DPT (line 59) | class DPT():
method __init__ (line 60) | def __init__(self, task='depth', device='cuda'):
method __call__ (line 97) | def __call__(self, image):
FILE: raymarching/backend.py
function find_cl_path (line 17) | def find_cl_path():
FILE: raymarching/raymarching.py
function get_backend (line 14) | def get_backend():
class _near_far_from_aabb (line 31) | class _near_far_from_aabb(Function):
method forward (line 34) | def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
class _sph_from_ray (line 64) | class _sph_from_ray(Function):
method forward (line 67) | def forward(ctx, rays_o, rays_d, radius):
class _morton3D (line 95) | class _morton3D(Function):
method forward (line 97) | def forward(ctx, coords):
class _morton3D_invert (line 118) | class _morton3D_invert(Function):
method forward (line 120) | def forward(ctx, indices):
class _packbits (line 141) | class _packbits(Function):
method forward (line 144) | def forward(ctx, grid, thresh, bitfield=None):
class _flatten_rays (line 170) | class _flatten_rays(Function):
method forward (line 172) | def forward(ctx, rays, M):
class _march_rays_train (line 197) | class _march_rays_train(Function):
method forward (line 200) | def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears,...
class _composite_rays_train (line 261) | class _composite_rays_train(Function):
method forward (line 264) | def forward(ctx, sigmas, rgbs, ts, rays, T_thresh=1e-4, binarize=False):
method backward (line 299) | def backward(ctx, grad_weights, grad_weights_sum, grad_depth, grad_ima...
class _march_rays (line 323) | class _march_rays(Function):
method forward (line 326) | def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, ...
class _composite_rays (line 374) | class _composite_rays(Function):
method forward (line 377) | def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts...
FILE: raymarching/setup.py
function find_cl_path (line 18) | def find_cl_path():
FILE: raymarching/src/bindings.cpp
function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: shencoder/backend.py
function find_cl_path (line 17) | def find_cl_path():
FILE: shencoder/setup.py
function find_cl_path (line 18) | def find_cl_path():
FILE: shencoder/sphere_harmonics.py
class _sh_encoder (line 14) | class _sh_encoder(Function):
method forward (line 17) | def forward(ctx, inputs, degree, calc_grad_inputs=False):
method backward (line 42) | def backward(ctx, grad):
class SHEncoder (line 61) | class SHEncoder(nn.Module):
method __init__ (line 62) | def __init__(self, input_dim=3, degree=4):
method __repr__ (line 72) | def __repr__(self):
method forward (line 75) | def forward(self, inputs, size=1):
FILE: shencoder/src/bindings.cpp
function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: taichi_modules/hash_encoder.py
function random_initialize (line 15) | def random_initialize(data: ti.types.ndarray()):
function ti_copy (line 21) | def ti_copy(data1: ti.template(), data2: ti.template()):
function ti_copy_array (line 27) | def ti_copy_array(data1: ti.types.ndarray(), data2: ti.types.ndarray()):
function ti_copy_field_array (line 33) | def ti_copy_field_array(data1: ti.template(), data2: ti.types.ndarray()):
function fast_hash (line 39) | def fast_hash(pos_grid_local):
function under_hash (line 49) | def under_hash(pos_grid_local, resolution):
function grid_pos2hash_index (line 59) | def grid_pos2hash_index(indicator, pos_grid_local, resolution, map_size):
function hash_encode_kernel (line 70) | def hash_encode_kernel(
function hash_encode_kernel_half2 (line 120) | def hash_encode_kernel_half2(
class HashEncoderTaichi (line 166) | class HashEncoderTaichi(torch.nn.Module):
method __init__ (line 168) | def __init__(self,
method zero_grad (line 300) | def zero_grad(self):
method forward (line 303) | def forward(self, positions, bound=1):
FILE: taichi_modules/intersection.py
function simple_ray_aabb_intersec_taichi_forward (line 10) | def simple_ray_aabb_intersec_taichi_forward(
class RayAABBIntersector (line 39) | class RayAABBIntersector(torch.autograd.Function):
method forward (line 60) | def forward(ctx, rays_o, rays_d, center, half_size, max_hits):
FILE: taichi_modules/ray_march.py
function raymarching_train (line 10) | def raymarching_train(rays_o: ti.types.ndarray(ndim=2),
function raymarching_train_backword (line 129) | def raymarching_train_backword(segments: ti.types.ndarray(ndim=2),
class RayMarcherTaichi (line 145) | class RayMarcherTaichi(torch.nn.Module):
method __init__ (line 147) | def __init__(self, batch_size=8192):
method forward (line 221) | def forward(self, rays_o, rays_d, hits_t, density_bitfield, cascades,
function raymarching_test_kernel (line 230) | def raymarching_test_kernel(
function raymarching_test (line 306) | def raymarching_test(rays_o, rays_d, hits_t, alive_indices, density_bitf...
FILE: taichi_modules/utils.py
function scalbn (line 18) | def scalbn(x, exponent):
function calc_dt (line 23) | def calc_dt(t, exp_step_factor, grid_size, scale):
function frexp_bit (line 29) | def frexp_bit(x):
function mip_from_pos (line 47) | def mip_from_pos(xyz, cascades):
function mip_from_dt (line 56) | def mip_from_dt(dt, grid_size, cascades):
function __expand_bits (line 64) | def __expand_bits(v):
function __morton3D (line 73) | def __morton3D(xyz):
function __morton3D_invert (line 79) | def __morton3D_invert(x):
function morton3D_invert_kernel (line 89) | def morton3D_invert_kernel(indices: ti.types.ndarray(ndim=1),
function morton3D_invert (line 98) | def morton3D_invert(indices):
function morton3D_kernel (line 109) | def morton3D_kernel(xyzs: ti.types.ndarray(ndim=2),
function morton3D (line 116) | def morton3D(coords1):
function packbits (line 126) | def packbits(density_grid: ti.types.ndarray(ndim=1),
function torch2ti (line 141) | def torch2ti(field: ti.template(), data: ti.types.ndarray()):
function ti2torch (line 147) | def ti2torch(field: ti.template(), data: ti.types.ndarray()):
function ti2torch_grad (line 153) | def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):
function torch2ti_grad (line 159) | def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):
function torch2ti_vec (line 165) | def torch2ti_vec(field: ti.template(), data: ti.types.ndarray()):
function ti2torch_vec (line 171) | def ti2torch_vec(field: ti.template(), data: ti.types.ndarray()):
function ti2torch_grad_vec (line 178) | def ti2torch_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
function torch2ti_grad_vec (line 185) | def torch2ti_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
function extract_model_state_dict (line 191) | def extract_model_state_dict(ckpt_path,
function load_ckpt (line 210) | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]):
function depth2img (line 219) | def depth2img(depth):
FILE: taichi_modules/volume_render_test.py
function composite_test (line 5) | def composite_test(
FILE: taichi_modules/volume_train.py
function composite_train_fw_array (line 10) | def composite_train_fw_array(
function composite_train_fw (line 52) | def composite_train_fw(sigmas: ti.template(), rgbs: ti.template(),
function check_value (line 102) | def check_value(
class VolumeRendererTaichi (line 112) | class VolumeRendererTaichi(torch.nn.Module):
method __init__ (line 114) | def __init__(self, batch_size=8192, data_type=data_type):
method zero_grad (line 231) | def zero_grad(self):
method forward (line 237) | def forward(self, sigmas, rgbs, deltas, ts, rays_a, T_threshold):
FILE: tets/generate_tets.py
function generate_tetrahedron_grid_file (line 21) | def generate_tetrahedron_grid_file(res=32, root='..'):
function convert_from_quartet_to_npz (line 31) | def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile...
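The symbol index above lists `__expand_bits`, `__morton3D`, `__morton3D_invert`, `morton3D` and `morton3D_invert` in `taichi_modules/utils.py`. Judging by the names, these appear to implement Morton (Z-order) encoding of 3D grid coordinates, as used by instant-ngp-style occupancy grids. The snippet below is a minimal scalar Python sketch of the standard 10-bits-per-axis bit-interleaving trick, assuming x occupies the lowest bit of each interleaved triple. It is not the repository's Taichi kernel, and the helper names are hypothetical.

```python
from typing import Tuple


def expand_bits(v: int) -> int:
    """Spread the low 10 bits of v so that two zero bits separate each original bit."""
    v &= 0x3FF                             # keep 10 bits (coordinates in [0, 1023])
    v = (v * 0x00010001) & 0xFF0000FF      # v | (v << 16), then mask
    v = (v * 0x00000101) & 0x0F00F00F      # v | (v << 8), then mask
    v = (v * 0x00000011) & 0xC30C30C3      # v | (v << 4), then mask
    v = (v * 0x00000005) & 0x49249249      # v | (v << 2), then mask: bit i -> bit 3i
    return v


def compact_bits(v: int) -> int:
    """Inverse of expand_bits: gather every third bit back into the low 10 bits."""
    v &= 0x49249249
    v = (v | (v >> 2)) & 0xC30C30C3
    v = (v | (v >> 4)) & 0x0F00F00F
    v = (v | (v >> 8)) & 0xFF0000FF
    v = (v | (v >> 16)) & 0x0000FFFF
    return v


def morton3d(x: int, y: int, z: int) -> int:
    """Interleave x, y, z into one 30-bit Morton code (x in the lowest bit)."""
    return expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)


def morton3d_invert(code: int) -> Tuple[int, int, int]:
    """Recover (x, y, z) from a Morton code produced by morton3d."""
    return compact_bits(code), compact_bits(code >> 1), compact_bits(code >> 2)
```

Interleaving keeps spatially neighbouring cells close together in the flattened index, which is the usual reason occupancy grids are stored in Morton order.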
Condensed preview: 132 files, each showing path, character count, and a content snippet.
[
{
"path": ".github/ISSUE_TEMPLATE/bug_report.yaml",
"chars": 1299,
"preview": "name: Bug Report\ndescription: File a bug report\ntitle: \"<title>\"\nlabels: [\"bug\"]\nbody:\n - type: markdown\n attributes"
},
{
"path": ".github/ISSUE_TEMPLATE/feature_request.md",
"chars": 604,
"preview": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: enhancement\nassignees: ''\n\n---\n\n**Is"
},
{
"path": ".gitignore",
"chars": 497,
"preview": "__pycache__/\nbuild/\n*.egg-info/\n*.so\nvenv_*/\n\ntmp*\n# data/\nldm/data/\ndata2\nscripts2\ntrial*/\n.vs/\n\nTOKEN\n*.ckpt\n\ndensegri"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "activation.py",
"chars": 526,
"preview": "import torch\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd \n\nclass _trunc_exp(Fu"
},
{
"path": "assets/advanced.md",
"chars": 3780,
"preview": "\n# Code organization & Advanced tips\n\nThis is a simple description of the most important implementation details.\nIf you "
},
{
"path": "assets/update_logs.md",
"chars": 1824,
"preview": "### 2023.4.19\n* Fix depth supervision, migrate depth estimation model to omnidata.\n* Add normal supervision (also by omn"
},
{
"path": "config/anya.csv",
"chars": 122,
"preview": "zero123_weight, radius, polar, azimuth, image\n1, 3, 90, 0, data/anya_front_rgba.png\n1, 3, 90, 180, data/anya_back_rgba.p"
},
{
"path": "config/car.csv",
"chars": 200,
"preview": "zero123_weight, radius, polar, azimuth, image\n4, 3.2, 90, 0, data/car_left_rgba.png\n1, 3, 90, 90, data/car_front_rgba.pn"
},
{
"path": "config/corgi.csv",
"chars": 105,
"preview": "zero123_weight, radius, polar, azimuth, image\n1, 3.2, 90, 0, data/corgi_puppy_sitting_looking_up_rgba.png"
},
{
"path": "docker/Dockerfile",
"chars": 1300,
"preview": "FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04\n\n# Remove any third-party apt sources to avoid issues with expiring key"
},
{
"path": "docker/README.md",
"chars": 2231,
"preview": "### Docker installation\n\n## Build image\nTo build the docker image on your own machine, which may take 15-30 mins:\n```\ndo"
},
{
"path": "dpt.py",
"chars": 27507,
"preview": "import math\nimport types\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport timm\n\nclass BaseMod"
},
{
"path": "encoding.py",
"chars": 3556,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass FreqEncoder_torch(nn.Module):\n def __init__"
},
{
"path": "evaluation/Prompt.py",
"chars": 3294,
"preview": "import textwrap\nfrom transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification\nfrom tran"
},
{
"path": "evaluation/mesh_to_video.py",
"chars": 2683,
"preview": "import os\nimport numpy as np\nimport trimesh\nimport argparse\nfrom pathlib import Path\nfrom tqdm import tqdm\nimport pyvist"
},
{
"path": "evaluation/r_precision.py",
"chars": 1166,
"preview": "from sentence_transformers import SentenceTransformer, util\nfrom PIL import Image\nimport argparse\nimport sys\n\n\nif __name"
},
{
"path": "evaluation/readme.md",
"chars": 1714,
"preview": "### Improvement:\n\n- Usage\n\n - r_precision.py <br>\n For prompt seperation <br>\n --text is for the prompt following the"
},
{
"path": "freqencoder/__init__.py",
"chars": 29,
"preview": "from .freq import FreqEncoder"
},
{
"path": "freqencoder/backend.py",
"chars": 1475,
"preview": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags"
},
{
"path": "freqencoder/freq.py",
"chars": 2232,
"preview": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function "
},
{
"path": "freqencoder/setup.py",
"chars": 1740,
"preview": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = "
},
{
"path": "freqencoder/src/bindings.cpp",
"chars": 275,
"preview": "#include <torch/extension.h>\n\n#include \"freqencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n m.def(\"freq_encod"
},
{
"path": "freqencoder/src/freqencoder.cu",
"chars": 3738,
"preview": "#include <stdint.h>\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext"
},
{
"path": "freqencoder/src/freqencoder.h",
"chars": 540,
"preview": "# pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// _backend.freq_encode_forward(inputs, B, input_dim, degre"
},
{
"path": "gridencoder/__init__.py",
"chars": 29,
"preview": "from .grid import GridEncoder"
},
{
"path": "gridencoder/backend.py",
"chars": 1454,
"preview": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags"
},
{
"path": "gridencoder/grid.py",
"chars": 8804,
"preview": "import math\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autogr"
},
{
"path": "gridencoder/setup.py",
"chars": 1719,
"preview": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = "
},
{
"path": "gridencoder/src/bindings.cpp",
"chars": 444,
"preview": "#include <torch/extension.h>\n\n#include \"gridencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n m.def(\"grid_encod"
},
{
"path": "gridencoder/src/gridencoder.cu",
"chars": 30774,
"preview": "#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/t"
},
{
"path": "gridencoder/src/gridencoder.h",
"chars": 1521,
"preview": "#ifndef _HASH_ENCODE_H\n#define _HASH_ENCODE_H\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// inputs: [B, D], float, i"
},
{
"path": "guidance/clip_utils.py",
"chars": 1761,
"preview": "import torch\nimport torch.nn as nn\n\nimport torchvision.transforms as T\nimport torchvision.transforms.functional as TF\n\ni"
},
{
"path": "guidance/if_utils.py",
"chars": 9323,
"preview": "from transformers import logging\nfrom diffusers import IFPipeline, DDPMScheduler\n\n# suppress partial model loading warni"
},
{
"path": "guidance/perpneg_utils.py",
"chars": 2058,
"preview": "import torch\n\n# Please refer to the https://perp-neg.github.io/ for details about the paper and algorithm\ndef get_perpen"
},
{
"path": "guidance/sd_utils.py",
"chars": 15279,
"preview": "from transformers import CLIPTextModel, CLIPTokenizer, logging\nfrom diffusers import AutoencoderKL, UNet2DConditionModel"
},
{
"path": "guidance/zero123_utils.py",
"chars": 14204,
"preview": "import math\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom pathlib import Path\n\nimport torch\nimport torch.nn as"
},
{
"path": "ldm/extras.py",
"chars": 2560,
"preview": "from pathlib import Path\nfrom omegaconf import OmegaConf\nimport torch\nfrom ldm.util import instantiate_from_config\nimpor"
},
{
"path": "ldm/guidance.py",
"chars": 3322,
"preview": "from typing import List, Tuple\nfrom scipy import interpolate\nimport numpy as np\nimport torch\nimport matplotlib.pyplot as"
},
{
"path": "ldm/lr_scheduler.py",
"chars": 3882,
"preview": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n \"\"\"\n note: use with a base_lr of 1.0\n \"\"\"\n def __in"
},
{
"path": "ldm/models/autoencoder.py",
"chars": 17619,
"preview": "import torch\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom contextlib import contextmanager\n\nfrom "
},
{
"path": "ldm/models/diffusion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/models/diffusion/classifier.py",
"chars": 10276,
"preview": "import os\nimport torch\nimport pytorch_lightning as pl\nfrom omegaconf import OmegaConf\nfrom torch.nn import functional as"
},
{
"path": "ldm/models/diffusion/ddim.py",
"chars": 16719,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\nfrom einops im"
},
{
"path": "ldm/models/diffusion/ddpm.py",
"chars": 96160,
"preview": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e316"
},
{
"path": "ldm/models/diffusion/plms.py",
"chars": 13647,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ldm.modu"
},
{
"path": "ldm/models/diffusion/sampling_util.py",
"chars": 1623,
"preview": "import torch\nimport numpy as np\n\n\ndef append_dims(x, target_dims):\n \"\"\"Appends dimensions to the end of a tensor unti"
},
{
"path": "ldm/modules/attention.py",
"chars": 8945,
"preview": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfro"
},
{
"path": "ldm/modules/diffusionmodules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/modules/diffusionmodules/model.py",
"chars": 33409,
"preview": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom ein"
},
{
"path": "ldm/modules/diffusionmodules/openaimodel.py",
"chars": 37219,
"preview": "from abc import abstractmethod\nfrom functools import partial\nimport math\nfrom typing import Iterable\n\nimport numpy as np"
},
{
"path": "ldm/modules/diffusionmodules/util.py",
"chars": 9561,
"preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
},
{
"path": "ldm/modules/distributions/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/modules/distributions/distributions.py",
"chars": 2970,
"preview": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n def sample(self):\n raise NotImplementedError()\n"
},
{
"path": "ldm/modules/ema.py",
"chars": 2982,
"preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n def __init__(self, model, decay=0.9999, use_num_upates="
},
{
"path": "ldm/modules/encoders/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/modules/encoders/modules.py",
"chars": 20838,
"preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom functools import partial\nimport kornia\n\nfrom ldm.modules.x_tr"
},
{
"path": "ldm/modules/evaluate/adm_evaluator.py",
"chars": 25770,
"preview": "import argparse\nimport io\nimport os\nimport random\nimport warnings\nimport zipfile\nfrom abc import ABC, abstractmethod\nfro"
},
{
"path": "ldm/modules/evaluate/evaluate_perceptualsim.py",
"chars": 20269,
"preview": "import argparse\nimport glob\nimport os\nfrom tqdm import tqdm\nfrom collections import namedtuple\n\nimport numpy as np\nimpor"
},
{
"path": "ldm/modules/evaluate/frechet_video_distance.py",
"chars": 5241,
"preview": "# coding=utf-8\n# Copyright 2022 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "ldm/modules/evaluate/ssim.py",
"chars": 3367,
"preview": "# MIT Licence\n\n# Methods to predict the SSIM, taken from\n# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorc"
},
{
"path": "ldm/modules/evaluate/torch_frechet_video_distance.py",
"chars": 10549,
"preview": "# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!\nimport os\nimport num"
},
{
"path": "ldm/modules/image_degradation/__init__.py",
"chars": 208,
"preview": "from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr\nfrom ldm.modules.image"
},
{
"path": "ldm/modules/image_degradation/bsrgan.py",
"chars": 25198,
"preview": "# -*- coding: utf-8 -*-\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# ------------------------"
},
{
"path": "ldm/modules/image_degradation/bsrgan_light.py",
"chars": 22190,
"preview": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom sci"
},
{
"path": "ldm/modules/image_degradation/utils_image.py",
"chars": 29022,
"preview": "import os\nimport math\nimport random\nimport numpy as np\nimport torch\nimport cv2\nfrom torchvision.utils import make_grid\nf"
},
{
"path": "ldm/modules/losses/__init__.py",
"chars": 68,
"preview": "from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator"
},
{
"path": "ldm/modules/losses/contperceptual.py",
"chars": 5581,
"preview": "import torch\nimport torch.nn as nn\n\nfrom taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?\n"
},
{
"path": "ldm/modules/losses/vqperceptual.py",
"chars": 7941,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discrim"
},
{
"path": "ldm/modules/x_transformer.py",
"chars": 20168,
"preview": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nimport torch\nfrom torch import "
},
{
"path": "ldm/thirdp/psp/helpers.py",
"chars": 3604,
"preview": "# https://github.com/eladrich/pixel2style2pixel\n\nfrom collections import namedtuple\nimport torch\nfrom torch.nn import Co"
},
{
"path": "ldm/thirdp/psp/id_loss.py",
"chars": 847,
"preview": "# https://github.com/eladrich/pixel2style2pixel\nimport torch\nfrom torch import nn\nfrom ldm.thirdp.psp.model_irse import "
},
{
"path": "ldm/thirdp/psp/model_irse.py",
"chars": 2883,
"preview": "# https://github.com/eladrich/pixel2style2pixel\n\nfrom torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, D"
},
{
"path": "ldm/util.py",
"chars": 7819,
"preview": "import importlib\n\nimport torchvision\nimport torch\nfrom torch import optim\nimport numpy as np\n\nfrom inspect import isfunc"
},
{
"path": "main.py",
"chars": 24270,
"preview": "import torch\nimport argparse\nimport pandas as pd\nimport sys\n\nfrom nerf.provider import NeRFDataset\nfrom nerf.utils impor"
},
{
"path": "meshutils.py",
"chars": 3880,
"preview": "import numpy as np\nimport pymeshlab as pml\n\ndef poisson_mesh_reconstruction(points, normals=None):\n # points/normals:"
},
{
"path": "nerf/gui.py",
"chars": 20302,
"preview": "import math\nimport torch\nimport numpy as np\nimport dearpygui.dearpygui as dpg\nfrom scipy.spatial.transform import Rotati"
},
{
"path": "nerf/network.py",
"chars": 8045,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom activation import trunc_exp\nfrom .renderer impo"
},
{
"path": "nerf/network_grid.py",
"chars": 6008,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom activation import trunc_exp, biased_softplus\nfr"
},
{
"path": "nerf/network_grid_taichi.py",
"chars": 6024,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom activation import trunc_exp\nfrom .renderer impo"
},
{
"path": "nerf/network_grid_tcnn.py",
"chars": 5747,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom activation import trunc_exp, biased_softplus\nfr"
},
{
"path": "nerf/provider.py",
"chars": 12284,
"preview": "import os\nimport cv2\nimport glob\nimport json\nimport tqdm\nimport random\nimport numpy as np\nfrom scipy.spatial.transform i"
},
{
"path": "nerf/renderer.py",
"chars": 51178,
"preview": "import os\nimport math\nimport cv2\nimport trimesh\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.f"
},
{
"path": "nerf/utils.py",
"chars": 56535,
"preview": "import os\nimport gc\nimport glob\nimport tqdm\nimport math\nimport imageio\nimport psutil\nfrom pathlib import Path\nimport ran"
},
{
"path": "optimizer.py",
"chars": 11804,
"preview": "# Copyright 2022 Garena Online Private Limited\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you"
},
{
"path": "preprocess_image.py",
"chars": 7439,
"preview": "import os\nimport sys\nimport cv2\nimport argparse\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nimport torch\nimport "
},
{
"path": "pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml",
"chars": 3006,
"preview": "model:\n base_learning_rate: 1.0e-04\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "raymarching/__init__.py",
"chars": 26,
"preview": "from .raymarching import *"
},
{
"path": "raymarching/backend.py",
"chars": 1454,
"preview": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags"
},
{
"path": "raymarching/raymarching.py",
"chars": 15284,
"preview": "import numpy as np\nimport time\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.a"
},
{
"path": "raymarching/setup.py",
"chars": 2155,
"preview": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = "
},
{
"path": "raymarching/src/bindings.cpp",
"chars": 968,
"preview": "#include <torch/extension.h>\n\n#include \"raymarching.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n // utils\n m.de"
},
{
"path": "raymarching/src/raymarching.cu",
"chars": 31876,
"preview": "#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/t"
},
{
"path": "raymarching/src/raymarching.h",
"chars": 2590,
"preview": "#pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n\nvoid near_far_from_aabb(const at::Tensor rays_o, const at::"
},
{
"path": "readme.md",
"chars": 17611,
"preview": "# Stable-Dreamfusion\n\nA pytorch implementation of the text-to-3D model **Dreamfusion**, powered by the [Stable Diffusion"
},
{
"path": "requirements.txt",
"chars": 685,
"preview": "tqdm\nrich\nninja\nnumpy\npandas\nscipy\nscikit-learn\nmatplotlib\nopencv-python\nimageio\nimageio-ffmpeg\n\ntorch\ntorch-ema\neinops\n"
},
{
"path": "scripts/install_ext.sh",
"chars": 101,
"preview": "pip install ./raymarching\npip install ./shencoder\npip install ./freqencoder\npip install ./gridencoder"
},
{
"path": "scripts/res64.args",
"chars": 25,
"preview": "-O --vram_O --w 64 --h 64"
},
{
"path": "scripts/run.sh",
"chars": 1669,
"preview": "#! /bin/bash\nCUDA_VISIBLE_DEVICES=1 python main.py -O --text \"a DSLR photo of a delicious hamburger\" --workspace trial_h"
},
{
"path": "scripts/run2.sh",
"chars": 988,
"preview": "#! /bin/bash\n\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a DSLR photo of a shiba inu playing golf wearing tartan g"
},
{
"path": "scripts/run3.sh",
"chars": 1155,
"preview": "#! /bin/bash\n\nCUDA_VISIBLE_DEVICES=7 python main.py -O --text \"ironman, full body\" --workspace trial_ironman --iters 100"
},
{
"path": "scripts/run4.sh",
"chars": 1264,
"preview": "#! /bin/bash\n\nCUDA_VISIBLE_DEVICES=5 python main.py -O --text \"a rabbit, animated movie character, high detail 3d model\""
},
{
"path": "scripts/run5.sh",
"chars": 1076,
"preview": "#! /bin/bash\n\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"Perched blue jay bird\" --workspace trial_jay --iters 1000"
},
{
"path": "scripts/run6.sh",
"chars": 2101,
"preview": "#! /bin/bash\nCUDA_VISIBLE_DEVICES=4 python main.py -O --text \"a baby bunny sitting on top of a stack of pancakes\" --work"
},
{
"path": "scripts/run_if.sh",
"chars": 2185,
"preview": "#! /bin/bash\nCUDA_VISIBLE_DEVICES=2 python main.py -O --text \"a baby bunny sitting on top of a stack of pancakes\" --work"
},
{
"path": "scripts/run_if2.sh",
"chars": 1997,
"preview": "#! /bin/bash\nCUDA_VISIBLE_DEVICES=3 python main.py -O --text \"a corgi taking a selfie\" --workspace trial_if_corgi --iter"
},
{
"path": "scripts/run_if2_perpneg.sh",
"chars": 2308,
"preview": "#! /bin/bash\n# To avoid the Janus problem caused by the diffusion model's front view bias, utilize the Perp-Neg algorith"
},
{
"path": "scripts/run_image.sh",
"chars": 2408,
"preview": "# zero123 backend (single object, images like 3d model rendering)\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image data"
},
{
"path": "scripts/run_image_anya.sh",
"chars": 3058,
"preview": "# Phase 1 - barely fits in A100 40GB.\n# Conclusion: results in concave-ish face, no neck, excess hair in the back\nCUDA_V"
},
{
"path": "scripts/run_image_hard_examples.sh",
"chars": 1874,
"preview": "bash scripts/run_image_procedure.sh 0 30 90 anya_front \"A DSLR 3D photo of a cute anime schoolgirl stands proudly with h"
},
{
"path": "scripts/run_image_procedure.sh",
"chars": 3483,
"preview": "# Perform a 2D-to-3D reconstruction, similar to the Anya case study: https://github.com/ashawkey/stable-dreamfusion/issu"
},
{
"path": "scripts/run_image_text.sh",
"chars": 1543,
"preview": "# sd backend (realistic images)\n\nCUDA_VISIBLE_DEVICES=4 python main.py -O --image data/teddy_rgba.png --text \"a brown te"
},
{
"path": "scripts/run_images.sh",
"chars": 935,
"preview": "# zero123 backend (single object, images like 3d model rendering)\n\nCUDA_VISIBLE_DEVICES=6 python main.py -O --image_conf"
},
{
"path": "shencoder/__init__.py",
"chars": 39,
"preview": "from .sphere_harmonics import SHEncoder"
},
{
"path": "shencoder/backend.py",
"chars": 1451,
"preview": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags"
},
{
"path": "shencoder/setup.py",
"chars": 1713,
"preview": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = "
},
{
"path": "shencoder/sphere_harmonics.py",
"chars": 2697,
"preview": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function "
},
{
"path": "shencoder/src/bindings.cpp",
"chars": 261,
"preview": "#include <torch/extension.h>\n\n#include \"shencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n m.def(\"sh_encode_fo"
},
{
"path": "shencoder/src/shencoder.cu",
"chars": 37131,
"preview": "#include <stdint.h>\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext"
},
{
"path": "shencoder/src/shencoder.h",
"chars": 430,
"preview": "# pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// inputs: [B, D], float, in [-1, 1]\n// outputs: [B, F], fl"
},
{
"path": "taichi_modules/__init__.py",
"chars": 224,
"preview": "from .ray_march import RayMarcherTaichi, raymarching_test\nfrom .volume_train import VolumeRendererTaichi\nfrom .intersect"
},
{
"path": "taichi_modules/hash_encoder.py",
"chars": 10961,
"preview": "import numpy as np\nimport taichi as ti\nimport torch\nfrom taichi.math import uvec3\nfrom torch.cuda.amp import custom_bwd,"
},
{
"path": "taichi_modules/intersection.py",
"chars": 2278,
"preview": "import taichi as ti\nimport torch\nfrom taichi.math import vec3\nfrom torch.cuda.amp import custom_fwd\n\nfrom .utils import "
},
{
"path": "taichi_modules/ray_march.py",
"chars": 13205,
"preview": "import taichi as ti\nimport torch\nfrom taichi.math import vec3\nfrom torch.cuda.amp import custom_fwd\n\nfrom .utils import "
},
{
"path": "taichi_modules/utils.py",
"chars": 6369,
"preview": "import taichi as ti\nimport torch\nfrom taichi.math import uvec3\n\ntaichi_block_size = 128\n\ndata_type = ti.f32\ntorch_type ="
},
{
"path": "taichi_modules/volume_render_test.py",
"chars": 1460,
"preview": "import taichi as ti\n\n\n@ti.kernel\ndef composite_test(\n sigmas: ti.types.ndarray(ndim=2), rgbs: ti.types.ndarray(ndim=3"
},
{
"path": "taichi_modules/volume_train.py",
"chars": 9419,
"preview": "import taichi as ti\nimport torch\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\nfrom .utils import (data_type, ti2to"
},
{
"path": "tets/README.md",
"chars": 263,
"preview": "Place the tet grid files in this folder. \nWe provide a few example grids. See the main README.md for a download link.\n\nY"
},
{
"path": "tets/generate_tets.py",
"chars": 2327,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
}
]
// ... and 3 more files
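The index lists `RayAABBIntersector.forward(ctx, rays_o, rays_d, center, half_size, max_hits)` in `taichi_modules/intersection.py`, and the `raymarching/src/raymarching.h` preview declares `near_far_from_aabb`. The sketch below shows the standard slab method for one ray against one axis-aligned box described by `center` and `half_size`. It is an illustration under stated assumptions, not the repository's kernel: it ignores batching and `max_hits`, clamps the entry distance to 0 when the ray origin is inside the box (a convention chosen here), and the function name is hypothetical.

```python
import math
from typing import Optional, Sequence, Tuple


def ray_aabb_intersect(
    ray_o: Sequence[float],
    ray_d: Sequence[float],
    center: Sequence[float],
    half_size: Sequence[float],
) -> Optional[Tuple[float, float]]:
    """Slab test: return (t_near, t_far) along the ray, or None if the box is missed."""
    t_near, t_far = -math.inf, math.inf
    for o, d, c, h in zip(ray_o, ray_d, center, half_size):
        lo, hi = c - h, c + h              # slab bounds on this axis
        if d == 0.0:
            # Ray parallel to this slab: it must already lie between the planes.
            if o < lo or o > hi:
                return None
            continue
        t0, t1 = (lo - o) / d, (hi - o) / d
        if t0 > t1:                        # negative direction: swap entry/exit
            t0, t1 = t1, t0
        t_near = max(t_near, t0)           # latest entry over all axes
        t_far = min(t_far, t1)             # earliest exit over all axes
        if t_near > t_far:
            return None
    if t_far < 0.0:                        # box lies entirely behind the origin
        return None
    return max(t_near, 0.0), t_far         # assumed convention: clamp entry to 0
```

A batched renderer would typically vectorize the same per-axis arithmetic over all rays at once, which fits with the index exposing the intersector as a `torch.autograd.Function`.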
About this extraction
This page contains the full source code of the ashawkey/stable-dreamfusion GitHub repository, extracted and formatted as plain text. The extraction includes 132 files (1016.8 KB), approximately 275.2k tokens, and a symbol index with 1102 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.