Repository: ashawkey/stable-dreamfusion
Branch: main
Commit: 5550b91862a3
Files: 132
Total size: 1016.8 KB
Directory structure:
gitextract_l054hyr6/
├── .github/
│ └── ISSUE_TEMPLATE/
│ ├── bug_report.yaml
│ └── feature_request.md
├── .gitignore
├── LICENSE
├── activation.py
├── assets/
│ ├── advanced.md
│ └── update_logs.md
├── config/
│ ├── anya.csv
│ ├── car.csv
│ └── corgi.csv
├── docker/
│ ├── Dockerfile
│ └── README.md
├── dpt.py
├── encoding.py
├── evaluation/
│ ├── Prompt.py
│ ├── mesh_to_video.py
│ ├── r_precision.py
│ └── readme.md
├── freqencoder/
│ ├── __init__.py
│ ├── backend.py
│ ├── freq.py
│ ├── setup.py
│ └── src/
│ ├── bindings.cpp
│ ├── freqencoder.cu
│ └── freqencoder.h
├── gridencoder/
│ ├── __init__.py
│ ├── backend.py
│ ├── grid.py
│ ├── setup.py
│ └── src/
│ ├── bindings.cpp
│ ├── gridencoder.cu
│ └── gridencoder.h
├── guidance/
│ ├── clip_utils.py
│ ├── if_utils.py
│ ├── perpneg_utils.py
│ ├── sd_utils.py
│ └── zero123_utils.py
├── ldm/
│ ├── extras.py
│ ├── guidance.py
│ ├── lr_scheduler.py
│ ├── models/
│ │ ├── autoencoder.py
│ │ └── diffusion/
│ │ ├── __init__.py
│ │ ├── classifier.py
│ │ ├── ddim.py
│ │ ├── ddpm.py
│ │ ├── plms.py
│ │ └── sampling_util.py
│ ├── modules/
│ │ ├── attention.py
│ │ ├── diffusionmodules/
│ │ │ ├── __init__.py
│ │ │ ├── model.py
│ │ │ ├── openaimodel.py
│ │ │ └── util.py
│ │ ├── distributions/
│ │ │ ├── __init__.py
│ │ │ └── distributions.py
│ │ ├── ema.py
│ │ ├── encoders/
│ │ │ ├── __init__.py
│ │ │ └── modules.py
│ │ ├── evaluate/
│ │ │ ├── adm_evaluator.py
│ │ │ ├── evaluate_perceptualsim.py
│ │ │ ├── frechet_video_distance.py
│ │ │ ├── ssim.py
│ │ │ └── torch_frechet_video_distance.py
│ │ ├── image_degradation/
│ │ │ ├── __init__.py
│ │ │ ├── bsrgan.py
│ │ │ ├── bsrgan_light.py
│ │ │ └── utils_image.py
│ │ ├── losses/
│ │ │ ├── __init__.py
│ │ │ ├── contperceptual.py
│ │ │ └── vqperceptual.py
│ │ └── x_transformer.py
│ ├── thirdp/
│ │ └── psp/
│ │ ├── helpers.py
│ │ ├── id_loss.py
│ │ └── model_irse.py
│ └── util.py
├── main.py
├── meshutils.py
├── nerf/
│ ├── gui.py
│ ├── network.py
│ ├── network_grid.py
│ ├── network_grid_taichi.py
│ ├── network_grid_tcnn.py
│ ├── provider.py
│ ├── renderer.py
│ └── utils.py
├── optimizer.py
├── preprocess_image.py
├── pretrained/
│ └── zero123/
│ └── sd-objaverse-finetune-c_concat-256.yaml
├── raymarching/
│ ├── __init__.py
│ ├── backend.py
│ ├── raymarching.py
│ ├── setup.py
│ └── src/
│ ├── bindings.cpp
│ ├── raymarching.cu
│ └── raymarching.h
├── readme.md
├── requirements.txt
├── scripts/
│ ├── install_ext.sh
│ ├── res64.args
│ ├── run.sh
│ ├── run2.sh
│ ├── run3.sh
│ ├── run4.sh
│ ├── run5.sh
│ ├── run6.sh
│ ├── run_if.sh
│ ├── run_if2.sh
│ ├── run_if2_perpneg.sh
│ ├── run_image.sh
│ ├── run_image_anya.sh
│ ├── run_image_hard_examples.sh
│ ├── run_image_procedure.sh
│ ├── run_image_text.sh
│ └── run_images.sh
├── shencoder/
│ ├── __init__.py
│ ├── backend.py
│ ├── setup.py
│ ├── sphere_harmonics.py
│ └── src/
│ ├── bindings.cpp
│ ├── shencoder.cu
│ └── shencoder.h
├── taichi_modules/
│ ├── __init__.py
│ ├── hash_encoder.py
│ ├── intersection.py
│ ├── ray_march.py
│ ├── utils.py
│ ├── volume_render_test.py
│ └── volume_train.py
└── tets/
├── 128_tets.npz
├── 32_tets.npz
├── 64_tets.npz
├── README.md
└── generate_tets.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yaml
================================================
name: Bug Report
description: File a bug report
title: "
"
labels: ["bug"]
body:
- type: markdown
attributes:
value: |
Before filing a bug report, [search for an existing issue](https://github.com/ashawkey/stable-dreamfusion/issues).
Also, ensure you are running the latest version.
- type: textarea
id: description
attributes:
label: Description
description: Provide a clear and concise description of what the bug is.
placeholder: Description
validations:
required: true
- type: textarea
id: steps
attributes:
label: Steps to Reproduce
description: List the steps needed to reproduce the issue.
placeholder: |
1. Go to '...'
2. Click on '...'
validations:
required: true
- type: textarea
id: expected-behavior
attributes:
label: Expected Behavior
description: Describe what you expected to happen.
placeholder: |
The 'action' would do 'some amazing thing'.
validations:
required: true
- type: textarea
id: environment
attributes:
label: Environment
description: Describe your environment.
placeholder: |
Ubuntu 22.04, PyTorch 1.13, CUDA 11.6
validations:
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: enhancement
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
================================================
FILE: .gitignore
================================================
__pycache__/
build/
*.egg-info/
*.so
venv_*/
tmp*
# data/
ldm/data/
data2
scripts2
trial*/
.vs/
TOKEN
*.ckpt
densegridencoder
tets/256_tets.npz
.vscode/launch.json
data2
data/car*
data/chair*
data/warrior*
data/wd*
data/space*
data/corgi*
data/turtle*
# Only keep the original image, not the automatically-generated depth, normals, rgba
data/baby_phoenix_on_ice_*
data/bollywood_actress_*
data/beach_house_1_*
data/beach_house_2_*
data/mona_lisa_*
data/futuristic_car_*
data/church_ruins_*
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: activation.py
================================================
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
class _trunc_exp(Function):
    """Exponential whose gradient is computed from an input clamped at 15.

    The forward value is the plain ``torch.exp(x)``. Only the backward pass
    uses ``exp(min(x, 15))``, which keeps gradients bounded for large inputs.
    ``custom_fwd(cast_inputs=torch.float)`` runs the forward in float32 when
    autocast is active.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float)
    def forward(ctx, x):
        # Keep the raw input; backward recomputes the (truncated) exponential.
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        (saved_x,) = ctx.saved_tensors
        truncated = torch.exp(saved_x.clamp(max=15))
        return g * truncated


# Functional entry point: y = trunc_exp(x)
trunc_exp = _trunc_exp.apply
def biased_softplus(x, bias=0):
    """Return ``softplus(x - bias)``.

    A positive ``bias`` shifts the softplus curve to the right.
    """
    shifted = x - bias
    return torch.nn.functional.softplus(shifted)
================================================
FILE: assets/advanced.md
================================================
# Code organization & Advanced tips
This is a simple description of the most important implementation details.
If you are interested in improving this repo, this might be a starting point.
Any contribution would be greatly appreciated!
* The SDS loss is located at `./guidance/sd_utils.py > StableDiffusion > train_step`:
```python
## 1. we need to interpolate the NeRF rendering to 512x512, to feed it to SD's VAE.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
## 2. image (512x512) --- VAE --> latents (64x64), this is SD's difference from Imagen.
latents = self.encode_imgs(pred_rgb_512)
... # timestep sampling, noise adding and UNet noise predicting
## 3. the SDS loss
w = (1 - self.alphas[t])
grad = w * (noise_pred - noise)
# since the UNet part is ignored and cannot simply be autodiffed, there are several ways to set the grad:
# 3.1. call backward and set the grad now (need to retain graph since we will call a second backward for the other losses later)
latents.backward(gradient=grad, retain_graph=True)
return 0 # dummy loss
# 3.2. use a custom function to set a hook in backward, so we only call backward once (credits to @elliottzheng)
class SpecifyGradient(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, input_tensor, gt_grad):
ctx.save_for_backward(gt_grad)
# we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
@staticmethod
@custom_bwd
def backward(ctx, grad_scale):
gt_grad, = ctx.saved_tensors
gt_grad = gt_grad * grad_scale
return gt_grad, None
loss = SpecifyGradient.apply(latents, grad)
return loss # functional loss
# 3.3. reparameterization (credits to @Xallt)
# d(loss)/d(latents) = grad, since grad is already detached, it's this simple.
loss = (grad * latents).sum()
return loss
# 3.4. reparameterization (credits to threestudio)
# this is the same as 3.3, but the loss value only reflects the magnitude of grad, which is more informative.
targets = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents, targets, reduction='sum')
return loss
```
* Other regularizations are in `./nerf/utils.py > Trainer > train_step`.
* The generation seems quite sensitive to regularizations on weights_sum (alphas for each ray). The original opacity loss tends to make NeRF disappear (zero density everywhere), so we use an entropy loss to replace it for now (encourages alpha to be either 0 or 1).
* NeRF Rendering core function: `./nerf/renderer.py > NeRFRenderer > run & run_cuda`.
* Shading & normal evaluation: `./nerf/network*.py > NeRFNetwork > forward`.
* Light direction: the current implementation uses a plane light source instead of a point light source.
* View-dependent prompting: `./nerf/provider.py > get_view_direction`.
* use `--angle_overhead, --angle_front` to set the border.
* Network backbone (`./nerf/network*.py`) can be chosen by the `--backbone` option.
* Spatial density bias (density blob): `./nerf/network*.py > NeRFNetwork > density_blob`.
# Debugging
`debugpy-run` is a convenient way to remotely debug this project. Simply replace a command like this one:
```bash
python main.py --text "a hamburger" --workspace trial -O --vram_O
```
... with:
```bash
debugpy-run main.py -- --text "a hamburger" --workspace trial -O --vram_O
```
For more details: https://github.com/bulletmark/debugpy-run
# Axes and directions of polar, azimuth, etc. in NeRF and Zero123
This code refers to theta for polar, phi for azimuth.
================================================
FILE: assets/update_logs.md
================================================
### 2023.4.19
* Fix depth supervision, migrate depth estimation model to omnidata.
* Add normal supervision (also by omnidata).
https://user-images.githubusercontent.com/25863658/232403294-b77409bf-ddc7-4bb8-af32-ee0cc123825a.mp4
### 2023.4.7
Improvement on mesh quality & DMTet finetuning support.
https://user-images.githubusercontent.com/25863658/230535363-298c960e-bf9c-4906-8b96-cd60edcb24dd.mp4
### 2023.3.30
* adopt ideas from [Fantasia3D](https://fantasia3d.github.io/) to concatenate normal and mask as the latent code in a warm up stage, which shows faster convergence of shape.
https://user-images.githubusercontent.com/25863658/230535373-6ee28f16-bb21-4ec4-bc86-d46597361a04.mp4
### 2023.1.30
* Use an MLP to predict the surface normals as in Magic3D to avoid finite difference / second order gradient, generation quality is greatly improved.
* More efficient two-pass raymarching in training inspired by nerfacc.
https://user-images.githubusercontent.com/25863658/215996308-9fd959f5-b5c7-4a8e-a241-0fe63ec86a4a.mp4
### 2022.12.3
* Support Stable-diffusion 2.0 base.
### 2022.11.15
* Add the vanilla backbone that is pure-pytorch.
### 2022.10.9
* The shading (partially) starts to work, at least it won't make scene empty. For some prompts, it shows better results (less severe Janus problem). The textureless rendering mode is still disabled.
* Enable shading by default (--latent_iter_ratio 1000).
### 2022.10.5
* Basic reproduction finished.
* The non-`--cuda_ray` mode and `--tcnn` are not working yet; need to fix.
* Shading is not working, disabled in utils.py for now. Surface normals are bad.
* Use an entropy loss to regularize weights_sum (alpha); the original L2 reg always leads to degenerate geometry...
https://user-images.githubusercontent.com/25863658/194241493-f3e68f78-aefe-479e-a4a8-001424a61b37.mp4
================================================
FILE: config/anya.csv
================================================
zero123_weight, radius, polar, azimuth, image
1, 3, 90, 0, data/anya_front_rgba.png
1, 3, 90, 180, data/anya_back_rgba.png
================================================
FILE: config/car.csv
================================================
zero123_weight, radius, polar, azimuth, image
4, 3.2, 90, 0, data/car_left_rgba.png
1, 3, 90, 90, data/car_front_rgba.png
4, 3.2, 90, 180, data/car_right_rgba.png
1, 3, 90, -90, data/car_back_rgba.png
================================================
FILE: config/corgi.csv
================================================
zero123_weight, radius, polar, azimuth, image
1, 3.2, 90, 0, data/corgi_puppy_sitting_looking_up_rgba.png
================================================
FILE: docker/Dockerfile
================================================
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
# Remove any third-party apt sources to avoid issues with expiring keys.
RUN rm -f /etc/apt/sources.list.d/*.list
# Refresh the package index and install tzdata plus some basic utilities in a
# single layer. A separately cached `apt-get update` layer can go stale and make
# later installs fail or pull outdated packages. The index is removed at the end
# to keep the image small.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive TZ=Europe/MADRID apt-get install -y tzdata \
    && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    python3 \
    python3-pip \
    libglfw3-dev \
    libgles2-mesa-dev \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*
# Create a working directory. WORKDIR sets the directory for every following
# instruction; a separate `RUN cd /app` has no lasting effect, so it is not used.
RUN mkdir /app
WORKDIR /app
RUN git clone https://github.com/ashawkey/stable-dreamfusion.git
RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
WORKDIR /app/stable-dreamfusion
RUN pip3 install -r requirements.txt
RUN pip3 install git+https://github.com/NVlabs/nvdiffrast/
# Needs nvidia runtime, if you have "No CUDA runtime is found" error: https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime, first answer
RUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
RUN pip3 install git+https://github.com/openai/CLIP.git
RUN bash scripts/install_ext.sh
# Set the default command to python3
#CMD ["python3"]
================================================
FILE: docker/README.md
================================================
### Docker installation
## Build image
To build the docker image on your own machine, which may take 15-30 mins:
```
docker build -t stable-dreamfusion:latest .
```
If you get the error **No CUDA runtime is found** when building the wheels for tiny-cuda-nn, you need to set up the nvidia-runtime for docker.
```
sudo apt-get install nvidia-container-runtime
```
Then edit `/etc/docker/daemon.json` and add the default-runtime:
```
{
"runtimes": {
"nvidia": {
"path": "nvidia-container-runtime",
"runtimeArgs": []
}
},
"default-runtime": "nvidia"
}
```
And restart docker:
```
sudo systemctl restart docker
```
Now you can build tiny-cuda-nn inside docker.
## Download image
To download the image (~6GB) instead:
```
docker pull supercabb/stable-dreamfusion:3080_0.0.1
docker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion
```
## Use image
You can launch an interactive shell inside the container:
```
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash
```
From this shell, all the code in the repo should work.
To run any single command `COMMAND` inside the docker container:
```
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "COMMAND"
```
To train:
```
export TOKEN="#HUGGING FACE ACCESS TOKEN#"
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "echo ${TOKEN} > TOKEN \
&& python3 main.py --text \"a hamburger\" --workspace trial -O"
```
Run test without gui:
```
export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
main.py --workspace trial -O --test"
```
Run test with gui:
```
export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
xhost +
docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
main.py --workspace trial -O --test --gui"
xhost -
```
================================================
FILE: dpt.py
================================================
import math
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
class BaseModel(torch.nn.Module):
    """Module base class that can restore its weights from a file."""

    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path

        The file is loaded onto the CPU. If the loaded object contains an
        ``"optimizer"`` entry (a training checkpoint), the weights are taken
        from its ``"model"`` entry; otherwise the loaded object itself is used
        as the state dict.
        """
        state = torch.load(path, map_location=torch.device('cpu'))
        is_training_checkpoint = "optimizer" in state
        if is_training_checkpoint:
            state = state["model"]
        self.load_state_dict(state)
def unflatten_with_named_tensor(input, dim, sizes):
    """Workaround for unflattening with named tensor.

    Splits dimension ``dim`` of ``input`` into ``sizes`` with ``Tensor.view``
    instead of ``nn.Unflatten``, because the tracer acts up with unflatten.
    See https://github.com/pytorch/pytorch/issues/49538.
    """
    shape = list(input.shape)
    leading, trailing = shape[:dim], shape[dim + 1:]
    return input.view(*(leading + list(sizes) + trailing))
class Slice(nn.Module):
    """Drop the leading readout token(s) from a token sequence.

    Returns ``x[:, start_index:]``, i.e. removes the first ``start_index``
    entries along dimension 1.
    """

    def __init__(self, start_index=1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x):
        first_kept = self.start_index
        return x[:, first_kept:]
class AddReadout(nn.Module):
    """Add the readout token to every remaining token.

    With ``start_index == 2`` the readout is the mean of tokens 0 and 1;
    otherwise it is token 0. The first ``start_index`` tokens are dropped
    from the result.
    """

    def __init__(self, start_index=1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index != 2:
            readout = x[:, 0]
        else:
            readout = (x[:, 0] + x[:, 1]) / 2
        tokens = x[:, self.start_index:]
        # Broadcast the single readout over the token dimension.
        return tokens + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
    """Fuse the readout token into every token with a learned projection.

    Each token from ``start_index`` on is concatenated (last dim) with token 0.
    The result is mapped from ``2 * in_features`` back to ``in_features``
    channels by Linear + GELU.
    """

    def __init__(self, in_features, start_index=1):
        super().__init__()
        self.start_index = start_index
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        tokens = x[:, self.start_index:]
        # The readout is always token 0, regardless of start_index.
        readout = x[:, 0].unsqueeze(1).expand_as(tokens)
        fused = torch.cat((tokens, readout), -1)
        return self.project(fused)
class Transpose(nn.Module):
    """``nn.Module`` wrapper around ``Tensor.transpose(dim0, dim1)``.

    Useful for placing a transpose inside an ``nn.Sequential``.
    """

    def __init__(self, dim0, dim1):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1)
def forward_vit(pretrained, x):
    """Run the ViT backbone and reassemble four intermediate feature maps.

    ``pretrained.model.forward_flex(x)`` is called only for its side effect of
    filling ``pretrained.activations`` (keys "1".."4") through forward hooks.
    Its return value is unused. Each recorded activation is then passed
    through the matching ``act_postprocess{k}``. Modules 0-1 run first, then
    the token axis (dim 2) is unflattened to the patch grid if still 3-D, then
    modules 3 onward run.

    Returns:
        tuple of four tensors ``(layer_1, layer_2, layer_3, layer_4)``.
    """
    _, _, height, width = x.shape

    # Populates pretrained.activations via the registered forward hooks.
    pretrained.model.forward_flex(x)

    spatial_dim = 2
    grid_size = (
        int(torch.div(height, pretrained.model.patch_size[1], rounding_mode='floor')),
        int(torch.div(width, pretrained.model.patch_size[0], rounding_mode='floor')),
    )
    unflatten = nn.Sequential(nn.Unflatten(spatial_dim, grid_size))

    postprocessors = (
        pretrained.act_postprocess1,
        pretrained.act_postprocess2,
        pretrained.act_postprocess3,
        pretrained.act_postprocess4,
    )
    outputs = []
    for level, post in enumerate(postprocessors):
        feat = pretrained.activations[str(level + 1)]
        feat = post[0:2](feat)
        if feat.ndim == 3:
            if level < 2:
                feat = unflatten(feat)
            else:
                # Deeper levels use the view-based helper (tracer workaround).
                feat = unflatten_with_named_tensor(feat, spatial_dim, grid_size)
        feat = post[3:len(post)](feat)
        outputs.append(feat)
    return tuple(outputs)
def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, : self.start_index],
posemb[0, self.start_index :],
)
gs_old = int(math.sqrt(posemb_grid.shape[0]))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def forward_flex(self, x):
    """Run the ViT on ``x`` with position embeddings resized to its resolution.

    Steps:
    1. Resize the position embeddings to the ``H // patch_size[1]`` by
       ``W // patch_size[0]`` grid.
    2. Optionally run ``patch_embed.backbone``.
    3. Project to patch tokens.
    4. Prepend the class token (and the distillation token when present).
    5. Add the position embeddings and apply dropout.
    6. Run all blocks, then the final norm.

    Returns:
        the normalized token sequence.
    """
    batch, _, height, width = x.shape
    grid_h = torch.div(height, self.patch_size[1], rounding_mode='floor')
    grid_w = torch.div(width, self.patch_size[0], rounding_mode='floor')
    pos_embed = self._resize_pos_embed(self.pos_embed, grid_h, grid_w)

    # If the patch embedding carries a backbone, run it before the projection.
    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    # (B, C, h, w) -> (B, h*w, C) token sequence
    tokens = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    # cls_tokens impl originally from Phil Wang
    prefix = [self.cls_token.expand(batch, -1, -1)]
    if getattr(self, "dist_token", None) is not None:
        prefix.append(self.dist_token.expand(batch, -1, -1))
    tokens = torch.cat(prefix + [tokens], dim=1)

    tokens = self.pos_drop(tokens + pos_embed)
    for blk in self.blocks:
        tokens = blk(tokens)
    return self.norm(tokens)
# Shared store written by the forward hooks below: hook name -> latest module output.
activations = {}


def get_activation(name):
    """Build a forward hook that records a module's output under ``name``.

    The hook writes into the module-level ``activations`` dict and overwrites
    the previous value for ``name`` on every forward call. It returns None,
    so the module's output is left unchanged.
    """
    def record_output(module, inputs, output):
        activations[name] = output

    return record_output
def get_readout_oper(vit_features, features, use_readout, start_index=1):
    """Build one readout-handling module per output feature level.

    Args:
        vit_features: channel width of the ViT tokens; used as ``in_features``
            of ``ProjectReadout`` for the "project" mode.
        features: sequence whose length sets how many operators are returned.
        use_readout: "ignore" (``Slice``: drop readout tokens), "add"
            (``AddReadout``: add the readout to every token) or "project"
            (``ProjectReadout``: concatenate the readout, then Linear + GELU).
        start_index: number of leading readout tokens in the sequence.

    Returns:
        list of ``len(features)`` modules. For "ignore"/"add" every entry is
        the same module instance. For "project" each level gets its own
        separately-parameterized module.

    Raises:
        AssertionError: if ``use_readout`` is not one of the values above.
    """
    if use_readout == "ignore":
        return [Slice(start_index)] * len(features)
    if use_readout == "add":
        return [AddReadout(start_index)] * len(features)
    if use_readout == "project":
        return [ProjectReadout(vit_features, start_index) for _ in features]
    # Raise explicitly instead of `assert False`: assert statements are removed
    # under `python -O`, which made invalid input fall through to an
    # UnboundLocalError on the return. The exception type is kept unchanged.
    raise AssertionError(
        "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
    )
def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a timm ViT with 16x16 patches as a four-stage DPT encoder.

    Forward hooks on ``model.blocks[hooks[0..3]]`` store those blocks' outputs in the
    module-level ``activations`` dict under the keys "1".."4". Each
    ``act_postprocessN`` stage does the following:

    1. applies the readout handler from ``get_readout_oper``;
    2. transposes the tokens;
    3. unflattens them into a (size[0] // 16) x (size[1] // 16) grid;
    4. projects the grid to ``features[N-1]`` channels with a 1x1 conv.

    After that projection, the stages differ:

    * stage 1: 4x upsampling (ConvTranspose2d, kernel 4, stride 4)
    * stage 2: 2x upsampling (ConvTranspose2d, kernel 2, stride 2)
    * stage 3: no resampling
    * stage 4: 2x downsampling (3x3 conv, stride 2, padding 1)

    Args:
        model: timm VisionTransformer. It is modified in place: hooks are registered,
            and ``start_index``, ``patch_size``, ``forward_flex`` and
            ``_resize_pos_embed`` are set on it.
        features: output channels of the four stages.
        size: input resolution the Unflatten grids are sized for.
        hooks: indices of the transformer blocks feeding the four stages.
        vit_features: ViT embedding width (input channels of the 1x1 convs).
        use_readout: readout-token handling, see ``get_readout_oper``.
        start_index: number of leading non-patch tokens.

    Returns:
        nn.Module with attributes ``model``, ``activations`` and
        ``act_postprocess1``..``act_postprocess4``.

    NOTE(review): the default list arguments are only read, never mutated, here.
    Module creation order below fixes the random-initialisation order, so keep it
    when editing.
    """
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    # The module-level dict: every backbone built in this process writes into the
    # same object, so interleaved forward passes of two models overwrite each other.
    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # Stage 1: token grid -> features[0] channels, upsampled 4x.
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    # Stage 2: token grid -> features[1] channels, upsampled 2x.
    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    # Stage 3: token grid -> features[2] channels, resolution unchanged.
    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    # Stage 4: token grid -> features[3] channels, downsampled 2x.
    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    """ViT-L/16 at 384px from timm, wrapped with the four DPT reassemble stages.

    ``pretrained`` is forwarded to ``timm.create_model``; ``hooks`` defaults to the
    outputs of transformer blocks 5, 11, 17 and 23.
    """
    vit = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
    stage_hooks = [5, 11, 17, 23] if hooks is None else hooks
    return _make_vit_b16_backbone(
        vit,
        features=[256, 512, 1024, 1024],
        hooks=stage_hooks,
        vit_features=1024,
        use_readout=use_readout,
    )
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    """ViT-B/16 at 384px from timm, wrapped with the four DPT reassemble stages.

    ``hooks`` defaults to the outputs of transformer blocks 2, 5, 8 and 11.
    """
    vit = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
    stage_hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        vit,
        features=[96, 192, 384, 768],
        hooks=stage_hooks,
        use_readout=use_readout,
    )
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
    """DeiT-B/16 at 384px from timm, wrapped with the four DPT reassemble stages.

    ``hooks`` defaults to the outputs of transformer blocks 2, 5, 8 and 11.
    NOTE(review): "vit_deit_base_patch16_384" is an older timm model name; newer timm
    releases may only register "deit_base_patch16_384" -- confirm the pinned version.
    """
    deit = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
    stage_hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        deit,
        features=[96, 192, 384, 768],
        hooks=stage_hooks,
        use_readout=use_readout,
    )
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
    """Distilled DeiT-B/16 at 384px from timm, wrapped with the DPT reassemble stages.

    The distilled model carries a class token and a distillation token, hence
    ``start_index=2``. ``hooks`` defaults to blocks 2, 5, 8 and 11.
    """
    deit = timm.create_model(
        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
    )
    stage_hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        deit,
        features=[96, 192, 384, 768],
        hooks=stage_hooks,
        use_readout=use_readout,
        start_index=2,
    )
def _make_vit_b_rn50_backbone(
    model,
    features=[256, 512, 768, 768],
    size=[384, 384],
    hooks=[0, 1, 8, 11],
    vit_features=768,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a timm hybrid ViT (ResNet-50 patch embedding + ViT-B) as a DPT encoder.

    Stages 3 and 4 always come from transformer blocks ``hooks[2]`` and ``hooks[3]``.
    How stages 1 and 2 are produced depends on ``use_vit_only``:

    * ``use_vit_only == True``: they come from transformer blocks ``hooks[0]`` and
      ``hooks[1]``. They are reassembled like the plain ViT backbone: 1x1 conv, then
      4x and 2x transposed-conv upsampling respectively.
    * otherwise: they come from ``patch_embed.backbone.stages[0]`` and ``[1]`` (the
      convolutional stem). Their postprocess stages are identities, and
      ``hooks[0]``/``hooks[1]`` are unused.

    Stage 3 is a 1x1 conv to ``features[2]``. Stage 4 is a 1x1 conv to
    ``features[3]`` followed by a stride-2 3x3 conv. Outputs are stored in the
    module-level ``activations`` dict under keys "1".."4".

    Args:
        model: timm hybrid VisionTransformer. It is modified in place: hooks are
            registered, and ``start_index``, ``patch_size``, ``forward_flex`` and
            ``_resize_pos_embed`` are set on it.
        features: output channels of the four stages.
        size: input resolution the Unflatten grids are sized for.
        hooks: block indices as described above.
        vit_features: ViT embedding width.
        use_vit_only: take stages 1/2 from transformer blocks instead of ResNet stages.
        use_readout: readout-token handling, see ``get_readout_oper``.
        start_index: number of leading non-patch tokens.

    Returns:
        nn.Module with attributes ``model``, ``activations`` and
        ``act_postprocess1``..``act_postprocess4``.

    NOTE(review): module creation order below fixes the random-initialisation order,
    so keep it when editing.
    """
    pretrained = nn.Module()

    pretrained.model = model

    if use_vit_only == True:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    else:
        # Hybrid mode: the two highest-resolution features come from the ResNet stem.
        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
            get_activation("1")
        )
        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
            get_activation("2")
        )

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    # Shared module-level dict (see the note on _make_vit_b16_backbone's counterpart).
    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if use_vit_only == True:
        # Stage 1: token grid -> features[0] channels, upsampled 4x.
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        # Stage 2: token grid -> features[1] channels, upsampled 2x.
        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )
    else:
        # ResNet stage outputs are already spatial feature maps: pass them through.
        # Three Identity entries presumably keep the Sequential length aligned with
        # how callers index into these stages -- confirm against forward_vit.
        pretrained.act_postprocess1 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )
        pretrained.act_postprocess2 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )

    # Stage 3: token grid -> features[2] channels, resolution unchanged.
    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    # Stage 4: token grid -> features[3] channels, downsampled 2x.
    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject these functions into the VisionTransformer instance so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
def _make_pretrained_vitb_rn50_384(
    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    """Hybrid ViT-B + ResNet-50 at 384px from timm, wrapped as a DPT encoder.

    ``hooks`` defaults to [0, 1, 8, 11]; see ``_make_vit_b_rn50_backbone`` for how
    ``use_vit_only`` changes which modules feed stages 1 and 2.
    """
    hybrid = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
    stage_hooks = [0, 1, 8, 11] if hooks is None else hooks
    return _make_vit_b_rn50_backbone(
        hybrid,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=stage_hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
if backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == "vitb_rn50_384":
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups, expand=expand
) # ViT-H/16 - 85.0% Top1 (backbone)
elif backbone == "vitb16_384":
pretrained = _make_pretrained_vitb16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
elif backbone == "efficientnet_lite3":
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand==True:
out_shape1 = out_shape
out_shape2 = out_shape*2
out_shape3 = out_shape*4
out_shape4 = out_shape*8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    """Load ``tf_efficientnet_lite3`` via torch.hub and split it into four encoder stages.

    ``use_pretrained`` and ``exportable`` are forwarded to the hub entry point of
    ``rwightman/gen-efficientnet-pytorch`` (torch.hub fetches the repo if not cached).
    """
    hub_repo = "rwightman/gen-efficientnet-pytorch"
    entry_point = "tf_efficientnet_lite3"
    effnet = torch.hub.load(
        hub_repo, entry_point, pretrained=use_pretrained, exportable=exportable
    )
    return _make_efficientnet_backbone(effnet)
def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
)
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
return pretrained
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
    """ResNeXt-101 32x8d (WSL) from torch.hub, regrouped into four encoder stages.

    NOTE(review): ``use_pretrained`` is not forwarded to ``torch.hub.load`` here, so it
    has no effect; confirm whether the hub entry point always loads weights.
    """
    return _make_resnet_backbone(
        torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    )
class Interpolate(nn.Module):
    """Interpolation module.

    ``nn.Module`` wrapper around ``nn.functional.interpolate`` so that resizing can be
    placed inside an ``nn.Sequential``.
    """

    # Modes for which interpolate accepts an ``align_corners`` value. For the others
    # (e.g. "nearest", "area") passing anything but None raises ValueError, which
    # previously made Interpolate(..., mode="nearest") unusable.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
            align_corners (bool): forwarded for the linear-family modes; ignored for
                modes that do not support it.
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        if self.mode in self._ALIGN_CORNERS_MODES:
            align_corners = self.align_corners
        else:
            align_corners = None

        x = self.interp(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=align_corners
        )

        return x
class ResidualConvUnit(nn.Module):
    """Residual convolution module.

    forward(x) applies relu -> conv1 (3x3) -> relu -> conv2 (3x3) and adds a skip
    connection; channel count and spatial size are preserved (stride 1, padding 1).

    NOTE(review): ``self.relu`` is ``nn.ReLU(inplace=True)``, so the first
    ``self.relu(x)`` in forward overwrites the caller's tensor ``x`` in place. The skip
    connection therefore adds relu(x) rather than the original x, and the caller's
    input is modified. Left unchanged because pretrained weights may rely on this
    behaviour -- confirm before switching to a non-inplace ReLU.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)  # in-place: `out` and `x` are now the same tensor
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Optionally adds a refined skip input to the main input, refines the sum with a
    second residual unit, and upsamples the result 2x (bilinear, align_corners=True).
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Args:
            *xs: ``(x,)`` or ``(x, skip)``; when ``skip`` is given,
                ``resConfUnit1(skip)`` is added to ``x``.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            # NOTE(review): in-place add -- this modifies the caller's xs[0] tensor.
            # ResidualConvUnit applies ReLU(inplace=True) to its input, so xs[1] (and,
            # in resConfUnit2 below, xs[0]) are modified in place as well.
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output
class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module with a configurable activation and optional BatchNorm.

    forward(x) = x + bn2(conv2(act(bn1(conv1(act(x)))))), where the BatchNorm layers
    exist only when ``bn == True``. Both convs are 3x3, stride 1, padding 1, so the
    channel count and spatial size are preserved.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of input and output channels
            activation (nn.Module): activation applied before each convolution
            bn (bool): add BatchNorm2d after each convolution when True
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        # presumably used instead of "+" so the residual add can be observed when the
        # model is quantized
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = x
        for stage, conv in enumerate((self.conv1, self.conv2), start=1):
            out = conv(self.activation(out))
            if self.bn == True:
                out = getattr(self, "bn%d" % stage)(out)

        # NOTE(review): dead branch -- groups is fixed to 1 and conv_merge is never created.
        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block with configurable activation, BatchNorm and output width.

    forward(x) or forward(x, skip): optionally adds ``resConfUnit1(skip)`` to ``x``,
    refines with ``resConfUnit2``, upsamples 2x bilinearly and applies a 1x1
    ``out_conv`` (halving the channels when ``expand == True``).
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Init.

        Args:
            features (int): number of features
            activation (nn.Module): activation used inside the residual units
            deconv (bool): stored only; not used by this block
            bn (bool): use BatchNorm inside the residual units
            expand (bool): halve the number of output channels
            align_corners (bool): forwarded to the bilinear upsampling
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1
        self.expand = expand

        out_features = features // 2 if self.expand == True else features
        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        fused = xs[0]
        if len(xs) == 2:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        refined = self.resConfUnit2(fused)
        upsampled = nn.functional.interpolate(
            refined, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )
        return self.out_conv(upsampled)
def _make_fusion_block(features, use_bn):
    """Fusion block as used by DPT.

    Uses a non-inplace ReLU, no deconvolution and no channel expansion, with
    align_corners=True bilinear upsampling. BatchNorm is used only when ``use_bn`` is set.
    """
    options = dict(deconv=False, bn=use_bn, expand=False, align_corners=True)
    return FeatureFusionBlock_custom(features, nn.ReLU(False), **options)
class DPT(BaseModel):
    """Dense Prediction Transformer: ViT encoder + convolutional fusion decoder.

    The encoder yields four feature maps (via ``forward_vit``). ``scratch.layerN_rn``
    projects each of them to ``features`` channels. The fusion blocks then merge them
    from the coarsest (4) to the finest (1), and ``head`` produces the output.

    Args:
        head (nn.Module): output head applied to the finest fused feature map.
        features (int): decoder width.
        backbone (str): one of "vitb_rn50_384", "vitb16_384", "vitl16_384".
        readout (str): readout-token handling ("ignore", "add" or "project").
        channels_last (bool): convert the input to channels_last memory format.
        use_bn (bool): use BatchNorm inside the fusion blocks.
    """

    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            True,  # use_pretrained: load ImageNet weights; set to False to train the backbone from scratch
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        if self.channels_last == True:
            # Tensor.contiguous returns a new tensor; the result must be kept, otherwise
            # the memory-format conversion has no effect.
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # coarse-to-fine fusion; each block upsamples its result 2x
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out
class DPTDepthModel(DPT):
    """DPT with a dense-regression head (e.g. depth).

    The head applies a 3x3 conv to features // 2 channels, then 2x bilinear
    upsampling, then a 3x3 conv to 32 channels, a ReLU, and a 1x1 conv to
    ``num_channels``. A final ReLU is added when ``non_negative`` is set. ``forward``
    squeezes dim 1, so with ``num_channels == 1`` the output has no channel axis.

    Args:
        path: optional checkpoint passed to ``self.load`` after construction.
        non_negative (bool): clamp outputs to be >= 0 with a final ReLU.
        num_channels (int): number of output channels.
        **kwargs: forwarded to ``DPT`` (``features`` defaults to 256).
    """

    def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs):
        features = kwargs.get("features", 256)
        hidden = features // 2

        head = nn.Sequential(
            nn.Conv2d(features, hidden, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(hidden, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        prediction = super().forward(x)
        return prediction.squeeze(dim=1)
================================================
FILE: encoding.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class FreqEncoder_torch(nn.Module):
    """Pure-PyTorch frequency (positional) encoding.

    Output: ``[x (if include_input), f(x * freq_0) for f in periodic_fns, ...,
    f(x * freq_{N-1}) for f in periodic_fns]`` concatenated on the last dim, so
    ``output_dim = input_dim * (include_input + N_freqs * len(periodic_fns))``.

    Args:
        input_dim: size of the last input dimension.
        max_freq_log2: highest frequency is ``2 ** max_freq_log2``.
        N_freqs: number of frequency bands.
        log_sampling: space the bands as powers of two (True) or linearly between
            ``1`` and ``2 ** max_freq_log2`` (False).
        include_input: prepend the raw input.
        periodic_fns: functions applied to ``x * freq``.
    """

    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):

        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns
        self.N_freqs = N_freqs

        self.output_dim = 0
        if self.include_input:
            self.output_dim += self.input_dim

        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)

        if log_sampling:
            self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
        else:
            self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)

        self.freq_bands = self.freq_bands.numpy().tolist()

    def forward(self, input, max_level=None, **kwargs):
        """Encode ``input`` (``[..., input_dim]``).

        Args:
            max_level: optional fraction of bands to use; only the first
                ``int(max_level * N_freqs)`` bands are computed and the remaining
                ones are zero-filled, so the output width stays ``output_dim``.
        """
        if max_level is None:
            max_level = self.N_freqs
        else:
            max_level = int(max_level * self.N_freqs)

        out = []
        if self.include_input:
            out.append(input)

        for i in range(max_level):
            freq = self.freq_bands[i]
            for p_fn in self.periodic_fns:
                out.append(p_fn(input * freq))

        # Zero-fill the disabled bands. The width must use len(periodic_fns) (matching
        # output_dim); a hard-coded 2 was wrong for non-(sin, cos) periodic_fns.
        n_missing = self.N_freqs - max_level
        if n_missing > 0:
            pad_width = n_missing * len(self.periodic_fns) * input.shape[-1]
            out.append(torch.zeros(*input.shape[:-1], pad_width, device=input.device, dtype=input.dtype))

        out = torch.cat(out, dim=-1)

        return out
def get_encoder(encoding, input_dim=3,
                multires=6,
                degree=4,
                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, interpolation='linear',
                **kwargs):
    """Create an input encoder by name.

    Args:
        encoding: 'None', 'frequency_torch', 'frequency', 'sphere_harmonics',
            'hashgrid', 'tiledgrid' or 'hashgrid_taichi'.
        input_dim: dimensionality of the inputs.
        multires: number of frequency bands for the frequency encoders.
        degree: degree for the spherical-harmonics encoder.
        num_levels, level_dim, base_resolution, log2_hashmap_size,
        desired_resolution, align_corners, interpolation: options of the
            'hashgrid' / 'tiledgrid' GridEncoder.
        **kwargs: ignored.

    Returns:
        ``(encoder, output_dim)``. For 'None' the encoder is the identity function and
        ``output_dim`` equals ``input_dim``.

    Raises:
        NotImplementedError: for an unknown ``encoding``.
    """
    if encoding == 'None':
        return lambda x, **kwargs: x, input_dim

    elif encoding == 'frequency_torch':
        encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)

    elif encoding == 'frequency': # CUDA implementation, faster than torch.
        from freqencoder import FreqEncoder
        encoder = FreqEncoder(input_dim=input_dim, degree=multires)

    elif encoding == 'sphere_harmonics':
        from shencoder import SHEncoder
        encoder = SHEncoder(input_dim=input_dim, degree=degree)

    elif encoding in ('hashgrid', 'tiledgrid'):
        from gridencoder import GridEncoder
        gridtype = 'hash' if encoding == 'hashgrid' else 'tiled'
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype=gridtype, align_corners=align_corners, interpolation=interpolation)

    elif encoding == 'hashgrid_taichi':
        from taichi_modules.hash_encoder import HashEncoderTaichi
        encoder = HashEncoderTaichi(batch_size=4096) #TODO: hard encoded batch size

    else:
        # List every supported mode; the old message omitted frequency_torch and hashgrid_taichi.
        raise NotImplementedError(f"Unknown encoding mode '{encoding}', choose from [None, frequency_torch, frequency, sphere_harmonics, hashgrid, tiledgrid, hashgrid_taichi]")

    return encoder, encoder.output_dim
================================================
FILE: evaluation/Prompt.py
================================================
import textwrap
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification
from transformers import pipeline
import argparse
import sys
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Usage: python Prompt.py --text "a dog is in front of a rabbit" --model vlt5


def _keywords_seq2seq(text):
    """Extract keywords from ``text`` with the Voicelab/vlt5-base-keywords seq2seq model."""
    tokenizer = AutoTokenizer.from_pretrained("Voicelab/vlt5-base-keywords")
    model = AutoModelForSeq2SeqLM.from_pretrained("Voicelab/vlt5-base-keywords")
    task_prefix = "Keywords: "

    input_ids = tokenizer(
        [task_prefix + text], return_tensors="pt", truncation=True
    ).input_ids
    output = model.generate(input_ids, no_repeat_ngram_size=3, num_beams=4)
    return tokenizer.decode(output[0], skip_special_tokens=True)


def _keywords_token_classification(model_name, text):
    """Keep the tokens of ``text`` that a token-classification model labels non-zero."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")

    # Classify tokens
    outputs = model(input_ids)
    predictions = outputs.logits.detach().numpy()[0]
    labels = predictions.argmax(axis=1)

    # Drop special tokens using the tokenizer's own mask instead of slicing [1:-1]:
    # BERT wraps the text as [CLS] ... [SEP], but XLNet-style tokenizers append both
    # special tokens at the end, where [1:-1] drops a real token and keeps a special one.
    ids = input_ids[0].tolist()
    special = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
    keep = [i for i, is_special in enumerate(special) if not is_special]

    labels = labels[keep]
    print(labels)
    tokens = tokenizer.convert_ids_to_tokens([ids[i] for i in keep])
    output_tokens = [tok for tok, label in zip(tokens, labels) if label != 0]
    return tokenizer.convert_tokens_to_string(output_tokens)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--text', default="", type=str, help="text prompt")
    # `choices` rejects unknown models up front; previously they crashed with a NameError.
    parser.add_argument('--model', default='vlt5', type=str, choices=["vlt5", "bert", "XLNet"], help="model choices - vlt5, bert, XLNet")
    opt = parser.parse_args()

    if opt.model == "vlt5":
        output_text = _keywords_seq2seq(opt.text)
    elif opt.model == "bert":
        output_text = _keywords_token_classification("yanekyuk/bert-uncased-keyword-extractor", opt.text)
    else:  # "XLNet"
        output_text = _keywords_token_classification("jasminejwebb/KeywordIdentifier", opt.text)

    # Print the keywords in a 50-column box.
    wrapped_text = textwrap.fill(output_text, width=50)
    border = '+' + '-' * 52 + '+'
    print(border)
    for line in wrapped_text.split('\n'):
        print('| {} |'.format(line.ljust(50)))
    print(border)
================================================
FILE: evaluation/mesh_to_video.py
================================================
import os
import numpy as np
import trimesh
import argparse
from pathlib import Path
from tqdm import tqdm
import pyvista as pv
def render_video(anim_mesh, frame_dir="result", output_file="result/output.mp4", n_frames=360, radius=10):
    """Render a turntable animation of ``anim_mesh`` and encode it to MP4 with ffmpeg.

    The camera orbits the mesh's center of mass in the plane z = center_z at distance
    ``radius``, advancing ``2 * pi / n_frames`` radians per frame, with +z as the up
    vector. Frames are saved as ``frame_dir/frame_<i>.png`` and then encoded at 30 fps.

    Args:
        anim_mesh: mesh to render (callers pass a ``trimesh`` mesh).
        frame_dir: directory for the PNG frames (created if missing). It must be the
            directory ffmpeg reads from; the original saved frames to the CWD while
            ffmpeg read ``result/frame_%d.png``, so no video was produced.
        output_file: path of the encoded MP4 (its directory is created if missing).
        n_frames: number of frames in one full orbit.
        radius: camera distance from the center of mass.
    """
    import subprocess

    os.makedirs(frame_dir, exist_ok=True)
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    center = anim_mesh.center_mass
    plotter = pv.Plotter(off_screen=True)
    plotter.add_mesh(anim_mesh)

    angle_step = 2 * np.pi / n_frames
    for i in tqdm(range(n_frames)):
        angle = i * angle_step
        camera_pos = [center[0] + radius * np.cos(angle), center[1] + radius * np.sin(angle), center[2]]
        plotter.camera_position = (camera_pos, center, (0, 0, 1))
        plotter.show(screenshot=os.path.join(frame_dir, f'frame_{i}.png'), auto_close=False)
    plotter.close()

    # Argument list instead of a shell string: paths with spaces stay intact. The exit
    # status is not checked, matching the original os.system call.
    subprocess.run([
        'ffmpeg', '-r', '30', '-f', 'image2', '-s', '1920x1080',
        '-i', os.path.join(frame_dir, 'frame_%d.png'),
        '-vcodec', 'libx264', '-crf', '25', '-pix_fmt', 'yuv420p',
        output_file,
    ])
def generate_mesh(obj1,obj2,transform_vector):
# Read 2 objects
filename1 = obj1 # Central Object
filename2 = obj2 # Surrounding Object
mesh1 = trimesh.load_mesh(filename1)
mesh2 = trimesh.load_mesh(filename2)
extents1 = mesh1.extents
extents2 = mesh1.extents
radius1 = sum(extents1) / 3.0
radius2 = sum(extents2) / 3.0
center1 = mesh1.center_mass
center2 = mesh2.center_mass
# Move
T1 = -center1
new =[]
for i in transform_vector:
try:
new.append(float(i))*radius1
except:
pass
transform_vector = new
print(T1, transform_vector, radius1)
T2 = -center2 + transform_vector
# Transform
mesh1.apply_translation(T1)
mesh2.apply_translation(T2)
# merge mesh
merged_mesh = trimesh.util.concatenate((mesh1, mesh2))
# save mesh
merged_mesh.export('merged_mesh.obj')
print("----> merge mesh done")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Generate rotating mesh animation.')
    parser.add_argument('--center_obj', type=str, help='Input OBJ1 file.')
    parser.add_argument('--surround_obj', type=str, help='Input OBJ2 file.')
    parser.add_argument('--transform_vector', help='Transform_vector.')
    parser.add_argument('--output_file', type=str, default="result/Demo.mp4", help='Output MP4 file.')
    parser.add_argument('--num_frames', type=int, default=100, help='Number of frames to render.')
    args = parser.parse_args()

    # 1. Merge the two meshes into merged_mesh.obj in the working directory.
    generate_mesh(args.center_obj, args.surround_obj, args.transform_vector)

    # 2. Create a "frames" directory next to the requested output file.
    # NOTE(review): --output_file and --num_frames are parsed, but neither they nor this
    # directory are passed to render_video() below.
    frames_dir = Path(args.output_file).parent / 'frames'
    frames_dir.mkdir(parents=True, exist_ok=True)

    # 3. Reload the merged mesh and render the turntable video.
    merged = trimesh.load_mesh(str(Path("merged_mesh.obj")))
    render_video(merged)
================================================
FILE: evaluation/r_precision.py
================================================
from sentence_transformers import SentenceTransformer, util
from PIL import Image
import argparse
import sys
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--text', default="", type=str, help="text prompt")
    # Help text fixed: this is the run's workspace folder, not a text prompt.
    parser.add_argument('--workspace', default="trial", type=str, help="workspace folder of the run, read from ../results/<workspace>")
    parser.add_argument('--latest', default='ep0001', type=str, help="which epoch result you want to use for image path")
    parser.add_argument('--mode', default='rgb', type=str, help="mode of result, color (rgb) or depth")
    parser.add_argument('--clip', default="clip-ViT-B-32", type=str, help="CLIP model to encode the img and prompt")
    opt = parser.parse_args()

    # Load CLIP model
    model = SentenceTransformer(f'{opt.clip}')

    # Encode the rendered validation image; the context manager closes the file handle.
    image_path = f'../results/{opt.workspace}/validation/df_{opt.latest}_0005_{opt.mode}.png'
    with Image.open(image_path) as image:
        img_emb = model.encode(image)

    # Encode text descriptions
    text_emb = model.encode([f'{opt.text}'])

    # Compute cosine similarities
    cos_scores = util.cos_sim(img_emb, text_emb)
    print("The final CLIP R-Precision is:", cos_scores[0][0].cpu().numpy())
================================================
FILE: evaluation/readme.md
================================================
### Improvement:
- Usage
- r_precision.py
For CLIP R-precision evaluation of a rendered result
--text is the prompt the model was trained with (same format as the Stable DreamFusion prompt)
--workspace is the workspace folder that Stable DreamFusion creates for each prompt
--latest selects the checkpoint whose validation images are used. Stable DreamFusion saves results for every epoch; this is normally ep0100 unless training was stopped early or extended
--mode is either rgb or depth, corresponding to the color and textureless results in Figure 5 (Qualitative comparison with baselines) of the original paper
--clip selects the CLIP model: clip-ViT-B-32 (CLIP B/32), clip-ViT-B-16 (CLIP B/16) or clip-ViT-L-14 (CLIP L/14), as in the original paper
```bash
python r_precision.py --text "matte painting of a castle made of cheesecake surrounded by a moat made of ice cream" --workspace ../castle --latest ep0100 --mode rgb --clip clip-ViT-B-32
```
- Prompt.py (model name case sensitive)
For prompt separation (keyword extraction)
--text is the prompt to split, in the same format as the Stable DreamFusion prompt
--model selects the pretrained keyword-extraction model: vlt5, bert or XLNet
```bash
python Prompt.py --text "a dog is in front of a rabbit" --model vlt5
python Prompt.py --text "a dog is in front of a rabbit" --model bert
python Prompt.py --text "a dog is in front of a rabbit" --model XLNet
```
- mesh_to_video.py
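--center_obj is the central object (it is moved so its center of mass is at the origin)
--surround_obj is the surrounding object (the one that is repositioned)
--transform_vector is the X, Y, Z translation vector applied to the surrounding object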
```bash
python mesh_to_video.py --center_obj 'mesh_whiterabbit/mesh.obj' --surround_obj 'mesh_snake/mesh.obj' --transform_vector [1,0,0]
```
================================================
FILE: freqencoder/__init__.py
================================================
from .freq import FreqEncoder
================================================
FILE: freqencoder/backend.py
================================================
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
'-use_fast_math'
]
if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']
# find cl.exe
def find_cl_path():
import glob
for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path
# JIT-compile (or reuse the cached build of) the CUDA extension when this module is imported.
_backend = load(
    name='_freqencoder',
    sources=[os.path.join(_src_path, 'src', fname) for fname in ('freqencoder.cu', 'bindings.cpp')],
    extra_cflags=c_flags,
    extra_cuda_cflags=nvcc_flags,
)

__all__ = ['_backend']
================================================
FILE: freqencoder/freq.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _freqencoder as _backend
except ImportError:
from .backend import _backend
class _freq_encoder(Function):
    """Autograd wrapper around the ``_freqencoder`` CUDA kernels.

    forward(inputs [B, input_dim], degree, output_dim) -> outputs [B, output_dim].
    Only ``inputs`` receives a gradient; ``degree`` and ``output_dim`` get None.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
    def forward(ctx, inputs, degree, output_dim):
        # inputs: [B, input_dim], float
        # RETURN: [B, F], float

        # The kernels are CUDA-only, so CPU inputs are copied to the current GPU.
        # NOTE(review): the gradient returned by backward then lives on the GPU even if
        # the caller passed a CPU tensor -- confirm callers always pass CUDA tensors.
        if not inputs.is_cuda: inputs = inputs.cuda()
        inputs = inputs.contiguous()

        B, input_dim = inputs.shape # batch size, coord dim

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

        _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)

        # outputs are saved as well; presumably the backward kernel reuses the computed
        # periodic values instead of recomputing them -- confirm in freqencoder.cu.
        ctx.save_for_backward(inputs, outputs)
        ctx.dims = [B, input_dim, degree, output_dim]

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, output_dim] (same shape as the forward outputs)

        grad = grad.contiguous()
        inputs, outputs = ctx.saved_tensors
        B, input_dim, degree, output_dim = ctx.dims

        grad_inputs = torch.zeros_like(inputs)
        _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)

        # no gradients for the integer arguments degree and output_dim
        return grad_inputs, None, None


# Functional entry point: freq_encode(inputs, degree, output_dim).
freq_encode = _freq_encoder.apply
class FreqEncoder(nn.Module):
    """Frequency encoding computed by the ``_freqencoder`` CUDA extension.

    Maps ``[..., input_dim]`` inputs to ``[..., output_dim]`` with
    ``output_dim = input_dim * (1 + 2 * degree)``.
    """

    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim
        self.degree = degree
        self.output_dim = input_dim * (1 + 2 * degree)

    def __repr__(self):
        return "FreqEncoder: input_dim={} degree={} output_dim={}".format(
            self.input_dim, self.degree, self.output_dim
        )

    def forward(self, inputs, **kwargs):
        """Encode ``inputs`` of shape ``[..., input_dim]``; extra kwargs are ignored."""
        batch_dims = list(inputs.shape[:-1])
        flat = inputs.reshape(-1, self.input_dim)
        encoded = freq_encode(flat, self.degree, self.output_dim)
        return encoded.reshape(*batch_dims, self.output_dim)
================================================
FILE: freqencoder/setup.py
================================================
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))

# nvcc flags: optimised C++14 build. The -U flags undefine the __CUDA_NO_HALF*__ macros, presumably
# defined by default in PyTorch's extension build, to re-enable half-precision operators/conversions.
# -use_fast_math allows the fast, less precise intrinsics.
nvcc_flags = [
    '-O3',
    '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__',
    '-U__CUDA_NO_HALF_CONVERSIONS__',
    '-U__CUDA_NO_HALF2_OPERATORS__',
    '-use_fast_math',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    def find_cl_path():
        """Return the newest MSVC Hostx64/x64 bin directory found under Program Files, or None."""
        import glob
        for program_files in (r"C:\\Program Files (x86)", r"C:\\Program Files"):
            for edition in ("Enterprise", "Professional", "BuildTools", "Community"):
                pattern = r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)
                candidates = sorted(glob.glob(pattern), reverse=True)
                if candidates:
                    return candidates[0]
        return None

    # cl.exe is required by the build; if it is not already on PATH, try to locate it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path
# Build `_freqencoder` (the module freq.py tries to import first) from the CUDA/C++ sources in ./src.
setup(
    name='freqencoder',  # package name, import this to use python API
    cmdclass={'build_ext': BuildExtension},
    ext_modules=[
        CUDAExtension(
            name='_freqencoder',  # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', fname) for fname in ('freqencoder.cu', 'bindings.cpp')],
            extra_compile_args=dict(cxx=c_flags, nvcc=nvcc_flags),
        )
    ],
)
================================================
FILE: freqencoder/src/bindings.cpp
================================================
#include
#include "freqencoder.h"
// Python bindings for the frequency encoder. TORCH_EXTENSION_NAME is presumably set by the torch
// extension build (setup.py names the extension `_freqencoder`).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    // (inputs, B, D, deg, C, outputs): fills outputs in place
    ext.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
    // (grad, outputs, B, D, deg, C, grad_inputs): fills grad_inputs in place
    ext.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
}
================================================
FILE: freqencoder/src/freqencoder.cu
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
// Argument-validation helpers for the host entry points below; each fails via TORCH_CHECK.
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
// NOTE(review): admits Half/Double as well, but the kernels in this file take float pointers —
// callers are expected to pass float32; confirm.
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
// pi in single precision; kernel_freq uses PI() / 2 as the phase shift that turns sin into cos.
inline constexpr __device__ float PI() { return 3.141592653589793f; }
// Integer ceiling division: number of `divisor`-sized chunks needed to cover `val`.
template
__host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}
// Forward frequency-encoding kernel, one thread per output element.
// inputs:  [B, D] (row-major)
// outputs: [B, C], C = D + D * deg * 2
//   columns [0, D) copy the input coordinates;
//   band k (k = 0 .. 2*deg-1) occupies columns [D * (k + 1), D * (k + 2)) and holds
//   sin(2^(k/2) * x + (k % 2) * pi/2), i.e. even bands are sines, odd bands cosines.
__global__ void kernel_freq(
    const float * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * outputs
) {
    const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * C) return;

    const uint32_t row = idx / C;
    const uint32_t col = idx - row * C;
    const float * row_in = inputs + row * D;

    float value;
    if (col < D) {
        // identity part
        value = row_in[col];
    } else {
        const uint32_t band = col / D - 1;             // which sin/cos band
        const uint32_t exponent = band / 2;            // frequency 2^exponent
        const float shift = (band % 2) * (PI() / 2);   // odd bands: sin(. + pi/2) == cos(.)
        // scalbnf(x, e) == x * 2^e
        value = __sinf(scalbnf(row_in[col % D], exponent) + shift);
    }
    outputs[idx] = value;
}
// Backward frequency-encoding kernel, one thread per element of grad_inputs.
// grad:        [B, C], C = D + D * deg * 2 (gradient w.r.t. the forward outputs)
// outputs:     [B, C], the forward outputs (sin/cos values are reused, not recomputed)
// grad_inputs: [B, D]
// For frequency 2^f: d sin(2^f x)/dx = 2^f cos(2^f x) and d cos(2^f x)/dx = -2^f sin(2^f x),
// so each (sin, cos) band pair contributes 2^f * (g_sin * y_cos - g_cos * y_sin).
__global__ void kernel_freq_backward(
    const float * __restrict__ grad,
    const float * __restrict__ outputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * grad_inputs
) {
    const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * D) return;

    const uint32_t row = idx / D;
    const uint32_t d = idx - row * D;

    const float * g = grad + row * C;
    const float * y = outputs + row * C;

    // identity columns pass their gradient straight through
    float acc = g[d];
    for (uint32_t f = 0; f < deg; f++) {
        // the sin band for frequency 2^f starts at column D * (2f + 1); its cos band follows
        const uint32_t sin_col = D * (2 * f + 1) + d;
        const uint32_t cos_col = sin_col + D;
        acc += scalbnf(1.0f, f) * (g[sin_col] * y[cos_col] - g[cos_col] * y[sin_col]);
    }
    grad_inputs[idx] = acc;
}
// Host entry point bound as `freq_encode_forward`; freq.py calls it as
// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs).
//   inputs:  [B, D] contiguous CUDA tensor
//   outputs: [B, C] contiguous CUDA tensor (C = D + D * deg * 2), written by kernel_freq
// kernel_freq handles one output element per thread; blocks have 128 threads.
// NOTE(review): the launch is not followed by an error check, and B == 0 would presumably
// request an empty grid — confirm callers never pass empty batches.
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);
    static constexpr uint32_t N_THREADS = 128;
    kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr());
}
// Host entry point bound as `freq_encode_backward`; freq.py calls it as
// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs).
//   grad, outputs: [B, C] contiguous CUDA tensors (outputs are the saved forward outputs)
//   grad_inputs:   [B, D] contiguous CUDA tensor, written by kernel_freq_backward
// kernel_freq_backward handles one grad_inputs element per thread; blocks have 128 threads.
// NOTE(review): no error check after the launch — confirm this is acceptable.
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
    CHECK_CUDA(grad);
    CHECK_CUDA(outputs);
    CHECK_CUDA(grad_inputs);
    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(outputs);
    CHECK_CONTIGUOUS(grad_inputs);
    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(outputs);
    CHECK_IS_FLOATING(grad_inputs);
    static constexpr uint32_t N_THREADS = 128;
    kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr());
}
================================================
FILE: freqencoder/src/freqencoder.h
================================================
# pragma once
#include
#include
// Host entry points implemented in freqencoder.cu and exposed to Python by bindings.cpp.
// All tensor arguments must be contiguous CUDA tensors; results go into the last tensor argument.
// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
//   inputs: [B, D] -> outputs: [B, C], C = D + D * deg * 2
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
//   grad, outputs: [B, C] (forward outputs are reused) -> grad_inputs: [B, D]
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
================================================
FILE: gridencoder/__init__.py
================================================
from .grid import GridEncoder
================================================
FILE: gridencoder/backend.py
================================================
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))

# nvcc flags: optimised C++14 build. The -U flags undefine the __CUDA_NO_HALF*__ macros, presumably
# defined by default in PyTorch's extension build, so the half / __half2 code paths compile.
nvcc_flags = [
    '-O3',
    '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__',
    '-U__CUDA_NO_HALF_CONVERSIONS__',
    '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    def find_cl_path():
        """Return the newest MSVC Hostx64/x64 bin directory found under Program Files, or None."""
        import glob
        for program_files in (r"C:\\Program Files (x86)", r"C:\\Program Files"):
            for edition in ("Enterprise", "Professional", "BuildTools", "Community"):
                pattern = r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)
                candidates = sorted(glob.glob(pattern), reverse=True)
                if candidates:
                    return candidates[0]
        return None

    # cl.exe is required by the JIT build; if it is not already on PATH, try to locate it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path
# JIT-compile and load the extension. grid.py falls back to this module when the prebuilt
# `_gridencoder` extension (from setup.py) cannot be imported.
_backend = load(
    name='_grid_encoder',
    sources=[os.path.join(_src_path, 'src', src) for src in ('gridencoder.cu', 'bindings.cpp')],
    extra_cflags=c_flags,
    extra_cuda_cflags=nvcc_flags,
)

__all__ = ['_backend']
================================================
FILE: gridencoder/grid.py
================================================
import math
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _gridencoder as _backend
except ImportError:
from .backend import _backend
# Integer codes passed to the CUDA backend for the user-facing option names.
# gridtype: 0 == hash, 1 == tiled (dense, wrapped modulo the table size)
_gridtype_to_id = dict(hash=0, tiled=1)
# interpolation: 0 == multilinear, 1 == smoothstep-eased weights
_interp_to_id = dict(linear=0, smoothstep=1)
class _grid_encode(Function):
    """Autograd function wrapping the ``_gridencoder`` multi-resolution grid encoding kernels.

    forward returns [B, L * C] features; backward yields gradients for ``inputs`` (only when
    ``calc_grad_inputs`` was set) and ``embeddings``. All other arguments are non-differentiable.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0, max_level=None):
        # inputs: [B, D] float in [0, 1]; embeddings: [sO, C]; offsets: [L + 1] int
        # returns: [B, L * C]
        inputs = inputs.contiguous()
        B, D = inputs.shape
        num_levels = offsets.shape[0] - 1
        C = embeddings.shape[1]
        S = np.log2(per_level_scale)  # the kernel rebuilds the per-level scale with exp2f
        H = base_resolution

        if max_level is None:
            max_level = num_levels
        else:
            # max_level is a fraction of the levels; convert to a count clamped to [1, L]
            max_level = max(min(int(math.ceil(max_level * num_levels)), num_levels), 1)

        # Under autocast only the embeddings become half (inputs stay float for precision).
        # Odd C stays float, since half atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        device, dtype = inputs.device, embeddings.dtype
        partial = max_level < num_levels

        # level-major buffer [L, B, C] so one level's table is cache-resident at a time
        outputs = torch.empty(num_levels, B, C, device=device, dtype=dtype)
        if partial:
            outputs.zero_()  # skipped levels are never written by the kernel

        dy_dx = None
        if calc_grad_inputs:
            dy_dx = torch.empty(B, num_levels * D * C, device=device, dtype=dtype)
            if partial:
                dy_dx.zero_()

        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, num_levels, max_level, S, H, dy_dx, gridtype, align_corners, interpolation)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, num_levels, S, H, gridtype, interpolation, max_level]
        ctx.align_corners = align_corners

        # [L, B, C] -> [B, L * C]
        return outputs.permute(1, 0, 2).reshape(B, num_levels * C)

    @staticmethod
    # (once_differentiable intentionally not applied)
    @custom_bwd
    def backward(ctx, grad):
        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype, interpolation, max_level = ctx.dims

        # [B, L * C] -> [L, B, C], matching the forward's level-major layout
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)
        wants_input_grad = dy_dx is not None
        grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) if wants_input_grad else None

        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, ctx.align_corners, interpolation)

        if wants_input_grad:
            grad_inputs = grad_inputs.to(inputs.dtype)

        # one slot per forward argument (10): only inputs and embeddings are differentiable
        return (grad_inputs, grad_embeddings) + (None,) * 8


grid_encode = _grid_encode.apply
class GridEncoder(nn.Module):
    """Multi-resolution grid encoding (hash or tiled tables) backed by the ``_gridencoder`` extension.

    Args:
        input_dim: coordinate dimensionality D (the CUDA side dispatches D in {2, 3, 4, 5}).
        num_levels: number of resolution levels L.
        level_dim: feature channels per level C (the CUDA side dispatches C in {1, 2, 4, 8, 16, 32}).
        per_level_scale: resolution multiplier between consecutive levels.
        base_resolution: grid resolution at level 0.
        log2_hashmap_size: log2 of the maximum number of table entries per level.
        desired_resolution: if given (and num_levels > 1), overrides per_level_scale so that
            the last level reaches this resolution.
        gridtype: 'hash' or 'tiled'.
        align_corners: if True, inputs 0 and 1 map exactly onto the first/last grid vertices.
        interpolation: 'linear' or 'smoothstep'.
    """

    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):
        super().__init__()

        # The finest resolution desired at the last level; if provided, it overrides per_level_scale.
        # With a single level there is no last level to scale towards, and the formula would divide
        # by zero. That gives an infinite scale, which turns into NaN resolutions on the GPU, so the
        # given per_level_scale is kept in that case.
        if desired_resolution is not None and num_levels > 1:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim  # coord dims, 2 or 3
        self.num_levels = num_levels  # num levels, each level multiply resolution by 2
        self.level_dim = level_dim  # encode channels per level
        self.per_level_scale = per_level_scale  # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype]  # "tiled" or "hash"
        self.interpolation = interpolation
        self.interp_id = _interp_to_id[interpolation]  # "linear" or "smoothstep"
        self.align_corners = align_corners

        # allocate parameters: per level, min(dense grid size, 2^log2_hashmap_size) rows,
        # rounded up to a multiple of 8; offsets[i] is the first row of level i.
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution) ** input_dim)  # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"

    def forward(self, inputs, bound=1, max_level=None):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # max_level: fraction in (0, 1] of the levels to evaluate (ceil(max_level * num_levels)
        #            levels, at least 1; the rest are zero); None evaluates all levels
        # return: [..., num_levels * level_dim]
        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]

        prefix_shape = list(inputs.shape[:-1])
        # reshape, not view: elementwise ops keep the input's strides, so a permuted input
        # stays non-contiguous and flattening its leading dims cannot be expressed as a view.
        inputs = inputs.reshape(-1, self.input_dim)

        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id, max_level)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        return outputs

    # always run in float precision!
    @torch.cuda.amp.autocast(enabled=False)
    def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
        # Total-variation regularization applied through the backend; it requires
        # self.embeddings.grad to exist (presumably it accumulates into it), so call it after
        # loss.backward() and before optimizer.step().
        # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss;
        #         if None, B points are sampled uniformly in [0, 1]^D.

        D = self.input_dim
        C = self.embeddings.shape[1]  # embedding dim for each level
        L = self.offsets.shape[0] - 1  # level
        S = np.log2(self.per_level_scale)  # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = self.base_resolution  # base resolution

        if inputs is None:
            # randomized in [0, 1]
            inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
        else:
            inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
            # reshape (not view) for non-contiguous inputs; contiguous() because the backend
            # presumably checks contiguity like grid_encode_forward does
            inputs = inputs.reshape(-1, self.input_dim).contiguous()
            B = inputs.shape[0]

        if self.embeddings.grad is None:
            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')

        _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)

    @torch.cuda.amp.autocast(enabled=False)
    def grad_weight_decay(self, weight=0.1):
        # level-wise meaned weight decay (ref: zip-nerf)

        B = self.embeddings.shape[0]  # size of embedding
        C = self.embeddings.shape[1]  # embedding dim for each level
        L = self.offsets.shape[0] - 1  # level

        if self.embeddings.grad is None:
            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')

        _backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L)
================================================
FILE: gridencoder/setup.py
================================================
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))

# nvcc flags: optimised C++14 build. The -U flags undefine the __CUDA_NO_HALF*__ macros, presumably
# defined by default in PyTorch's extension build, so the half / __half2 code paths compile.
nvcc_flags = [
    '-O3',
    '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__',
    '-U__CUDA_NO_HALF_CONVERSIONS__',
    '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    def find_cl_path():
        """Return the newest MSVC Hostx64/x64 bin directory found under Program Files, or None."""
        import glob
        for program_files in (r"C:\\Program Files (x86)", r"C:\\Program Files"):
            for edition in ("Enterprise", "Professional", "BuildTools", "Community"):
                pattern = r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)
                candidates = sorted(glob.glob(pattern), reverse=True)
                if candidates:
                    return candidates[0]
        return None

    # cl.exe is required by the build; if it is not already on PATH, try to locate it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path
# Build `_gridencoder` (the module grid.py tries to import first) from the CUDA/C++ sources in ./src.
setup(
    name='gridencoder',  # package name, import this to use python API
    cmdclass={'build_ext': BuildExtension},
    ext_modules=[
        CUDAExtension(
            name='_gridencoder',  # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', fname) for fname in ('gridencoder.cu', 'bindings.cpp')],
            extra_compile_args=dict(cxx=c_flags, nvcc=nvcc_flags),
        )
    ],
)
================================================
FILE: gridencoder/src/bindings.cpp
================================================
#include
#include "gridencoder.h"
// Python bindings for the grid encoder. TORCH_EXTENSION_NAME is presumably set by the torch
// extension build (`_gridencoder` from setup.py, `_grid_encoder` from the JIT backend.py).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    // forward / backward of the multi-resolution encoding
    ext.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
    ext.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
    // in-place gradient regularizers used by GridEncoder.grad_total_variation / grad_weight_decay
    ext.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
    ext.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)");
}
================================================
FILE: gridencoder/src/gridencoder.cu
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
// Argument-validation helpers for the host entry points; each fails via TORCH_CHECK.
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
// Just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF: this overload only
// has to exist so the half instantiation of the scalar atomicAdd path compiles. The program will never reach here!
// NOTE(review): the body performs no addition and has no return statement, so actually calling it is
// undefined behaviour. It stays unreachable only because grid.py casts embeddings to half solely when
// C is even, which routes half gradients through the __half2 path in kernel_grid_backward.
__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
    // requires CUDA >= 10 and ARCH >= 70
    // this is very slow compared to float or __half2, never use it.
    //return atomicAdd(reinterpret_cast<__half*>(address), val);
}
// Integer ceiling division: number of `divisor`-sized chunks needed to cover `val`
// (used to size the CUDA launch grids).
template
__host__ __device__ inline T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}
// Smoothstep ease curve 3t^2 - 2t^3, applied to the in-cell fraction when interp == 1.
template
__device__ inline T smoothstep(T val) {
    return val*val*(3.0f - 2.0f * val);
}
// Derivative of smoothstep: d/dt (3t^2 - 2t^3) = 6t(1 - t); used for dy_dx.
template
__device__ inline T smoothstep_derivative(T val) {
    return 6*val*(1.0f - val);
}
// Spatial hash of a D-dimensional integer grid vertex: XOR over dimensions of coordinate * prime
// (uint32 arithmetic, wrapping). The first prime is 1, so dimension 0 enters unmultiplied
// (presumably the "coherent" hashing of instant-ngp).
// The primes table has 7 entries, so D must be <= 7 (the host dispatch only instantiates D = 2..5).
template
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
    // coherent type of hashing
    constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };
    uint32_t result = 0;
    #pragma unroll
    for (uint32_t i = 0; i < D; ++i) {
        result ^= pos_grid[i] * primes[i];
    }
    return result;
}
// Map an integer grid vertex of one level to an element offset in that level's [hashmap_size, C] table.
// - A dense row-major index (dimension 0 fastest) is accumulated while the running stride still fits
//   in hashmap_size; the remaining dimensions are skipped once it does not.
// - gridtype 0 (hash): if the dense grid does not fit (stride > hashmap_size), use fast_hash instead.
// - gridtype 1 (tiled): never hashes; an oversized grid simply wraps modulo hashmap_size.
// Returns (index % hashmap_size) * C + ch, the offset of channel `ch` of that table row.
template
__device__ uint32_t get_grid_index(const uint32_t gridtype, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;
    #pragma unroll
    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        stride *= resolution;
    }
    // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
    // gridtype: 0 == hash, 1 == tiled
    if (gridtype == 0 && stride > hashmap_size) {
        index = fast_hash(pos_grid);
    }
    return (index % hashmap_size) * C + ch;
}
// Forward kernel: interpolate one level's grid features (multilinear, or smoothstep-eased) for one point.
// blockIdx.x * blockDim.x + threadIdx.x selects the point b; blockIdx.y selects the level.
// D and C are compile-time template parameters (they size the register arrays below).
//   inputs:  [B, D] float, expected in [0, 1]; points outside get zero features and zero dy_dx
//   grid:    embedding table for all levels; rows offsets[level] .. offsets[level + 1] belong to `level`
//   offsets: [L + 1] int
//   outputs: [L, B, C] (level-major)
//   dy_dx:   optional (nullptr to skip), [B, L, D, C]: d output[ch] / d inputs[d] for each level
//   S: log2(per_level_scale), H: base resolution
//   gridtype: 0 = hash, 1 = tiled (see get_grid_index); interp: 1 = smoothstep, otherwise linear
template
__global__ void kernel_grid(
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ outputs,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    scalar_t * __restrict__ dy_dx,
    const uint32_t gridtype,
    const bool align_corners,
    const uint32_t interp
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate
    grid += (uint32_t)offsets[level] * C;
    inputs += b * D;
    outputs += level * B * C + b * C;

    // check input range (should be in [0, 1])
    bool flag_oob = false;
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }
    // if input out of bound, just set output to 0
    if (flag_oob) {
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            outputs[ch] = 0;
        }
        if (dy_dx) {
            dy_dx += b * D * L * C + level * D * C; // B L D C
            #pragma unroll
            for (uint32_t d = 0; d < D; d++) {
                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    dy_dx[d * C + ch] = 0;
                }
            }
        }
        return;
    }

    // table size of this level, and its resolution ceil(H * per_level_scale^level)
    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);

    // calculate coordinate (always use float for precision!)
    float pos[D];
    float pos_deriv[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {

        // align_corners
        if (align_corners) {
            pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
            pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
        } else {
            pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
            pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
        }
        // pos now holds the fractional position inside the cell
        pos[d] -= (float)pos_grid[d];

        // smoothstep instead of linear
        if (interp == 1) {
            pos_deriv[d] = smoothstep_derivative(pos[d]);
            pos[d] = smoothstep(pos[d]);
        } else {
            pos_deriv[d] = 1.0f;
        }

    }

    // verification of alignment
    // if (level == L - 1 && b < 4) {
    //     printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
    // }

    // interpolate: visit the 2^D cell corners; bit d of idx picks the lower (0) or upper (1) corner
    // along dimension d, and w is the product of the matching per-dimension weights.
    scalar_t results[C] = {0}; // temp results in register

    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
            }
        }

        uint32_t index = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid_local);

        // writing to register (fast)
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            results[ch] += w * grid[index + ch];
        }

        //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
    }

    // writing to global memory (slow)
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        outputs[ch] = results[ch];
    }

    // prepare dy_dx
    // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
    // For each input dimension gd, sum over the 2^(D-1) cell edges parallel to gd:
    //   (upper-corner value - lower-corner value) * weights of the other D-1 dimensions,
    // scaled by d pos / d input (resolution - 1 or resolution) and the smoothstep derivative.
    // NOTE(review): without align_corners, pos is clamped to [0, resolution - 1], but the scale factor
    // ignores that clamp, so the derivative is not zeroed in the clamped border region — confirm intended.
    if (dy_dx) {

        dy_dx += b * D * L * C + level * D * C; // B L D C

        #pragma unroll
        for (uint32_t gd = 0; gd < D; gd++) {

            scalar_t results_grad[C] = {0};

            #pragma unroll
            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
                float w = (float)(align_corners ? resolution - 1 : resolution);
                uint32_t pos_grid_local[D];

                #pragma unroll
                for (uint32_t nd = 0; nd < D - 1; nd++) {
                    // skip over the gradient dimension gd
                    const uint32_t d = (nd >= gd) ? (nd + 1) : nd;

                    if ((idx & (1 << nd)) == 0) {
                        w *= 1 - pos[d];
                        pos_grid_local[d] = pos_grid[d];
                    } else {
                        w *= pos[d];
                        pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
                    }
                }

                pos_grid_local[gd] = pos_grid[gd];
                uint32_t index_left = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid_local);
                pos_grid_local[gd] = min(pos_grid[gd] + 1, resolution - 1);
                uint32_t index_right = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid_local);

                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
                }
            }

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                dy_dx[gd * C + ch] = results_grad[ch];
            }
        }
    }
}
// Backward kernel w.r.t. the embeddings: scatter grad * interpolation weight into grad_grid with atomics.
// Each thread handles N_C consecutive channels (starting at `ch`) of one point b; blockIdx.y is the level.
//   grad:      [L, B, C] (level-major, as produced by the Python wrapper)
//   grad_grid: same layout as the embedding table; must be zero-initialised by the caller, because
//              points outside [0, 1] simply return without writing
// The corner weights are recomputed exactly as in kernel_grid (smoothstep applied when interp == 1).
template
__global__ void kernel_grid_backward(
    const scalar_t * __restrict__ grad,
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ grad_grid,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners,
    const uint32_t interp
) {
    // flat thread id * N_C is the flat (b, ch) element index into a [B, C] slice
    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
    if (b >= B) return;

    const uint32_t level = blockIdx.y;
    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

    // locate
    grad_grid += offsets[level] * C;
    inputs += b * D;
    grad += level * B * C + b * C + ch; // L, B, C

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);

    // check input range (should be in [0, 1])
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            return; // grad is init as 0, so we simply return.
        }
    }

    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        // align_corners
        if (align_corners) {
            pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
            pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
        } else {
            pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
            pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
        }
        pos[d] -= (float)pos_grid[d];
        // smoothstep instead of linear
        if (interp == 1) {
            pos[d] = smoothstep(pos[d]);
        }
    }

    scalar_t grad_cur[N_C] = {0}; // fetch to register
    #pragma unroll
    for (uint32_t c = 0; c < N_C; c++) {
        grad_cur[c] = grad[c];
    }

    // interpolate: same 2^D corner enumeration and weights as the forward pass
    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
            }
        }

        uint32_t index = get_grid_index(gridtype, ch, hashmap_size, resolution, pos_grid_local);

        // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
        // TODO: use float which is better than __half, if N_C % 2 != 0
        if (std::is_same::value && N_C % 2 == 0) {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c += 2) {
                // process two __half at once (by interpreting as a __half2)
                __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
                atomicAdd((__half2*)&grad_grid[index + c], v);
            }
        // float, or __half when N_C % 2 != 0 (which means C == 1)
        } else {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c++) {
                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
            }
        }
    }
}
// Backward kernel w.r.t. the inputs: grad_inputs[b, d] = sum over (l, ch) of grad[l, b, ch] * dy_dx[b, l, d, ch].
// One thread per element t of grad_inputs [B, D].
//   grad:  [L, B, C]; dy_dx: [B, L, D, C] saved by the forward pass
// All L levels are summed; for levels skipped via max_level, dy_dx was zero-filled by the Python wrapper.
// NOTE(review): the loop counters are int while L is uint32_t (signed/unsigned comparison) — harmless
// for realistic sizes.
template
__global__ void kernel_input_backward(
    const scalar_t * __restrict__ grad,
    const scalar_t * __restrict__ dy_dx,
    scalar_t * __restrict__ grad_inputs,
    uint32_t B, uint32_t L
) {
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D;

    dy_dx += b * L * D * C;

    scalar_t result = 0;

    # pragma unroll
    for (int l = 0; l < L; l++) {
        # pragma unroll
        for (int ch = 0; ch < C; ch++) {
            result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
        }
    }

    grad_inputs[t] = result;
}
// Launch kernel_grid for a fixed D, turning the runtime channel count C into a compile-time instantiation.
// Grid: x covers the B points in blocks of 512 threads; y covers only the first max_level levels, so
// levels >= max_level are never written (the Python wrapper zero-fills outputs / dy_dx in that case).
// Throws std::runtime_error for unsupported C.
template
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), max_level, 1 };
    switch (C) {
        case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 16: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 32: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
    }
}
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int32 (read as const int *)
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
// H: base resolution
// dy_dx: [B, L * D * C]
// Dispatches the runtime coordinate dimension D to a compile-time kernel_grid_wrapper instantiation;
// throws std::runtime_error for D outside {2, 3, 4, 5}.
template
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    switch (D) {
        case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
    }
}
// Launch the backward kernels for a fixed D, turning runtime C into a compile-time instantiation:
//   kernel_grid_backward: N_C = min(2, C) channels per thread, so ceil(B * C / N_C / 256) blocks of
//   256 threads in x, and max_level levels in y;
//   kernel_input_backward: only when dy_dx is non-null (input gradients were requested); its launch
//   configuration is not visible here (presumably covering the B * D elements of grad_inputs).
// Throws std::runtime_error for unsupported C.
template
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    static constexpr uint32_t N_THREAD = 256;
    const uint32_t N_C = std::min(2u, C); // n_features_per_thread
    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), max_level, 1 };
    switch (C) {
        case 1:
            kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 2:
            kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 4:
            kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 8:
            kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 16:
            kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 32:
            kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
            break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
    }
}
// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int32 (read as const int *)
// grad_embeddings: [sO, C], zero-initialised by the caller
// H: base resolution
// dy_dx: [B, L * D * C] from the forward pass, or null when input gradients are not needed
// grad_inputs: [B, D], written only when dy_dx is non-null
// Dispatches the runtime coordinate dimension D to a compile-time kernel_grid_backward_wrapper
// instantiation; throws std::runtime_error for D outside {2, 3, 4, 5}.
template
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    switch (D) {
        case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
    }
}
// Forward host entry point (declared in gridencoder.h).
// Validates the tensor arguments, then dispatches on the embeddings' dtype
// (float / double / half via AT_DISPATCH_FLOATING_TYPES_AND_HALF) to
// grid_encode_forward_cuda.
//   inputs:     [B, D], float, expected in [0, 1]
//   embeddings: [sO, C], floating
//   offsets:    [L + 1], int32 per-level offsets into embeddings
//   outputs:    floating output buffer
//   dy_dx:      optional. When present it is handed to the kernel as a raw
//               pointer, so it must also be a contiguous floating CUDA tensor.
//               When absent, nullptr is passed instead.
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(outputs);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(outputs);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(outputs);

    // The optional derivative buffer is written through a raw pointer too,
    // so validate it whenever it is supplied.
    if (dy_dx.has_value()) {
        CHECK_CUDA(dy_dx.value());
        CHECK_CONTIGUOUS(dy_dx.value());
        CHECK_IS_FLOATING(dy_dx.value());
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    embeddings.scalar_type(), "grid_encode_forward", ([&] {
    grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners, interp);
    }));
}
// Backward host entry point (declared in gridencoder.h).
// Validates the tensor arguments, then dispatches on grad's dtype
// (float / double / half via AT_DISPATCH_FLOATING_TYPES_AND_HALF) to
// grid_encode_backward_cuda.
//   grad:            [L, B, C], floating
//   inputs:          [B, D], float, expected in [0, 1]
//   embeddings:      [sO, C], floating
//   offsets:         [L + 1], int32
//   grad_embeddings: [sO, C], floating gradient buffer
//   dy_dx / grad_inputs: optional. The input-gradient kernel is launched only
//       when dy_dx is given, and it then writes grad_inputs through a raw
//       pointer. So grad_inputs is required whenever dy_dx is present, and both
//       must be contiguous floating CUDA tensors.
// Throws std::runtime_error if dy_dx is given without grad_inputs.
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(grad_embeddings);

    if (dy_dx.has_value()) {
        // Without grad_inputs the input-gradient kernel would write through nullptr.
        if (!grad_inputs.has_value()) {
            throw std::runtime_error{"GridEncoding: grad_inputs must be provided when dy_dx is given."};
        }
        CHECK_CUDA(dy_dx.value());
        CHECK_CONTIGUOUS(dy_dx.value());
        CHECK_IS_FLOATING(dy_dx.value());
    }
    if (grad_inputs.has_value()) {
        CHECK_CUDA(grad_inputs.value());
        CHECK_CONTIGUOUS(grad_inputs.value());
        CHECK_IS_FLOATING(grad_inputs.value());
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "grid_encode_backward", ([&] {
    grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp);
    }));
}
// Total-variation gradient kernel.
// Thread layout (from the indexing below): x covers the B input samples
// (threads with b >= B exit), blockIdx.y selects the level.
// For each sample whose coordinates all lie in [0, 1], it locates the lower
// corner vertex (pos_grid) of the cell containing the sample at this level.
// Along every axis it compares that vertex with its +1 and -1 neighbours
// (the -1 neighbour only when the coordinate is > 0) and accumulates, per channel:
//   results = sum(diff), idelta = sum(diff^2), with diff = grid[vertex] - grid[neighbour].
// It then adds weight / (2 * D) * results / sqrt(idelta + 1e-9) to grad at the
// corner vertex with atomicAdd, because several samples can hit the same index.
// inputs, grid and grad share the scalar type scalar_t; offsets gives the
// per-level start of each level's table (scaled by C below).
template
__global__ void kernel_grad_tv(
const scalar_t * __restrict__ inputs,
const scalar_t * __restrict__ grid,
scalar_t * __restrict__ grad,
const int * __restrict__ offsets,
const float weight,
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
const uint32_t gridtype,
const bool align_corners
) {
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B) return;
const uint32_t level = blockIdx.y;
// locate: this sample's coordinates and this level's slice of grid / grad
inputs += b * D;
grid += (uint32_t)offsets[level] * C;
grad += (uint32_t)offsets[level] * C;
// check input range (should be in [0, 1])
bool flag_oob = false;
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if (inputs[d] < 0 || inputs[d] > 1) {
flag_oob = true;
}
}
// if input out of bound, do nothing
if (flag_oob) return;
// number of entries in this level's table, and the level's resolution H * 2^(level * S), rounded up
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
// calculate coordinate
float pos[D];
uint32_t pos_grid[D]; // [0, resolution]
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
// align_corners
if (align_corners) {
pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
} else {
pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
}
}
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
// total variation on pos_grid
scalar_t results[C] = {0}; // temp results in register
scalar_t idelta[C] = {0}; // running sum of squared differences per channel
uint32_t index = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid);
// the regularisation weight is split evenly over the 2 * D neighbours
scalar_t w = weight / (2 * D);
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
uint32_t cur_d = pos_grid[d];
scalar_t grad_val;
// right side
// NOTE(review): pos_grid[d] <= resolution - 1 by construction above, so this
// guard is always true. With align_corners == false the +1 neighbour can
// have coordinate `resolution`; whether that maps in range depends on
// get_grid_index's indexing.
if (cur_d < resolution) {
pos_grid[d] = cur_d + 1;
uint32_t index_right = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid);
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
grad_val = (grid[index + ch] - grid[index_right + ch]);
results[ch] += grad_val;
idelta[ch] += grad_val * grad_val;
}
}
// left side
if (cur_d > 0) {
pos_grid[d] = cur_d - 1;
uint32_t index_left = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid);
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
grad_val = (grid[index + ch] - grid[index_left + ch]);
results[ch] += grad_val;
idelta[ch] += grad_val * grad_val;
}
}
// reset
pos_grid[d] = cur_d;
}
// writing to global memory (slow)
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
// index may collide, so use atomic!
atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
}
}
// Host-side launcher for kernel_grad_tv, instantiated for the channel count C.
// The grid (blocks_hashgrid) has div_round_up(B, N_THREAD) blocks along x,
// covering the B samples, and L blocks along y (one per level).
// Only C in {1, 2, 4, 8, 16, 32} is supported; other values throw std::runtime_error.
template
void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
static constexpr uint32_t N_THREAD = 512;
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
switch (C) {
case 1: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 2: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 4: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 8: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 16: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
case 32: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
}
}
// Dispatch on the input dimensionality D to the matching
// kernel_grad_tv_wrapper instantiation; all other arguments are forwarded
// unchanged. Only D in {2, 3, 4, 5} is supported; other values throw std::runtime_error.
template
void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
switch (D) {
case 2: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
case 3: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
case 4: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
case 5: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
}
}
// Adds the gradient of a total-variation regulariser on the grid embeddings
// to `grad` in place (the kernel accumulates with atomicAdd).
//   inputs:     [B, D] sample locations, expected in [0, 1]; read with the same
//               scalar type as embeddings; rows outside [0, 1] are skipped
//   embeddings: [sO, C] grid parameters
//   grad:       gradient buffer laid out like embeddings, updated in place
//   offsets:    [L + 1] int32 per-level offsets into embeddings
//   weight:     regularisation strength (split evenly over the 2 * D neighbours)
// The tensors are validated like in grid_encode_forward/backward, because the
// kernel accesses them through raw pointers.
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(grad);
    CHECK_CUDA(offsets);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(offsets);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_FLOATING(grad);
    CHECK_IS_INT(offsets);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    embeddings.scalar_type(), "grad_total_variation", ([&] {
    grad_total_variation_cuda(inputs.data_ptr(), embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, D, C, L, S, H, gridtype, align_corners);
    }));
}
// Weight-decay gradient kernel: one thread per scalar entry of the embedding
// table (b < B * C, where B is the number of rows and C the channel count).
// Each thread finds the level that owns its row n = b / C, i.e. the last m in
// [0, L) with offsets[m] <= n. It then adds 2 * weight * grid / hashmap_size
// to its own grad entry, which is the gradient of
// weight * sum(grid^2) / hashmap_size for that level.
// Every thread touches a distinct entry, so no atomics are needed.
template
__global__ void kernel_grad_wd(
const scalar_t * __restrict__ grid,
scalar_t * __restrict__ grad,
const int * __restrict__ offsets,
const float weight,
const uint32_t B, const uint32_t L, const uint32_t C
) {
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B * C) return;
// locate
grid += b;
grad += b;
// decide in which level is this thread...
uint32_t level = 0;
const uint32_t n = b / C;
// binary search row n in offsets: keep the last level whose start is <= n
uint32_t l = 0, r = L;
while (l < r) {
uint32_t m = (l + r) / 2;
if (offsets[m] <= n) {
level = m;
l = m + 1;
} else {
r = m;
}
}
// number of rows in this level's table
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
grad[0] += 2 * weight * grid[0] / hashmap_size;
}
// Adds the weight-decay gradient 2 * weight * embeddings / hashmap_size
// (hashmap_size = number of rows in the level each row belongs to) to `grad`
// in place.
//   embeddings: [B, C] embedding table (B rows of C channels)
//   grad:       same layout, updated in place
//   offsets:    [L + 1] int32 per-level row offsets
// The tensors are validated like the other entry points in this file, because
// the kernel accesses them through raw pointers.
void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L) {
    CHECK_CUDA(embeddings);
    CHECK_CUDA(grad);
    CHECK_CUDA(offsets);

    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(offsets);

    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_FLOATING(grad);
    CHECK_IS_INT(offsets);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    embeddings.scalar_type(), "grad_weight_decay", ([&] {
    static constexpr uint32_t N_THREAD = 1024;
    const dim3 blocks_hashgrid = { div_round_up(B * C, N_THREAD), 1, 1 };
    kernel_grad_wd<<>>(embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, L, C);
    }));
}
================================================
FILE: gridencoder/src/gridencoder.h
================================================
// Host entry points of the grid-encoder CUDA extension (implemented in gridencoder.cu).
// The include guard avoids the reserved `_Uppercase` identifier form.
#ifndef GRIDENCODER_GRIDENCODER_H
#define GRIDENCODER_GRIDENCODER_H
#include
#include
// Shared argument conventions:
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int32 (validated with CHECK_IS_INT in the implementation)
// outputs: [B, L * C], float
// H: base resolution

// Forward encoding. dy_dx is optional; when given it is presumably filled with
// d(outputs)/d(inputs) for use by the backward pass.
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
// Backward pass: fills grad_embeddings, and grad_inputs when dy_dx is provided.
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
// Adds a total-variation regularisation gradient to grad in place.
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
// Adds a per-level weight-decay gradient to grad in place.
void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L);
#endif // GRIDENCODER_GRIDENCODER_H
================================================
FILE: guidance/clip_utils.py
================================================
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import clip
class CLIP(nn.Module):
    """CLIP ViT-B/16 guidance: scores rendered images against reference
    image and/or text embeddings.

    Every embedding produced here is divided by its L2 norm along the last
    dimension, so dot products between them are cosine similarities.
    """

    def __init__(self, device, **kwargs):
        super().__init__()

        self.device = device

        model, preprocess = clip.load("ViT-B/16", device=self.device, jit=False)
        self.clip_model = model
        self.clip_preprocess = preprocess

        # Tensor-side preprocessing: resize to 224x224, then apply CLIP's
        # per-channel mean/std normalisation.
        resize = T.Resize((224, 224))
        normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        self.aug = T.Compose([resize, normalize])

    def get_text_embeds(self, prompt, **kwargs):
        """Tokenize `prompt` and return its unit-norm CLIP text features."""
        tokens = clip.tokenize(prompt).to(self.device)
        features = self.clip_model.encode_text(tokens)
        return features / features.norm(dim=-1, keepdim=True)

    def get_img_embeds(self, image, **kwargs):
        """Apply self.aug to `image` and return its unit-norm CLIP image features."""
        features = self.clip_model.encode_image(self.aug(image))
        return features / features.norm(dim=-1, keepdim=True)

    def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs):
        """
        Negative, scaled cosine similarity between rendered images and references.

        Args:
            clip_z: dict that may contain 'image' and/or 'text' embeddings
                (presumably produced by get_img_embeds / get_text_embeds).
            pred_rgb: rendered images; passed through self.aug before encoding.
            grad_scale: scalar or 1-tensor of size [B], i.e. 1 grad_scale per batch item.

        Returns:
            The sum over the keys present of -mean(similarity * grad_scale),
            or the integer 0 when clip_z holds neither key.
        """
        # TODO: self.aug resizes to 224x224 with T.Resize; for NeRF-rendered
        # inputs (e.g. 128x128) this can trigger the PyTorch/torchvision warning
        # about `antialias=None`. Consider resizing explicitly beforehand.
        image_z = self.clip_model.encode_image(self.aug(pred_rgb))
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)

        loss = 0
        for key in ('image', 'text'):
            if key not in clip_z:
                continue
            similarity = (image_z * clip_z[key]).sum(-1)
            loss = loss - (similarity * grad_scale).mean()
        return loss
================================================
FILE: guidance/if_utils.py
================================================
from transformers import logging
from diffusers import IFPipeline, DDPMScheduler
# suppress partial model loading warning
logging.set_verbosity_error()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from .perpneg_utils import weighted_perpendicular_aggregator
def seed_everything(seed):
    """Seed torch's CPU and CUDA random generators with the same `seed`."""
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # cuDNN switches, intentionally left disabled:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
class IF(nn.Module):
    """Score-distillation guidance backed by DeepFloyd IF stage I (IF-I-XL).

    IF stage I works on pixels directly: rendered images are resized to
    64x64 and mapped to [-1, 1] before noising (no VAE is involved).
    """

    def __init__(self, device, vram_O, t_range=[0.02, 0.98]):
        """
        Args:
            device: torch device the pipeline is moved to (when not vram_O).
            vram_O: if True, use a channels-last UNet, attention slicing and
                model CPU offload instead of `pipe.to(device)`.
            t_range: [low, high] fractions of the training timesteps; the
                timesteps sampled by the train steps lie in that range.
        """
        super().__init__()

        self.device = device

        print(f'[INFO] loading DeepFloyd IF-I-XL...')

        model_key = "DeepFloyd/IF-I-XL-v1.0"

        is_torch2 = torch.__version__[0] == '2'

        # Create model
        pipe = IFPipeline.from_pretrained(model_key, variant="fp16", torch_dtype=torch.float16)
        if not is_torch2:
            # xformers attention is only enabled when torch is not 2.x.
            pipe.enable_xformers_memory_efficient_attention()

        if vram_O:
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.unet = pipe.unet
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.scheduler = pipe.scheduler

        self.pipe = pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience

        print(f'[INFO] loaded DeepFloyd IF-I-XL!')

    @torch.no_grad()
    def get_text_embeds(self, prompt):
        """Encode prompt(s) with the pipeline's text encoder.

        prompt: [str]. The prompts go through the pipeline's text preprocessing
        (clean_caption=False) and are padded/truncated to 77 tokens; the
        encoder's first output is returned.
        """
        # TODO: should I add the preprocessing at https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py#LL486C10-L486C28
        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
        inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt')
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]

        return embeddings

    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1):
        """Score distillation sampling (SDS) loss for a batch of rendered images.

        Args:
            text_embeddings: [uncond; cond] embeddings stacked along dim 0
                (the UNet output is split with chunk(2) in that order).
            pred_rgb: rendered images in [0, 1]; resized to 64x64 here.
            guidance_scale: classifier-free guidance weight.
            grad_scale: multiplier applied to the SDS gradient.

        Returns:
            A scalar loss whose gradient w.r.t. the resized images equals
            grad_scale * (1 - alphas_cumprod[t]) * (noise_pred - noise) / B.
        """

        # [0, 1] to [-1, 1] and make sure shape is [64, 64]
        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(images)
            images_noisy = self.scheduler.add_noise(images, noise, t)

            # pred noise
            model_input = torch.cat([images_noisy] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * 2)
            noise_pred = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            # The UNet output carries extra channels after the first
            # model_input.shape[1]; only that first part is used as the noise.
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # TODO: how to use the variance here?
        # noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        # Reparameterised so that d(loss)/d(images) == grad / B.
        targets = (images - grad).detach()
        loss = 0.5 * F.mse_loss(images.float(), targets, reduction='sum') / images.shape[0]

        return loss

    def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, grad_scale=1):
        """Perp-Neg variant of train_step.

        Args:
            text_embeddings: [(1 + K) * B, ...]: B unconditional embeddings
                followed by K chunks of B prompt embeddings.
            weights: per-prompt combination weights, forwarded to
                weighted_perpendicular_aggregator.
            pred_rgb: rendered images in [0, 1]; resized to 64x64 here.
            guidance_scale: guidance weight applied to the aggregated direction.
            grad_scale: multiplier applied to the SDS gradient.

        Returns:
            A scalar loss with the same reparameterisation as train_step.
        """

        B = pred_rgb.shape[0]
        K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts

        # [0, 1] to [-1, 1] and make sure shape is [64, 64]
        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(images)
            images_noisy = self.scheduler.add_noise(images, noise, t)

            # pred noise
            model_input = torch.cat([images_noisy] * (1 + K))
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * (1 + K))
            unet_output = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:]
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
            # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)
            noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        targets = (images - grad).detach()
        loss = 0.5 * F.mse_loss(images.float(), targets, reduction='sum') / images.shape[0]

        return loss

    @torch.no_grad()
    def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5):
        """Sample one image with classifier-free guidance.

        Args:
            text_embeddings: [uncond; cond] embeddings stacked along dim 0.
            height, width: spatial size of the sampled image.
            num_inference_steps: number of scheduler steps.
            guidance_scale: classifier-free guidance weight.

        Returns:
            Tensor [1, 3, height, width] in [0, 1], with the dtype and device of
            text_embeddings.
        """

        images = torch.randn((1, 3, height, width), device=text_embeddings.device, dtype=text_embeddings.dtype)
        images = images * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            model_input = torch.cat([images] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            images = self.scheduler.step(noise_pred, t, images).prev_sample

        images = (images + 1) / 2

        # The final sample can fall slightly outside [-1, 1]. Clamp so that the
        # uint8 conversion in prompt_to_img cannot overflow or wrap around.
        return images.clamp(0, 1)

    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
        """Generate an image for `prompts` and return it as a uint8 numpy array.

        Args:
            prompts: a prompt string or a list of prompt strings.
            negative_prompts: negative prompt(s), used as the unconditional branch.
            height, width: output size. NOTE(review): train_step works at 64x64
                and the __main__ demo passes 64; confirm larger sizes are intended.
            num_inference_steps, guidance_scale: sampler settings.
            latents: unused; accepted for signature compatibility.

        Returns:
            uint8 array of shape [1, height, width, 3].
        """

        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds (padded to 77 tokens)
        pos_embeds = self.get_text_embeds(prompts)
        neg_embeds = self.get_text_embeds(negative_prompts)
        text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [uncond; cond]

        # Text embeds -> img
        imgs = self.produce_imgs(text_embeds, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 3, height, width] in [0, 1]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs
if __name__ == '__main__':
    import argparse
    import matplotlib.pyplot as plt

    # Command-line demo: sample an image for a prompt with IF stage I and show it.
    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
    parser.add_argument('-H', type=int, default=64)
    parser.add_argument('-W', type=int, default=64)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device('cuda')
    guidance = IF(device, opt.vram_O)
    images = guidance.prompt_to_img(opt.prompt, negative_prompts=opt.negative, height=opt.H, width=opt.W,
                                    num_inference_steps=opt.steps)

    # visualize image
    plt.imshow(images[0])
    plt.show()
================================================
FILE: guidance/perpneg_utils.py
================================================
import torch
# Please refer to https://perp-neg.github.io/ for details about the Perp-Neg paper and algorithm.
def get_perpendicular_component(x, y):
    """Return the component of `x` orthogonal to `y` (tensors of identical shape).

    The projection coefficient <x, y> / ||y||^2 uses a floor of 1e-6 on the
    denominator, so a (near-)zero `y` does not cause a division by zero.
    """
    assert x.shape == y.shape
    coeff = torch.mul(x, y).sum() / max(torch.norm(y) ** 2, 1e-6)
    return x - coeff * y
def batch_get_perpendicular_component(x, y):
    """Apply get_perpendicular_component to each pair x[i], y[i] along dim 0
    and stack the results back into one tensor."""
    assert x.shape == y.shape
    return torch.stack([get_perpendicular_component(x[i], y[i]) for i in range(x.shape[0])])
def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size):
    """Combine per-prompt guidance directions with the Perp-Neg rule.

    The first chunk of `batch_size` entries is the main positive direction.
    Every following chunk k adds weights[k] times its component perpendicular
    to the main direction (computed per batch item). Batch items whose weight
    magnitude is <= 1e-4 are skipped.

    Notes:
    - weights: an array with the weights for combining the noise predictions,
      split into chunks of `batch_size` like delta_noise_preds; the first
      chunk must be all 1.0
    - delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir

    Returns:
        Tensor shaped like one chunk: main + sum_k w_k * perp_k.
    """
    delta_noise_preds = delta_noise_preds.split(batch_size, dim=0) # K x [B, 4, 64, 64]
    weights = weights.split(batch_size, dim=0) # K x [B]
    assert torch.all(weights[0] == 1.0)

    main_positive = delta_noise_preds[0] # [B, 4, 64, 64]

    accumulated_output = torch.zeros_like(main_positive)
    for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1):
        idx_non_zero = torch.abs(weights[i]) > 1e-4
        # Tensor.any() avoids iterating the mask element by element in Python.
        if not idx_non_zero.any():
            continue
        accumulated_output[idx_non_zero] += weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * batch_get_perpendicular_component(complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero])

    assert accumulated_output.shape == main_positive.shape, f"{accumulated_output.shape = }, {main_positive.shape = }"

    return accumulated_output + main_positive
================================================
FILE: guidance/sd_utils.py
================================================
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from os.path import isfile
from pathlib import Path
# suppress partial model loading warning
logging.set_verbosity_error()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
from torch.cuda.amp import custom_bwd, custom_fwd
from .perpneg_utils import weighted_perpendicular_aggregator
def seed_everything(seed):
    """Make torch random draws reproducible by seeding the CPU and CUDA generators."""
    seeders = [torch.manual_seed, torch.cuda.manual_seed]
    for seeder in seeders:
        seeder(seed)
    # Optional cuDNN switches, deliberately not enabled:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
class StableDiffusion(nn.Module):
def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None, t_range=[0.02, 0.98]):
    """Load a Stable Diffusion checkpoint for score-distillation guidance.

    Args:
        device: torch device the pipeline is moved to (when not vram_O).
        fp16: load weights as float16 instead of float32.
        vram_O: use sequential CPU offload, VAE slicing, a channels-last UNet
            and attention slicing instead of moving the pipeline to `device`.
        sd_version: '2.1', '2.0' or '1.5'; ignored when hf_key is given.
        hf_key: optional Hugging Face model id that overrides sd_version.
        t_range: [low, high] fractions of the training timesteps bounding
            the timesteps sampled during training.

    Raises:
        ValueError: if hf_key is None and sd_version is not supported.
    """
    super().__init__()

    self.device = device
    self.sd_version = sd_version

    print(f'[INFO] loading stable diffusion...')

    if hf_key is not None:
        print(f'[INFO] using hugging face custom model key: {hf_key}')
        model_key = hf_key
    else:
        model_key = {
            '2.1': "stabilityai/stable-diffusion-2-1-base",
            '2.0': "stabilityai/stable-diffusion-2-base",
            '1.5': "runwayml/stable-diffusion-v1-5",
        }.get(self.sd_version)
        if model_key is None:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

    self.precision_t = torch.float16 if fp16 else torch.float32

    # Create model
    pipeline = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t)

    if not vram_O:
        pipeline.to(device)
    else:
        pipeline.enable_sequential_cpu_offload()
        pipeline.enable_vae_slicing()
        pipeline.unet.to(memory_format=torch.channels_last)
        pipeline.enable_attention_slicing(1)
        # pipeline.enable_model_cpu_offload()

    self.vae = pipeline.vae
    self.tokenizer = pipeline.tokenizer
    self.text_encoder = pipeline.text_encoder
    self.unet = pipeline.unet

    self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler", torch_dtype=self.precision_t)

    del pipeline

    total_steps = self.scheduler.config.num_train_timesteps
    self.num_train_timesteps = total_steps
    self.min_step = int(total_steps * t_range[0])
    self.max_step = int(total_steps * t_range[1])
    self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience

    print(f'[INFO] loaded stable diffusion!')
@torch.no_grad()
def get_text_embeds(self, prompt):
# prompt: [str]
inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
return embeddings
def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1,
save_guidance_path:Path=None):
if as_latent:
latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
else:
# interp to 512x512 to be fed into vae.
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
# encode image into latents with vae, requires grad!
latents = self.encode_imgs(pred_rgb_512)
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)
# predict the noise residual with unet, NO grad!
with torch.no_grad():
# add noise
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
# pred noise
latent_model_input = torch.cat([latents_noisy] * 2)
tt = torch.cat([t] * 2)
noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample
# perform guidance (high scale from paper!)
noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
# import kiui
# latents_tmp = torch.randn((1, 4, 64, 64), device=self.device)
# latents_tmp = latents_tmp.detach()
# kiui.lo(latents_tmp)
# self.scheduler.set_timesteps(30)
# for i, t in enumerate(self.scheduler.timesteps):
# latent_model_input = torch.cat([latents_tmp] * 3)
# noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
# noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
# noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond)
# latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample']
# imgs = self.decode_latents(latents_tmp)
# kiui.vis.plot_image(imgs)
# w(t), sigma_t^2
w = (1 - self.alphas[t])
grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
grad = torch.nan_to_num(grad)
if save_guidance_path:
with torch.no_grad():
if as_latent:
pred_rgb_512 = self.decode_latents(latents)
# visualize predicted denoised image
# The following block of code is equivalent to `predict_start_from_noise`...
# see zero123_utils.py's version for a simpler implementation.
alphas = self.scheduler.alphas.to(latents)
total_timesteps = self.max_step - self.min_step + 1
index = total_timesteps - t.to(latents.device) - 1
b = len(noise_pred)
a_t = alphas[index].reshape(b,1,1,1).to(self.device)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device)
pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0
result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))
# visualize noisier image
result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))
# TODO: also denoise all-the-way
# all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)
save_image(viz_images, save_guidance_path)
targets = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
return loss
def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1,
                       save_guidance_path:Path=None):
    """Perp-Neg variant of train_step.

    Args:
        text_embeddings: [(1 + K) * B, ...]: B unconditional embeddings
            followed by K chunks of B prompt embeddings.
        weights: per-prompt combination weights, forwarded to
            weighted_perpendicular_aggregator.
        pred_rgb: rendered images in [0, 1]; or, with as_latent=True, tensors
            that are resized to 64x64, mapped by x * 2 - 1 and used as latents.
        guidance_scale: guidance weight applied to the aggregated direction.
        as_latent: skip the VAE encoder (see pred_rgb).
        grad_scale: multiplier applied to the SDS gradient.
        save_guidance_path: if given, save [input, noised, predicted x0]
            images there for debugging.

    Returns:
        A scalar loss whose gradient w.r.t. the latents equals
        grad_scale * (1 - self.alphas[t]) * (noise_pred - noise) / B.
    """

    B = pred_rgb.shape[0]
    K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts

    if as_latent:
        latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
    else:
        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        # encode image into latents with vae, requires grad!
        latents = self.encode_imgs(pred_rgb_512)

    # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
    t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

    # predict the noise residual with unet, NO grad!
    with torch.no_grad():
        # add noise
        noise = torch.randn_like(latents)
        latents_noisy = self.scheduler.add_noise(latents, noise, t)
        # pred noise
        latent_model_input = torch.cat([latents_noisy] * (1 + K))
        tt = torch.cat([t] * (1 + K))
        unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample

        # perform guidance (high scale from paper!)
        noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:]
        delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)
        noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B)

    # w(t), sigma_t^2
    w = (1 - self.alphas[t])
    grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
    grad = torch.nan_to_num(grad)

    if save_guidance_path:
        with torch.no_grad():
            if as_latent:
                pred_rgb_512 = self.decode_latents(latents)

            # visualize predicted denoised image:
            #   x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
            # where alpha_bar_t = self.alphas[t] (the scheduler's alphas_cumprod,
            # indexed by the sampled timestep itself).
            b = len(noise_pred)
            alpha_bar_t = self.alphas[t].reshape(b, 1, 1, 1)
            pred_x0 = (latents_noisy - torch.sqrt(1 - alpha_bar_t) * noise_pred) / alpha_bar_t.sqrt() # current prediction for x_0
            result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))

            # visualize noisier image
            result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))

            # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
            viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)
            save_image(viz_images, save_guidance_path)

    # Reparameterised so that d(loss)/d(latents) == grad / B.
    targets = (latents - grad).detach()
    loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]

    return loss
@torch.no_grad()
def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
    """Run classifier-free-guided denoising and return the final latents.

    Args:
        text_embeddings: embeddings for the unconditional half followed by the
            conditional half of the batch (the UNet output is split with ``chunk(2)``).
        height, width: target image size in pixels; the latent grid is ``// 8`` of it.
        num_inference_steps: number of scheduler steps to run.
        guidance_scale: classifier-free guidance weight.
        latents: optional starting latents; drawn from N(0, I) when ``None``.

    Returns:
        The latents produced by the last scheduler step.
    """
    if latents is None:
        # Read the channel count through ``.config`` (as decode_latents does for the
        # VAE) and give the noise the embeddings' dtype so a reduced-precision model
        # is not fed float32 input (NOTE(review): assumes UNet and text encoder share a dtype).
        latents = torch.randn(
            (text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8),
            device=self.device,
            dtype=text_embeddings.dtype,
        )

    self.scheduler.set_timesteps(num_inference_steps)

    for t in self.scheduler.timesteps:
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)
        # predict the noise residual
        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
        # perform guidance
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

    return latents
def decode_latents(self, latents):
    """Decode scaled VAE latents into images clamped to [0, 1]."""
    # undo the latent scaling applied by encode_imgs before handing latents to the VAE
    inverse_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents).sample
    # VAE output lives in [-1, 1]; map to [0, 1]
    return (decoded / 2 + 0.5).clamp(0, 1)
def encode_imgs(self, imgs):
    """Encode images into scaled VAE latents by sampling the posterior.

    imgs: [B, 3, H, W]; they are mapped with ``x * 2 - 1`` before encoding
    (presumably inputs are in [0, 1] -- confirm with callers).
    """
    vae_input = imgs * 2 - 1
    latent_dist = self.vae.encode(vae_input).latent_dist
    scale = self.vae.config.scaling_factor
    return latent_dist.sample() * scale
def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
    """Generate images for one or more text prompts.

    Args:
        prompts: a prompt string or a list of prompt strings.
        negative_prompts: a negative prompt string, or a list with one entry per
            prompt. A single negative prompt is shared by every prompt.
        height, width, num_inference_steps, guidance_scale, latents:
            forwarded to ``produce_latents``.

    Returns:
        uint8 numpy array of shape [B, H, W, 3].
    """
    if isinstance(prompts, str):
        prompts = [prompts]

    if isinstance(negative_prompts, str):
        negative_prompts = [negative_prompts]

    # produce_latents splits the embeddings with chunk(2), so the unconditional
    # half must have exactly as many rows as the conditional half.
    if len(negative_prompts) == 1 and len(prompts) > 1:
        negative_prompts = negative_prompts * len(prompts)

    # Prompts -> text embeds
    pos_embeds = self.get_text_embeds(prompts) # [B, 77, 768]
    neg_embeds = self.get_text_embeds(negative_prompts)
    text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2B, 77, 768]

    # Text embeds -> img latents
    latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [B, 4, 64, 64]

    # Img latents -> imgs
    imgs = self.decode_latents(latents) # [B, 3, 512, 512]

    # Img to Numpy
    imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
    imgs = (imgs * 255).round().astype('uint8')

    return imgs
if __name__ == '__main__':
    # Command-line smoke test: render one prompt and display the first image.
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
    parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
    parser.add_argument('--fp16', action='store_true', help="use float16 for training")
    parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
    parser.add_argument('-H', type=int, default=512)
    parser.add_argument('-W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    opt = parser.parse_args()

    seed_everything(opt.seed)

    sd = StableDiffusion(torch.device('cuda'), opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key)

    imgs = sd.prompt_to_img(
        opt.prompt,
        negative_prompts=opt.negative,
        height=opt.H,
        width=opt.W,
        num_inference_steps=opt.steps,
    )

    # visualize image
    first_image = imgs[0]
    plt.imshow(first_image)
    plt.show()
================================================
FILE: guidance/zero123_utils.py
================================================
import math
import numpy as np
from omegaconf import OmegaConf
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from torchvision.utils import save_image
from diffusers import DDIMScheduler
import sys
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
from ldm.util import instantiate_from_config
# load model
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=False):
    """Build ``config.model`` and load its weights from the checkpoint at ``ckpt``.

    The state dict is loaded non-strictly. If the model keeps EMA weights, they
    are copied into ``model.model`` and the EMA copy is deleted. With ``vram_O``,
    the first-stage decoder is deleted as well. The model is returned in eval
    mode on ``device``.
    """
    checkpoint = torch.load(ckpt, map_location='cpu')

    if verbose and 'global_step' in checkpoint:
        print(f'[INFO] Global Step: {checkpoint["global_step"]}')

    state_dict = checkpoint['state_dict']

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if verbose and len(missing) > 0:
        print('[INFO] missing keys: \n', missing)
    if verbose and len(unexpected) > 0:
        print('[INFO] unexpected keys: \n', unexpected)

    # manually load ema and delete it to save GPU memory
    if model.use_ema:
        if verbose:
            print('[INFO] loading EMA...')
        model.model_ema.copy_to(model.model)
        del model.model_ema

    if vram_O:
        # we don't need decoder
        del model.first_stage_model.decoder

    torch.cuda.empty_cache()

    model.eval().to(device)

    return model
class Zero123(nn.Module):
    """Zero-1-to-3 novel-view diffusion model used as a score-distillation guidance source.

    Wraps the ``ldm`` model built by ``load_model_from_config`` and a diffusers
    ``DDIMScheduler`` configured from the same config's beta schedule.
    """

    def __init__(self, device, fp16,
                 config='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml',
                 ckpt='./pretrained/zero123/zero123-xl.ckpt', vram_O=False, t_range=[0.02, 0.98], opt=None):
        """
        Args:
            device: device the model and ``alphas_cumprod`` are moved to.
            fp16: stored as ``self.fp16``; the model is still loaded in full precision (see TODO).
            config: path of the OmegaConf model config.
            ckpt: checkpoint path passed to ``load_model_from_config``.
            vram_O: forwarded to ``load_model_from_config`` (which then drops the first-stage decoder).
            t_range: fractions of the training timesteps bounding the timestep sampled in ``train_step``.
            opt: options object; ``train_step`` reads ``opt.zero123_grad_scale``.
        """
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.vram_O = vram_O
        self.t_range = t_range
        self.opt = opt

        self.config = OmegaConf.load(config)
        # TODO: seems it cannot load into fp16...
        self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O)

        # timesteps: use diffuser for convenience... hope it's alright.
        self.num_train_timesteps = self.config.model.params.timesteps

        self.scheduler = DDIMScheduler(
            self.num_train_timesteps,
            self.config.model.params.linear_start,
            self.config.model.params.linear_end,
            beta_schedule='scaled_linear',
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience

    @torch.no_grad()
    def get_img_embeds(self, x):
        """Return ``(c, v)``: per-image lists of CLIP conditioning and first-stage latents (mode).

        x: image tensor [B, 3, 256, 256] in [0, 1]; each image is embedded separately.
        """
        x = x * 2 - 1
        c = [self.model.get_learned_conditioning(xx.unsqueeze(0)) for xx in x] #.tile(n_samples, 1, 1)
        v = [self.model.encode_first_stage(xx.unsqueeze(0)).mode() for xx in x]
        return c, v

    def angle_between(self, sph_v1, sph_v2):
        """Pairwise angles (radians) between rows of (r, theta, phi) spherical coordinates.

        Returns a CPU tensor of shape [len(sph_v1), len(sph_v2)].
        """
        def sph2cart(sv):
            r, theta, phi = sv[0], sv[1], sv[2]
            return torch.tensor([r * torch.sin(theta) * torch.cos(phi), r * torch.sin(theta) * torch.sin(phi), r * torch.cos(theta)])

        def unit_vector(v):
            return v / torch.linalg.norm(v)

        def angle_between_2_sph(sv1, sv2):
            v1, v2 = sph2cart(sv1), sph2cart(sv2)
            v1_u, v2_u = unit_vector(v1), unit_vector(v2)
            # clip guards arccos against dot products drifting outside [-1, 1]
            return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))

        angles = torch.empty(len(sph_v1), len(sph_v2))
        for i, sv1 in enumerate(sph_v1):
            for j, sv2 in enumerate(sph_v2):
                angles[i][j] = angle_between_2_sph(sv1, sv2)
        return angles

    def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scale=3, as_latent=False, grad_scale=1, save_guidance_path:Path=None):
        """Compute the SDS surrogate loss for rendered views.

        Args:
            embeddings: dict with 'ref_radii', 'ref_polars', 'ref_azimuths', 'zero123_ws',
                'c_crossattn' and 'c_concat' (one entry per reference image).
            pred_rgb: rendered images, tensor [B, 3, H, W] in [0, 1].
            polar, azimuth, radius: per-sample camera deltas relative to the default view
                (degrees for the angles).
            guidance_scale: classifier-free guidance weight.
            as_latent: treat ``pred_rgb`` as latents (resized to 32x32) instead of encoding it.
            grad_scale: base gradient scale, rescaled per ``opt.zero123_grad_scale``.
            save_guidance_path: if set, save a visualization of the prediction there.

        Returns:
            Scalar loss whose gradient w.r.t. the latents equals the SDS gradient.
        """
        # adjust SDS scale based on how far the novel view is from the known view
        ref_radii = embeddings['ref_radii']
        ref_polars = embeddings['ref_polars']
        ref_azimuths = embeddings['ref_azimuths']
        v1 = torch.stack([radius + ref_radii[0], torch.deg2rad(polar + ref_polars[0]), torch.deg2rad(azimuth + ref_azimuths[0])], dim=-1)   # polar,azimuth,radius are all actually delta wrt default
        v2 = torch.stack([torch.tensor(ref_radii), torch.deg2rad(torch.tensor(ref_polars)), torch.deg2rad(torch.tensor(ref_azimuths))], dim=-1)
        angles = torch.rad2deg(self.angle_between(v1, v2)).to(self.device)
        if self.opt.zero123_grad_scale == 'angle':
            grad_scale = (angles.min(dim=1)[0] / (180/len(ref_azimuths))) * grad_scale # rethink 180/len(ref_azimuths) # claforte: try inverting grad_scale or just fixing it to 1.0
        elif self.opt.zero123_grad_scale == 'None':
            grad_scale = 1.0 # claforte: I think this might converge faster...?
        else:
            assert False, f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}'

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256)

        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

        # Set weights acc to closeness in angle
        if len(ref_azimuths) > 1:
            inv_angles = 1/angles
            inv_angles[inv_angles > 100] = 100
            inv_angles /= inv_angles.max(dim=-1, keepdim=True)[0]
            inv_angles[inv_angles < 0.1] = 0
        else:
            inv_angles = torch.tensor([1.]).to(self.device)

        # Multiply closeness-weight by user-given weights
        zero123_ws = torch.tensor(embeddings['zero123_ws'])[None, :].to(self.device) * inv_angles
        zero123_ws /= zero123_ws.max(dim=-1, keepdim=True)[0]
        zero123_ws[zero123_ws < 0.1] = 0

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)

            noise_preds = []
            # Loop through each ref image
            for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T,
                                                                                             embeddings['c_crossattn'], embeddings['c_concat'],
                                                                                             ref_polars, ref_azimuths, ref_radii):
                # polar,azimuth,radius are all actually delta wrt default
                p = polar + ref_polars[0] - ref_polar
                a = azimuth + ref_azimuths[0] - ref_azimuth
                a[a > 180] -= 360 # range in [-180, 180]
                r = radius + ref_radii[0] - ref_radius
                # Batched form of: T = [radians(p), sin(radians(-a)), cos(radians(a)), r]
                T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :]
                cond = {}
                clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1))
                # zeroed conditioning forms the unconditional half of the CFG batch
                cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
                cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), c_concat.repeat(len(T), 1, 1, 1)], dim=0)]
                noise_pred = self.model.apply_model(x_in, t_in, cond)
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                noise_preds.append(zero123_w[:, None, None, None] * noise_pred)

        # weighted average of the per-reference predictions
        noise_pred = torch.stack(noise_preds).sum(dim=0) / zero123_ws.sum(dim=-1)[:, None, None, None]

        w = (1 - self.alphas[t])
        grad = (grad_scale * w)[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        if save_guidance_path:
            with torch.no_grad():
                if as_latent:
                    pred_rgb_256 = self.decode_latents(latents) # claforte: test!

                # visualize predicted denoised image
                result_hopefully_less_noisy_image = self.decode_latents(self.model.predict_start_from_noise(latents_noisy, t, noise_pred))

                # visualize noisier image
                result_noisier_image = self.decode_latents(latents_noisy)

                # TODO: also denoise all-the-way

                # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
                viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1)
                save_image(viz_images, save_guidance_path)

        # MSE against a detached target: d(loss)/d(latents) == grad (per-sample mean)
        targets = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
        return loss

    # verification
    @torch.no_grad()
    def __call__(self,
            image, # image tensor [1, 3, H, W] in [0, 1]
            polar=0, azimuth=0, radius=0, # new view params
            scale=3, ddim_steps=50, ddim_eta=1, h=256, w=256, # diffusion params
            c_crossattn=None, c_concat=None, post_process=True,
        ):
        """Sample one novel view of ``image`` with DDIM.

        ``c_crossattn``/``c_concat`` may be passed precomputed; any that is missing
        is computed from ``image[0]`` via ``get_img_embeds``. Returns an array
        [1, h, w, 3] when ``post_process`` is set, else a tensor [1, 3, h, w].
        """
        if c_crossattn is None or c_concat is None:
            # get_img_embeds returns per-image lists; this path renders the first image.
            c_list, v_list = self.get_img_embeds(image)
            if c_crossattn is None:
                c_crossattn = c_list[0]
            if c_concat is None:
                c_concat = v_list[0]

        T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius])
        T = T[None, None, :].to(self.device)

        cond = {}
        clip_emb = self.model.cc_projection(torch.cat([c_crossattn, T], dim=-1))
        cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
        cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]

        # produce latents loop
        latents = torch.randn((1, 4, h // 8, w // 8), device=self.device)
        self.scheduler.set_timesteps(ddim_steps)

        for t in self.scheduler.timesteps:
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.view(1)] * 2).to(self.device)

            noise_pred = self.model.apply_model(x_in, t_in, cond)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']

        imgs = self.decode_latents(latents)
        imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs

        return imgs

    def decode_latents(self, latents):
        """Decode first-stage latents ([B, 4, 32, 32]) into RGB images clamped to [0, 1]."""
        # with self.model.ema_scope():
        imgs = self.model.decode_first_stage(latents)
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs # [B, 3, 256, 256] RGB space image

    def encode_imgs(self, imgs):
        """Encode RGB images ([B, 3, 256, 256], mapped from [0, 1] to [-1, 1]) one at a time into first-stage latents."""
        # with self.model.ema_scope():
        imgs = imgs * 2 - 1
        latents = torch.cat([self.model.get_first_stage_encoding(self.model.encode_first_stage(img.unsqueeze(0))) for img in imgs], dim=0)
        return latents # [B, 4, 32, 32] Latent space image
if __name__ == '__main__':
    # Command-line check: render one novel view of an input image with Zero123.
    import cv2
    import argparse
    import numpy as np
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str)
    parser.add_argument('--fp16', action='store_true', help="use float16 for training") # no use now, can only run in fp32

    parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
    parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
    parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')

    opt = parser.parse_args()

    device = torch.device('cuda')

    print(f'[INFO] loading image from {opt.input} ...')
    image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f'[ERROR] cannot read image {opt.input}')
    # IMREAD_UNCHANGED keeps grayscale (2-D) and alpha (4-channel) images as-is,
    # which COLOR_BGR2RGB rejects, so pick the conversion from the channel count.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[-1] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    else:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
    image = image.astype(np.float32) / 255.0
    if image.shape[-1] == 4:
        # composite RGBA onto a white background to get the 3-channel input the model expects
        image = image[..., :3] * image[..., 3:] + (1 - image[..., 3:])
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)

    print(f'[INFO] loading model ...')
    zero123 = Zero123(device, opt.fp16, opt=opt)

    print(f'[INFO] running model ...')
    outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius)
    plt.imshow(outputs[0])
    plt.show()
================================================
FILE: ldm/extras.py
================================================
from pathlib import Path
from omegaconf import OmegaConf
import torch
from ldm.util import instantiate_from_config
import logging
from contextlib import contextmanager
from contextlib import contextmanager
import logging
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """
    Context manager that suppresses every logging record at or below
    ``highest_level`` while its body runs, restoring the previous global
    ``logging.disable`` level afterwards (also when the body raises).

    :param highest_level: the maximum logging level in use.
        Only needs changing if a custom level above CRITICAL is defined.

    https://gist.github.com/simon-weber/7853144
    """
    # The active level cannot be queried directly, so it is delegated to the
    # caller. The module-level override is read from the undocumented (but
    # public) ``logging.root.manager.disable`` attribute.
    saved_disable_level = logging.root.manager.disable

    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(saved_disable_level)
def load_training_dir(train_dir, device, epoch="last"):
    """Load a checkpoint and config from training directory

    Searches ``train_dir`` recursively for exactly one ``*{epoch}.ckpt`` file and
    at least one ``*-project.yaml`` config; with several configs, the
    lexicographically last path is used.

    Raises:
        AssertionError: if the checkpoint is not unique or no config is found.
    """
    train_dir = Path(train_dir)
    ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
    assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
    config = list(train_dir.rglob(f"*-project.yaml"))
    # check the config list (not the ckpt list) so a missing config is reported clearly
    assert len(config) > 0, f"didn't find any config in {train_dir}"
    if len(config) > 1:
        print(f"found {len(config)} matching config files")
        config = sorted(config)[-1]
        print(f"selecting {config}")
    else:
        config = config[0]

    config = OmegaConf.load(config)
    return load_model_from_config(config, ckpt[0], device)
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Loads a model from config and a ckpt
    if config is a path will use omegaconf to load

    Logging is disabled while the checkpoint is read and the model is built.
    The state dict is loaded non-strictly; with ``verbose`` the global step (if
    present) and any missing / unexpected keys are printed. Returns the model in
    eval mode on ``device``.
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    with all_logging_disabled():
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        # Not every checkpoint records a step count; don't fail when it is absent.
        global_step = pl_sd.get("global_step")
        if verbose and global_step is not None:
            print(f"Global step: {global_step}")
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
        model.to(device)
        model.eval()
        model.cond_stage_model.device = device
        return model
================================================
FILE: ldm/guidance.py
================================================
from typing import List, Tuple
from scipy import interpolate
import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython.display import clear_output
import abc
class GuideModel(torch.nn.Module, abc.ABC):
    """Abstract guidance model: subclasses turn decoded images into a loss.

    ``Guider.modify_score`` calls ``preprocess`` on the decoded image and then
    ``compute_loss`` on the result.
    """

    def __init__(self) -> None:
        super().__init__()

    @abc.abstractmethod
    def preprocess(self, x_img):
        """Prepare a decoded image batch for ``compute_loss``."""

    @abc.abstractmethod
    def compute_loss(self, inp):
        """Return the loss whose gradient steers sampling."""
class Guider(torch.nn.Module):
    """Classifier guidance hook that adjusts a DDIM sampler's noise predictions."""

    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Apply classifier guidance
        Specify a guidance scale as either a scalar
        Or a schedule as a list of tuples t = 0->1 and scale, e.g.
        [(0, 10), (0.5, 20), (1, 50)]

        Args:
            sampler: provides ``ddim_timesteps``, ``ddpm_num_timesteps`` and
                ``ddim_sqrt_one_minus_alphas``.
            guide_model: object with ``preprocess`` and ``compute_loss`` (see GuideModel).
            scale: scalar scale or list/tuple of (time fraction, scale) pairs.
            verbose: plot and record per-step diagnostics.
        """
        super().__init__()
        self.sampler = sampler
        self.index = 0
        self.show = verbose
        self.guide_model = guide_model
        self.history = []

        if isinstance(scale, (tuple, list)):
            times = np.array([x[0] for x in scale])
            values = np.array([x[1] for x in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)

        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        """Return one guidance scale per DDIM timestep (linearly interpolated for schedules)."""
        if isinstance(self.scale_schedule, float):
            return len(self.ddim_timesteps)*[self.scale_schedule]

        interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
        fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
        return interpolater(fractional_steps)

    def modify_score(self, model, e_t, x, t, c):
        """Return ``e_t`` shifted by the scaled gradient of the guide loss for this step.

        Advances ``self.index`` by one on every call.
        """
        # TODO look up index by t
        scale = self.get_scales()[self.index]

        if (scale == 0):
            # Still advance the step counter so the next call reads its own scale
            # instead of getting stuck on this zero entry.
            self.index += 1
            return e_t

        sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            # 0.18215 is the latent scaling factor undone before decoding
            x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)

            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            grads = torch.autograd.grad(loss.sum(), x_in)[0]
            correction = grads * scale

            if self.show:
                clear_output(wait=True)
                print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
                self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
                plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
                plt.axis('off')
                plt.show()
                plt.imshow(correction[0][0].detach().cpu())
                plt.axis('off')
                plt.show()

        e_t_mod = e_t - sqrt_1ma*correction
        if self.show:
            fig, axs = plt.subplots(1, 3)
            axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
            plt.show()
        self.index += 1
        return e_t_mod
================================================
FILE: ldm/lr_scheduler.py
================================================
import numpy as np
class LambdaWarmUpCosineScheduler:
    """
    LR multiplier: linear warm-up from ``lr_start`` to ``lr_max``, then cosine
    decay to ``lr_min`` reached at ``max_decay_steps`` and held afterwards.

    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        """Return the multiplier for step ``n`` and remember it in ``last_lr``."""
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")

        if n < self.lr_warm_up_steps:
            slope = (self.lr_max - self.lr_start) / self.lr_warm_up_steps
            lr = slope * n + self.lr_start
        else:
            decay_span = self.lr_max_decay_steps - self.lr_warm_up_steps
            progress = min((n - self.lr_warm_up_steps) / decay_span, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(progress * np.pi))

        self.last_lr = lr
        return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n,**kwargs)
class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.

    Each cycle ``i`` warms up linearly from ``f_start[i]`` to ``f_max[i]`` over
    ``warm_up_steps[i]`` steps, then cosine-decays to ``f_min[i]`` by the end of
    ``cycle_lengths[i]``.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        """Index of the first cycle whose cumulative end is >= n (None past the last cycle)."""
        for idx, cycle_end in enumerate(self.cum_cycles[1:]):
            if n <= cycle_end:
                return idx

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n`` and remember it in ``last_f``."""
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")

        warm_up = self.lr_warm_up_steps[cycle]
        if n < warm_up:
            f = (self.f_max[cycle] - self.f_start[cycle]) / warm_up * n + self.f_start[cycle]
        else:
            progress = min((n - warm_up) / (self.cycle_lengths[cycle] - warm_up), 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(progress * np.pi))

        self.last_f = f
        return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Cycle-wise schedule: linear warm-up, then linear decay computed over the whole cycle length."""

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n`` and remember it in ``last_f``."""
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")

        lo, hi = self.f_min[cycle], self.f_max[cycle]
        length = self.cycle_lengths[cycle]
        if n < self.lr_warm_up_steps[cycle]:
            f = (hi - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
        else:
            f = lo + (hi - lo) * (length - n) / (length)

        self.last_f = f
        return f
================================================
FILE: ldm/models/autoencoder.py
================================================
import re
from contextlib import contextmanager

import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
    """Vector-quantized autoencoder: encoder -> quant_conv -> codebook -> post_quant_conv -> decoder.

    Trained with two optimizers (autoencoder and the loss module's discriminator);
    the loss is built from ``lossconfig``.
    """
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False, # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in EMA weights (no-op unless ``use_ema``)."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load ``state_dict`` from ``path`` non-strictly, dropping keys that start with any ``ignore_keys`` prefix."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        # report each list on its own, so unexpected keys are shown even when none are missing
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode and quantize; returns ``(quant, emb_loss, info)`` from the quantizer."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode without quantizing."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a float NCHW tensor (input is channels-last), optionally randomly resized."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Log validation losses under the ``val{suffix}`` split and return the combined loss dict."""
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        # rec_loss was already logged above; presumably newer Lightning (>= 1.4)
        # rejects logging the same key again, so drop it from the dict.
        pl_major_minor = tuple(int(part) for part in re.findall(r"\d+", pl.__version__)[:2])
        if pl_major_minor >= (1, 4):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # return the logged values, not the bound ``self.log_dict`` method
        return {**log_dict_ae, **log_dict_disc}

    def configure_optimizers(self):
        """Adam for the autoencoder (lr scaled by ``lr_g_factor``) and for the discriminator, with optional LambdaLR schedules."""
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return inputs and reconstructions (and EMA reconstructions when ``plot_ema``) for logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        # keep the model-space input: x may be replaced by its RGB projection below
        x_model = x
        xrec, _ = self(x_model)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x_model)
                if x_model.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels in [-1, 1] with a fixed random 1x1 conv."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
class VQModelInterface(VQModel):
    """VQModel variant whose ``encode`` stops before quantization and whose
    ``decode`` quantizes its input first unless ``force_not_quantize`` is set."""

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Return pre-quantization latents (encoder followed by quant_conv)."""
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        """Decode latents, passing them through the quantizer unless ``force_not_quantize``."""
        if force_not_quantize:
            quant = h
        else:
            # also go through quantization layer
            quant, _, _ = self.quantize(h)
        return self.decoder(self.post_quant_conv(quant))
class AutoencoderKL(pl.LightningModule):
    """KL-regularized autoencoder with a diagonal-Gaussian latent posterior.

    ``quant_conv`` maps encoder features to ``2 * embed_dim`` channels of
    moments wrapped in ``DiagonalGaussianDistribution``. ``post_quant_conv``
    maps an ``embed_dim``-channel latent back to ``z_channels`` for the
    decoder. Training uses two optimizers: index 0 for the autoencoder
    parameters and index 1 for ``self.loss.discriminator``.

    NOTE: ``configure_optimizers`` reads ``self.learning_rate``. It is not set
    in ``__init__`` and must be assigned by the caller before training.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # quant_conv expects 2 * z_channels input features (presumably mean and
        # log-variance moments, split by DiagonalGaussianDistribution).
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights from the checkpoint at ``path`` (non-strict).

        Accepts either a Lightning checkpoint with a ``"state_dict"`` entry or a
        bare state dict. Keys starting with any prefix in ``ignore_keys`` are
        dropped before loading.
        """
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in sd:
            sd = sd["state_dict"]
        for k in list(sd.keys()):
            # any(): a key matching several prefixes must be deleted only once.
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode ``x`` and return its ``DiagonalGaussianDistribution`` posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent ``z`` back to data space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Reconstruct ``input``.

        Decodes a posterior sample, or the posterior mode when
        ``sample_posterior`` is False. Returns ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` and permute it from channels-last to channels-first float.

        A 3-D tensor gets a trailing singleton channel axis first.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the autoencoder (idx 0) or discriminator (idx 1) loss."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Log both loss branches on a validation batch.

        Returns the merged log dictionaries of the two branches.
        """
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # Return the logged metrics (a dict), not the bound `self.log_dict` method.
        return {**log_dict_ae, **log_dict_disc}

    def configure_optimizers(self):
        """Adam optimizers for the autoencoder and the discriminator (same lr)."""
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Weight of the decoder's final conv, passed to the loss as ``last_layer``."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Collect inputs, reconstructions and random-latent samples for logging.

        The samples are decoded from standard-normal noise shaped like a
        posterior sample.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel map to 3 channels in [-1, 1] for visualization.

        Uses a lazily created random 1x1 ``colorize`` buffer. A constant
        projection maps to all -1 instead of NaN.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        lo, hi = x.min(), x.max()
        # Clamp only a zero/subnormal span so a constant image does not give 0/0.
        span = (hi - lo).clamp_min(torch.finfo(x.dtype).tiny)
        x = 2. * (x - lo) / span - 1.
        return x
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: ``encode``, ``decode`` and ``forward`` return the input unchanged.

    Args:
        vq_interface: if True, ``quantize`` returns a VQ-style triple
            ``(x, None, [None, None, None])``; otherwise it returns ``x``.
        *args, **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # nn.Module.__init__ must run before attributes are assigned on a Module.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
================================================
FILE: ldm/models/diffusion/__init__.py
================================================
================================================
FILE: ldm/models/diffusion/classifier.py
================================================
import os
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
# Classifier network class to build for each supported label key
# (looked up with the diffusion model's label / conditioning key).
__models__ = dict(
    class_label=EncoderUNetModel,
    segmentation=UNetModel,
)
def disabled_train(self, mode=True):
    """Stand-in for ``nn.Module.train`` that freezes the train/eval mode.

    Meant to overwrite a module's ``train`` so that its mode no longer
    changes. ``mode`` is ignored and the first argument is returned as is.
    """
    result = self
    return result
class NoisyLatentImageClassifier(pl.LightningModule):
    """Classifier trained on noisy latents from a frozen diffusion model.

    The diffusion model is built from the newest ``configs/*-project.yaml``
    under ``diffusion_path`` and frozen. ``get_x_noisy`` noises latents with its
    ``q_sample`` at random (training) or fixed timesteps. ``self.model`` is
    chosen from ``__models__`` by the label key and predicts labels from those
    noisy latents.

    NOTE: ``configure_optimizers`` reads ``self.learning_rate``. It is not set
    here and must be assigned by the caller.
    """

    def __init__(self,
                 diffusion_path,
                 num_classes,
                 ckpt_path=None,
                 pool='attention',
                 label_key=None,
                 diffusion_ckpt_path=None,
                 scheduler_config=None,
                 weight_decay=1.e-2,
                 log_steps=10,
                 monitor='val/loss',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        config_dir = os.path.join(diffusion_path, 'configs')
        config_paths = natsorted(glob(os.path.join(config_dir, '*-project.yaml')))
        if not config_paths:
            raise FileNotFoundError(f"No '*-project.yaml' config found in {config_dir}")
        diffusion_config = config_paths[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        # max(..., 1): log_steps == 0 disables noisy-image logging rather than
        # raising ZeroDivisionError here.
        self.log_time_interval = self.diffusion_model.num_timesteps // max(log_steps, 1)
        self.log_steps = log_steps

        # The diffusion model's cond_stage_key, when present, takes precedence.
        self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
            else self.diffusion_model.cond_stage_key

        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
        """Load weights from ``path`` (non-strict) and report key mismatches.

        Accepts a Lightning checkpoint (with ``"state_dict"``) or a bare state
        dict. Keys starting with any prefix in ``ignore_keys`` are dropped.
        When ``only_model`` is True the weights go into ``self.model`` only.
        """
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        for k in list(sd.keys()):
            # any(): a key matching several prefixes must be deleted only once.
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        target = self.model if only_model else self
        missing, unexpected = target.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        """Instantiate the diffusion model in eval mode and freeze it."""
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        """Build ``self.model`` from the diffusion UNet config.

        Inputs take the UNet's ``out_channels`` and outputs are
        ``num_classes``. ``pool`` is applied only for ``'class_label'``.
        Optionally loads ``ckpt_path``.
        """
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print('#####################################################################')
            print(f'load from ckpt "{ckpt_path}"')
            print('#####################################################################')
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        """Noise ``x`` to timestep(s) ``t`` with the diffusion model's ``q_sample``.

        NOTE(review): relies on the diffusion model exposing
        ``use_continuous_noise`` / ``sample_continuous_noise_level`` and a
        ``q_sample`` that accepts ``continuous_sqrt_alpha_cumprod``. Confirm
        against the diffusion class actually configured.
        """
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)

    def forward(self, x_noisy, t, *args, **kwargs):
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a channels-first float tensor.

        A 3-D input gets a trailing channel axis first.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        """Return the targets ``batch[k]`` (default: ``self.label_key``) on this device.

        Segmentation targets are made channels-first and halved in resolution
        ``self.numd`` times with nearest-neighbour interpolation.
        """
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            for _ in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        """Top-k accuracy of ``logits`` (classes on dim 1) against ``labels``.

        ``k`` is clamped to the number of classes. With ``reduction="mean"``
        a Python float is returned; with ``"none"`` the per-element hit tensor.

        Raises:
            ValueError: for any other ``reduction``.
        """
        # torch.topk raises if k exceeds the class dimension (e.g. acc@5 with < 5 classes).
        k = min(k, logits.shape[1])
        _, top_ks = torch.topk(logits, k, dim=1)
        hits = (top_ks == labels[:, None]).float().sum(dim=-1)
        if reduction == "mean":
            return hits.mean().item()
        elif reduction == "none":
            return hits
        raise ValueError(f"Unsupported reduction: {reduction!r}")

    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        """Log loss, acc@1, acc@5, global step and current lr under a train/val prefix."""
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
        self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)

    def shared_step(self, batch, t=None):
        """Noise the batch latents and return ``(loss, logits, x_noisy, targets)``.

        ``t`` is sampled uniformly per example when None; otherwise the same
        ``t`` is used for the whole batch. One-hot / 4-D targets are
        argmax-reduced over dim 1.
        """
        x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        """Reset per-timestep accuracy buffers (timesteps spaced by ``log_every_t``)."""
        self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
                          range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}

    def on_validation_start(self):
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Validation loss plus top-1/top-5 accuracy at each tracked timestep."""
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
            self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))

        return loss

    def configure_optimizers(self):
        """AdamW over ``self.model``, with an optional per-step LambdaLR schedule."""
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        """Log inputs, labels, and noisy inputs / predictions at ``log_steps`` timesteps.

        Every entry is truncated to the first ``N`` items.

        NOTE(review): the ``'b h w c -> b c h w'`` rearrange of the one-hot
        predictions assumes spatial (segmentation) logits. Verify before using
        with ``'class_label'``.
        """
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log['labels'] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)

            for step in range(self.log_steps):
                current_time = step * self.log_time_interval

                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

                log[f'inputs@t{current_time}'] = x_noisy

                pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
                pred = rearrange(pred, 'b h w c -> b c h w')

                log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)

        for key in log:
            log[key] = log[key][:N]

        return log
================================================
FILE: ldm/models/diffusion/ddim.py
================================================
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from einops import rearrange
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
class DDIMSampler(object):
    """DDIM sampler wrapping a trained diffusion ``model``.

    The wrapped model must provide ``num_timesteps``, ``alphas_cumprod``,
    ``alphas_cumprod_prev``, ``betas``, ``device`` and ``apply_model``, all of
    which are read below. ``make_schedule`` (called by ``sample``) must run
    before ``ddim_sampling``, ``p_sample_ddim``, ``encode``,
    ``stochastic_encode`` or ``decode``.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def to(self, device):
        """Move every tensor attribute to ``device``, like ``torch.nn.Module.to``.

        This class is not an ``nn.Module``, so tensor attributes are moved by hand.
        """
        for k, v in list(self.__dict__.items()):
            if isinstance(v, torch.Tensor):
                setattr(self, k, v.to(device))

    def register_buffer(self, name, attr):
        """Store ``attr`` as attribute ``name``.

        Tensors are moved to the wrapped model's device; other values
        (e.g. numpy arrays) are stored unchanged.
        """
        if type(attr) == torch.Tensor:
            # Follow the model's device rather than hard-coding CUDA, so the
            # sampler also works on CPU-only machines and non-default GPUs.
            device = self.model.device
            if attr.device != device:
                attr = attr.to(device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute DDPM and DDIM schedule quantities for ``ddim_num_steps`` steps."""
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               **kwargs
               ):
        """Run ``S`` DDIM steps for a batch of shape ``(batch_size, *shape)``.

        Builds the schedule for ``S`` steps and ``eta``, then delegates to
        ``ddim_sampling``. Returns ``(samples, intermediates)``. A mismatch
        between the conditioning batch size and ``batch_size`` only prints a
        warning.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        # print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      t_start=-1):
        """Iterate DDIM updates from ``x_T`` (or fresh noise) down the timestep schedule.

        ``t_start`` slices the schedule (``timesteps[:t_start]``). With the
        default -1 the last, noisiest schedule entry is dropped.
        NOTE(review): confirm that default is intended.

        Returns ``(img, intermediates)``, where ``intermediates`` holds
        ``'x_inter'`` / ``'pred_x0'`` lists recorded every ``log_every_t``
        steps.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        if ddim_use_original_steps:
            # Here `timesteps` is an int step count; slicing it directly would
            # raise TypeError, so apply the same cut to the count.
            timesteps = len(range(timesteps)[:t_start])
        else:
            timesteps = timesteps[:t_start]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        # print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback:
                img = callback(i, img, pred_x0)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """Perform one DDIM update at schedule position ``index``.

        Returns ``(x_prev, pred_x0)``. When ``unconditional_conditioning`` is
        given and the guidance scale is not 1, classifier-free guidance is
        applied by batching the unconditional and conditional passes.
        """
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [torch.cat([
                            unconditional_conditioning[k][i],
                            c[k][i]]) for i in range(len(c[k]))]
                    else:
                        c_in[k] = torch.cat([
                                unconditional_conditioning[k],
                                c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None):
        """Deterministically invert ``x0`` forward through ``t_enc`` DDIM steps.

        Returns ``(x_encoded, info)``. ``info`` contains ``'x_encoded'`` and
        ``'intermediate_steps'``, plus ``'intermediates'`` when
        ``return_intermediates`` (roughly how many snapshots to keep) is set.
        """
        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        # Snapshot spacing; at least 1 so `i % stride` never divides by zero
        # when more snapshots than steps are requested.
        stride = max(num_steps // return_intermediates, 1) if return_intermediates else 1

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            if return_intermediates and i % stride == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Sample q(x_t | x0) in one shot (fast, but not exactly invertible)."""
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False):
        """Denoise ``x_latent`` from schedule position ``t_start`` back to step 0."""
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        # print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
        return x_dec
================================================
FILE: ldm/models/diffusion/ddpm.py
================================================
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager, nullcontext
from functools import partial
import itertools
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from omegaconf import ListConfig
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.attention import CrossAttention
# Conditioning mode -> name of the corresponding conditioning input
# (consumers of this table are outside this view).
__conditioning_keys__ = dict(
    concat='c_concat',
    crossattn='c_crossattn',
    adm='y',
)
def disabled_train(self, mode=True):
    """Stand-in for ``nn.Module.train`` that freezes the train/eval mode.

    Meant to overwrite a module's ``train`` so that its mode no longer
    changes. ``mode`` is ignored and the first argument is returned as is.
    """
    result = self
    return result
def uniform_on_device(r1, r2, shape, device):
    """Draw a tensor of ``shape`` on ``device`` as ``(r1 - r2) * U[0, 1) + r2``.

    For ``r1 > r2`` the values lie in ``[r2, r1)``.
    """
    width = r1 - r2
    unit = torch.rand(*shape, device=device)
    return width * unit + r2
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
             unet_config,
             timesteps=1000,
             beta_schedule="linear",
             loss_type="l2",
             ckpt_path=None,
             ignore_keys=[],
             load_only_unet=False,
             monitor="val/loss",
             use_ema=True,
             first_stage_key="image",
             image_size=256,
             channels=3,
             log_every_t=100,
             clip_denoised=True,
             linear_start=1e-4,
             linear_end=2e-2,
             cosine_s=8e-3,
             given_betas=None,
             original_elbo_weight=0.,
             v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
             l_simple_weight=1.,
             conditioning_key=None,
             parameterization="eps",  # all assuming fixed variance schedules
             scheduler_config=None,
             use_positional_encodings=False,
             learn_logvar=False,
             logvar_init=0.,
             make_it_fit=False,
             ucg_training=None,
             ):
    """Build a Gaussian-diffusion model around a denoiser UNet.

    Args (selected):
        unet_config: denoiser config, wrapped in ``DiffusionWrapper`` together
            with ``conditioning_key``.
        timesteps, beta_schedule, linear_start, linear_end, cosine_s,
            given_betas: noise-schedule settings forwarded to
            ``register_schedule``.
        ckpt_path, ignore_keys, load_only_unet: optional checkpoint loaded via
            ``init_from_ckpt``. ``load_only_unet`` loads into ``self.model``
            only.
        use_ema: keep an ``LitEma`` copy of ``self.model`` (used by ``ema_scope``).
        parameterization: "eps" or "x0" (asserted below).
        make_it_fit: read by ``init_from_ckpt``, which then adapts checkpoint
            tensors whose shapes differ from the current parameters.
        learn_logvar, logvar_init: per-timestep log-variance initialized to
            ``logvar_init``; wrapped in an ``nn.Parameter`` when ``learn_logvar``.
        ucg_training: optional dict; when non-empty, a ``np.random.RandomState``
            is created as ``self.ucg_prng``.
    """
    super().__init__()
    assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
    self.parameterization = parameterization
    print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
    self.cond_stage_model = None
    self.clip_denoised = clip_denoised
    self.log_every_t = log_every_t
    self.first_stage_key = first_stage_key
    self.image_size = image_size  # try conv?
    self.channels = channels
    self.use_positional_encodings = use_positional_encodings
    self.model = DiffusionWrapper(unet_config, conditioning_key)
    count_params(self.model, verbose=True)
    self.use_ema = use_ema
    if self.use_ema:
        self.model_ema = LitEma(self.model)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
    self.use_scheduler = scheduler_config is not None
    if self.use_scheduler:
        self.scheduler_config = scheduler_config
    # v_posterior must be set before register_schedule, which reads it.
    self.v_posterior = v_posterior
    self.original_elbo_weight = original_elbo_weight
    self.l_simple_weight = l_simple_weight
    if monitor is not None:
        self.monitor = monitor
    # Must be set before init_from_ckpt, which reads self.make_it_fit.
    self.make_it_fit = make_it_fit
    if ckpt_path is not None:
        # NOTE(review): this runs before register_schedule below, so the
        # schedule buffers are not registered yet at load time and are
        # (re)computed afterwards. Confirm this ordering is intended.
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
    self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                           linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
    self.loss_type = loss_type
    self.learn_logvar = learn_logvar
    # NOTE(review): a plain tensor attribute (not a registered buffer) unless
    # learn_logvar wraps it in an nn.Parameter below.
    self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
    if self.learn_logvar:
        self.logvar = nn.Parameter(self.logvar, requires_grad=True)
    self.ucg_training = ucg_training or dict()
    if self.ucg_training:
        self.ucg_prng = np.random.RandomState()
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                      linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Compute the noise schedule and register the derived float32 buffers.

    Uses ``given_betas`` when provided, otherwise ``make_beta_schedule``. Sets
    ``self.num_timesteps`` from the number of betas. Registers buffers for the
    alpha products, the q(x_t | x_0) coefficients, the posterior
    q(x_{t-1} | x_t, x_0) variance and mean coefficients, and the
    non-persistent ``lvlb_weights`` (presumably weights for the variational
    bound loss term; their use is outside this view).
    Reads ``self.v_posterior`` and ``self.parameterization``, so both must be
    set before this is called.
    """
    if exists(given_betas):
        betas = given_betas
    else:
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    # alpha_bar_{t-1}, with alpha_bar_{-1} := 1
    alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

    timesteps, = betas.shape
    self.num_timesteps = int(timesteps)
    self.linear_start = linear_start
    self.linear_end = linear_end
    assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

    to_torch = partial(torch.tensor, dtype=torch.float32)

    self.register_buffer('betas', to_torch(betas))
    self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
    self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

    # calculations for diffusion q(x_t | x_{t-1}) and others
    self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
    self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
    self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
    self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
    self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod) + self.v_posterior * betas
    # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
    self.register_buffer('posterior_variance', to_torch(posterior_variance))
    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
    self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
    self.register_buffer('posterior_mean_coef1', to_torch(
        betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
    self.register_buffer('posterior_mean_coef2', to_torch(
        (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    if self.parameterization == "eps":
        lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
    elif self.parameterization == "x0":
        # NOTE(review): `2. * 1 - x` evaluates as `2 - x`, not `2 * (1 - x)`.
        # Confirm which was intended before relying on these weights with
        # x0-parameterization.
        lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
    else:
        raise NotImplementedError("mu not supported")
    # TODO how to choose this term
    # With v_posterior == 0 the t=0 posterior variance is 0 (alphas_cumprod_prev[0] == 1),
    # so the eps weight at t=0 divides by zero; reuse the t=1 weight instead.
    lvlb_weights[0] = lvlb_weights[1]
    self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
    # NOTE(review): `.all()` only fails when every weight is NaN; `.any()` would be stricter.
    assert not torch.isnan(self.lvlb_weights).all()
@contextmanager
def ema_scope(self, context=None):
    """Context manager that runs its body with the EMA weights loaded into ``self.model``.

    When ``self.use_ema`` is set, the current parameters are stored and the EMA
    copy is written into ``self.model`` on entry; on exit (normal or via an
    exception) the stored training parameters are restored.

    :param context: optional label; when not None, a message is printed on swap and restore.
    """
    if self.use_ema:
        ema = self.model_ema
        ema.store(self.model.parameters())
        ema.copy_to(self.model)
        if context is not None:
            print(f"{context}: Switched to EMA weights")
    try:
        yield None
    finally:
        # Re-check the flag here, exactly as on entry, before restoring.
        if self.use_ema:
            self.model_ema.restore(self.model.parameters())
            if context is not None:
                print(f"{context}: Restored training weights")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
    """Load weights from a checkpoint file into this module (or into ``self.model``).

    :param path: file readable by ``torch.load``; if it holds a ``"state_dict"``
        entry, that entry is used as the state dict.
    :param ignore_keys: key prefixes; every entry whose key starts with one of
        them is removed from the state dict before loading.
    :param only_model: load into ``self.model`` instead of ``self``.

    When ``self.make_it_fit`` is set, checkpoint tensors whose shape differs from
    the matching parameter/buffer (only the first two axes may differ) are
    tiled to the new shape, with indices taken modulo the old size.
    Loading is non-strict; missing and unexpected keys are printed.
    """
    sd = torch.load(path, map_location="cpu")
    if "state_dict" in list(sd.keys()):
        sd = sd["state_dict"]

    # Apply ignore_keys. This argument used to be accepted (and passed by
    # subclasses) but silently ignored.
    for k in list(sd.keys()):
        for ik in ignore_keys:
            if k.startswith(ik):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
                break  # already deleted; further prefixes must not delete again

    if self.make_it_fit:
        n_params = len([name for name, _ in
                        itertools.chain(self.named_parameters(),
                                        self.named_buffers())])
        for name, param in tqdm(
                itertools.chain(self.named_parameters(),
                                self.named_buffers()),
                desc="Fitting old weights to new weights",
                total=n_params
        ):
            if not name in sd:
                continue
            old_shape = sd[name].shape
            new_shape = param.shape
            assert len(old_shape) == len(new_shape)
            if len(new_shape) > 2:
                # we only modify first two axes
                assert new_shape[2:] == old_shape[2:]
            # assumes first axis corresponds to output dim
            if not new_shape == old_shape:
                new_param = param.clone()
                old_param = sd[name]
                if len(new_shape) == 1:
                    for i in range(new_param.shape[0]):
                        new_param[i] = old_param[i % old_shape[0]]
                elif len(new_shape) >= 2:
                    for i in range(new_param.shape[0]):
                        for j in range(new_param.shape[1]):
                            new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]

                    # NOTE(review): the counter starts at 1, so each tiled input
                    # column is divided by (number of copies + 1). If plain
                    # averaging over copies was intended this should start from
                    # zeros; confirm before changing, since it alters fitted weights.
                    n_used_old = torch.ones(old_shape[1])
                    for j in range(new_param.shape[1]):
                        n_used_old[j % old_shape[1]] += 1
                    n_used_new = torch.zeros(new_shape[1])
                    for j in range(new_param.shape[1]):
                        n_used_new[j] = n_used_old[j % old_shape[1]]

                    # Broadcast the per-input-column divisor over the trailing axes.
                    n_used_new = n_used_new[None, :]
                    while len(n_used_new.shape) < len(new_shape):
                        n_used_new = n_used_new.unsqueeze(-1)
                    new_param /= n_used_new

                sd[name] = new_param

    missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
        sd, strict=False)
    print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
    if len(missing) > 0:
        print(f"Missing Keys: {missing}")
    if len(unexpected) > 0:
        print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
    """
    Moments of the forward-process distribution q(x_t | x_0).

    :param x_start: the [N x C x ...] tensor of clean inputs.
    :param t: diffusion step indices (number of steps minus 1); 0 means a single step.
    :return: tuple (mean, variance, log_variance), all of x_start's shape.
    """
    shape = x_start.shape
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, shape)
    variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, shape)
    log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, shape)
    return signal_scale * x_start, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
    # Recover x_0 from x_t and predicted noise:
    #   x_0 = sqrt(1 / alphas_cumprod) * x_t - sqrt(1 / alphas_cumprod - 1) * noise
    shape = x_t.shape
    x_coef = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, shape)
    noise_coef = extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, shape)
    return x_coef * x_t - noise_coef * noise
def q_posterior(self, x_start, x_t, t):
    """Mean, variance and clipped log-variance of the posterior q(x_{t-1} | x_t, x_0)."""
    shape = x_t.shape
    coef_start = extract_into_tensor(self.posterior_mean_coef1, t, shape)
    coef_noisy = extract_into_tensor(self.posterior_mean_coef2, t, shape)
    mean = coef_start * x_start + coef_noisy * x_t
    variance = extract_into_tensor(self.posterior_variance, t, shape)
    log_variance = extract_into_tensor(self.posterior_log_variance_clipped, t, shape)
    return mean, variance, log_variance
def p_mean_variance(self, x, t, clip_denoised: bool):
    """Reverse-step distribution p(x_{t-1} | x_t) predicted by ``self.model``.

    :param x: noisy sample x_t.
    :param t: timestep indices.
    :param clip_denoised: clamp the reconstructed x_0 to [-1, 1] (in place).
    :return: (model_mean, posterior_variance, posterior_log_variance).
    :raises NotImplementedError: if ``self.parameterization`` is neither "eps" nor "x0".
    """
    model_out = self.model(x, t)
    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    else:
        # Without this branch an unknown parameterization reached the clamp /
        # q_posterior call with x_recon unbound (UnboundLocalError).
        raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
    if clip_denoised:
        x_recon.clamp_(-1., 1.)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    """Draw x_{t-1} from the model's reverse distribution given x_t (no noise where t == 0)."""
    batch = x.shape[0]
    device = x.device
    model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
    noise = noise_like(x.shape, device, repeat_noise)
    # 0 for samples at t == 0, 1 otherwise, shaped to broadcast over non-batch dims.
    keep_noise = (1 - (t == 0).float()).reshape(batch, *([1] * (x.dim() - 1)))
    std = (0.5 * model_log_variance).exp()
    return model_mean + keep_noise * std * noise
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
    """Run the full reverse chain, starting from Gaussian noise of ``shape``.

    :param shape: output shape; ``shape[0]`` is the batch size.
    :param return_intermediates: also return the starting noise plus the
        samples at every ``log_every_t``-th step and at the final timestep index.
    :return: the final sample, or ``(sample, intermediates)``.
    """
    device = self.betas.device
    batch = shape[0]
    img = torch.randn(shape, device=device)
    intermediates = [img]
    last_step = self.num_timesteps - 1
    for step in tqdm(range(last_step, -1, -1), desc='Sampling t', total=self.num_timesteps):
        timesteps = torch.full((batch,), step, device=device, dtype=torch.long)
        img = self.p_sample(img, timesteps, clip_denoised=self.clip_denoised)
        if step % self.log_every_t == 0 or step == last_step:
            intermediates.append(img)
    return (img, intermediates) if return_intermediates else img
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
    """Sample ``batch_size`` square images of side ``self.image_size`` with ``self.channels`` channels."""
    side = self.image_size
    shape = (batch_size, self.channels, side, side)
    return self.p_sample_loop(shape, return_intermediates=return_intermediates)
def q_sample(self, x_start, t, noise=None):
    """Diffuse ``x_start`` to step ``t``: sqrt(alphas_cumprod) * x_0 + sqrt(1 - alphas_cumprod) * noise."""
    noise = default(noise, lambda: torch.randn_like(x_start))
    shape = x_start.shape
    signal = extract_into_tensor(self.sqrt_alphas_cumprod, t, shape) * x_start
    perturbation = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, shape) * noise
    return signal + perturbation
def get_loss(self, pred, target, mean=True):
    """Regression loss between ``pred`` and ``target`` selected by ``self.loss_type``.

    :param pred: model prediction.
    :param target: regression target.
    :param mean: if True return the scalar mean; otherwise the unreduced elementwise loss.
    :return: 'l1' -> |target - pred|, 'l2' -> squared error.
    :raises NotImplementedError: for any other ``self.loss_type``.
    """
    if self.loss_type == 'l1':
        loss = (target - pred).abs()
        if mean:
            loss = loss.mean()
    elif self.loss_type == 'l2':
        reduction = 'mean' if mean else 'none'
        loss = torch.nn.functional.mse_loss(target, pred, reduction=reduction)
    else:
        # Previously the message lacked the f-prefix and named an undefined variable.
        raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
    return loss
def p_losses(self, x_start, t, noise=None):
    """Unconditional training loss at timesteps ``t``.

    :return: (loss, loss_dict), where loss = l_simple_weight * mean(per-sample loss)
        + original_elbo_weight * mean(lvlb_weights[t] * per-sample loss).
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    model_out = self.model(x_noisy, t)

    if self.parameterization == "eps":
        target = noise
    elif self.parameterization == "x0":
        target = x_start
    else:
        raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")

    # Per-sample loss averaged over dims 1-3.
    per_sample = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
    prefix = 'train' if self.training else 'val'

    loss_simple = per_sample.mean() * self.l_simple_weight
    loss_vlb = (self.lvlb_weights[t] * per_sample).mean()
    total = loss_simple + self.original_elbo_weight * loss_vlb

    loss_dict = {
        f'{prefix}/loss_simple': per_sample.mean(),
        f'{prefix}/loss_vlb': loss_vlb,
        f'{prefix}/loss': total,
    }
    return total, loss_dict
def forward(self, x, *args, **kwargs):
    # Draw one uniform timestep in [0, num_timesteps) per batch element, then compute the loss.
    batch = x.shape[0]
    timesteps = torch.randint(0, self.num_timesteps, (batch,), device=self.device).long()
    return self.p_losses(x, timesteps, *args, **kwargs)
def get_input(self, batch, k):
    """Return ``batch[k]`` as a contiguous float tensor in channels-first layout.

    A 3-D entry first gets a trailing singleton channel axis; the tensor is then
    rearranged from 'b h w c' to 'b c h w'.
    """
    images = batch[k]
    if images.ndim == 3:
        images = images[..., None]
    images = rearrange(images, 'b h w c -> b c h w')
    return images.to(memory_format=torch.contiguous_format).float()
def shared_step(self, batch):
    # Loss for one batch: fetch the first-stage input and run forward() on it.
    inputs = self.get_input(batch, self.first_stage_key)
    loss, loss_dict = self(inputs)
    return loss, loss_dict
def training_step(self, batch, batch_idx):
    """One training step with unconditional-guidance dropout on selected batch keys.

    For each key in ``self.ucg_training``, every element of ``batch[key]`` is
    replaced in place, with probability ``p`` (drawn from ``self.ucg_prng``), by
    ``val`` (or ``""`` when ``val`` is None). The loss is then computed, logged and returned.
    """
    for key in self.ucg_training:
        spec = self.ucg_training[key]
        prob = spec["p"]
        replacement = spec["val"]
        if replacement is None:
            replacement = ""
        entries = batch[key]
        for idx in range(len(entries)):
            if self.ucg_prng.choice(2, p=[1 - prob, prob]):
                entries[idx] = replacement

    loss, loss_dict = self.shared_step(batch)

    self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
    self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)

    if self.use_scheduler:
        current_lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', current_lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

    return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
    """Log validation losses for the raw weights and for the EMA weights (keys suffixed ``_ema``)."""
    _, raw_losses = self.shared_step(batch)
    with self.ema_scope():
        _, ema_losses = self.shared_step(batch)
        ema_losses = {name + '_ema': value for name, value in ema_losses.items()}
    epoch_logging = dict(prog_bar=False, logger=True, on_step=False, on_epoch=True)
    self.log_dict(raw_losses, **epoch_logging)
    self.log_dict(ema_losses, **epoch_logging)
def on_train_batch_end(self, *args, **kwargs):
    # After each training batch, update the EMA copy from the current model weights.
    if not self.use_ema:
        return
    self.model_ema(self.model)
def _get_rows_from_list(self, samples):
    """Tile ``n`` snapshots (each 'b c h w') into one image grid.

    The grid has one row per batch element and one column per snapshot.
    """
    n_columns = len(samples)
    flat = rearrange(samples, 'n b c h w -> (b n) c h w')
    return make_grid(flat, nrow=n_columns)
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
    """Collect images for logging.

    Produces 'inputs', a forward-diffusion grid 'diffusion_row' for the first
    ``n_row`` inputs and, if ``sample`` is set, EMA 'samples' plus their
    'denoise_row' grid. If ``return_keys`` is given and overlaps the collected
    keys, only those keys are returned (each must be present).
    """
    log = dict()
    x = self.get_input(batch, self.first_stage_key)
    N = min(x.shape[0], N)
    n_row = min(x.shape[0], n_row)
    x = x.to(self.device)[:N]
    log["inputs"] = x

    # Forward-diffusion snapshots every log_every_t steps (and at the last step).
    x_start = x[:n_row]
    diffusion_row = []
    last_step = self.num_timesteps - 1
    for step in range(self.num_timesteps):
        if step % self.log_every_t != 0 and step != last_step:
            continue
        t = repeat(torch.tensor([step]), '1 -> b', b=n_row).to(self.device).long()
        noise = torch.randn_like(x_start)
        diffusion_row.append(self.q_sample(x_start=x_start, t=t, noise=noise))
    log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

    if sample:
        with self.ema_scope("Plotting"):
            samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
        log["samples"] = samples
        log["denoise_row"] = self._get_rows_from_list(denoise_row)

    if not return_keys:
        return log
    if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
        return log
    return {key: log[key] for key in return_keys}
def configure_optimizers(self):
    """AdamW over the diffusion model's parameters (plus ``self.logvar`` when it is learned)."""
    trainable = list(self.model.parameters())
    if self.learn_logvar:
        trainable.append(self.logvar)
    return torch.optim.AdamW(trainable, lr=self.learning_rate)
class LatentDiffusion(DDPM):
    """Main class: a DDPM that runs diffusion in the latent space of a frozen first-stage autoencoder.

    Images are encoded with ``first_stage_model`` and multiplied by
    ``scale_factor``; conditioning inputs are encoded by ``cond_stage_model``.
    This variant adds ``cc_projection`` (Linear 772 -> 768), which projects the
    image CLIP embedding concatenated with the per-sample ``T`` vector from the
    batch into the cross-attention conditioning.
    """
def __init__(self,
             first_stage_config,
             cond_stage_config,
             num_timesteps_cond=None,
             cond_stage_key="image",
             cond_stage_trainable=False,
             concat_mode=True,
             cond_stage_forward=None,
             conditioning_key=None,
             scale_factor=1.0,
             scale_by_std=False,
             unet_trainable=True,
             *args, **kwargs):
    """Build the latent diffusion model.

    :param first_stage_config: config for ``instantiate_from_config``; the
        resulting autoencoder is frozen (see ``instantiate_first_stage``).
    :param cond_stage_config: config for the conditioning encoder, or one of the
        sentinels ``"__is_first_stage__"`` / ``"__is_unconditional__"``.
    :param num_timesteps_cond: defaults to 1; must not exceed ``kwargs['timesteps']``
        (which therefore must be supplied).
    :param cond_stage_key: batch key holding the conditioning input.
    :param cond_stage_trainable: if False, the conditioning encoder is frozen.
    :param concat_mode: picks 'concat' (True) or 'crossattn' when ``conditioning_key`` is None.
    :param cond_stage_forward: optional name of a cond-stage-model method used
        instead of ``encode``/``__call__`` in ``get_learned_conditioning``.
    :param conditioning_key: forwarded to the parent constructor; forced to None
        for unconditional models.
    :param scale_factor: latent scaling; registered as a buffer when ``scale_by_std`` is set.
    :param scale_by_std: replace ``scale_factor`` by 1/std of the first training batch's latents.
    :param unet_trainable: stored on the instance; not used by this constructor.

    The extra kwargs ``ckpt_path`` and ``ignore_keys`` are consumed here (not passed
    to the parent) and, if ``ckpt_path`` is given, used to load a checkpoint at the end.
    """
    # Set before super().__init__ — presumably because the parent constructor calls
    # register_schedule(), which reads self.num_timesteps_cond. TODO confirm.
    self.num_timesteps_cond = default(num_timesteps_cond, 1)
    self.scale_by_std = scale_by_std
    assert self.num_timesteps_cond <= kwargs['timesteps']
    # for backwards compatibility after implementation of DiffusionWrapper
    if conditioning_key is None:
        conditioning_key = 'concat' if concat_mode else 'crossattn'
    if cond_stage_config == '__is_unconditional__':
        conditioning_key = None
    # Pop checkpoint options so the parent constructor does not load the checkpoint itself;
    # it is loaded below once all submodules (incl. cc_projection) exist.
    ckpt_path = kwargs.pop("ckpt_path", None)
    ignore_keys = kwargs.pop("ignore_keys", [])
    super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
    self.concat_mode = concat_mode
    self.cond_stage_trainable = cond_stage_trainable
    self.unet_trainable = unet_trainable
    self.cond_stage_key = cond_stage_key
    try:
        self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
    except:
        # NOTE(review): bare except — any failure (incl. KeyboardInterrupt) falls back to 0.
        self.num_downs = 0
    if not scale_by_std:
        self.scale_factor = scale_factor
    else:
        # A buffer, so the value computed in on_train_batch_start is saved with the checkpoint.
        self.register_buffer('scale_factor', torch.tensor(scale_factor))
    self.instantiate_first_stage(first_stage_config)
    self.instantiate_cond_stage(cond_stage_config)
    self.cond_stage_forward = cond_stage_forward

    # construct linear projection layer for concatenating image CLIP embedding and RT
    # Init: the 768x768 block of the weight is set to identity and the bias to zero;
    # the remaining 4 input columns keep nn.Linear's default initialization.
    self.cc_projection = nn.Linear(772, 768)
    nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768])
    nn.init.zeros_(list(self.cc_projection.parameters())[1])
    self.cc_projection.requires_grad_(True)

    self.clip_denoised = False
    self.bbox_tokenizer = None

    # on_train_batch_start skips std-rescaling when this is True.
    self.restarted_from_ckpt = False
    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys)
        self.restarted_from_ckpt = True
def make_cond_schedule(self):
    """Build ``self.cond_ids`` (long tensor of length num_timesteps).

    The first ``num_timesteps_cond`` entries are rounded, evenly spaced values in
    [0, num_timesteps - 1]; every later entry is num_timesteps - 1.
    """
    last = self.num_timesteps - 1
    ids = torch.full(size=(self.num_timesteps,), fill_value=last, dtype=torch.long)
    spaced = torch.linspace(0, last, self.num_timesteps_cond).round().long()
    ids[:self.num_timesteps_cond] = spaced
    self.cond_ids = ids
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
    """On the very first batch of a fresh run with ``scale_by_std``, set ``scale_factor`` to 1/std of the latents."""
    is_very_first_batch = (self.scale_by_std and self.current_epoch == 0 and self.global_step == 0
                           and batch_idx == 0 and not self.restarted_from_ckpt)
    if not is_very_first_batch:
        return
    assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
    print("### USING STD-RESCALING ###")
    images = super().get_input(batch, self.first_stage_key).to(self.device)
    latents = self.get_first_stage_encoding(self.encode_first_stage(images)).detach()
    # Replace the old attribute with a buffer holding 1/std of the latents.
    del self.scale_factor
    self.register_buffer('scale_factor', 1. / latents.flatten().std())
    print(f"setting self.scale_factor to {self.scale_factor}")
    print("### USING STD-RESCALING ###")
def register_schedule(self,
                      given_betas=None, beta_schedule="linear", timesteps=1000,
                      linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Register the parent noise schedule, then the shortened conditioning schedule when num_timesteps_cond > 1."""
    super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
    shorten = self.num_timesteps_cond > 1
    self.shorten_cond_schedule = shorten
    if shorten:
        self.make_cond_schedule()
def instantiate_first_stage(self, config):
    """Build the first-stage autoencoder from ``config`` and freeze it.

    The model is put in eval mode, its ``train`` method is replaced by
    ``disabled_train`` and all of its parameters stop requiring gradients.
    """
    autoencoder = instantiate_from_config(config).eval()
    autoencoder.train = disabled_train
    for weight in autoencoder.parameters():
        weight.requires_grad = False
    self.first_stage_model = autoencoder
def instantiate_cond_stage(self, config):
    """Create ``self.cond_stage_model`` from ``config``.

    Trainable: built from ``config`` as-is (the sentinel strings are rejected).
    Frozen: ``"__is_first_stage__"`` reuses the first-stage model,
    ``"__is_unconditional__"`` sets it to None, otherwise the model is built,
    put in eval mode, has ``train`` disabled and its parameters frozen.
    """
    if self.cond_stage_trainable:
        assert config != '__is_first_stage__'
        assert config != '__is_unconditional__'
        self.cond_stage_model = instantiate_from_config(config)
        return

    if config == "__is_first_stage__":
        print("Using first stage also as cond stage.")
        self.cond_stage_model = self.first_stage_model
    elif config == "__is_unconditional__":
        print(f"Training {self.__class__.__name__} as an unconditional model.")
        self.cond_stage_model = None
    else:
        encoder = instantiate_from_config(config).eval()
        encoder.train = disabled_train
        for weight in encoder.parameters():
            weight.requires_grad = False
        self.cond_stage_model = encoder
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
    """Decode every latent snapshot in ``samples`` and tile them into a grid.

    The grid has one row per batch element and one column per snapshot.
    """
    decoded = [self.decode_first_stage(zd.to(self.device),
                                       force_not_quantize=force_no_decoder_quantization)
               for zd in tqdm(samples, desc=desc)]
    n_columns = len(decoded)
    stacked = torch.stack(decoded)  # n_log_step, n_row, C, H, W
    flat = rearrange(stacked, 'n b c h w -> (b n) c h w')
    return make_grid(flat, nrow=n_columns)
def get_first_stage_encoding(self, encoder_posterior):
    """Turn an encoder output into a scaled latent.

    A DiagonalGaussianDistribution is sampled and a plain tensor is used as-is;
    the result is multiplied by ``self.scale_factor``. Other types raise
    NotImplementedError.
    """
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        return self.scale_factor * encoder_posterior.sample()
    if isinstance(encoder_posterior, torch.Tensor):
        return self.scale_factor * encoder_posterior
    raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
def get_learned_conditioning(self, c):
    """Encode conditioning input ``c`` with ``self.cond_stage_model``.

    If ``cond_stage_forward`` names a method, that method is called. Otherwise
    the model's callable ``encode`` is used when present (a
    DiagonalGaussianDistribution result is reduced to its mode); failing that,
    the model itself is called.
    """
    if self.cond_stage_forward is not None:
        assert hasattr(self.cond_stage_model, self.cond_stage_forward)
        return getattr(self.cond_stage_model, self.cond_stage_forward)(c)

    model = self.cond_stage_model
    encode = getattr(model, 'encode', None)
    if not callable(encode):
        return model(c)
    encoded = encode(c)
    if isinstance(encoded, DiagonalGaussianDistribution):
        encoded = encoded.mode()
    return encoded
def meshgrid(self, h, w):
    """Integer coordinate grid of shape (h, w, 2) whose last axis holds (row, col)."""
    rows = torch.arange(0, h).view(h, 1, 1).expand(h, w, 1)
    cols = torch.arange(0, w).view(1, w, 1).expand(h, w, 1)
    return torch.cat([rows, cols], dim=-1)
def delta_border(self, h, w):
    """
    :param h: height
    :param w: width
    :return: normalized distance to the image border, with min distance = 0
        at the border and max distance = 0.5 at the image center
    """
    corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    coords = self.meshgrid(h, w) / corner
    to_top_left = coords.min(dim=-1, keepdim=True).values
    to_bottom_right = (1 - coords).min(dim=-1, keepdim=True).values
    return torch.cat([to_top_left, to_bottom_right], dim=-1).min(dim=-1).values
def get_weighting(self, h, w, Ly, Lx, device):
    """Blending weights for h x w crops, repeated for the Ly * Lx crop positions.

    Per-pixel weights are ``delta_border(h, w)`` clipped to
    [clip_min_weight, clip_max_weight]. With ``tie_braker``, they are also
    multiplied by per-crop weights from ``delta_border(Ly, Lx)`` clipped to the
    tie-weight range. Returns a tensor of shape (1, h * w, Ly * Lx) on ``device``.
    """
    params = self.split_input_params
    pixel_weights = torch.clip(self.delta_border(h, w), params["clip_min_weight"], params["clip_max_weight"])
    pixel_weights = pixel_weights.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
    if not params["tie_braker"]:
        return pixel_weights
    crop_weights = torch.clip(self.delta_border(Ly, Lx),
                              params["clip_min_tie_weight"],
                              params["clip_max_tie_weight"])
    return pixel_weights * crop_weights.view(1, 1, Ly * Lx).to(device)
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time
    """
    Build the Unfold/Fold pair and blending weights for processing ``x`` in overlapping crops.

    :param x: img of size (bs, c, h, w)
    :param kernel_size: crop size (kh, kw) on ``x``.
    :param stride: crop stride (sh, sw) on ``x``.
    :param uf: upscaling factor of the output grid; ``fold`` writes to an image ``uf`` times larger.
    :param df: downscaling factor of the output grid; ``fold`` writes to an image ``df`` times smaller.
    :return: (fold, unfold, normalization, weighting). ``unfold`` cuts crops from ``x``;
        ``fold`` stitches processed crops back on the (scaled) output grid;
        ``normalization`` is ``fold(weighting)`` (per-pixel overlap sum) of shape
        (1, 1, H', W'); ``weighting`` has shape (1, 1, kh', kw', Ly * Lx), where
        primes denote sizes scaled by ``uf``/``df``.
    :raises NotImplementedError: unless (uf, df) is (1, 1), (>1, 1) or (1, >1).
    """
    bs, nc, h, w = x.shape

    # number of crops in image
    Ly = (h - kernel_size[0]) // stride[0] + 1
    Lx = (w - kernel_size[1]) // stride[1] + 1

    if not ((uf == 1 and df == 1) or (uf > 1 and df == 1) or (df > 1 and uf == 1)):
        raise NotImplementedError

    def scale(v):
        # Map a size/stride on x to the corresponding size on the output grid.
        if uf > 1:
            return v * uf
        if df > 1:
            return v // df
        return v

    unfold = torch.nn.Unfold(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)

    # Scale each axis independently. The previous code used kernel_size[0] for
    # both axes of the scaled Fold kernel, which broke non-square crops.
    out_ks = (scale(kernel_size[0]), scale(kernel_size[1]))
    out_stride = (scale(stride[0]), scale(stride[1]))
    out_size = (scale(h), scale(w))
    fold = torch.nn.Fold(output_size=out_size, kernel_size=out_ks, dilation=1, padding=0, stride=out_stride)

    weighting = self.get_weighting(out_ks[0], out_ks[1], Ly, Lx, x.device).to(x.dtype)
    normalization = fold(weighting).view(1, 1, out_size[0], out_size[1])  # normalizes the overlap
    weighting = weighting.view((1, 1, out_ks[0], out_ks[1], Ly * Lx))

    return fold, unfold, normalization, weighting
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
              cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
    """Prepare latents and conditioning for one batch.

    :param batch: batch mapping; must contain ``k``, the conditioning key and ``'T'``.
    :param k: key of the target images.
    :param return_first_stage_outputs: also return the input images and their reconstruction.
    :param force_c_encode: unused; kept for interface compatibility.
    :param cond_key: key of the conditioning image (defaults to ``self.cond_stage_key``).
    :param return_original_cond: also return the raw conditioning image.
    :param bs: if given, keep only the first ``bs`` batch elements.
    :param uncond: per-sample dropout rate for classifier-free guidance (see below).
    :return: list ``[z, cond]``, extended with ``x, xrec`` and/or ``xc`` per the flags;
        ``cond`` holds one-element lists under ``"c_crossattn"`` and ``"c_concat"``.
    """
    x = super().get_input(batch, k)
    T = batch['T'].to(memory_format=torch.contiguous_format).float()

    if bs is not None:
        x = x[:bs]
        T = T[:bs]

    x = x.to(self.device)
    # T is concatenated with device tensors below. It used to be moved to the
    # device only when `bs` was given.
    T = T.to(self.device)
    encoder_posterior = self.encode_first_stage(x)
    z = self.get_first_stage_encoding(encoder_posterior).detach()
    cond_key = cond_key or self.cond_stage_key
    xc = super().get_input(batch, cond_key).to(self.device)
    if bs is not None:
        xc = xc[:bs]
    cond = {}

    # Classifier-free guidance dropout, using one uniform draw r per sample:
    #   r < uncond               -> drop only the CLIP prompt embedding
    #   uncond <= r < 2*uncond   -> drop both
    #   2*uncond <= r < 3*uncond -> drop only the concatenated image latent
    random = torch.rand(x.size(0), device=x.device)
    prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
    input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
    # (A duplicate, unused get_learned_conditioning([""]) call outside this block was removed.)

    # Gradients are enabled here so that cc_projection receives gradients even though
    # the method runs under no_grad; the CLIP outputs themselves are detached.
    with torch.enable_grad():
        clip_emb = self.get_learned_conditioning(xc).detach()
        null_prompt = self.get_learned_conditioning([""]).detach()
        cond["c_crossattn"] = [self.cc_projection(torch.cat([torch.where(prompt_mask, null_prompt, clip_emb), T[:, None, :]], dim=-1))]
    cond["c_concat"] = [input_mask * self.encode_first_stage((xc.to(self.device))).mode().detach()]
    out = [z, cond]
    if return_first_stage_outputs:
        xrec = self.decode_first_stage(z)
        out.extend([x, xrec])
    if return_original_cond:
        out.append(xc)
    return out
# Not decorated with @torch.no_grad(): gradients may flow through decoding.
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    """Map latents ``z`` back to image space with ``self.first_stage_model``.

    :param z: latents; divided by ``self.scale_factor`` before decoding.
    :param predict_cids: treat ``z`` as codebook predictions. A 4-D ``z`` is
        argmax-ed over dim 1, looked up in the VQ codebook and moved to channels-first.
    :param force_not_quantize: forwarded (OR-ed with ``predict_cids``) to VQ decoders.
    :return: decoded images. With ``split_input_params['patch_distributed_vq']``,
        decoding runs on overlapping crops that are blended back together.
    """
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    z = 1. / self.scale_factor * z
    skip_quantize = predict_cids or force_not_quantize

    def _decode(latent):
        # Only VQ autoencoders accept the force_not_quantize switch.
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(latent, force_not_quantize=skip_quantize)
        return self.first_stage_model.decode(latent)

    if not hasattr(self, "split_input_params") or not self.split_input_params["patch_distributed_vq"]:
        return _decode(z)

    params = self.split_input_params
    ks = params["ks"]  # eg. (128, 128)
    stride = params["stride"]  # eg. (64, 64)
    uf = params["vqf"]
    bs, nc, h, w = z.shape
    if ks[0] > h or ks[1] > w:
        ks = (min(ks[0], h), min(ks[1], w))
        print("reducing Kernel")
    if stride[0] > h or stride[1] > w:
        stride = (min(stride[0], h), min(stride[1], w))
        print("reducing stride")

    fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
    crops = unfold(z)  # (bn, nc * prod(**ks), L)
    crops = crops.view((crops.shape[0], -1, ks[0], ks[1], crops.shape[-1]))  # (bn, nc, ks[0], ks[1], L)
    decoded_crops = [_decode(crops[:, :, :, :, i]) for i in range(crops.shape[-1])]
    blended = torch.stack(decoded_crops, dim=-1) * weighting  # (bn, nc, ks[0], ks[1], L)
    blended = blended.view((blended.shape[0], -1, blended.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
    # Stitch the crops together and divide out the overlap (normalization is (1, 1, h, w)).
    return fold(blended) / normalization
# Deliberately NOT decorated with @torch.no_grad() (a no_grad here previously caused a hard-to-find bug).
def encode_first_stage(self, x):
    """Encode images ``x`` with ``self.first_stage_model``.

    With ``split_input_params['patch_distributed_vq']`` set, ``x`` is encoded as
    overlapping crops whose outputs are blended back together. This also records
    ``split_input_params['original_image_size']``.
    """
    if not hasattr(self, "split_input_params") or not self.split_input_params["patch_distributed_vq"]:
        return self.first_stage_model.encode(x)

    params = self.split_input_params
    ks = params["ks"]  # eg. (128, 128)
    stride = params["stride"]  # eg. (64, 64)
    df = params["vqf"]
    params['original_image_size'] = x.shape[-2:]
    bs, nc, h, w = x.shape
    if ks[0] > h or ks[1] > w:
        ks = (min(ks[0], h), min(ks[1], w))
        print("reducing Kernel")
    if stride[0] > h or stride[1] > w:
        stride = (min(stride[0], h), min(stride[1], w))
        print("reducing stride")

    fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
    crops = unfold(x)  # (bn, nc * prod(**ks), L)
    crops = crops.view((crops.shape[0], -1, ks[0], ks[1], crops.shape[-1]))  # (bn, nc, ks[0], ks[1], L)
    encoded_crops = [self.first_stage_model.encode(crops[:, :, :, :, i]) for i in range(crops.shape[-1])]
    blended = torch.stack(encoded_crops, dim=-1) * weighting
    blended = blended.view((blended.shape[0], -1, blended.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
    # Stitch the crops together and divide out the overlap.
    return fold(blended) / normalization
def shared_step(self, batch, **kwargs):
    # Turn the batch into (latents, conditioning) and return whatever forward() returns.
    latents, conditioning = self.get_input(batch, self.first_stage_key)
    return self(latents, conditioning)
def forward(self, x, c, *args, **kwargs):
    """Sample a random timestep per element and return ``p_losses`` for latents ``x`` under conditioning ``c``.

    For conditional models ``c`` must not be None. With a shortened conditioning
    schedule, ``c`` is itself diffused to ``cond_ids[t]``.
    """
    batch = x.shape[0]
    t = torch.randint(0, self.num_timesteps, (batch,), device=self.device).long()
    conditioned = self.model.conditioning_key is not None
    if conditioned:
        assert c is not None
    if conditioned and self.shorten_cond_schedule:  # TODO: drop this option
        tc = self.cond_ids[t].to(self.device)
        c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
    return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
    """Express (x, y, w, h) boxes relative to a crop given as (x, y, w, h).

    Offsets are shifted by the crop origin, divided by the crop size and passed
    through ``clamp``; widths/heights are divided by the crop size and capped so
    that the box ends by 1. Returns a list of (x0, y0, w, h) tuples.
    """
    rescaled = []
    for box in bboxes:
        x0 = clamp((box[0] - crop_coordinates[0]) / crop_coordinates[2])
        y0 = clamp((box[1] - crop_coordinates[1]) / crop_coordinates[3])
        width = min(box[2] / crop_coordinates[2], 1 - x0)
        height = min(box[3] / crop_coordinates[3], 1 - y0)
        rescaled.append((x0, y0, width, height))
    return rescaled
def apply_model(self, x_noisy, t, cond, return_ids=False):
    """Run the denoising network ``self.model`` on ``x_noisy`` at timesteps ``t``.

    :param x_noisy: noisy latents.
    :param t: timestep indices.
    :param cond: a dict of conditioning lists (passed to ``self.model`` as keyword
        arguments), or a single value/list that is wrapped as ``{'c_concat': [...]}``
        when ``self.model.conditioning_key == 'concat'`` and as
        ``{'c_crossattn': [...]}`` otherwise.
    :param return_ids: when False and the model returns a tuple, only its first
        element is returned.

    If ``split_input_params`` is set, the input is processed crop-by-crop via
    ``get_fold_unfold`` and the outputs are blended back together.
    """
    if isinstance(cond, dict):
        # hybrid case, cond is expected to be a dict
        pass
    else:
        if not isinstance(cond, list):
            cond = [cond]
        key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
        cond = {key: cond}

    if hasattr(self, "split_input_params"):
        assert len(cond) == 1  # todo can only deal with one conditioning atm
        assert not return_ids
        ks = self.split_input_params["ks"]  # eg. (128, 128)
        stride = self.split_input_params["stride"]  # eg. (64, 64)

        h, w = x_noisy.shape[-2:]

        fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)

        z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)
        # Reshape to img shape
        z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
        z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]

        if self.cond_stage_key in ["image", "LR_image", "segmentation",
                                   'bbox_img'] and self.model.conditioning_key:  # todo check for completeness
            # Spatial conditioning: crop it exactly like the input.
            c_key = next(iter(cond.keys()))  # get key
            c = next(iter(cond.values()))  # get value
            assert (len(c) == 1)  # todo extend to list with more than one elem
            c = c[0]  # get element

            c = unfold(c)
            c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

            cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]

        elif self.cond_stage_key == 'coordinates_bbox':
            # NOTE(review): relies on self.bbox_tokenizer, which __init__ sets to None;
            # it must be assigned elsewhere before this branch can run.
            assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'

            # assuming padding of unfold is always 0 and its dilation is always 1
            n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
            full_img_h, full_img_w = self.split_input_params['original_image_size']
            # as we are operating on latents, we need the factor from the original image size to the
            # spatial latent size to properly rescale the crops for regenerating the bbox annotations
            num_downs = self.first_stage_model.encoder.num_resolutions - 1
            rescale_latent = 2 ** (num_downs)

            # get top left positions of patches as conforming for the bbox tokenizer, therefore we
            # need to rescale the tl patch coordinates to be in between (0,1)
            tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
                                     rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
                                    for patch_nr in range(z.shape[-1])]

            # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, w, h)
            patch_limits = [(x_tl, y_tl,
                             rescale_latent * ks[0] / full_img_w,
                             rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
            # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]

            # tokenize crop coordinates for the bounding boxes of the respective patches
            patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
                                  for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)
            # cut tknzd crop position from conditioning
            assert isinstance(cond, dict), 'cond must be dict to be fed into model'
            cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)

            adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
            adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
            adapted_cond = self.get_learned_conditioning(adapted_cond)
            adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])

            cond_list = [{'c_crossattn': [e]} for e in adapted_cond]

        else:
            # Non-spatial conditioning: every crop sees the same conditioning.
            cond_list = [cond for i in range(z.shape[-1])]  # Todo make this more efficient

        # apply model by loop over crops
        output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
        assert not isinstance(output_list[0],
                              tuple)  # todo cant deal with multiple model outputs check this never happens

        o = torch.stack(output_list, axis=-1)
        o = o * weighting
        # Reverse reshape to img shape
        o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
        # stitch crops together
        x_recon = fold(o) / normalization

    else:
        x_recon = self.model(x_noisy, t, **cond)

    if isinstance(x_recon, tuple) and not return_ids:
        return x_recon[0]
    else:
        return x_recon
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    # Solve x_0 = sqrt(1/alphas_cumprod) * x_t - sqrt(1/alphas_cumprod - 1) * eps for eps.
    shape = x_t.shape
    residual = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, shape) * x_t - pred_xstart
    return residual / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, shape)
def _prior_bpd(self, x_start):
    """
    Prior KL term of the variational lower bound, in bits per dimension.

    It cannot be optimized because it depends only on the encoder.

    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    final_step = self.num_timesteps - 1
    t = torch.tensor([final_step] * x_start.shape[0], device=x_start.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
    # KL against a standard normal (mean 0, log-variance 0), converted from nats to bits.
    kl = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
    return mean_flat(kl) / np.log(2.0)
def p_losses(self, x_start, cond, t, noise=None):
    """Conditional training loss at timesteps ``t``.

    :param x_start: clean latents.
    :param cond: conditioning passed to ``apply_model``.
    :param t: timestep indices.
    :param noise: optional noise (drawn with ``randn_like`` when omitted).
    :return: (loss, loss_dict). The loss is
        l_simple_weight * mean(loss_simple / exp(logvar_t) + logvar_t)
        + original_elbo_weight * mean(lvlb_weights[t] * loss_simple).
    :raises NotImplementedError: for a parameterization other than "x0"/"eps".
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    model_output = self.apply_model(x_noisy, t, cond)

    loss_dict = {}
    prefix = 'train' if self.training else 'val'

    if self.parameterization == "x0":
        target = x_start
    elif self.parameterization == "eps":
        target = noise
    else:
        raise NotImplementedError()

    # Per-sample loss averaged over dims 1-3.
    loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
    loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

    logvar_t = self.logvar[t].to(self.device)
    loss = loss_simple / torch.exp(logvar_t) + logvar_t
    if self.learn_logvar:
        loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
        loss_dict.update({'logvar': self.logvar.data.mean()})

    loss = self.l_simple_weight * loss.mean()

    # The VLB term uses the same per-sample loss; reuse it rather than running
    # get_loss a second time on identical inputs.
    loss_vlb = (self.lvlb_weights[t] * loss_simple).mean()
    loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
    loss += (self.original_elbo_weight * loss_vlb)
    loss_dict.update({f'{prefix}/loss': loss})

    return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                    return_x0=False, score_corrector=None, corrector_kwargs=None):
    """Reverse-step distribution p(x_{t-1} | x_t, c) predicted by the conditional model.

    :param x: noisy latents x_t.
    :param c: conditioning passed to ``apply_model``.
    :param t: timestep indices.
    :param clip_denoised: clamp the reconstructed x_0 to [-1, 1] (in place).
    :param return_codebook_ids: expect ``(output, logits)`` from the model and also return the logits.
    :param quantize_denoised: pass x_0 through the first-stage quantizer.
    :param return_x0: also return the reconstructed x_0.
    :param score_corrector: optional object whose ``modify_score`` adjusts the eps prediction.
    :param corrector_kwargs: extra keyword arguments for ``modify_score`` (None means none).
    :return: (mean, variance, log_variance), plus logits or x_0 depending on the flags.
    :raises NotImplementedError: for a parameterization other than "eps"/"x0".
    """
    model_out = self.apply_model(x, t, c, return_ids=return_codebook_ids)

    if score_corrector is not None:
        assert self.parameterization == "eps"
        # corrector_kwargs defaults to None; `**None` used to raise TypeError.
        model_out = score_corrector.modify_score(self, model_out, x, t, c, **(corrector_kwargs or {}))

    if return_codebook_ids:
        model_out, logits = model_out

    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    else:
        raise NotImplementedError()

    if clip_denoised:
        x_recon.clamp_(-1., 1.)
    if quantize_denoised:
        x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    if return_codebook_ids:
        return model_mean, posterior_variance, posterior_log_variance, logits
    elif return_x0:
        return model_mean, posterior_variance, posterior_log_variance, x_recon
    else:
        return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
             return_codebook_ids=False, quantize_denoised=False, return_x0=False,
             temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
    """Draw x_{t-1} from p(x_{t-1} | x_t) using the model's posterior mean and log-variance.

    Noise comes from ``noise_like`` scaled by ``temperature``, with optional dropout
    (``noise_dropout``). It is masked out for batch entries where ``t == 0``.

    Returns:
        The new sample, or ``(sample, x0_estimate)`` when ``return_x0`` is set.

    Raises:
        DeprecationWarning: if ``return_codebook_ids`` is requested (no longer supported).
    """
    if return_codebook_ids:
        # Fail before spending a model evaluation; everything after this was unreachable anyway.
        raise DeprecationWarning("Support dropped.")
    b, *_, device = *x.shape, x.device
    outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                   return_codebook_ids=False,
                                   quantize_denoised=quantize_denoised,
                                   return_x0=return_x0,
                                   score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
    if return_x0:
        model_mean, _, model_log_variance, x0 = outputs
    else:
        model_mean, _, model_log_variance = outputs
    noise = noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    # no noise when t == 0; reshape to (b, 1, 1, ...) so the mask broadcasts over x
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    sample = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
    if return_x0:
        return sample, x0
    return sample
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                          img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                          score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                          log_every_t=None):
    """Run ancestral sampling and record the model's x_0 estimate along the way.

    Args:
        cond: conditioning, truncated to ``batch_size`` entries (per key for dicts, per element
            for lists).
        shape: per-sample shape when ``batch_size`` is given, otherwise the full batch shape.
        temperature: a scalar (int or float) applied to every step, or a sequence indexed by
            timestep.
        mask, x0: optional inpainting constraint. Where ``mask`` is 1, the sample is replaced
            by ``q_sample(x0, t)`` after each step.
        log_every_t: stride for recording intermediates (defaults to ``self.log_every_t``).

    Returns:
        ``(final_sample, intermediates)``, where intermediates are x_0 estimates.
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    timesteps = self.num_timesteps
    if batch_size is not None:
        b = batch_size
        shape = [batch_size] + list(shape)
    else:
        b = batch_size = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=self.device)
    else:
        img = x_T
    intermediates = []
    if cond is not None:
        if isinstance(cond, dict):
            cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                    list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
        else:
            cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
    if start_T is not None:
        timesteps = min(timesteps, start_T)
    iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                    total=timesteps) if verbose else reversed(
        range(0, timesteps))
    # Broadcast a scalar temperature to one value per step. isinstance (rather than an exact
    # type check) also accepts ints and float subclasses, which previously crashed on indexing.
    if isinstance(temperature, (int, float)):
        temperature = [temperature] * timesteps
    for i in iterator:
        ts = torch.full((b,), i, device=self.device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
        img, x0_partial = self.p_sample(img, cond, ts,
                                        clip_denoised=self.clip_denoised,
                                        quantize_denoised=quantize_denoised, return_x0=True,
                                        temperature=temperature[i], noise_dropout=noise_dropout,
                                        score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if mask is not None:
            assert x0 is not None
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img
        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(x0_partial)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)
    return img, intermediates
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
                  x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, start_T=None,
                  log_every_t=None):
    """Ancestral sampling: apply ``p_sample`` from t = timesteps - 1 down to 0.

    Args:
        cond: conditioning passed to every ``p_sample`` call. It is re-noised each step when
            ``self.shorten_cond_schedule`` is set.
        shape: shape of the initial noise when ``x_T`` is not given; ``shape[0]`` is the batch.
        timesteps: number of steps (defaults to ``self.num_timesteps``); ``start_T`` caps it.
        mask, x0: optional inpainting constraint. After each step, positions where ``mask`` is 1
            are replaced by ``q_sample(x0, t)``. ``x0`` and ``mask`` must agree on dims 2 and 3.
        callback, img_callback: called as ``callback(i)`` / ``img_callback(img, i)`` after every
            step.
        log_every_t: stride for recording intermediates (defaults to ``self.log_every_t``).

    Returns:
        The final sample, or ``(sample, intermediates)`` when ``return_intermediates`` is set.
        ``intermediates`` starts with the initial noise.
    """
    if not log_every_t:
        log_every_t = self.log_every_t
    device = self.betas.device
    b = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=device)
    else:
        img = x_T
    intermediates = [img]
    if timesteps is None:
        timesteps = self.num_timesteps
    if start_T is not None:
        timesteps = min(timesteps, start_T)
    iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
        range(0, timesteps))
    if mask is not None:
        assert x0 is not None
        # spatial size has to match: compare both height (dim 2) and width (dim 3)
        assert x0.shape[2:4] == mask.shape[2:4], "x0 and mask must have the same spatial size"
    for i in iterator:
        ts = torch.full((b,), i, device=device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != 'hybrid'
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
        img = self.p_sample(img, cond, ts,
                            clip_denoised=self.clip_denoised,
                            quantize_denoised=quantize_denoised)
        if mask is not None:
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1. - mask) * img
        if i % log_every_t == 0 or i == timesteps - 1:
            intermediates.append(img)
        if callback:
            callback(i)
        if img_callback:
            img_callback(img, i)
    if return_intermediates:
        return img, intermediates
    return img
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
           verbose=True, timesteps=None, quantize_denoised=False,
           mask=None, x0=None, shape=None,**kwargs):
    """Generate ``batch_size`` samples with ``p_sample_loop``.

    ``cond`` is cut down to its first ``batch_size`` entries: per key for a dict (and per element
    for list-valued keys), per element for a list, or directly for anything else. When ``shape``
    is omitted it defaults to ``(batch_size, channels, image_size, image_size)``. Extra
    ``**kwargs`` are accepted and ignored.
    """
    def _head(item):
        return item[:batch_size]

    if shape is None:
        shape = (batch_size, self.channels, self.image_size, self.image_size)

    if isinstance(cond, dict):
        cond = {name: ([_head(part) for part in value] if isinstance(value, list) else _head(value))
                for name, value in cond.items()}
    elif isinstance(cond, list):
        cond = [_head(part) for part in cond]
    elif cond is not None:
        cond = _head(cond)

    return self.p_sample_loop(cond, shape,
                              return_intermediates=return_intermediates, x_T=x_T,
                              verbose=verbose, timesteps=timesteps,
                              quantize_denoised=quantize_denoised,
                              mask=mask, x0=x0)
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
    """Sample for logging, either with a ``DDIMSampler`` or with ancestral ``self.sample``.

    Returns ``(samples, intermediates)``. ``**kwargs`` go to whichever sampler is used.
    ``ddim_steps`` is only used on the DDIM path.
    """
    if not ddim:
        samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                             return_intermediates=True, **kwargs)
        return samples, intermediates
    sampler = DDIMSampler(self)
    latent_shape = (self.channels, self.image_size, self.image_size)
    samples, intermediates = sampler.sample(ddim_steps, batch_size,
                                            latent_shape, cond, verbose=False, **kwargs)
    return samples, intermediates
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None, image_size=512):
    """Build an unconditional conditioning dict from ``null_label``.

    The label (an OmegaConf ListConfig is converted to a list; non dict/list values with ``.to``
    are moved to ``self.device``) is encoded with ``get_learned_conditioning``. The result is
    repeated along a leading singleton axis to ``batch_size``.

    Returns:
        ``{"c_crossattn": [embedding], "c_concat": [zeros]}``, where ``zeros`` has shape
        ``(batch_size, 4, image_size // 8, image_size // 8)``.

    Raises:
        NotImplementedError: when ``null_label`` is None.
    """
    if null_label is None:
        # todo: get null label from cond_stage_model
        raise NotImplementedError()
    label = list(null_label) if isinstance(null_label, ListConfig) else null_label
    if not isinstance(label, (dict, list)) and hasattr(label, "to"):
        label = label.to(self.device)
    embedding = repeat(self.get_learned_conditioning(label), '1 ... -> b ...', b=batch_size).to(self.device)
    latent_side = image_size // 8
    return {
        "c_crossattn": [embedding],
        "c_concat": [torch.zeros([batch_size, 4, latent_side, latent_side]).to(self.device)],
    }
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
               quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
               plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
               use_ema_scope=True,
               **kwargs):
    """Collect images for logging: inputs, reconstructions, conditioning and various samples.

    Sampling goes through ``sample_log`` and uses DDIM whenever ``ddim_steps`` is not None.
    Sampling runs inside ``self.ema_scope`` unless ``use_ema_scope`` is False. The boolean flags
    switch individual entries on or off.

    Args:
        return_keys: optional key or iterable of keys. Only keys that were actually logged are
            returned; if none of them were logged, the whole dict is returned.

    Returns:
        dict mapping names to image tensors / grids.
    """
    ema_scope = self.ema_scope if use_ema_scope else nullcontext
    use_ddim = ddim_steps is not None
    log = dict()
    z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                       return_first_stage_outputs=True,
                                       force_c_encode=True,
                                       return_original_cond=True,
                                       bs=N)
    N = min(x.shape[0], N)
    n_row = min(x.shape[0], n_row)
    log["inputs"] = x
    log["reconstruction"] = xrec
    if self.model.conditioning_key is not None:
        if hasattr(self.cond_stage_model, "decode"):
            xc = self.cond_stage_model.decode(c)
            log["conditioning"] = xc
        elif self.cond_stage_key in ["caption", "txt"]:
            xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
            log["conditioning"] = xc
        elif self.cond_stage_key == 'class_label':
            xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
            log['conditioning'] = xc
        elif isimage(xc):
            log["conditioning"] = xc
        if ismap(xc):
            log["original_conditioning"] = self.to_rgb(xc)
    if plot_diffusion_rows:
        # get diffusion row: forward-noise the latents at every log_every_t step and decode
        diffusion_row = list()
        z_start = z[:n_row]
        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(z_start)
                z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                diffusion_row.append(self.decode_first_stage(z_noisy))
        diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
        diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
        diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
        diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
        log["diffusion_row"] = diffusion_grid
    if sample:
        # get denoise row
        with ema_scope("Sampling"):
            samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
        x_samples = self.decode_first_stage(samples)
        log["samples"] = x_samples
        if plot_denoise_rows:
            denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
            log["denoise_row"] = denoise_grid
        if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                self.first_stage_model, IdentityFirstStage):
            # also display when quantizing x0 while sampling
            with ema_scope("Plotting Quantized Denoised"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta,
                                                         quantize_denoised=True)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_x0_quantized"] = x_samples
    if unconditional_guidance_scale > 1.0:
        uc = self.get_unconditional_conditioning(N, unconditional_guidance_label, image_size=x.shape[-1])
        with ema_scope("Sampling with classifier-free guidance"):
            samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uc,
                                             )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
    if inpaint:
        # make a simple center square
        h, w = z.shape[2], z.shape[3]
        mask = torch.ones(N, h, w).to(self.device)
        # zeros will be filled in
        mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
        mask = mask[:, None, ...]
        with ema_scope("Plotting Inpaint"):
            samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                         ddim_steps=ddim_steps, x0=z[:N], mask=mask)
        x_samples = self.decode_first_stage(samples.to(self.device))
        log["samples_inpainting"] = x_samples
        log["mask"] = mask
        # outpaint: invert the mask so the border is generated instead
        mask = 1. - mask
        with ema_scope("Plotting Outpaint"):
            samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                         ddim_steps=ddim_steps, x0=z[:N], mask=mask)
        x_samples = self.decode_first_stage(samples.to(self.device))
        log["samples_outpainting"] = x_samples
    if plot_progressive_rows:
        with ema_scope("Plotting Progressives"):
            img, progressives = self.progressive_denoising(c,
                                                           shape=(self.channels, self.image_size, self.image_size),
                                                           batch_size=N)
        prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
        log["progressive_row"] = prog_row
    if return_keys:
        if isinstance(return_keys, str):
            return_keys = [return_keys]
        # Keep only requested keys that were actually produced; indexing unconditionally
        # raised KeyError whenever a requested entry was disabled by one of the flags.
        selected = {key: log[key] for key in return_keys if key in log}
        if selected:
            return selected
    return log
def configure_optimizers(self):
    """Create the AdamW optimizer (and optionally a LambdaLR schedule) for training.

    ``self.unet_trainable`` selects the UNet parameters:
      * ``"attn"``: parameters of ``CrossAttention`` modules whose name ends with ``attn2``;
      * ``"conv_in"``: the first input conv (``diffusion_model.input_blocks[0][0]``);
      * ``True`` or ``"all"``: every parameter of ``self.model``.
    Conditioner parameters are added when ``cond_stage_trainable`` is set, and ``self.logvar``
    when ``learn_logvar`` is set. A ``cc_projection`` module, when present, gets its own
    parameter group at ``10 * lr``.

    Returns:
        The optimizer, or ``([optimizer], [scheduler_config])`` when ``use_scheduler`` is set.

    Raises:
        ValueError: for an unrecognised ``unet_trainable`` value.
    """
    lr = self.learning_rate
    params = []
    if self.unet_trainable == "attn":
        print("Training only unet attention layers")
        for n, m in self.model.named_modules():
            if isinstance(m, CrossAttention) and n.endswith('attn2'):
                params.extend(m.parameters())
    elif self.unet_trainable == "conv_in":  # was a plain `if`, which made "attn" fall into the else/raise
        print("Training only unet input conv layers")
        params = list(self.model.diffusion_model.input_blocks[0][0].parameters())
    elif self.unet_trainable is True or self.unet_trainable == "all":
        print("Training the full unet")
        params = list(self.model.parameters())
    else:
        raise ValueError(f"Unrecognised setting for unet_trainable: {self.unet_trainable}")
    if self.cond_stage_trainable:
        print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
        params = params + list(self.cond_stage_model.parameters())
    if self.learn_logvar:
        print('Diffusion model optimizing logvar')
        params.append(self.logvar)

    # Build the param groups from the selection above, so unet_trainable / cond_stage_trainable /
    # learn_logvar actually take effect. The cc_projection keeps its own 10x learning-rate group.
    param_groups = []
    if params:
        param_groups.append({"params": params, "lr": lr})
    cc_projection = getattr(self, "cc_projection", None)
    if cc_projection is not None:
        print('========== optimizing for cc projection weight ==========')
        param_groups.append({"params": list(cc_projection.parameters()), "lr": 10. * lr})
    opt = torch.optim.AdamW(param_groups, lr=lr)
    if self.use_scheduler:
        assert 'target' in self.scheduler_config
        scheduler = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        scheduler = [
            {
                'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                'interval': 'step',
                'frequency': 1
            }]
        return [opt], scheduler
    return opt
@torch.no_grad()
def to_rgb(self, x):
    """Map a multi-channel tensor to 3 channels for visualisation, rescaled to [-1, 1].

    A random 1x1 convolution weight (``torch.randn``, shape ``(3, C, 1, 1)``) is drawn on first
    use and cached as ``self.colorize``. The projected result is min-max normalised over the
    whole tensor.
    """
    x = x.float()
    if not hasattr(self, "colorize"):
        self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
    projected = nn.functional.conv2d(x, weight=self.colorize)
    low, high = projected.min(), projected.max()
    return 2. * (projected - low) / (high - low) - 1.
class DiffusionWrapper(pl.LightningModule):
    """Thin wrapper that feeds conditioning to ``diffusion_model`` according to ``conditioning_key``.

    * ``None``: ``model(x, t)``
    * ``'concat'``: ``c_concat`` tensors are concatenated onto ``x`` along dim 1
    * ``'crossattn'``: ``c_crossattn`` tensors are concatenated along dim 1 and passed as ``context``
    * ``'hybrid'``: both of the above
    * ``'hybrid-adm'``: hybrid plus ``c_adm`` passed as ``y`` (required)
    * ``'adm'``: the first ``c_crossattn`` entry passed as ``y``
    """

    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        mode = self.conditioning_key
        net = self.diffusion_model
        if mode is None:
            return net(x, t)
        if mode == 'concat':
            return net(torch.cat([x] + c_concat, dim=1), t)
        if mode == 'crossattn':
            # each c_crossattn entry is e.g. (batch, tokens, dim); join along the token axis
            return net(x, t, context=torch.cat(c_crossattn, 1))
        if mode == 'hybrid':
            return net(torch.cat([x] + c_concat, dim=1), t, context=torch.cat(c_crossattn, 1))
        if mode == 'hybrid-adm':
            assert c_adm is not None
            return net(torch.cat([x] + c_concat, dim=1), t, context=torch.cat(c_crossattn, 1), y=c_adm)
        if mode == 'adm':
            return net(x, t, y=c_crossattn[0])
        raise NotImplementedError()
class LatentUpscaleDiffusion(LatentDiffusion):
    """Latent diffusion conditioned on a low-resolution image via a frozen ``low_scale_model``.

    ``get_input`` feeds ``batch[low_scale_key]`` through ``low_scale_model``. Its first output
    becomes ``c_concat`` and the returned noise level becomes ``c_adm``.
    """

    def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs):
        super().__init__(*args, **kwargs)
        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
        assert not self.cond_stage_trainable
        self.instantiate_low_stage(low_scale_config)
        self.low_scale_key = low_scale_key

    def instantiate_low_stage(self, config):
        """Instantiate the low-scale model and freeze it (eval mode, train() disabled, no grads)."""
        model = instantiate_from_config(config)
        self.low_scale_model = model.eval()
        self.low_scale_model.train = disabled_train
        for param in self.low_scale_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
        """Return ``(z, conds)``, or extra logging tensors when ``log_mode`` is set.

        ``conds`` is ``{"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}``.
        In ``log_mode`` the tuple is ``(z, conds, x, xrec, xc, x_low, x_low_rec, noise_level)``.
        """
        if not log_mode:
            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
        else:
            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                                  force_c_encode=True, return_original_cond=True, bs=bs)
        x_low = batch[self.low_scale_key][:bs]
        x_low = rearrange(x_low, 'b h w c -> b c h w')
        x_low = x_low.to(memory_format=torch.contiguous_format).float()
        zx, noise_level = self.low_scale_model(x_low)
        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
        if log_mode:
            # TODO: maybe disable if too expensive
            interpretability = False
            if interpretability:
                zx = zx[:, :, ::2, ::2]
            x_low_rec = self.low_scale_model.decode(zx)
            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
        return z, all_conds

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
                   **kwargs):
        """Collect inputs, low-res inputs/reconstructions, conditioning and samples for logging."""
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None
        log = dict()
        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
                                                                          log_mode=True)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        log["x_lr"] = x_low
        log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
                log["conditioning"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)
        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))
            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid
        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid
        if unconditional_guidance_scale > 1.0:
            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            if isinstance(uc_tmp, dict):
                # get_unconditional_conditioning returns a full conditioning dict;
                # only its cross-attention tensor belongs in uc["c_crossattn"].
                uc_tmp = uc_tmp["c_crossattn"][0]
            # TODO explore better "unconditional" choices for the other keys
            # maybe guide away from empty text label and highest noise level and maximally degraded zx?
            uc = dict()
            for k in c:
                if k == "c_crossattn":
                    assert isinstance(c[k], list) and len(c[k]) == 1
                    uc[k] = [uc_tmp]
                elif k == "c_adm":  # todo: only run with text-based guidance?
                    assert isinstance(c[k], torch.Tensor)
                    uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
                elif isinstance(c[k], list):
                    uc[k] = [c[k][i] for i in range(len(c[k]))]
                else:
                    uc[k] = c[k]
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
        if plot_progressive_rows:
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row
        return log
class LatentInpaintDiffusion(LatentDiffusion):
    """
    can either run as pure inpainting model (only concat mode) or with mixed conditionings,
    e.g. mask as concat and text via cross-attn.
    To disable finetuning mode, set finetune_keys to None
    """

    def __init__(self,
                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
                                "model_ema.diffusion_modelinput_blocks00weight"
                                ),
                 concat_keys=("mask", "masked_image"),
                 masked_image_key="masked_image",
                 keep_finetune_dims=4,  # if model was trained without concat mode before and we would like to keep these channels
                 c_concat_log_start=None,  # to log reconstruction of c_concat codes
                 c_concat_log_end=None,
                 *args, **kwargs
                 ):
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", list())
        super().__init__(*args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys
        self.finetune_keys = finetune_keys
        self.concat_keys = concat_keys
        self.keep_dims = keep_finetune_dims
        self.c_concat_log_start = c_concat_log_start
        self.c_concat_log_end = c_concat_log_end
        if exists(self.finetune_keys):
            assert exists(ckpt_path), 'can only finetune from a given checkpoint'
        if exists(ckpt_path):
            self.init_from_ckpt(ckpt_path, ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        """Load a checkpoint, dropping ``ignore_keys`` prefixes and widening ``finetune_keys``.

        A finetune key is replaced by a zero tensor shaped like the matching parameter, and the
        checkpoint weights are copied into its first ``keep_dims`` input channels. Loading is
        non-strict; missing/unexpected keys are printed.
        """
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
                # A deleted key can neither be deleted again (second matching prefix) nor be
                # widened below; both previously raised KeyError.
                continue
            # make it explicit, finetune by including extra input channels
            if exists(self.finetune_keys) and k in self.finetune_keys:
                new_entry = None
                for name, param in self.named_parameters():
                    if name in self.finetune_keys:
                        print(f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
                        new_entry = torch.zeros_like(param)  # zero init
                assert exists(new_entry), 'did not find matching parameter to modify'
                new_entry[:, :self.keep_dims, ...] = sd[k]
                sd[k] = new_entry
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
        """Build ``{"c_concat": [...], "c_crossattn": [c]}`` from ``concat_keys``.

        The masked image is encoded with the first stage. Other concat keys (e.g. the mask) are
        resized to the latent spatial size.
        """
        # note: restricted to non-trainable encoders currently
        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                              force_c_encode=True, return_original_cond=True, bs=bs)
        assert exists(self.concat_keys)
        c_cat = list()
        for ck in self.concat_keys:
            cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
            if bs is not None:
                cc = cc[:bs]
                cc = cc.to(self.device)
            bchw = z.shape
            if ck != self.masked_image_key:
                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
            else:
                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        """Collect inputs, conditioning, inpainting samples and the masked input for logging."""
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None
        log = dict()
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
                log["conditioning"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)
        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
            log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))
            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid
        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                         batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid
        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            if isinstance(uc_cross, dict):
                # get_unconditional_conditioning returns a full conditioning dict;
                # only its cross-attention tensor belongs in uc_full["c_crossattn"].
                uc_cross = uc_cross["c_crossattn"][0]
            uc_cat = c_cat
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
        # read the masked image through the configured key rather than a hard-coded name
        log["masked_image"] = rearrange(batch[self.masked_image_key],
                                        'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
        return log
class Layout2ImgDiffusion(LatentDiffusion):
    """Latent diffusion conditioned on tokenized bounding-box layouts (``coordinates_bbox``).

    ``log_images`` additionally renders the first ``N`` layouts as 256x256 images under
    ``'bbox_image'``.
    """
    # TODO: move all layout-specific hacks to this class

    def __init__(self, cond_stage_key, *args, **kwargs):
        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)

    def log_images(self, batch, N=8, *args, **kwargs):
        logs = super().log_images(batch=batch, N=N, *args, **kwargs)
        split = 'train' if self.training else 'validation'
        dataset = self.trainer.datamodule.datasets[split]
        builder = dataset.conditional_builders[self.cond_stage_key]

        def label_of(catno):
            # category number -> human-readable label, as the dataset defines it
            return dataset.get_textual_label(dataset.get_category_id(catno))

        rendered = [builder.plot(layout.detach().cpu(), label_of, (256, 256))
                    for layout in batch[self.cond_stage_key][:N]]
        logs['bbox_image'] = torch.stack(rendered, dim=0)
        return logs
class SimpleUpscaleDiffusion(LatentDiffusion):
    """Latent diffusion conditioned on a low-resolution image encoded by the model's own first stage.

    ``batch[low_scale_key]`` is encoded with ``encode_first_stage`` and supplied as ``c_concat``.
    """

    def __init__(self, *args, low_scale_key="LR", **kwargs):
        super().__init__(*args, **kwargs)
        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
        assert not self.cond_stage_trainable
        self.low_scale_key = low_scale_key

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
        """Return ``(z, conds)``, or ``(z, conds, x, xrec, xc, x_low)`` when ``log_mode`` is set.

        ``conds`` is ``{"c_concat": [zx], "c_crossattn": [c]}``, with ``zx`` the first-stage encoding
        of the low-resolution input.
        """
        if not log_mode:
            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
        else:
            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                                  force_c_encode=True, return_original_cond=True, bs=bs)
        x_low = batch[self.low_scale_key][:bs]
        x_low = rearrange(x_low, 'b h w c -> b c h w')
        x_low = x_low.to(memory_format=torch.contiguous_format).float()
        encoder_posterior = self.encode_first_stage(x_low)
        zx = self.get_first_stage_encoding(encoder_posterior).detach()
        all_conds = {"c_concat": [zx], "c_crossattn": [c]}
        if log_mode:
            return z, all_conds, x, xrec, xc, x_low
        return z, all_conds

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
                   **kwargs):
        """Collect inputs, the low-res input, conditioning and samples for logging."""
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None
        log = dict()
        z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        log["x_lr"] = x_low
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
                log["conditioning"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)
        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
        if unconditional_guidance_scale > 1.0:
            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            if isinstance(uc_tmp, dict):
                # get_unconditional_conditioning returns a full conditioning dict;
                # only its cross-attention tensor belongs in uc["c_crossattn"].
                uc_tmp = uc_tmp["c_crossattn"][0]
            uc = dict()
            for k in c:
                if k == "c_crossattn":
                    assert isinstance(c[k], list) and len(c[k]) == 1
                    uc[k] = [uc_tmp]
                elif isinstance(c[k], list):
                    uc[k] = [c[k][i] for i in range(len(c[k]))]
                else:
                    uc[k] = c[k]
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
        return log
class MultiCatFrameDiffusion(LatentDiffusion):
    """Latent diffusion conditioned on several low-resolution frames stacked along channels.

    ``batch[low_scale_key]`` is sliced as a channels-last tensor whose last axis
    holds consecutive 3-channel frames. Each frame is encoded with the first
    stage and the per-frame latents are concatenated along dim 1 to form the
    ``c_concat`` conditioning; the parent's conditioning becomes ``c_crossattn``.
    """

    def __init__(self, *args, low_scale_key="LR", **kwargs):
        super().__init__(*args, **kwargs)
        # The concat conditioning is produced under torch.no_grad(), so neither the
        # cond stage nor the low-scale encoding path may hold trainable parameters.
        assert not self.cond_stage_trainable
        self.low_scale_key = low_scale_key

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
        """Return latents plus the ``{"c_concat", "c_crossattn"}`` conditioning dict.

        Args:
            batch: batch dict; must contain ``self.low_scale_key``.
            k: first-stage key forwarded to the parent when ``log_mode`` is False.
            cond_key: unused; kept for interface compatibility.
            bs: optional cap on the number of samples taken from the batch.
            log_mode: when True, additionally return ``x``, ``xrec`` and ``xc`` from
                the parent and the last low-resolution frame (as b c h w).
        """
        num_frames = 2
        if log_mode:
            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
                                                  force_c_encode=True, return_original_cond=True, bs=bs)
        else:
            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)

        stacked = batch[self.low_scale_key][:bs]
        frame_latents = []
        for frame in range(num_frames):
            first_ch, last_ch = 3 * frame, 3 * (frame + 1)
            x_low = rearrange(stacked[:, :, :, first_ch:last_ch], 'b h w c -> b c h w')
            x_low = x_low.to(memory_format=torch.contiguous_format).float()
            posterior = self.encode_first_stage(x_low)
            frame_latents.append(self.get_first_stage_encoding(posterior).detach())

        all_conds = {"c_concat": [torch.cat(frame_latents, dim=1)], "c_crossattn": [c]}
        if not log_mode:
            return z, all_conds
        return z, all_conds, x, xrec, xc, x_low

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
                   **kwargs):
        """Collect inputs, reconstructions, conditioning and samples for logging.

        Returns a dict with ``inputs``, ``reconstruction`` and ``x_lr``, optional
        ``conditioning`` / ``original_conditioning`` entries, ``samples`` when
        ``sample`` is set, and a ``samples_cfg_scale_*`` entry when
        ``unconditional_guidance_scale > 1``. Several keyword arguments
        (``return_keys``, ``plot_*``) are accepted but not used here.
        """
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log = {"inputs": x, "reconstruction": xrec, "x_lr": x_low}

        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
                log["conditioning"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if sample:
            with ema_scope("Sampling"):
                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta)
            log["samples"] = self.decode_first_stage(samples)

        if unconditional_guidance_scale > 1.0:
            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            # Only the cross-attention entry is swapped for the unconditional one;
            # every other conditioning entry is reused (lists shallow-copied).
            uc = {}
            for key, value in c.items():
                if key == "c_crossattn":
                    assert isinstance(value, list) and len(value) == 1
                    uc[key] = [uc_tmp]
                else:
                    uc[key] = list(value) if isinstance(value, list) else value
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc,
                                                 )
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = self.decode_first_stage(samples_cfg)
        return log
================================================
FILE: ldm/models/diffusion/plms.py
================================================
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
class PLMSSampler(object):
    """Pseudo Linear Multi-Step (PLMS) sampler.

    Reads ``num_timesteps``, ``betas``, ``alphas_cumprod``, ``alphas_cumprod_prev``,
    ``device`` and ``apply_model`` from ``model``. ``make_schedule`` (called by
    ``sample``) must run before ``plms_sampling`` / ``p_sample_plms`` with DDIM steps.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Set attribute ``name`` to ``attr``, moving tensors onto the model's device.

        Falls back to CUDA only when the model exposes no ``device`` attribute,
        so CPU (or non-default GPU) models work.
        """
        if isinstance(attr, torch.Tensor):
            target_device = getattr(self.model, "device", None)
            if target_device is None:
                target_device = torch.device("cuda")
            attr = attr.to(target_device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM timestep subset and the schedule buffers used by PLMS.

        Raises:
            ValueError: if ``ddim_eta`` is non-zero (PLMS requires eta == 0).
        """
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        # Stored on the sampler (not the model); p_sample_plms reads it from here.
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               **kwargs
               ):
        """Draw ``batch_size`` samples of shape ``shape`` (C, H, W) using ``S`` PLMS steps.

        Returns:
            ``(samples, intermediates)`` as produced by ``plms_sampling``.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """Run the PLMS loop from ``x_T`` (or fresh noise) and return ``(img, intermediates)``.

        ``intermediates`` holds ``x_inter`` and ``pred_x0`` lists, appended every
        ``log_every_t`` steps and at the first step.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        # History of the most recent raw eps predictions (at most 3 are kept).
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                      dynamic_threshold=None):
        """Perform one PLMS step and return ``(x_prev, pred_x0, e_t)``.

        ``old_eps`` is the list of previous eps predictions (None is treated as
        empty). With no history, a pseudo improved-Euler step is taken, which needs
        ``t_next``; otherwise a 2nd/3rd/4th-order Adams-Bashforth combination is used.
        """
        if old_eps is None:
            old_eps = []
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                if isinstance(c, dict):
                    assert isinstance(unconditional_conditioning, dict)
                    c_in = dict()
                    for k in c:
                        if isinstance(c[k], list):
                            c_in[k] = [torch.cat([
                                unconditional_conditioning[k][i],
                                c[k][i]]) for i in range(len(c[k]))]
                        else:
                            c_in[k] = torch.cat([
                                unconditional_conditioning[k],
                                c[k]])
                else:
                    c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        # make_schedule registers the original-step sigmas on the sampler itself.
        sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            if dynamic_threshold is not None:
                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        else:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
================================================
FILE: ldm/models/diffusion/sampling_util.py
================================================
import torch
import numpy as np
def append_dims(x, target_dims):
    """Right-pad ``x`` with singleton dimensions until it has ``target_dims`` dims.

    Adapted from k-diffusion
    (https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py).

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    expand_index = (Ellipsis,) + (None,) * missing
    return x[expand_index]
def renorm_thresholding(x0, value):
    """Percentile ("dynamic") thresholding applied in a renormalized range.

    ``x0`` is min-max rescaled over the whole tensor to [-1, 1]. For each sample,
    the ``value``-quantile of the absolute values is used as a threshold ``s``,
    which is never below 1.0. Values are clipped to [-s, s] and divided by ``s``,
    and the result is mapped back to the original min/max range.

    Args:
        x0: floating-point tensor whose first dimension is the batch.
        value: quantile in [0, 1] forwarded to ``torch.quantile``.

    Returns:
        A tensor with the same shape and device as ``x0``. A constant ``x0``
        (max == min) has no defined renormalization and is returned as a copy.
    """
    pred_max = x0.max()
    pred_min = x0.min()
    if pred_max == pred_min:
        return x0.clone()

    # renorm
    pred_x0 = (x0 - pred_min) / (pred_max - pred_min)  # 0 ... 1
    pred_x0 = 2 * pred_x0 - 1.  # -1 ... 1

    # per-sample threshold; reshape(b, -1) == rearrange('b ... -> b (...)')
    s = torch.quantile(pred_x0.reshape(pred_x0.shape[0], -1).abs(), value, dim=-1)
    s.clamp_(min=1.0)
    s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))

    # clip by threshold on x0's own device (no numpy round-trip, no model needed)
    pred_x0 = torch.maximum(torch.minimum(pred_x0, s), -s) / s

    # re.renorm
    pred_x0 = (pred_x0 + 1.) / 2.  # 0 ... 1
    pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min  # orig range
    return pred_x0
def norm_thresholding(x0, value):
    """Scale each sample down so its RMS (over all non-batch elements) is at most ``value``.

    Samples whose RMS is already <= ``value`` are left unchanged.
    """
    per_sample_rms = x0.pow(2).flatten(1).mean(1).sqrt()
    capped = append_dims(per_sample_rms.clamp(min=value), x0.ndim)
    return x0 * (value / capped)
def spatial_norm_thresholding(x0, value):
    """Per-position variant of ``norm_thresholding``: cap the RMS over dim 1 at ``value``.

    The RMS is taken over dim 1 with ``keepdim=True``. The original comment
    notes ``b c h w`` input, i.e. the RMS is over channels at each spatial location.
    """
    channel_rms = torch.sqrt(torch.mean(x0 ** 2, dim=1, keepdim=True))
    return x0 * (value / channel_rms.clamp(min=value))
================================================
FILE: ldm/modules/attention.py
================================================
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from ldm.modules.diffusionmodules.util import checkpoint
def exists(val):
    """Return True when ``val`` is not None (falsy values such as 0 still count)."""
    if val is None:
        return False
    return True
def uniq(arr):
    """Return a dict keys view of the distinct elements of ``arr`` in first-seen order."""
    return dict.fromkeys(arr, True).keys()
def default(val, d):
    """Return ``val`` unless it is None; otherwise fall back to ``d``.

    If ``d`` is a plain function it is called (lazily) to produce the fallback.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def max_neg_value(t):
    """Most negative finite value representable in ``t``'s floating dtype."""
    info = torch.finfo(t.dtype)
    return -info.max
def init_(tensor):
    """Fill ``tensor`` in place from U(-1/sqrt(d), 1/sqrt(d)), d = size of its last dim.

    Returns the same tensor object.
    """
    fan = tensor.shape[-1]
    bound = 1 / math.sqrt(fan)
    return tensor.uniform_(-bound, bound)
# feedforward
class GEGLU(nn.Module):
    """Gated GELU: one linear layer yields a value half and a gate half; output = value * gelu(gate)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, 2 * dim_out)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Transformer MLP: project to ``dim * mult`` (GELU or GEGLU), dropout, project back.

    The submodule layout (``net.0`` input projection, ``net.1`` dropout,
    ``net.2`` output linear) determines the checkpoint keys and is preserved.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_features = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(hidden, out_features))

    def forward(self, x):
        return self.net(x)
def zero_module(module):
    """Zero every parameter of ``module`` in place (without autograd tracking) and return it."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
def Normalize(in_channels):
    """32-group GroupNorm with eps=1e-6 and learnable affine parameters."""
    return nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
    """Multi-head linear attention on an NCHW feature map.

    Keys are softmax-normalized over spatial positions, aggregated with values
    into a per-head context, and that context is then applied to the queries.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, 3 * hidden_dim, kernel_size=1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, kernel_size=1)

    def forward(self, x):
        _, _, height, width = x.shape
        q, k, v = rearrange(self.to_qkv(x), 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = torch.softmax(k, dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(attended, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=height, w=width)
        return self.to_out(attended)
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention across the spatial positions of an NCHW map, with a residual add."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def conv1x1():
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = conv1x1()
        self.k = conv1x1()
        self.v = conv1x1()
        self.proj_out = conv1x1()

    def forward(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        # attention weights over spatial positions, scaled by 1/sqrt(channels)
        b, c, h, w = query.shape
        query = rearrange(query, 'b c h w -> b (h w) c')
        key = rearrange(key, 'b c h w -> b c (h w)')
        weights = torch.einsum('bij,bjk->bik', query, key)
        weights = weights * (int(c) ** (-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # weighted sum of values
        value = rearrange(value, 'b c h w -> b c (h w)')
        weights = rearrange(weights, 'b i j -> b j i')
        out = torch.einsum('bij,bjk->bik', value, weights)
        out = rearrange(out, 'b c (h w) -> b c h w', h=h)
        return x + self.proj_out(out)
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` to ``context``.

    Without a context this is self-attention. An optional boolean ``mask``
    (True = keep) over the context tokens is broadcast to every head.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        heads = self.heads
        if context is None:
            context = x

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

        q = split_heads(self.to_q(x))
        k = split_heads(self.to_k(context))
        v = split_heads(self.to_v(context))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        if mask is not None:
            flat_mask = rearrange(mask, 'b ... -> b (...)')
            flat_mask = repeat(flat_mask, 'b j -> (b h) () j', h=heads)
            sim.masked_fill_(~flat_mask, -torch.finfo(sim.dtype).max)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        return self.to_out(rearrange(out, '(b h) n d -> b n (h d)', h=heads))
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: attn1, then cross-attn (attn2), then feed-forward, each residual.

    ``attn1`` is self-attention unless ``disable_self_attn`` is set, in which
    case it also attends to the context. The forward pass goes through
    ``checkpoint`` with the ``checkpoint`` flag given at construction.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        attn1_context_dim = context_dim if self.disable_self_attn else None
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                    context_dim=attn1_context_dim)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        attn1_context = context if self.disable_self_attn else None
        x = x + self.attn1(self.norm1(x), context=attn1_context)
        x = x + self.attn2(self.norm2(x), context=context)
        x = x + self.ff(self.norm3(x))
        return x
class SpatialTransformer(nn.Module):
    """Apply a stack of transformer blocks to an NCHW feature map.

    Steps: GroupNorm, then a 1x1 conv to ``n_heads * d_head`` channels. The
    map is flattened to a (b, h*w, c) token sequence and run through ``depth``
    blocks. It is then reshaped back and passed through a zero-initialized
    1x1 conv, and the input is added as a residual. With no context, the
    cross-attention layers fall back to self-attention.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        blocks = [
            BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
                                  disable_self_attn=disable_self_attn)
            for _ in range(depth)
        ]
        self.transformer_blocks = nn.ModuleList(blocks)

        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None):
        residual = x
        _, _, h, w = x.shape
        tokens = self.proj_in(self.norm(x))
        tokens = rearrange(tokens, 'b c h w -> b (h w) c').contiguous()
        for block in self.transformer_blocks:
            tokens = block(tokens, context=context)
        tokens = rearrange(tokens, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        return self.proj_out(tokens) + residual
================================================
FILE: ldm/modules/diffusionmodules/__init__.py
================================================
================================================
FILE: ldm/modules/diffusionmodules/model.py
================================================
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from ldm.util import instantiate_from_config
from ldm.modules.attention import LinearAttention
def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal timestep embeddings.

    Matches the implementation in Denoising Diffusion Probabilistic Models (ported
    from Fairseq / tensor2tensor). It differs slightly from Section 3.5 of
    "Attention Is All You Need": the sines fill the first half of the output and
    the cosines the second half instead of being interleaved.

    Args:
        timesteps: 1-D tensor of timesteps.
        embedding_dim: output width. Odd widths are zero-padded by one column.

    Returns:
        Float tensor of shape (len(timesteps), embedding_dim).
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # With half_dim == 1 the denominator would be 0. A single frequency is
    # exp(0) == 1 whatever the scale, so clamping the denominator to 1 only
    # removes the ZeroDivisionError and changes no other output.
    emb = math.log(10000) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
def nonlinearity(x):
    """Swish / SiLU activation: x * sigmoid(x)."""
    gate = torch.sigmoid(x)
    return x * gate
def Normalize(in_channels, num_groups=32):
    """GroupNorm over ``in_channels`` with ``num_groups`` groups, eps=1e-6 and affine params."""
    return torch.nn.GroupNorm(num_channels=in_channels, num_groups=num_groups, eps=1e-6, affine=True)
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 same-size convolution."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 conv or, without conv, 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # padding is applied manually in forward (right/bottom only)
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # torch conv padding is symmetric, so pad one column/row on the right/bottom by hand
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class ResnetBlock(nn.Module):
    """Residual block: (GroupNorm, swish, 3x3 conv) twice, with optional timestep projection.

    When ``in_channels != out_channels`` the skip path is projected by a 3x3
    conv (``conv_shortcut=True``) or a 1x1 conv (``nin_shortcut``). The
    timestep projection ``temb_proj`` exists only when ``temb_channels > 0``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                     kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                    kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))
        if temb is not None:
            # broadcast the per-sample timestep projection over spatial dims
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        if self.in_channels != self.out_channels:
            shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            x = shortcut(x)
        return x + h
class LinAttnBlock(LinearAttention):
    """Single-head LinearAttention exposing AttnBlock's one-argument constructor."""

    def __init__(self, in_channels):
        # one head whose width equals the channel count
        super().__init__(in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over an NCHW map using batched matmuls, with residual."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def conv1x1():
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = conv1x1()
        self.k = conv1x1()
        self.v = conv1x1()
        self.proj_out = conv1x1()

    def forward(self, x):
        normed = self.norm(x)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)

        b, c, h, w = q.shape
        positions = h * w
        # (b, hw, c) @ (b, c, hw) -> (b, hw, hw): weights[b, i, j] = sum_c q[b, c, i] * k[b, c, j]
        weights = torch.bmm(q.reshape(b, c, positions).permute(0, 2, 1), k.reshape(b, c, positions))
        weights = weights * (int(c) ** (-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # (b, c, hw_k) @ (b, hw_k, hw_q) -> (b, c, hw_q)
        attended = torch.bmm(v.reshape(b, c, positions), weights.permute(0, 2, 1))
        attended = attended.reshape(b, c, h, w)
        return x + self.proj_out(attended)
def make_attn(in_channels, attn_type="vanilla"):
    """Build the attention layer for ``attn_type``: "vanilla", "linear" or "none" (identity).

    Raises:
        AssertionError: for any other ``attn_type``.
    """
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "none":
        return nn.Identity(in_channels)
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    return LinAttnBlock(in_channels)
class Model(nn.Module):
    """Timestep-conditioned encoder/decoder network with skip connections (U-Net style).

    The down path stores every intermediate activation in ``hs``. The up path
    pops them in reverse and concatenates each one onto its input along the
    channel axis. Module attribute paths (``down[i].block[j]``, ``mid.attn_1``,
    ``up[i].upsample``, ...) determine checkpoint keys.

    Args (keyword-only except as declared):
        ch: base channel width; level ``i`` uses ``ch * ch_mult[i]`` channels.
        out_ch: channels produced by ``conv_out``.
        ch_mult: per-level channel multipliers; its length is the number of levels.
        num_res_blocks: ResnetBlocks per level on the down path (one more on the up path).
        attn_resolutions: spatial resolutions (tracked from ``resolution``) at which
            attention layers are inserted after ResnetBlocks.
        dropout: dropout rate passed to every ResnetBlock.
        resamp_with_conv: whether Downsample/Upsample use convolutions.
        in_channels: channels of ``x`` (plus ``context`` if concatenated in forward).
        resolution: nominal input resolution, used only to decide attention placement.
        use_timestep: if True, ``forward`` requires ``t`` and builds a timestep embedding.
        use_linear_attn: shorthand that forces ``attn_type = "linear"``.
        attn_type: attention type forwarded to ``make_attn``.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding: two linear layers ch -> 4*ch -> 4*ch (swish in between, see forward)
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch,
                                self.temb_ch),
                torch.nn.Linear(self.temb_ch,
                                self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # curr_res tracks the nominal spatial size at each level to decide where attention goes
        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # every level except the last halves the resolution
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            # num_res_blocks+1 blocks: one per skip activation pushed on the down path at this level
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    # the last skip of this level comes from the previous level's output width
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t=None, context=None):
        """Run the network on ``x`` with optional timesteps ``t`` and channel-aligned ``context``."""
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling: every block output (and every downsample output) is kept in hs as a skip
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling: consume the skips in reverse (LIFO) order
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        """Return the weight of the final convolution."""
        return self.conv_out.weight
class Encoder(nn.Module):
    """Convolutional encoder: down path plus middle block, without timestep conditioning.

    Every level except the last halves the spatial resolution. Attention is
    inserted after ResnetBlocks at the resolutions listed in
    ``attn_resolutions``. ``conv_out`` emits ``2*z_channels`` channels when
    ``double_z`` is True, otherwise ``z_channels``. NOTE(review): the doubled
    width is presumably the mean/log-variance pair of a Gaussian posterior;
    confirm against the consumer. ``out_ch`` is accepted but not used here, and
    extra keyword arguments are ignored.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        # no timestep embedding: ResnetBlocks are built without a temb projection
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # curr_res tracks the nominal spatial size at each level to decide where attention goes
        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Encode ``x``; returns the output of ``conv_out`` (no sampling is done here)."""
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """Upsampling convolutional decoder (latent -> output tensor).

    Mirrors ``Encoder``: ``conv_in`` -> middle ResnetBlock / attention /
    ResnetBlock -> ``len(ch_mult)`` levels visited from the lowest resolution
    upward, each with ``num_res_blocks + 1`` ResnetBlocks (plus attention where the
    tracked resolution is in ``attn_resolutions``) and an Upsample on every level
    except level 0 -> Normalize -> nonlinearity -> ``conv_out`` (optionally tanh).

    :param give_pre_end: if True, ``forward`` returns the features before the
        output head (norm / nonlinearity / conv_out).
    :param tanh_out: if True, apply ``torch.tanh`` to the output.
    The remaining parameters match ``Encoder``; ``in_channels`` is only stored.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # Channel count and spatial size at the innermost (lowest) resolution.
        top = self.num_resolutions - 1
        cur_ch = ch * ch_mult[top]
        cur_res = resolution // 2 ** top
        self.z_shape = (1, z_channels, cur_res, cur_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # latent -> cur_ch feature map
        self.conv_in = torch.nn.Conv2d(z_channels, cur_ch, kernel_size=3, stride=1, padding=1)

        def resnet(c_in, c_out):
            return ResnetBlock(in_channels=c_in, out_channels=c_out,
                               temb_channels=self.temb_ch, dropout=dropout)

        self.mid = nn.Module()
        self.mid.block_1 = resnet(cur_ch, cur_ch)
        self.mid.attn_1 = make_attn(cur_ch, attn_type=attn_type)
        self.mid.block_2 = resnet(cur_ch, cur_ch)

        # Levels are built from the lowest resolution up (this creation order
        # determines weight initialization) and then stored indexed by level.
        levels = []
        for i_level in range(top, -1, -1):
            level_ch = ch * ch_mult[i_level]
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(resnet(cur_ch, level_ch))
                cur_ch = level_ch
                if cur_res in attn_resolutions:
                    attn_blocks.append(make_attn(cur_ch, attn_type=attn_type))
            level = nn.Module()
            level.block = res_blocks
            level.attn = attn_blocks
            if i_level != 0:
                level.upsample = Upsample(cur_ch, resamp_with_conv)
                cur_res = cur_res * 2
            levels.append(level)
        self.up = nn.ModuleList(levels[::-1])

        # output head
        self.norm_out = Normalize(cur_ch)
        self.conv_out = torch.nn.Conv2d(cur_ch, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode latent ``z``; also records its shape in ``self.last_z_shape``."""
        self.last_z_shape = z.shape
        temb = None  # no timestep conditioning

        h = self.conv_in(z)
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h, temb)), temb)

        for i_level in reversed(range(self.num_resolutions)):
            level = self.up[i_level]
            has_attn = len(level.attn) > 0
            for i_block, resblock in enumerate(level.block):
                h = resblock(h, temb)
                if has_attn:
                    h = level.attn[i_block](h)
            if i_level != 0:
                h = level.upsample(h)

        if self.give_pre_end:
            return h

        h = self.conv_out(nonlinearity(self.norm_out(h)))
        return torch.tanh(h) if self.tanh_out else h
class SimpleDecoder(nn.Module):
    """Small fixed decoder: 1x1 conv -> ResnetBlocks (C->2C->4C->2C) -> 1x1 conv
    (2C->C) -> Upsample -> Normalize -> nonlinearity -> 3x3 ``conv_out``.

    Extra positional/keyword arguments are accepted and ignored.
    """
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        c = in_channels
        layers = [
            nn.Conv2d(c, c, 1),
            ResnetBlock(in_channels=c, out_channels=2 * c, temb_channels=0, dropout=0.0),
            ResnetBlock(in_channels=2 * c, out_channels=4 * c, temb_channels=0, dropout=0.0),
            ResnetBlock(in_channels=4 * c, out_channels=2 * c, temb_channels=0, dropout=0.0),
            nn.Conv2d(2 * c, c, 1),
            Upsample(c, with_conv=True),
        ]
        self.model = nn.ModuleList(layers)
        # output head
        self.norm_out = Normalize(c)
        self.conv_out = torch.nn.Conv2d(c, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Positions 1..3 of self.model are ResnetBlocks, which take a (None) temb.
        resnet_positions = (1, 2, 3)
        for idx, layer in enumerate(self.model):
            x = layer(x, None) if idx in resnet_positions else layer(x)
        return self.conv_out(nonlinearity(self.norm_out(x)))
class UpsampleDecoder(nn.Module):
    """Stack of ResnetBlock levels with a 2x Upsample between levels, followed by
    Normalize -> nonlinearity -> 3x3 ``conv_out``.

    Level ``i`` holds ``num_res_blocks + 1`` ResnetBlocks producing
    ``ch * ch_mult[i]`` channels; every level except the last is followed by an
    Upsample (with conv).
    """
    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        last_level = self.num_resolutions - 1
        channels = in_channels
        # resolution bookkeeping only; no layer depends on it
        curr_res = resolution // 2 ** last_level
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            target = ch * ch_mult[i_level]
            stage = []
            for _ in range(self.num_res_blocks + 1):
                stage.append(ResnetBlock(in_channels=channels,
                                         out_channels=target,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                channels = target
            self.res_blocks.append(nn.ModuleList(stage))
            if i_level != last_level:
                self.upsample_blocks.append(Upsample(channels, True))
                curr_res = curr_res * 2
        # output head
        self.norm_out = Normalize(channels)
        self.conv_out = torch.nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        h = x
        last_level = self.num_resolutions - 1
        for i_level, stage in enumerate(self.res_blocks):
            for resblock in stage:
                h = resblock(h, None)
            if i_level != last_level:
                h = self.upsample_blocks[i_level](h)
        return self.conv_out(nonlinearity(self.norm_out(h)))
class LatentRescaler(nn.Module):
    """Resample a feature map by ``factor``.

    Pipeline: 3x3 ``conv_in`` -> ``depth`` ResnetBlocks -> interpolate to
    ``round(H * factor) x round(W * factor)`` (default ``interpolate`` mode) ->
    AttnBlock -> ``depth`` ResnetBlocks -> 1x1 ``conv_out``.
    """
    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)

        def make_blocks():
            return nn.ModuleList([
                ResnetBlock(in_channels=mid_channels, out_channels=mid_channels,
                            temb_channels=0, dropout=0.0)
                for _ in range(depth)
            ])

        self.res_block1 = make_blocks()
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = make_blocks()
        self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        height, width = x.shape[2], x.shape[3]
        target = (int(round(height * self.factor)), int(round(width * self.factor)))
        x = torch.nn.functional.interpolate(x, size=target)
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        return self.conv_out(x)
class MergedRescaleEncoder(nn.Module):
    """Encoder (single, non-doubled z) followed by a LatentRescaler.

    The encoder emits ``ch * ch_mult[-1]`` channels. The rescaler resamples them
    by ``rescale_factor`` and maps them to ``out_ch`` channels.
    """
    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        hidden = ch * ch_mult[-1]
        self.encoder = Encoder(in_channels=in_channels, ch=ch, ch_mult=ch_mult,
                               num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions,
                               dropout=dropout, resamp_with_conv=resamp_with_conv,
                               resolution=resolution, z_channels=hidden, double_z=False,
                               out_ch=None)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=hidden,
                                       mid_channels=hidden, out_channels=out_ch,
                                       depth=rescale_module_depth)

    def forward(self, x):
        return self.rescaler(self.encoder(x))
class MergedRescaleDecoder(nn.Module):
    """LatentRescaler (``z_channels`` -> ``z_channels * ch_mult[-1]``) followed by a Decoder."""
    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        hidden = z_channels * ch_mult[-1]
        self.decoder = Decoder(out_ch=out_ch, z_channels=hidden, ch=ch, ch_mult=ch_mult,
                               num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions,
                               dropout=dropout, resamp_with_conv=resamp_with_conv,
                               in_channels=None, resolution=resolution)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels,
                                       mid_channels=hidden, out_channels=hidden,
                                       depth=rescale_module_depth)

    def forward(self, x):
        return self.decoder(self.rescaler(x))
class Upsampler(nn.Module):
    """Upsample a latent from ``in_size`` to ``out_size`` (LatentRescaler, then Decoder).

    The Decoder is built with ``num_blocks`` levels and therefore upsamples by
    ``2 ** (num_blocks - 1)``. The LatentRescaler supplies the remaining factor,
    ``out_size / (in_size * 2 ** (num_blocks - 1))``, which is exactly 1.0 when
    ``out_size / in_size`` is a power of two.

    :param in_size: input spatial size.
    :param out_size: desired output spatial size (must be >= ``in_size``).
    :param in_channels: latent channels (also the Decoder's base channel count).
    :param out_channels: output channels.
    :param ch_mult: channel multiplier used for every Decoder level.
    """
    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size//in_size))+1
        # Remaining (possibly fractional) scale after the Decoder's power-of-two
        # upsampling. The previous `1. + (out_size % in_size)` was only right for
        # power-of-two ratios; e.g. 32 -> 48 produced a 17x rescale.
        factor_up = out_size / (in_size * 2 ** (num_blocks - 1))
        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
                                       out_channels=in_channels)
        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
                               attn_resolutions=[], in_channels=None, ch=in_channels,
                               ch_mult=[ch_mult for _ in range(num_blocks)])

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x
class Resize(nn.Module):
    """Fixed-interpolation resize layer.

    :param in_channels: only relevant to the (unimplemented) learned mode.
    :param learned: learned resampling is not implemented; True raises
        NotImplementedError.
    :param mode: interpolation mode passed to ``torch.nn.functional.interpolate``.
    """
    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # `__name__`, not `__name`: the misspelled attribute raised
            # AttributeError before the intended NotImplementedError was reached.
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            raise NotImplementedError()
            # Unreachable sketch of the intended learned path, kept for reference.
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1)

    def forward(self, x, scale_factor=1.0):
        """Return ``x`` unchanged for ``scale_factor == 1.0``; otherwise interpolate it."""
        if scale_factor==1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
        return x
class FirstStagePostProcessor(nn.Module):
    """Encode with a frozen first-stage model, then project and downsample the latent.

    Pipeline: ``pretrained_model.encode`` (taking ``.mode()`` of a
    DiagonalGaussianDistribution) -> Normalize -> 3x3 conv -> nonlinearity ->
    one (ResnetBlock, Downsample) pair per entry of ``ch_mult`` -> optional
    ``b c h w -> b (h w) c`` reshape.

    :param ch_mult: channel multipliers (relative to ``n_channels``) per stage.
    :param in_channels: channels of the pretrained model's latent.
    :param pretrained_model: model to use when ``pretrained_config`` is None.
    :param reshape: flatten the spatial dims into a token axis on output.
    :param n_channels: base width; defaults to ``pretrained_model.encoder.ch``.
    :param dropout: dropout passed to each ResnetBlock.
    :param pretrained_config: config instantiated (and frozen) instead of
        ``pretrained_model`` when given.
    """
    def __init__(self, ch_mult:list, in_channels,
                 pretrained_model:nn.Module=None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        if pretrained_config is not None:
            self.instantiate_pretrained(pretrained_config)
        else:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model

        self.do_reshape = reshape

        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
        self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1)

        blocks, downs = [], []
        ch_in = n_channels
        for m in ch_mult:
            ch_out = m * n_channels
            blocks.append(ResnetBlock(in_channels=ch_in, out_channels=ch_out, dropout=dropout))
            downs.append(Downsample(ch_out, with_conv=False))
            ch_in = ch_out

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)

    def instantiate_pretrained(self, config):
        """Build the first-stage model from ``config``, put it in eval mode and freeze it."""
        self.pretrained_model = instantiate_from_config(config).eval()
        self.pretrained_model.requires_grad_(False)

    @torch.no_grad()
    def encode_with_pretrained(self, x):
        """Encode without gradients; Gaussian posteriors are reduced to their mode."""
        c = self.pretrained_model.encode(x)
        return c.mode() if isinstance(c, DiagonalGaussianDistribution) else c

    def forward(self, x):
        z = self.proj(self.proj_norm(self.encode_with_pretrained(x)))
        z = nonlinearity(z)
        for block, down in zip(self.model, self.downsampler):
            z = down(block(z, temb=None))
        if self.do_reshape:
            z = rearrange(z, 'b c h w -> b (h w) c')
        return z
================================================
FILE: ldm/modules/diffusionmodules/openaimodel.py
================================================
from abc import abstractmethod
from functools import partial
import math
from typing import Iterable
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ldm.modules.diffusionmodules.util import (
checkpoint,
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
)
from ldm.modules.attention import SpatialTransformer
from ldm.util import exists
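# No-op stand-ins for fp16/fp32 module conversion: UNetModel.convert_to_fp16 / convert_to_fp32 apply these, so those calls change nothing.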
def convert_module_to_f16(x):
    """Placeholder fp16 conversion hook; intentionally does nothing to ``x``."""
    return None
def convert_module_to_f32(x):
    """Placeholder fp32 conversion hook; intentionally does nothing to ``x``."""
    return None
## go
class AttentionPool2d(nn.Module):
    """
    Attention pooling over spatial positions, adapted from CLIP:
    https://github.com/openai/CLIP/blob/main/clip/model.py

    A mean-pooled token is prepended to the flattened positions, a learned
    positional embedding is added, and the attention output of that first token
    is returned.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        n_tokens = spacial_dim ** 2 + 1  # flattened positions plus the pooled token
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, n_tokens) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c = x.shape[:2]
        tokens = x.reshape(b, c, -1)                                  # N x C x (HW)
        pooled = tokens.mean(dim=-1, keepdim=True)
        tokens = th.cat([pooled, tokens], dim=-1)                     # N x C x (HW+1)
        tokens = tokens + self.positional_embedding[None].to(tokens.dtype)
        out = self.c_proj(self.attention(self.qkv_proj(tokens)))
        return out[:, :, 0]
class TimestepBlock(nn.Module):
    """
    Marker base class for modules whose forward() receives timestep embeddings as
    its second positional argument. TimestepEmbedSequential uses an isinstance
    check against this class to decide whether to pass ``emb``.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Run the module on `x`, conditioned on the timestep embeddings `emb`.
        """
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that routes extra inputs to the children that accept
    them: TimestepBlocks get ``emb``, SpatialTransformers get ``context``, and
    every other child is called with the activation only.
    """

    def forward(self, x, emb, context=None):
        h = x
        for module in self:
            if isinstance(module, TimestepBlock):
                h = module(h, emb)
            elif isinstance(module, SpatialTransformer):
                h = module(h, context)
            else:
                h = module(h)
        return h
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

    :param channels: channels in the inputs (and outputs unless overridden).
    :param use_conv: if True, apply a convolution after upsampling.
    :param dims: 1, 2 or 3; for 3, only the inner two spatial axes are doubled.
    :param out_channels: output channels of the optional convolution.
    :param padding: padding of the optional convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # keep the first spatial axis, double the inner two
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling via an unpadded ConvTranspose2d (stride 2, kernel ``ks``)."""

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels if out_channels else channels
        self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)

    def forward(self, x):
        return self.up(x)
class Downsample(nn.Module):
    """
    2x downsampling, either by a strided 3x3 convolution or by average pooling.

    :param channels: channels in the inputs (and outputs unless overridden).
    :param use_conv: if True, use a strided convolution; otherwise average pooling
        (which requires ``out_channels == channels``).
    :param dims: 1, 2 or 3; for 3, only the inner two spatial axes are halved.
    :param out_channels: output channels of the convolution.
    :param padding: padding of the convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class ResBlock(TimestepBlock):
    """
    Residual block conditioned on a timestep embedding, optionally changing the
    channel count and/or resampling.

    :param channels: number of input channels.
    :param emb_channels: number of timestep-embedding channels.
    :param dropout: dropout rate, applied just before the final convolution.
    :param out_channels: output channel count (defaults to ``channels``).
    :param use_conv: when the channel count changes, use a 3x3 convolution instead
        of a 1x1 convolution on the skip path.
    :param use_scale_shift_norm: if True, the embedding yields a (scale, shift)
        pair applied after the output normalization; otherwise it is added.
    :param dims: dimensionality of the signal (1D, 2D or 3D).
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, upsample (nearest, no conv) inside the block.
    :param down: if True, downsample (average pooling) inside the block.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        c_out = self.out_channels

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, c_out, 3, padding=1),
        )

        # Resampling is applied to both the residual branch (h) and the skip (x).
        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        emb_width = 2 * c_out if use_scale_shift_norm else c_out
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(emb_channels, emb_width),
        )
        self.out_layers = nn.Sequential(
            normalization(c_out),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # zero_module: presumably zero-initializes this conv (TODO confirm in util)
            zero_module(conv_nd(dims, c_out, c_out, 3, padding=1)),
        )

        if c_out == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, c_out, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, c_out, 1)

    def forward(self, x, emb):
        """
        Apply the block to ``x`` (N x C x ...) conditioned on ``emb``
        (N x emb_channels); returns an N x out_channels x ... tensor.
        """
        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        if not self.updown:
            h = self.in_layers(x)
        else:
            # resample between the activation and the input convolution
            h = x
            for layer in self.in_layers[:-1]:
                h = layer(h)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = self.in_layers[-1](h)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # append trailing singleton axes so emb_out broadcasts over h's spatial dims
        while emb_out.dim() < h.dim():
            emb_out = emb_out.unsqueeze(-1)

        if self.use_scale_shift_norm:
            norm, rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = rest(norm(h) * (1 + scale) + shift)
        else:
            h = self.out_layers(h + emb_out)
        return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: input/output channel count.
    :param num_heads: number of heads, used when ``num_head_channels == -1``.
    :param num_head_channels: if not -1, fixes the per-head width and derives the
        head count as ``channels // num_head_channels``.
    :param use_checkpoint: whether forward uses gradient checkpointing.
    :param use_new_attention_order: use QKVAttention (split qkv, then heads)
        instead of QKVAttentionLegacy (split heads, then qkv).
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # Honor the constructor's `use_checkpoint` flag, as ResBlock does; it was
        # previously hard-coded to True, silently ignoring the argument.
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        # flatten spatial dims, attend, project, add residual, restore shape
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
    """
    Operation counter for the `thop` package covering an attention layer.

    Intended usage:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )

    Adds ``2 * batch * tokens**2 * channels`` to ``model.total_ops``, where the
    shape is taken from ``y[0]``.
    """
    batch, channels, *spatial = y[0].shape
    tokens = int(np.prod(spatial))
    # Two matmuls of equal cost: query-key weights, then weights applied to values.
    ops = 2 * batch * tokens * tokens * channels
    model.total_ops += th.DoubleTensor([ops])
class QKVAttentionLegacy(nn.Module):
    """
    Multi-head QKV attention using the legacy layout: the channel axis is first
    split into heads, then each head into (q, k, v).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        per_head = qkv.reshape(bs * self.n_heads, ch * 3, length)
        q, k, v = per_head.split(ch, dim=1)
        # Scale q and k separately by ch**-1/4 (more stable in f16 than dividing
        # the product afterwards).
        scale = 1 / math.sqrt(math.sqrt(ch))
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        out = th.einsum("bts,bcs->bct", weight, v)
        return out.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
    """
    Multi-head QKV attention using the "new" layout: the channel axis is first
    split into (q, k, v), then each of those into heads.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        heads = bs * self.n_heads
        # Scale q and k separately by ch**-1/4 (more stable in f16 than dividing
        # the product afterwards).
        scale = 1 / math.sqrt(math.sqrt(ch))
        q, k, v = qkv.chunk(3, dim=1)
        q = (q * scale).view(heads, ch, length)
        k = (k * scale).view(heads, ch, length)
        v = v.reshape(heads, ch, length)
        logits = th.einsum("bct,bcs->bts", q, k)
        weight = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        return th.einsum("bts,bcs->bct", weight, v).reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
#self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.") # todo: convert to warning
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
    """
    Cast the torso of the UNet to float16.

    Only ``input_blocks``, ``middle_block`` and ``output_blocks`` are visited
    (in that order) with ``convert_module_to_f16``; other submodules keep
    their current dtype.
    """
    for torso_part in (self.input_blocks, self.middle_block, self.output_blocks):
        torso_part.apply(convert_module_to_f16)
def convert_to_fp32(self):
    """
    Cast the torso of the UNet back to float32.

    Only ``input_blocks``, ``middle_block`` and ``output_blocks`` are visited
    (in that order) with ``convert_module_to_f32``; other submodules keep
    their current dtype.
    """
    for torso_part in (self.input_blocks, self.middle_block, self.output_blocks):
        torso_part.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    """
    Run the UNet on a batch.

    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param context: cross-attention conditioning, handed to every input,
        middle and output block.
    :param y: an [N] Tensor of labels; required if and only if the model
        is class-conditional.
    :param kwargs: accepted and ignored.
    :return: an [N x C x ...] Tensor of outputs, or the ``id_predictor``
        output when ``predict_codebook_ids`` is set.
    """
    is_class_conditional = self.num_classes is not None
    assert (y is not None) == is_class_conditional, "must specify y if and only if the model is class-conditional"

    emb = self.time_embed(
        timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    )
    if is_class_conditional:
        assert y.shape == (x.shape[0],)
        emb = emb + self.label_emb(y)

    # Encoder half: keep every block output for the skip connections.
    skips = []
    h = x.type(self.dtype)
    for block in self.input_blocks:
        h = block(h, emb, context)
        skips.append(h)

    h = self.middle_block(h, emb, context)

    # Decoder half: skips are consumed last-in, first-out.
    for block in self.output_blocks:
        h = block(th.cat([h, skips.pop()], dim=1), emb, context)
    h = h.type(x.dtype)

    head = self.id_predictor if self.predict_codebook_ids else self.out
    return head(h)
class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.
    For usage, see UNet.

    Builds only the downsampling path (input blocks and middle block) and
    reduces the final features with a head selected by ``pool``:

    - ``"adaptive"``: normalization, SiLU, global average pool, a
      zero-initialised 1x1 conv to ``out_channels``, flatten.
    - ``"attention"``: normalization, SiLU, ``AttentionPool2d`` over the
      bottleneck feature map (requires ``num_head_channels != -1``).
    - ``"spatial"`` / ``"spatial_v2"``: an MLP applied to the concatenated
      spatial means of every input block's output and the middle block's
      output.

    Any other ``pool`` value raises ``NotImplementedError``. Extra ``*args``
    and ``**kwargs`` are accepted and ignored.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        # dtype the torso runs in; forward() casts the input to it.
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        # Stored only; this class builds no upsampling path that would use it.
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        # MLP applied to the sinusoidal timestep embedding computed in forward().
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Stem: a single conv from in_channels to model_channels.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        # Running sum of the channel counts of every feature map that
        # forward() averages in the "spatial" pool modes; it sizes the first
        # Linear layer of those heads.
        self._feature_size = model_channels
        # NOTE(review): appended to below but never read in this class
        # (there is no decoder half consuming skip channels).
        input_block_chans = [model_channels]
        ch = model_channels
        # Current downsampling factor; matched against attention_resolutions.
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (not after the last one), using
                # either a ResBlock with down=True or a Downsample layer.
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # Bottleneck: ResBlock, AttentionBlock, ResBlock.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            # NOTE(review): AdaptiveAvgPool2d assumes dims == 2.
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                # image_size // ds: presumably the side length of the
                # bottleneck feature map for square inputs — confirm.
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                # Per-channel mean over dims (2, 3); NOTE(review): assumes 2-D feature maps.
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            # Concatenated width equals self._feature_size.
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)
================================================
FILE: ldm/modules/diffusionmodules/util.py
================================================
# adapted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a diffusion noise (beta) schedule.

    :param schedule: one of "linear", "cosine", "sqrt_linear", "sqrt".
        - "linear": linear in sqrt(beta) from linear_start to linear_end, squared.
        - "cosine": cosine schedule with offset ``cosine_s``, clipped to [0, 0.999].
        - "sqrt_linear": betas linear from linear_start to linear_end.
        - "sqrt": square root of the "sqrt_linear" betas.
    :param n_timestep: number of diffusion steps (length of the result).
    :param linear_start: first value for the linear-family schedules.
    :param linear_end: last value for the linear-family schedules.
    :param cosine_s: small offset s used by the cosine schedule.
    :return: a float64 numpy array of ``n_timestep`` betas.
    :raises ValueError: if ``schedule`` is not one of the names above.
    """
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clip with torch rather than np.clip: np.clip on a torch tensor
        # relies on numpy re-dispatching through ndarray/__array_wrap__, and
        # `betas` must stay a torch tensor for the `.numpy()` call below.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """
    Choose the subset of DDPM timesteps visited by a DDIM sampler.

    :param ddim_discr_method: 'uniform' (fixed integer stride) or 'quad'
        (quadratically spaced up to 80% of the DDPM steps).
    :param num_ddim_timesteps: requested number of DDIM steps.
    :param num_ddpm_timesteps: number of steps of the underlying DDPM.
    :param verbose: print the chosen timesteps.
    :return: integer numpy array of selected timesteps, each shifted by +1.
        With 'uniform' the length can exceed ``num_ddim_timesteps`` when the
        step counts do not divide evenly.
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        base_steps = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    elif ddim_discr_method == 'quad':
        upper = np.sqrt(num_ddpm_timesteps * .8)
        base_steps = (np.linspace(0, upper, num_ddim_timesteps) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # Shift by one so the final alpha values come out right (the ones from
    # the first scale to data during sampling).
    steps_out = base_steps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Derive the DDIM variance schedule for the selected timesteps.

    :param alphacums: numpy array of cumulative alpha products of the DDPM.
    :param ddim_timesteps: integer indices into ``alphacums``.
    :param eta: DDIM stochasticity (0 gives deterministic sampling).
    :param verbose: print the selected alphas and the resulting sigmas.
    :return: ``(sigmas, alphas, alphas_prev)``; ``alphas_prev`` for the first
        selected step is ``alphacums[0]``.
    """
    a_t = alphacums[ddim_timesteps]
    a_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # sigma_t according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {a_t}; a_(t-1): {a_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, a_t, a_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    Step i gets beta_i = 1 - alpha_bar((i+1)/N) / alpha_bar(i/N), capped at
    ``max_beta``.

    :param num_diffusion_timesteps: the number of betas to produce (N).
    :param alpha_bar: callable mapping t in [0, 1] to the cumulative product
                      of (1-beta) up to that point of the diffusion process.
    :param max_beta: upper bound on each beta; values below 1 prevent
                     singularities.
    :return: numpy array of N betas.
    """
    n = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((step + 1) / n) / alpha_bar(step / n), max_beta)
        for step in range(n)
    ])
def extract_into_tensor(a, t, x_shape):
    """
    Gather ``a`` at indices ``t`` along the last dim and reshape the result
    to ``[B, 1, 1, ...]`` so it broadcasts against a tensor of shape ``x_shape``.
    """
    batch, *_ = t.shape
    gathered = a.gather(-1, t)
    trailing_ones = [1] * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing_ones)
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing and call `func`
                 directly.
    """
    if not flag:
        return func(*inputs)
    packed = (*inputs, *params)
    return CheckpointFunction.apply(func, len(inputs), *packed)
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function behind ``checkpoint``.

    ``forward`` runs ``run_function`` under ``torch.no_grad()`` so no
    intermediate activations are stored. ``backward`` re-runs it with grad
    enabled on detached copies of the inputs and differentiates the
    recomputed outputs with respect to the inputs and the extra params.

    Every positional input must be a Tensor (backward calls ``.detach()`` on
    each). NOTE(review): the autocast state active during forward is not
    recorded or restored for the recomputation — confirm this is acceptable
    under mixed precision.
    """
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # args = (*inputs, *params): the first `length` entries are
        # run_function's positional inputs; the rest are parameters it uses
        # implicitly, passed through so autograd routes their grads here.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Fresh grad-requiring leaves so the recomputation builds a new graph.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # Drop references so the recomputed graph can be freed.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # No gradients for the run_function and length arguments.
        return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and just tile each
                        timestep value ``dim`` times.
    :return: an [N x dim] Tensor of positional embeddings: cosines in the
             first half, sines in the second, with one zero column appended
             when ``dim`` is odd.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    scaled = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(scaled / half).to(device=timesteps.device)
    phases = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the same module.
    """
    with torch.no_grad():
        for weight in module.parameters():
            weight.zero_()
    return module
def scale_module(module, scale):
    """
    Multiply every parameter of ``module`` by ``scale`` in place and return the same module.
    """
    with torch.no_grad():
        for weight in module.parameters():
            weight.mul_(scale)
    return module
def mean_flat(tensor):
    """
    Average over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
def normalization(channels):
    """
    Build the standard normalization layer: a 32-group ``GroupNorm32``.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    num_groups = 32
    return GroupNorm32(num_groups, channels)
# PyTorch ships SiLU from 1.7 onward; this local copy keeps PyTorch 1.5 working.
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)``."""
    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class GroupNorm32(nn.GroupNorm):
    """GroupNorm computed in float32, with the result cast back to the input's dtype."""
    def forward(self, x):
        normed = super().forward(x.float())
        return normed.type(x.dtype)
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module; extra arguments are forwarded
    to the matching ``nn.ConvNd`` constructor.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
    """
    Factory for ``nn.Linear``; all arguments are forwarded unchanged.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module; extra arguments are
    forwarded to the matching ``nn.AvgPoolNd`` constructor.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class HybridConditioner(nn.Module):
    """
    Pair of conditioners, each built from a config: one for the concat input
    and one for the cross-attention input.
    """

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        """Encode both inputs; each result is wrapped in a one-element list."""
        concat_out = self.concat_conditioner(c_concat)
        crossattn_out = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [concat_out], 'c_crossattn': [crossattn_out]}
def noise_like(shape, device, repeat=False):
    """
    Draw standard normal noise of ``shape`` on ``device``.

    With ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is drawn
    and tiled along the first dimension, so every batch element gets the
    same noise.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    single = torch.randn((1, *shape[1:]), device=device)
    return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
================================================
FILE: ldm/modules/distributions/__init__.py
================================================
================================================
FILE: ldm/modules/distributions/distributions.py
================================================
import torch
import numpy as np
class AbstractDistribution:
    """Minimal distribution interface; subclasses implement ``sample`` and ``mode``."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    """Point mass: both ``sample`` and ``mode`` return the stored ``value``."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value
class DiagonalGaussianDistribution(object):
    """
    Gaussian with diagonal covariance, parameterised by one packed tensor.

    :param parameters: tensor split in two along dim 1 with ``torch.chunk``:
        the first half is the mean, the second half the log-variance, which
        is clamped to [-30, 20].
    :param deterministic: if True, ``std`` and ``var`` are zero, so ``sample``
        returns the mean, and ``kl``/``nll`` return a one-element zero tensor.
    """
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the mean's device and dtype.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)``."""
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        """
        KL(self || other) summed over dims [1, 2, 3]; with ``other=None`` the
        reference is the standard normal N(0, I).

        NOTE(review): the fixed reduction dims assume 4-D parameters.
        """
        if self.deterministic:
            # On the parameters' device (previously always CPU), so it can be
            # added to GPU losses without a device-mismatch error.
            return torch.tensor([0.], device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Negative log-likelihood of ``sample``, summed over ``dims``.

        :param dims: dimensions to sum over (a tuple default avoids a shared
            mutable default; lists are still accepted).
        """
        if self.deterministic:
            return torch.tensor([0.], device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """The mode of a Gaussian is its mean."""
        return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Elementwise KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).
    Shapes are broadcast, so batches can be compared to scalars; at least one
    argument must be a Tensor.
    """
    reference = next(
        (obj for obj in (mean1, logvar1, mean2, logvar2) if isinstance(obj, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # torch.exp() needs Tensors, so promote scalar log-variances to Tensors
    # that match the reference tensor's dtype and device.
    def _as_tensor(value):
        return value if isinstance(value, torch.Tensor) else torch.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
================================================
FILE: ldm/modules/ema.py
================================================
import torch
from torch import nn
class LitEma(nn.Module):
    """
    Exponential moving average (EMA) of a model's trainable parameters.

    Shadow copies are registered as buffers (parameter names with '.'
    removed), so they move with ``.to()`` and are saved in the state_dict.

    :param model: module whose ``requires_grad`` parameters are tracked.
    :param decay: EMA decay in [0, 1].
    :param use_num_upates: if True, the effective decay is warmed up as
        ``min(decay, (1 + n) / (10 + n))`` where n counts updates; if False,
        ``decay`` is used as-is.
    :raises ValueError: if ``decay`` is outside [0, 1].
    """
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # -1 disables the warm-up ramp in forward().
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        """Move every shadow buffer one EMA step towards ``model``'s current parameters."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    shadow = shadow_params[self.m_name2s_name[key]]
                    # Update the registered buffer in place. Rebinding it via
                    # .type_as(param) made a temporary copy whenever the dtype
                    # or device differed, so the EMA step was silently lost.
                    param = m_param[key].to(device=shadow.device, dtype=shadow.dtype)
                    shadow.sub_(one_minus_decay * (shadow - param))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite ``model``'s tracked parameters with the EMA shadow values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
================================================
FILE: ldm/modules/encoders/__init__.py
================================================
================================================
FILE: ldm/modules/encoders/modules.py
================================================
import torch
import torch.nn as nn
import numpy as np
from functools import partial
import kornia
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
from ldm.util import default
import clip
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement ``encode``."""

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    """Encoder that returns its input unchanged."""

    def encode(self, x):
        return x
class FaceClipEncoder(AbstractEncoder):
    """
    CLIP image conditioning from two views of one image: a fixed crop
    (rows 190:440, cols 125:387) and the rest of the image with that region
    zeroed. The two CLIP encodings are concatenated along dim 1.

    NOTE(review): the fixed window presumably targets 512px aligned face
    images — confirm against the training data.
    """
    def __init__(self, augment=True, retreival_key=None):
        super().__init__()
        self.encoder = FrozenCLIPImageEmbedder()
        self.augment = augment
        self.retreival_key = retreival_key

    def forward(self, img):
        x_offset = 125
        rows = slice(190, 440)
        cols = slice(x_offset, 512 - x_offset)
        with torch.no_grad():
            if self.retreival_key:
                # Assumes retrieved image are packed into the second half of channels
                face = img[:, 3:, rows, cols]
                other = img[:, :3, ...].clone()
            else:
                face = img[:, :, rows, cols]
                other = img.clone()

            if self.augment:
                face = K.RandomHorizontalFlip()(face)

            other[:, :, rows, cols] *= 0
            encodings = [
                self.encoder.encode(face),
                self.encoder.encode(other),
            ]

        return torch.cat(encodings, dim=1)

    def encode(self, img):
        if isinstance(img, list):
            # Uncondition
            return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
        return self(img)
class FaceIdClipEncoder(AbstractEncoder):
    """
    Conditioning from a face-ID embedding of the whole image (resized to
    256x256) plus a CLIP image embedding of the image with the region
    rows 184:452, cols 122:396 zeroed. Results are concatenated along dim 1.

    :param id_model_path: weights for the face-ID encoder. The default keeps
        the previously hard-coded path; pass your own path on other machines.
    """
    def __init__(self, id_model_path="/home/jpinkney/code/stable-diffusion/model_ir_se50.pth"):
        super().__init__()
        self.encoder = FrozenCLIPImageEmbedder()
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.id = FrozenFaceEncoder(id_model_path, augment=True)

    def forward(self, img):
        with torch.no_grad():
            face = kornia.geometry.resize(img, (256, 256),
                                          interpolation='bilinear', align_corners=True)

            other = img.clone()
            other[:, :, 184:452, 122:396] *= 0
            encodings = [
                self.id.encode(face),
                self.encoder.encode(other),
            ]

        return torch.cat(encodings, dim=1)

    def encode(self, img):
        if isinstance(img, list):
            # Uncondition
            return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
        return self(img)
class ClassEmbedder(nn.Module):
    """
    Look up a learned embedding for integer class labels stored under
    ``key`` in a batch dict; returns shape (B, 1, embed_dim).
    """
    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        labels = batch[self.key if key is None else key]
        # add a length-1 sequence axis; this is for use in crossattn
        return self.embedding(labels[:, None])
class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers over token ids; returns the embeddings."""
    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))

    def forward(self, tokens):
        # tokens are moved to self.device before the transformer runs
        return self.transformer(tokens.to(self.device), return_embeddings=True)

    def encode(self, x):
        return self(x)
class BERTTokenizer(AbstractEncoder):
    """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)

    Produces ``max_length``-padded token ids on ``device``. With
    ``vq_interface`` set, ``encode`` mimics a VQ encoder's return tuple.
    """
    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to reuquirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                  return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        return encoding["input_ids"].to(self.device)

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if self.vq_interface:
            return None, None, [None, None, tokens]
        return tokens

    def decode(self, text):
        return text
class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers.

    :param device: device the tokenizer places token ids on; also stored as
        ``self.device``.
    :param use_tokenizer: if False, ``forward`` expects token ids directly.
    """
    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            # Forward `device`: previously the tokenizer always used its own
            # default ("cuda"), even when this embedder was built for CPU.
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len, device=device)
        self.device = device
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)

    def forward(self, text):
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)  # already on self.device
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        # output of length 77
        return self(text)
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores ``mode`` and returns
    ``self`` unchanged, so the module's train/eval mode stays fixed."""
    return self
class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text (weights frozen, eval mode)."""
    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length   # TODO: typical value?
        self.freeze()

    def freeze(self):
        """Switch the encoder to eval mode and disable gradients on every parameter."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        self.requires_grad_(False)

    def forward(self, text):
        """Tokenize (pad/truncate to max_length) and return the encoder's last hidden state."""
        encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                  return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        return self(text)
from ldm.thirdp.psp.id_loss import IDFeatures
import kornia.augmentation as K
class FrozenFaceEncoder(AbstractEncoder):
    """
    Frozen face-ID feature extractor followed by a trainable 512->768 mapper.

    :param model_path: weights for ``IDFeatures``.
    :param augment: if True, apply random horizontal flip / equalize before
        feature extraction.
    """
    def __init__(self, model_path, augment=False):
        super().__init__()
        self.loss_fn = IDFeatures(model_path)
        # face encoder is frozen
        for p in self.loss_fn.parameters():
            p.requires_grad = False
        # Mapper is trainable
        self.mapper = torch.nn.Linear(512, 768)
        aug_p = 0.25
        if augment:
            self.augment = K.AugmentationSequential(
                K.RandomHorizontalFlip(p=0.5),
                K.RandomEqualize(p=aug_p),
                # K.RandomPlanckianJitter(p=p),
                # K.RandomPlasmaBrightness(p=p),
                # K.RandomPlasmaContrast(p=p),
                # K.ColorJiggle(0.02, 0.2, 0.2, p=p),
            )
        else:
            # None rather than False: forward() tests `is not None`, and the
            # previous False sentinel passed that test and was then called,
            # raising TypeError whenever augment=False.
            self.augment = None

    def forward(self, img):
        if isinstance(img, list):
            # Uncondition
            return torch.zeros((1, 1, 768), device=self.mapper.weight.device)

        if self.augment is not None:
            # Transforms require 0-1
            img = self.augment((img + 1) / 2)
            img = 2 * img - 1

        feat = self.loss_fn(img, crop=True)
        feat = self.mapper(feat.unsqueeze(1))
        return feat

    def encode(self, img):
        return self(img)
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface); weights frozen, eval mode."""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length   # TODO: typical value?
        self.freeze()

    def freeze(self):
        """Switch the text model to eval mode and disable gradients on every parameter."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        self.requires_grad_(False)

    def forward(self, text):
        """Tokenize (pad/truncate to max_length) and return the text model's last hidden state."""
        encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                  return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        return self(text)
import torch.nn.functional as F
from transformers import CLIPVisionModel
class ClipImageProjector(AbstractEncoder):
    """
    Uses the CLIP image encoder.

    Runs a HuggingFace ``CLIPVisionModel`` (put in train mode; its parameters
    are not frozen here), projects every output token with a 1024->768 linear
    layer and pads or crops the token axis to ``max_length``. A list input
    returns the stored null condition (the CLIP text embedding of "").
    """
    def __init__(self, version="openai/clip-vit-large-patch14", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.model = CLIPVisionModel.from_pretrained(version)
        self.model.train()
        self.max_length = max_length   # TODO: typical value?
        self.antialias = True
        # NOTE(review): assumes the vision model's hidden size is 1024 (as for
        # ViT-L/14); other `version`s may need a different in_features.
        self.mapper = torch.nn.Linear(1024, 768)
        # CLIP image normalization statistics; non-persistent (not in state_dict).
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        # Persistent buffer, so the null condition is saved in the state_dict.
        null_cond = self.get_null_cond(version, max_length)
        self.register_buffer('null_cond', null_cond)

    @torch.no_grad()
    def get_null_cond(self, version, max_length):
        """Encode the empty prompt with a temporary FrozenCLIPEmbedder on the buffers' device."""
        device = self.mean.device
        embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
        null_cond = embedder([""])
        return null_cond

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        if isinstance(x, list):
            return self.null_cond
        # x is assumed to be in range [-1,1]
        x = self.preprocess(x)
        outputs = self.model(pixel_values=x)
        last_hidden_state = outputs.last_hidden_state
        last_hidden_state = self.mapper(last_hidden_state)
        # Pad the token axis (dim -2) at the end up to max_length. When the
        # model emits more tokens than max_length the pad amount is negative,
        # which F.pad treats as cropping.
        return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0])

    def encode(self, im):
        return self(im)
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
    """FrozenCLIPEmbedder text features followed by a 768->768 linear projection."""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
        self.projection = torch.nn.Linear(768, 768)

    def forward(self, text):
        return self.projection(self.embedder(text))

    def encode(self, text):
        return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder (OpenAI ``clip`` package; text tower deleted).
    Not actually frozen... If you want that set cond_stage_trainable=False in cfg
    """
    def __init__(
            self,
            model='ViT-L/14',
            jit=False,
            device='cpu',
            antialias=False,
        ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

    def preprocess(self, x):
        """Resize a [-1, 1] image batch to 224x224 and apply CLIP normalization."""
        resized = kornia.geometry.resize(x, (224, 224),
                                         interpolation='bicubic', align_corners=True,
                                         antialias=self.antialias)
        unit_range = (resized + 1.) / 2.
        # renormalize according to clip
        return kornia.enhance.normalize(unit_range, self.mean, self.std)

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            return torch.zeros(1, 768, device=self.model.visual.conv1.weight.device)
        pixels = self.preprocess(x)
        return self.model.encode_image(pixels).float()

    def encode(self, im):
        # add a length-1 token axis
        return self(im).unsqueeze(1)
from torchvision import transforms
import random
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    Not actually frozen... If you want that set cond_stage_trainable=False in cfg

    Each image is encoded as ``max_crops`` CLIP image embeddings, one per
    random resized crop, and the per-image results are stacked into a
    [B x max_crops x D] tensor. Each crop's embedding is zeroed with
    probability 0.1.
    """
    def __init__(
            self,
            model='ViT-L/14',
            jit=False,
            device='cpu',
            antialias=True,
            max_crops=5,
        ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        # NOTE(review): stored but not used by preprocess() in this class.
        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.max_crops = max_crops

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        # 224x224 square crops covering 8.5%-100% of the input area.
        randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1))
        max_crops = self.max_crops
        patches = []
        # Each call draws new crop parameters; forward() passes one image at a
        # time (leading dim 1), so the result has max_crops rows along dim 0.
        crops = [randcrop(x) for _ in range(max_crops)]
        patches.extend(crops)
        x = torch.cat(patches, dim=0)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            device = self.model.visual.conv1.weight.device
            return torch.zeros(1, self.max_crops, 768, device=device)
        batch_tokens = []
        for im in x:
            patches = self.preprocess(im.unsqueeze(0))
            tokens = self.model.encode_image(patches).float()
            # Randomly drop (zero, in place) individual crop embeddings.
            # NOTE(review): in-place edits of the encoder output can conflict
            # with autograd if the producing op saved its output — confirm.
            for t in tokens:
                if random.random() < 0.1:
                    t *= 0
            batch_tokens.append(tokens.unsqueeze(0))
        return torch.cat(batch_tokens, dim=0)

    def encode(self, im):
        return self(im)
class SpatialRescaler(nn.Module):
    """
    Rescale inputs by ``multiplier`` ``n_stages`` times using
    ``torch.nn.functional.interpolate``. When ``out_channels`` is given, the
    resized result is then passed through a 1x1 conv mapping ``in_channels`` to
    ``out_channels``.
    """
    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        allowed_modes = ('nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area')
        assert method in allowed_modes
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if not self.remap_output:
            return
        print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
        self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        """Apply ``n_stages`` rescalings, then the optional channel remap."""
        out = x
        for _ in range(self.n_stages):
            out = self.interpolator(out, scale_factor=self.multiplier)
        return self.channel_mapper(out) if self.remap_output else out

    def encode(self, x):
        return self(x)
from ldm.util import instantiate_from_config
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
class LowScaleEncoder(nn.Module):
    """
    Encode inputs with a first-stage model built from ``model_config``, scale the
    sampled latent, and noise it with ``q_sample`` at a random noise level in
    ``[0, max_noise_level)`` per sample. Returns the noised latent and the levels.
    """
    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
                 scale_factor=1.0):
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        # NOTE(review): register_schedule() returns None, so this attribute is always None;
        # the schedule itself lives in the buffers that call registers.
        self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
                                                            linear_end=linear_end)
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute the beta schedule and register derived float32 tensors as buffers."""
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        n_steps, = betas.shape  # unpacking also requires a 1-D schedule
        self.num_timesteps = int(n_steps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
        # Registration order is kept stable so the buffer order in the state dict does not change.
        buffers = (
            ('betas', betas),
            ('alphas_cumprod', alphas_cumprod),
            ('alphas_cumprod_prev', alphas_cumprod_prev),
            # calculations for diffusion q(x_t | x_{t-1}) and others
            ('sqrt_alphas_cumprod', np.sqrt(alphas_cumprod)),
            ('sqrt_one_minus_alphas_cumprod', np.sqrt(1. - alphas_cumprod)),
            ('log_one_minus_alphas_cumprod', np.log(1. - alphas_cumprod)),
            ('sqrt_recip_alphas_cumprod', np.sqrt(1. / alphas_cumprod)),
            ('sqrt_recipm1_alphas_cumprod', np.sqrt(1. / alphas_cumprod - 1)),
        )
        for buffer_name, values in buffers:
            self.register_buffer(buffer_name, torch.tensor(values, dtype=torch.float32))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to step ``t`` (fresh Gaussian noise unless ``noise`` is given)."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def forward(self, x):
        latent = self.model.encode(x).sample() * self.scale_factor
        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        latent = self.q_sample(latent, noise_level)
        if self.out_size is not None:
            latent = torch.nn.functional.interpolate(latent, size=self.out_size, mode="nearest")  # TODO: experiment with mode
        # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
        return latent, noise_level

    def decode(self, z):
        """Undo the latent scaling and decode with the first-stage model."""
        return self.model.decode(z / self.scale_factor)
if __name__ == "__main__":
    from ldm.util import count_params
    sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
    # Smoke test: build each text encoder on the GPU, report its parameter count
    # and print the shape of the embedding it produces for the sample sentences.
    # Factories are lambdas so each model is constructed only when its turn comes.
    model_factories = (
        lambda: FrozenT5Embedder(version="google/t5-v1_1-xl"),
        lambda: FrozenCLIPEmbedder(),
    )
    for make_model in model_factories:
        model = make_model().cuda()
        count_params(model, True)
        z = model(sentences)
        print(z.shape)
    print("done.")
================================================
FILE: ldm/modules/evaluate/adm_evaluator.py
================================================
import argparse
import io
import os
import random
import warnings
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from typing import Iterable, Optional, Tuple
import yaml
import numpy as np
import requests
import tensorflow.compat.v1 as tf
from scipy import linalg
from tqdm.auto import tqdm
# Location of the frozen Inception graph and the local file it is cached to.
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
# Tensor names inside the Inception graph used for FID and spatial FID features.
FID_POOL_NAME = "pool_3:0"
FID_SPATIAL_NAME = "mixed_6/conv:0"
REQUIREMENTS = "This script has the following requirements: \n" + "\n".join(
    ["tensorflow-gpu>=2.0", "scipy", "requests", "tqdm"]
)
def main():
    """
    CLI entry point: compute IS, FID, sFID, precision and recall of a sample
    batch against a reference batch (both ``.npz`` files). Results are printed
    and written to ``evaluation_metrics.yaml`` next to the sample batch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_batch", required=True, help="path to reference batch npz file")
    parser.add_argument("--sample_batch", required=True, help="path to sample batch npz file")
    args = parser.parse_args()

    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))

    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing evaluations...")
    is_ = evaluator.compute_inception_score(sample_acts[0])
    print("Inception Score:", is_)
    fid = sample_stats.frechet_distance(ref_stats)
    print("FID:", fid)
    sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
    print("sFID:", sfid)
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
    print("Precision:", prec)
    print("Recall:", recall)

    # os.path.dirname handles paths without a directory (-> "") and root-level paths
    # (-> "/") correctly, unlike a manual split on '/'.
    savepath = os.path.dirname(args.sample_batch)
    results_file = os.path.join(savepath, 'evaluation_metrics.yaml')
    print(f'Saving evaluation results to "{results_file}"')
    # Cast to built-in floats: frechet_distance returns NumPy scalars, which yaml.dump
    # serialises as opaque `!!python/object/apply` tags that safe loaders reject.
    results = {
        'IS': float(is_),
        'FID': float(fid),
        'sFID': float(sfid),
        'Precision': float(prec),
        'Recall': float(recall),
    }
    with open(results_file, 'w') as f:
        yaml.dump(results, f, default_flow_style=False)
class InvalidFIDException(Exception):
    """Exception type for invalid FID computations (not raised within this section)."""
class FIDStatistics:
    """Mean vector ``mu`` and covariance matrix ``sigma`` of a set of activations."""

    def __init__(self, mu: np.ndarray, sigma: np.ndarray):
        self.mu = mu
        self.sigma = sigma

    def frechet_distance(self, other, eps=1e-6):
        """
        Compute the Frechet distance between two sets of statistics.

        :param other: another FIDStatistics with matching dimensions.
        :param eps: diagonal offset added to both covariances if the matrix square
                    root of their product is not finite.
        :return: ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 Tr(sqrt(S1 S2)).
        :raises ValueError: if the square root has a significant imaginary part.
        """
        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
        mean_a = np.atleast_1d(self.mu)
        mean_b = np.atleast_1d(other.mu)
        cov_a = np.atleast_2d(self.sigma)
        cov_b = np.atleast_2d(other.sigma)

        assert (
            mean_a.shape == mean_b.shape
        ), f"Training and test mean vectors have different lengths: {mean_a.shape}, {mean_b.shape}"
        assert (
            cov_a.shape == cov_b.shape
        ), f"Training and test covariances have different dimensions: {cov_a.shape}, {cov_b.shape}"

        delta = mean_a - mean_b

        # The product of covariances may be almost singular.
        root, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
        if not np.isfinite(root).all():
            warnings.warn(
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
            )
            jitter = np.eye(cov_a.shape[0]) * eps
            root = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

        # Numerical error can leave a small imaginary component.
        if np.iscomplexobj(root):
            if not np.allclose(np.diagonal(root).imag, 0, atol=1e-3):
                largest = np.max(np.abs(root.imag))
                raise ValueError("Imaginary component {}".format(largest))
            root = root.real

        return delta.dot(delta) + np.trace(cov_a) + np.trace(cov_b) - 2 * np.trace(root)
class Evaluator:
    """
    Inception-based sample metrics (IS, FID, sFID, precision/recall) evaluated
    through a TensorFlow session.
    """
    def __init__(
        self,
        session,
        batch_size=64,
        softmax_batch_size=512,
    ):
        self.sess = session
        self.batch_size = batch_size
        self.softmax_batch_size = softmax_batch_size
        self.manifold_estimator = ManifoldEstimator(session)
        with self.sess.graph.as_default():
            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
            self.softmax = _create_softmax_graph(self.softmax_input)

    def warmup(self):
        """Run one all-zero batch through the feature graph before real work starts."""
        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))

    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Stream ``arr_0`` of an npz file in batches and compute its activations."""
        with open_npz_array(npz_path, "arr_0") as reader:
            return self.compute_activations(reader.read_batches(self.batch_size))

    def compute_activations(self, batches: Iterable[np.ndarray], silent=False) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute image features for downstream evals.

        :param batches: an iterable over NHWC numpy arrays in [0, 255].
        :param silent: if True, do not wrap the iterable in a tqdm progress bar.
        :return: a tuple of numpy arrays of shape [N x X], where X is a feature
                 dimension. The tuple is (pool_3, spatial).
        """
        preds = []
        spatial_preds = []
        it = batches if silent else tqdm(batches)
        for batch in it:
            batch = batch.astype(np.float32)
            pred, spatial_pred = self.sess.run(
                [self.pool_features, self.spatial_features], {self.image_input: batch}
            )
            preds.append(pred.reshape([pred.shape[0], -1]))
            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
        return (
            np.concatenate(preds, axis=0),
            np.concatenate(spatial_preds, axis=0),
        )

    def read_statistics(
        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[FIDStatistics, FIDStatistics]:
        """
        Return (pool, spatial) FID statistics. They are taken from the npz file when it
        stores ``mu``/``sigma``/``mu_s``/``sigma_s``; otherwise they are computed from
        ``activations``.
        """
        # Use the NpzFile as a context manager so its file handle is closed; a bare
        # np.load() on an .npz keeps the file open until garbage collection.
        with np.load(npz_path) as obj:
            if "mu" in obj.files:
                return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
                    obj["mu_s"], obj["sigma_s"]
                )
        return tuple(self.compute_statistics(x) for x in activations)

    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
        """Mean and covariance of ``activations`` (rows are samples)."""
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return FIDStatistics(mu, sigma)

    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
        """Inception Score from pool features, averaged over splits of ``split_size``."""
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

    def compute_prec_recall(
        self, activations_ref: np.ndarray, activations_sample: np.ndarray
    ) -> Tuple[float, float]:
        """Improved precision and recall for the first neighbourhood size."""
        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
        pr = self.manifold_estimator.evaluate_pr(
            activations_ref, radii_1, activations_sample, radii_2
        )
        return (float(pr[0][0]), float(pr[1][0]))
class ManifoldEstimator:
    """
    A helper for comparing manifolds of feature vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
    """

    def __init__(
        self,
        session,
        row_batch_size=10000,
        col_batch_size=10000,
        nhood_sizes=(3,),
        clamp_to_percentile=None,
        eps=1e-5,
    ):
        """
        Estimate the manifold of given feature vectors.

        :param session: the TensorFlow session.
        :param row_batch_size: row batch size to compute pairwise distances
                               (parameter to trade-off between memory usage and performance).
        :param col_batch_size: column batch size to compute pairwise distances.
        :param nhood_sizes: number of neighbors used to estimate the manifold.
        :param clamp_to_percentile: prune hyperspheres that have radius larger than
                                    the given percentile.
        :param eps: small number for numerical stability.
        """
        self.distance_block = DistanceBlock(session)
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.clamp_to_percentile = clamp_to_percentile
        self.eps = eps

    def warmup(self):
        """Run evaluate_pr once on a single zero vector."""
        feats, radii = (
            np.zeros([1, 2048], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        )
        self.evaluate_pr(feats, radii, feats, radii)

    def manifold_radii(self, features: np.ndarray) -> np.ndarray:
        """
        Radius of each feature's hypersphere: its distance (as returned by the
        distance block) to the k-th nearest neighbour, one column per k in
        ``nhood_sizes``.
        """
        num_images = len(features)

        # Estimate manifold of features by calculating distances to k-NN of each sample.
        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)

        for begin1 in range(0, num_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_images)
            row_batch = features[begin1:end1]

            for begin2 in range(0, num_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_images)
                col_batch = features[begin2:end2]

                # Compute distances between batches.
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(row_batch, col_batch)

            # Find the k-nearest neighbor from the current batch.
            radii[begin1:end1, :] = np.concatenate(
                [
                    x[:, self.nhood_sizes]
                    for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
                ],
                axis=0,
            )

        if self.clamp_to_percentile is not None:
            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
            radii[radii > max_distances] = 0
        return radii

    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
        """
        Evaluate if new feature vectors are at the manifold.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = radii.shape[0]
        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)

        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            feature_batch = eval_features[begin1:end1]

            for begin2 in range(0, num_ref_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_ref_images)
                ref_batch = features[begin2:end2]

                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)

            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies at the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)

            max_realism_score[begin1:end1] = np.max(
                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
            )
            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)

        # NOTE: the key "max_realisim_score" is misspelled but kept for callers relying on it.
        return {
            "fraction": float(np.mean(batch_predictions)),
            "batch_predictions": batch_predictions,
            "max_realisim_score": max_realism_score,
            "nearest_indices": nearest_indices,
        }

    def evaluate_pr(
        self,
        features_1: np.ndarray,
        radii_1: np.ndarray,
        features_2: np.ndarray,
        radii_2: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate precision and recall efficiently.

        :param features_1: [N1 x D] feature vectors for reference batch.
        :param radii_1: [N1 x K1] radii for reference vectors.
        :param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N2 x K2] radii for other vectors.
        :return: a tuple of arrays for (precision, recall):
                 - precision: an np.ndarray of length K1
                 - recall: an np.ndarray of length K2
        """
        # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24 and raises AttributeError.
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            end_1 = begin_1 + self.row_batch_size
            batch_1 = features_1[begin_1:end_1]
            for begin_2 in range(0, len(features_2), self.col_batch_size):
                end_2 = begin_2 + self.col_batch_size
                batch_2 = features_2[begin_2:end_2]
                batch_1_in, batch_2_in = self.distance_block.less_thans(
                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
                )
                features_1_status[begin_1:end_1] |= batch_1_in
                features_2_status[begin_2:end_2] |= batch_2_in
        return (
            np.mean(features_2_status.astype(np.float64), axis=0),
            np.mean(features_1_status.astype(np.float64), axis=0),
        )
class DistanceBlock:
    """
    TensorFlow graph that evaluates pairwise distances between two batches of
    feature vectors, plus radius-membership tests for precision/recall.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        self.session = session
        graph = session.graph
        with graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            # Try half precision first; if any entry is non-finite, recompute in float32.
            half_dists = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            half_is_finite = tf.reduce_all(tf.math.is_finite(half_dists))
            self.distance_block = tf.cond(
                half_is_finite,
                lambda: tf.cast(half_dists, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Membership tests: is a point inside any hypersphere of the other batch?
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dists = tf.cast(self.distance_block, tf.float32)[..., None]
            self._batch_1_in = tf.math.reduce_any(dists <= self._radii2, axis=1)
            self._batch_2_in = tf.math.reduce_any(dists <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        feeds = {self._features_batch1: U, self._features_batch2: V}
        return self.session.run(self.distance_block, feed_dict=feeds)

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        """Return (batch_1 rows inside a batch_2 sphere, batch_2 rows inside a batch_1 sphere)."""
        feeds = {
            self._features_batch1: batch_1,
            self._features_batch2: batch_2,
            self._radii1: radii_1,
            self._radii2: radii_2,
        }
        return self.session.run([self._batch_1_in, self._batch_2_in], feed_dict=feeds)
def _batch_pairwise_distances(U, V):
    """
    Compute pairwise distances between two batches of feature vectors.

    Returns the [rows(U) x rows(V)] matrix of squared Euclidean distances,
    |u|^2 - 2 u.v + |v|^2, clamped below at 0 to absorb rounding error.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms: U's as a column vector, V's as a row vector.
        sq_norm_u = tf.reshape(tf.reduce_sum(tf.square(U), 1), [-1, 1])
        sq_norm_v = tf.reshape(tf.reduce_sum(tf.square(V), 1), [1, -1])
        cross = tf.matmul(U, V, False, True)
        return tf.maximum(sq_norm_u - 2 * cross + sq_norm_v, 0.0)
class NpzArrayReader(ABC):
    """Sequential, batch-wise reader over a single array."""

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Return up to ``batch_size`` rows, or None once the array is exhausted."""

    @abstractmethod
    def remaining(self) -> int:
        """Number of rows not yet returned."""

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        """Iterable of successive batches whose ``len()`` is the batch count."""
        def _generate():
            batch = self.read_batch(batch_size)
            while batch is not None:
                yield batch
                batch = self.read_batch(batch_size)

        left = self.remaining()
        batch_count = -(-left // batch_size)  # ceiling division
        return BatchIterator(_generate, batch_count)
class BatchIterator:
    """Re-iterable wrapper: each ``iter()`` calls ``gen_fn``; ``len()`` reports ``length``."""

    def __init__(self, gen_fn, length):
        self.gen_fn, self.length = gen_fn, length

    def __len__(self):
        return self.length

    def __iter__(self):
        return self.gen_fn()
class StreamingNpzArrayReader(NpzArrayReader):
    """Reads rows of a C-ordered, non-object array straight from an open ``.npy`` stream."""

    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f
        self.shape = shape
        self.dtype = dtype
        self.idx = 0

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        total_rows = self.shape[0]
        if self.idx >= total_rows:
            return None

        bs = min(batch_size, total_rows - self.idx)
        self.idx += bs
        batch_shape = [bs, *self.shape[1:]]

        # Zero-size dtypes have no bytes to read.
        if self.dtype.itemsize == 0:
            return np.ndarray(batch_shape, dtype=self.dtype)

        byte_count = int(bs * np.prod(self.shape[1:]) * self.dtype.itemsize)
        raw = _read_bytes(self.arr_f, byte_count, "array data")
        return np.frombuffer(raw, dtype=self.dtype).reshape(batch_shape)

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)
class MemoryNpzArrayReader(NpzArrayReader):
    """Serves batches from an array held fully in memory."""

    def __init__(self, arr):
        self.arr = arr
        self.idx = 0

    @classmethod
    def load(cls, path: str, arr_name: str):
        """Load ``arr_name`` from the npz at ``path`` into memory."""
        # NOTE(review): np.load's default allow_pickle=False means object arrays
        # (one of the cases routed here by open_npz_array) will fail to load.
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        start = self.idx
        if start >= self.arr.shape[0]:
            return None
        self.idx = start + batch_size
        return self.arr[start : start + batch_size]

    def remaining(self) -> int:
        return max(0, self.arr.shape[0] - self.idx)
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    """
    Yield a reader for ``arr_name`` inside the npz at ``path``.

    Streams from the archive when the ``.npy`` header is format 1.0/2.0 and the
    array is C-ordered without objects; otherwise it loads the array into memory.
    """
    with _open_npy_file(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        header_readers = {
            (1, 0): np.lib.format.read_array_header_1_0,
            (2, 0): np.lib.format.read_array_header_2_0,
        }
        read_header = header_readers.get(version)
        if read_header is None:
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, fortran, dtype = read_header(arr_f)
        if fortran or dtype.hasobject:
            yield MemoryNpzArrayReader.load(path, arr_name)
        else:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)
def _read_bytes(fp, size, error_template="ran out of data"):
"""
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
Read from file-like object until size bytes are read.
Raises ValueError if not EOF is encountered before size bytes are read.
Non-blocking objects only supported if they derive from io objects.
Required as e.g. ZipExtFile in python 2.6 can return less data than
requested.
"""
data = bytes()
while True:
# io files (default in python3) return None or raise on
# would-block, python2 file will truncate, probably nothing can be
# done about that. note that regular files can't be non-blocking
try:
r = fp.read(size - len(data))
data += r
if len(r) == 0 or len(data) == size:
break
except io.BlockingIOError:
pass
if len(data) != size:
msg = "EOF: reading %s, expected %d bytes got %d"
raise ValueError(msg % (error_template, size, len(data)))
else:
return data
@contextmanager
def _open_npy_file(path: str, arr_name: str):
with open(path, "rb") as f:
with zipfile.ZipFile(f, "r") as zip_f:
if f"{arr_name}.npy" not in zip_f.namelist():
raise ValueError(f"missing {arr_name} in npz file")
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
yield arr_f
def _download_inception_model():
    """Download the Inception graph to INCEPTION_V3_PATH unless it already exists."""
    if os.path.exists(INCEPTION_V3_PATH):
        return
    print("downloading InceptionV3 model...")
    # Download to a temporary name first; rename only after the download completes.
    partial_path = INCEPTION_V3_PATH + ".tmp"
    with requests.get(INCEPTION_V3_URL, stream=True) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as out:
            for chunk in tqdm(response.iter_content(chunk_size=8192)):
                out.write(chunk)
    os.rename(partial_path, INCEPTION_V3_PATH)
def _create_feature_graph(input_batch):
    """
    Import the Inception graph under a random name scope, feed it ``input_batch``,
    and return its (pool_3, spatial) feature tensors. Only the first 7 channels of
    the spatial tensor are kept.
    """
    _download_inception_model()
    scope_name = "{}_{}".format(random.randrange(2**32), random.randrange(2**32))
    with open(INCEPTION_V3_PATH, "rb") as graph_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph_file.read())
    pool3, spatial = tf.import_graph_def(
        graph_def,
        input_map={"ExpandDims:0": input_batch},
        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
        name=scope_name,
    )
    _update_shapes(pool3)
    return pool3, spatial[..., :7]
def _create_softmax_graph(input_batch):
    """
    Apply the Inception classifier's final weight matrix to ``input_batch``
    and return the softmax over the resulting logits.
    """
    _download_inception_model()
    scope_name = "{}_{}".format(random.randrange(2**32), random.randrange(2**32))
    with open(INCEPTION_V3_PATH, "rb") as graph_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph_file.read())
    (logits_matmul,) = tf.import_graph_def(
        graph_def, return_elements=["softmax/logits/MatMul"], name=scope_name
    )
    weights = logits_matmul.inputs[1]
    return tf.nn.softmax(tf.matmul(input_batch, weights))
def _update_shapes(pool3):
    """
    Rewrite the static shapes of every tensor in ``pool3``'s graph, replacing a
    leading dimension of 1 with None (presumably so the imported Inception graph
    accepts any batch size). Returns ``pool3``.
    """
    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
    ops = pool3.graph.get_operations()
    for op in ops:
        for o in op.outputs:
            shape = o.get_shape()
            # Only shapes with a known dimension list are rewritten.
            if shape._dims is not None:  # pylint: disable=protected-access
                # shape = [s.value for s in shape] TF 1.x
                shape = [s for s in shape]  # TF 2.x
                new_shape = []
                for j, s in enumerate(shape):
                    # Only the first (batch) axis is relaxed, and only when it is 1.
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                # NOTE(review): overwrites a private TF attribute; may break across TF versions.
                o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
    return pool3
def _numpy_partition(arr, kth, **kwargs):
num_workers = min(cpu_count(), len(arr))
chunk_size = len(arr) // num_workers
extra = len(arr) % num_workers
start_idx = 0
batches = []
for i in range(num_workers):
size = chunk_size + (1 if i < extra else 0)
batches.append(arr[start_idx : start_idx + size])
start_idx += size
with ThreadPool(num_workers) as pool:
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
if __name__ == "__main__":
    # Print the dependency list before running the evaluation CLI.
    print(REQUIREMENTS)
    main()
================================================
FILE: ldm/modules/evaluate/evaluate_perceptualsim.py
================================================
import argparse
import glob
import os
from tqdm import tqdm
from collections import namedtuple
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
from ldm.modules.evaluate.ssim import ssim
# Image-to-tensor pipeline consisting only of torchvision's ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
def normalize_tensor(in_feat, eps=1e-10):
    """
    Divide a 4-D feature map by its L2 norm over dim 1 (channels) at each
    spatial position. ``eps`` guards against division by zero.
    """
    batch, _, height, width = in_feat.size()
    channel_norm = in_feat.pow(2).sum(dim=1).sqrt().view(batch, 1, height, width)
    return in_feat / (channel_norm.expand_as(in_feat) + eps)
def cos_sim(in0, in1):
    """
    Per-sample cosine similarity between two 4-D feature maps.

    Both inputs are channel-normalised, dotted over dim 1, and averaged over
    the two spatial dims. Returns a tensor of shape (N,).
    """
    channel_dot = (normalize_tensor(in0) * normalize_tensor(in1)).sum(dim=1)  # (N, X, Y)
    # Average over X, then over Y (same reduction order as before).
    return channel_dot.mean(dim=1).mean(dim=1)
class squeezenet(torch.nn.Module):
    """
    torchvision SqueezeNet 1.1 ``features`` split into seven sequential stages.
    ``forward`` returns each stage's output as a namedtuple (relu1 ... relu7).
    """
    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        pretrained_features = models.squeezenet1_1(
            pretrained=pretrained
        ).features
        # [start, end) indices into `features` for slice1 ... slice7.
        stage_bounds = [(0, 2), (2, 5), (5, 8), (8, 10), (10, 11), (11, 12), (12, 13)]
        for stage, (start, end) in enumerate(stage_bounds, start=1):
            block = torch.nn.Sequential()
            for layer_idx in range(start, end):
                block.add_module(str(layer_idx), pretrained_features[layer_idx])
            setattr(self, "slice%d" % stage, block)
        self.N_slices = 7
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = X
        stage_outputs = []
        for stage in range(1, 8):
            h = getattr(self, "slice%d" % stage)(h)
            stage_outputs.append(h)
        vgg_outputs = namedtuple(
            "SqueezeOutputs",
            ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
        )
        return vgg_outputs(*stage_outputs)
class alexnet(torch.nn.Module):
    """
    torchvision AlexNet ``features`` split into five sequential stages.
    ``forward`` returns each stage's output as a namedtuple (relu1 ... relu5).
    """
    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        alexnet_pretrained_features = models.alexnet(
            pretrained=pretrained
        ).features
        # [start, end) indices into `features` for slice1 ... slice5.
        stage_bounds = [(0, 2), (2, 5), (5, 8), (8, 10), (10, 12)]
        for stage, (start, end) in enumerate(stage_bounds, start=1):
            block = torch.nn.Sequential()
            for layer_idx in range(start, end):
                block.add_module(str(layer_idx), alexnet_pretrained_features[layer_idx])
            setattr(self, "slice%d" % stage, block)
        self.N_slices = 5
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = X
        stage_outputs = []
        for stage in range(1, 6):
            h = getattr(self, "slice%d" % stage)(h)
            stage_outputs.append(h)
        alexnet_outputs = namedtuple(
            "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
        )
        return alexnet_outputs(*stage_outputs)
class vgg16(torch.nn.Module):
    """
    torchvision VGG16 ``features`` split into five sequential stages.
    ``forward`` returns each stage's output as a namedtuple
    (relu1_2, relu2_2, relu3_3, relu4_3, relu5_3).
    """
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        # [start, end) indices into `features` for slice1 ... slice5.
        stage_bounds = [(0, 4), (4, 9), (9, 16), (16, 23), (23, 30)]
        for stage, (start, end) in enumerate(stage_bounds, start=1):
            block = torch.nn.Sequential()
            for layer_idx in range(start, end):
                block.add_module(str(layer_idx), vgg_pretrained_features[layer_idx])
            setattr(self, "slice%d" % stage, block)
        self.N_slices = 5
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = X
        stage_outputs = []
        for stage in range(1, 6):
            h = getattr(self, "slice%d" % stage)(h)
            stage_outputs.append(h)
        vgg_outputs = namedtuple(
            "VggOutputs",
            ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"],
        )
        return vgg_outputs(*stage_outputs)
class resnet(torch.nn.Module):
    """
    torchvision ResNet backbone exposing five intermediate activations:
    relu1 (after conv1/bn1/relu), then the outputs of layer1 ... layer4.

    :param requires_grad: when False (default), all parameters are frozen,
                          matching the other feature extractors in this file.
    :param pretrained: forwarded to the torchvision constructor.
    :param num: ResNet depth, one of 18, 34, 50, 101, 152.
    :raises ValueError: for an unsupported ``num``.
    """
    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        builders = {
            18: models.resnet18,
            34: models.resnet34,
            50: models.resnet50,
            101: models.resnet101,
            152: models.resnet152,
        }
        # Fail fast instead of leaving self.net undefined (previously a later AttributeError).
        if num not in builders:
            raise ValueError(
                "Unsupported ResNet depth {}; expected one of {}".format(num, sorted(builders))
            )
        self.net = builders[num](pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

        # requires_grad was previously accepted but ignored; honour it like the sibling nets.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple(
            "Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]
        )
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out
# Off-the-shelf deep network
class PNet(torch.nn.Module):
    """Pre-trained network with all channels equally weighted by default.

    The distance between two images is the sum, over the backbone's feature
    layers, of ``1 - cos_sim(feat0, feat1)``.

    Args:
        pnet_type: "vgg"/"vgg16", "alex", "squeeze" or "resnet<depth>"
            (e.g. "resnet18", "resnet101").
        pnet_rand: if True, use randomly initialised (not pretrained) weights.
        use_gpu: move the backbone and normalisation constants to CUDA.

    Raises:
        ValueError: for an unknown ``pnet_type``.
    """

    def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True):
        super(PNet, self).__init__()
        self.use_gpu = use_gpu
        self.pnet_type = pnet_type
        self.pnet_rand = pnet_rand
        # Per-channel shift/scale applied to inputs before the backbone.
        self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
        self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
        pretrained = not self.pnet_rand
        if self.pnet_type in ["vgg", "vgg16"]:
            self.net = vgg16(pretrained=pretrained, requires_grad=False)
        elif self.pnet_type == "alex":
            self.net = alexnet(pretrained=pretrained, requires_grad=False)
        elif self.pnet_type.startswith("resnet"):
            # Parse the full depth suffix. The previous [:-2] / [-2:] slicing
            # only recognised two-digit depths (rejecting resnet101/resnet152).
            self.net = resnet(
                pretrained=pretrained,
                requires_grad=False,
                num=int(self.pnet_type[len("resnet"):]),
            )
        elif self.pnet_type == "squeeze":
            self.net = squeezenet(pretrained=pretrained, requires_grad=False)
        else:
            raise ValueError(f"Unknown pnet_type: {self.pnet_type!r}")
        self.L = self.net.N_slices
        if use_gpu:
            self.net.cuda()
            self.shift = self.shift.cuda()
            self.scale = self.scale.cuda()

    def forward(self, in0, in1, retPerLayer=False):
        """Return the summed per-layer cosine distance between in0 and in1.

        If ``retPerLayer`` is True, return ``(total, [per_layer_scores])``.
        """
        # Each input is normalised with constants expanded to its own shape.
        in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
        in1_sc = (in1 - self.shift.expand_as(in1)) / self.scale.expand_as(in1)
        outs0 = self.net(in0_sc)
        outs1 = self.net(in1_sc)
        val = None
        all_scores = []
        for out0, out1 in zip(outs0, outs1):
            cur_score = 1.0 - cos_sim(out0, out1)
            val = 1.0 * cur_score if val is None else val + cur_score
            all_scores.append(cur_score)
        if retPerLayer:
            return (val, all_scores)
        return val
# The SSIM metric
def ssim_metric(img1, img2, mask=None):
    """Per-image SSIM between img1 and img2 (no batch averaging), optionally masked."""
    per_image = ssim(img1, img2, mask=mask, size_average=False)
    return per_image
# The PSNR metric
def psnr(img1, img2, mask=None,reshape=False):
    """Peak signal-to-noise ratio per batch element, with a peak value of 1.

    Args:
        img1, img2: tensors of identical shape whose first dim is the batch
            size. The formula ``10 * log10(1 / mse)`` implies values in
            [0, 1] (assumption; confirm with callers).
        mask: optional weight tensor whose first dim is the batch size,
            broadcast against ``(img1 - img2) ** 2``. The masked squared error
            is summed and divided by ``3 * mask.sum()`` (clamped to >= 1),
            i.e. 3-channel images and a mask without a channel dimension are
            assumed. TODO confirm for other layouts.
        reshape: kept for backward compatibility only. Flattening always uses
            ``reshape``, which also accepts non-contiguous tensors (``view``
            raised on them).

    Returns:
        Tensor of shape (b,). It is ``inf`` where the MSE is 0.
    """
    b = img1.size(0)
    sq_err = (img1 - img2).pow(2)
    if mask is not None:
        mse_err = (sq_err * mask).reshape(b, -1).sum(dim=1) / (
            3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)
        )
    else:
        mse_err = sq_err.reshape(b, -1).mean(dim=1)
    # Named differently from the function to avoid shadowing it.
    psnr_val = 10 * (1 / mse_err).log10()
    return psnr_val
# The perceptual similarity metric
def perceptual_sim(img1, img2, vgg16):
    """Apply the perceptual network ``vgg16`` to img1/img2 after mapping x -> 2x - 1.

    Inputs are presumably in [0, 1] (so they land in [-1, 1]). TODO confirm.
    """
    scaled1 = img1 * 2 - 1
    scaled2 = img2 * 2 - 1
    return vgg16(scaled1, scaled2)
def load_img(img_name, size=None):
    """Load an image file as a CUDA tensor with a leading batch dimension.

    Args:
        img_name: path passed to ``Image.open``.
        size: None keeps the original size; an int resizes to size x size;
            otherwise a (height, width) pair, e.g. a tensor's ``shape[2:]``.

    Returns:
        ``transform(img)`` moved to CUDA and unsqueezed at dim 0.

    Raises:
        Re-raises any error from opening/resizing/transforming, after printing
        the offending path.
    """
    try:
        # The context manager closes the underlying file handle.
        with Image.open(img_name) as img:
            if isinstance(size, int):
                img = img.resize((size, size))
            elif size is not None:
                # PIL expects (width, height); `size` is (height, width).
                img = img.resize((size[1], size[0]))
            tensor = transform(img).cuda()
    except Exception as e:
        print("Failed at loading %s " % img_name)
        print(e)
        raise
    return tensor.unsqueeze(0)
def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other):
    """Score predictions against targets stored in per-sample sub-folders.

    For every entry ``f`` of ``folder`` (sorted), the glob patterns
    ``pred_img`` and ``tgt_img`` are resolved inside ``folder/f``. Exactly
    one target match is required. Over all prediction matches, the best
    score is kept: minimum perceptual distance, maximum SSIM and maximum PSNR.

    Args:
        folder: root directory with one sub-folder per sample.
        pred_img: glob pattern for predictions, relative to each sub-folder.
        tgt_img: glob pattern for the target, relative to each sub-folder.
        take_every_other: if True, merge samples (2i, 2i + 1) by again keeping
            the better score of each pair.

    Returns:
        dict mapping metric name to ``(mean, std)`` over samples.
    """
    pnet = PNet().to("cuda")
    pnet.eval()
    values_percsim = []
    values_ssim = []
    values_psnr = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for f in tqdm(sorted(os.listdir(folder))):
            # os.path.join works whether or not `folder` ends with a separator.
            pred_imgs = glob.glob(os.path.join(folder, f, pred_img))
            tgt_imgs = glob.glob(os.path.join(folder, f, tgt_img))
            assert len(tgt_imgs) == 1
            # Sentinels kept when no prediction matches.
            perc_sim = 10000
            ssim_sim = -10
            psnr_sim = -10
            # The target is the same for every prediction; load it once.
            t_img = load_img(tgt_imgs[0])
            for p_path in pred_imgs:
                p_img = load_img(p_path, size=t_img.shape[2:])
                perc_sim = min(perc_sim, perceptual_sim(p_img, t_img, pnet).item())
                ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
                psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
            values_percsim.append(perc_sim)
            values_ssim.append(ssim_sim)
            values_psnr.append(psnr_sim)
    if take_every_other:
        n_pairs = len(values_percsim) // 2
        values_percsim = [min(values_percsim[2 * i], values_percsim[2 * i + 1]) for i in range(n_pairs)]
        values_psnr = [max(values_psnr[2 * i], values_psnr[2 * i + 1]) for i in range(n_pairs)]
        values_ssim = [max(values_ssim[2 * i], values_ssim[2 * i + 1]) for i in range(n_pairs)]
    avg_percsim = np.mean(np.array(values_percsim))
    std_percsim = np.std(np.array(values_percsim))
    avg_psnr = np.mean(np.array(values_psnr))
    std_psnr = np.std(np.array(values_psnr))
    avg_ssim = np.mean(np.array(values_ssim))
    std_ssim = np.std(np.array(values_ssim))
    return {
        "Perceptual similarity": (avg_percsim, std_percsim),
        "PSNR": (avg_psnr, std_psnr),
        "SSIM": (avg_ssim, std_ssim),
    }
def compute_perceptual_similarity_from_list(pred_imgs_list, tgt_imgs_list,
                                            take_every_other,
                                            simple_format=True):
    """Score lists of prediction paths against target paths.

    ``pred_imgs_list[i]`` (a path or a non-empty list of paths) is compared
    with ``tgt_imgs_list[i]``. The best score over the predictions is kept:
    minimum perceptual distance, maximum SSIM and maximum PSNR. Infinite PSNR
    (zero MSE) is excluded from the PSNR statistics, and a message is printed
    counting whether the images were allclose ("equal") or not ("ambiguous").

    Args:
        pred_imgs_list: per-target prediction path(s).
        tgt_imgs_list: target image paths.
        take_every_other: merge samples (2i, 2i + 1), keeping the better score.
            NOTE(review): PSNR entries are paired by index too, so this is
            misaligned if any infinite PSNR was dropped.
        simple_format: return plain floats in lists (YAML-friendly) instead of
            numpy scalars in tuples.

    Returns:
        dict mapping metric name to (mean, std).
    """
    pnet = PNet().to("cuda")
    pnet.eval()
    values_percsim = []
    values_ssim = []
    values_psnr = []
    equal_count = 0
    ambig_count = 0
    with torch.no_grad():
        for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
            pred_imgs = pred_imgs_list[i]
            if not isinstance(pred_imgs, list):
                pred_imgs = [pred_imgs]
            assert len(pred_imgs) > 0
            perc_sim = 10000
            ssim_sim = -10
            psnr_sim = -10
            # The target is the same for every prediction; load it once.
            t_img = load_img(tgt_img)
            for p_path in pred_imgs:
                p_img = load_img(p_path, size=t_img.shape[2:])
                perc_sim = min(perc_sim, perceptual_sim(p_img, t_img, pnet).item())
                ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
                psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
            values_percsim.append(perc_sim)
            values_ssim.append(ssim_sim)
            # `np.float` was removed in NumPy 1.24; plain float("inf") is equivalent.
            if psnr_sim != float("inf"):
                values_psnr.append(psnr_sim)
            elif torch.allclose(p_img, t_img):
                # `p_img` is the last prediction loaded for this target.
                equal_count += 1
                print("{} equal src and wrp images.".format(equal_count))
            else:
                ambig_count += 1
                print("{} ambiguous src and wrp images.".format(ambig_count))
    if take_every_other:
        n_pairs = len(values_percsim) // 2
        values_percsim = [min(values_percsim[2 * i], values_percsim[2 * i + 1]) for i in range(n_pairs)]
        values_psnr = [max(values_psnr[2 * i], values_psnr[2 * i + 1]) for i in range(n_pairs)]
        values_ssim = [max(values_ssim[2 * i], values_ssim[2 * i + 1]) for i in range(n_pairs)]
    avg_percsim = np.mean(np.array(values_percsim))
    std_percsim = np.std(np.array(values_percsim))
    avg_psnr = np.mean(np.array(values_psnr))
    std_psnr = np.std(np.array(values_psnr))
    avg_ssim = np.mean(np.array(values_ssim))
    std_ssim = np.std(np.array(values_ssim))
    if simple_format:
        # just to make yaml formatting readable
        return {
            "Perceptual similarity": [float(avg_percsim), float(std_percsim)],
            "PSNR": [float(avg_psnr), float(std_psnr)],
            "SSIM": [float(avg_ssim), float(std_ssim)],
        }
    return {
        "Perceptual similarity": (avg_percsim, std_percsim),
        "PSNR": (avg_psnr, std_psnr),
        "SSIM": (avg_ssim, std_ssim),
    }
def compute_perceptual_similarity_from_list_topk(pred_imgs_list, tgt_imgs_list,
                                                 take_every_other, resize=False):
    """Score several predictions per target, keeping both best-of and individual scores.

    Args:
        pred_imgs_list: for each target, a list of prediction paths.
        tgt_imgs_list: target image paths.
        take_every_other: not supported yet.
        resize: load targets at 256x256 instead of their native size.
            Predictions are always resized to the loaded target's size.

    Returns:
        {"avg_of_best": {metric: [mean, std]},
         "individual": {"PSIM"/"PSNR"/"SSIM": np.array of per-prediction scores}}

    Raises:
        TypeError: if an entry of ``pred_imgs_list`` is not a list.
        NotImplementedError: if ``take_every_other`` is set.
    """
    # Explicit raises instead of `assert False`, which is stripped under
    # `python -O` and used to fall through into unreachable code.
    if take_every_other:
        raise NotImplementedError(
            "Do this later, after specifying topk to get proper results"
        )
    pnet = PNet().to("cuda")
    pnet.eval()
    values_percsim = []
    values_ssim = []
    values_psnr = []
    individual_percsim = []
    individual_ssim = []
    individual_psnr = []
    with torch.no_grad():
        for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
            pred_imgs = pred_imgs_list[i]
            if not isinstance(pred_imgs, list):
                raise TypeError(
                    f"pred_imgs_list[{i}] must be a list of paths, got {type(pred_imgs).__name__}"
                )
            perc_sim = 10000
            ssim_sim = -10
            psnr_sim = -10
            sample_percsim = []
            sample_ssim = []
            sample_psnr = []
            # The target is the same for every prediction; load it once.
            if resize:
                t_img = load_img(tgt_img, size=(256, 256))
            else:
                t_img = load_img(tgt_img)
            for p_path in pred_imgs:
                p_img = load_img(p_path, size=t_img.shape[2:])
                t_perc_sim = perceptual_sim(p_img, t_img, pnet).item()
                sample_percsim.append(t_perc_sim)
                perc_sim = min(perc_sim, t_perc_sim)
                t_ssim = ssim_metric(p_img, t_img).item()
                sample_ssim.append(t_ssim)
                ssim_sim = max(ssim_sim, t_ssim)
                t_psnr = psnr(p_img, t_img).item()
                sample_psnr.append(t_psnr)
                psnr_sim = max(psnr_sim, t_psnr)
            values_percsim.append(perc_sim)
            values_ssim.append(ssim_sim)
            values_psnr.append(psnr_sim)
            individual_percsim.append(sample_percsim)
            individual_ssim.append(sample_ssim)
            individual_psnr.append(sample_psnr)
    avg_percsim = np.mean(np.array(values_percsim))
    std_percsim = np.std(np.array(values_percsim))
    avg_psnr = np.mean(np.array(values_psnr))
    std_psnr = np.std(np.array(values_psnr))
    avg_ssim = np.mean(np.array(values_ssim))
    std_ssim = np.std(np.array(values_ssim))
    return {
        "avg_of_best": {
            "Perceptual similarity": [float(avg_percsim), float(std_percsim)],
            "PSNR": [float(avg_psnr), float(std_psnr)],
            "SSIM": [float(avg_ssim), float(std_ssim)],
        },
        "individual": {
            "PSIM": np.array(individual_percsim),
            "PSNR": np.array(individual_psnr),
            "SSIM": np.array(individual_ssim),
        }
    }
if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument("--folder", type=str, default="")
    args.add_argument("--pred_image", type=str, default="")
    args.add_argument("--target_image", type=str, default="")
    args.add_argument("--take_every_other", action="store_true", default=False)
    args.add_argument("--output_file", type=str, default="")
    opts = args.parse_args()
    folder = opts.folder
    pred_img = opts.pred_image
    tgt_img = opts.target_image
    results = compute_perceptual_similarity(
        folder, pred_img, tgt_img, opts.take_every_other
    )
    # --output_file is optional. With its default "" the old code crashed in
    # open("") only after every metric had been computed.
    out_file = open(opts.output_file, 'w') if opts.output_file else None
    try:
        for key in results:
            header = "%s for %s: \n" % (key, opts.folder)
            line = "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1])
            print(header)
            print(line)
            if out_file is not None:
                out_file.write(header)
                out_file.write(line)
    finally:
        # Close the report even if formatting/writing fails.
        if out_file is not None:
            out_file.close()
================================================
FILE: ldm/modules/evaluate/frechet_video_distance.py
================================================
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Minimal Reference implementation for the Frechet Video Distance (FVD).
FVD is a metric for the quality of video generation models. It is inspired by
the FID (Frechet Inception Distance) used for images, but uses a different
embedding that is better suited to videos.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub
def preprocess(videos, target_resolution):
  """Runs some preprocessing on the videos for I3D model.

  Args:
    videos: [batch_size, num_frames, height, width, depth] The videos to be
      preprocessed. We don't care about the specific dtype of the videos, it can
      be anything that tf.image.resize_bilinear accepts. Values are expected to
      be in the range 0-255.
    target_resolution: (height, width): target video resolution. It is passed
      unchanged as the `size` argument of tf.image.resize_bilinear (which takes
      (new_height, new_width)) and spliced into the output shape in that order.

  Returns:
    videos: [batch_size, num_frames, target_height, target_width, 3], cast to
      float32 and rescaled from [0, 255] to [-1, 1].
  """
  videos_shape = list(videos.shape)
  # Flatten batch and time so every frame is resized independently.
  all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
  resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
  # Restore the batch dimension; the frame count is inferred (-1).
  # Note: 3 output channels are hard-coded here.
  target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
  output_videos = tf.reshape(resized_videos, target_shape)
  scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1
  return scaled_videos
def _is_in_graph(tensor_name):
  """Return True if a tensor called `tensor_name` exists in the default graph."""
  graph = tf.get_default_graph()
  try:
    graph.get_tensor_by_name(tensor_name)
  except KeyError:
    return False
  else:
    return True
def create_id3_embedding(videos,warmup=False,batch_size=16):
  """Embeds the given videos using the Inflated 3D Convolution network.

  Downloads the graph of the I3D from tf.hub and adds it to the graph on the
  first call.

  Args:
    videos: [batch_size, num_frames, height=224, width=224, depth=3].
      Expected range is [-1, 1].
    warmup: if True, also check at graph-construction time (Python assert)
      that the static batch dimension of `videos` is `batch_size`, -1 or None.
    batch_size: batch size `videos` must have at run time. It is enforced by a
      tf.assert_equal op, together with tf.Assert range checks on the values.

  Returns:
    embedding: [batch_size, embedding_size]. embedding_size depends
               on the model used.

  Raises:
    AssertionError: (warmup only) when the static batch size is invalid. The
      tf assert ops above fail when the graph is executed, not here.
  """
  module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"
  # Making sure that we import the graph separately for
  # each different input video tensor.
  module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
      videos.name).replace(":", "_")

  assert_ops = [
      tf.Assert(
          tf.reduce_max(videos) <= 1.001,
          ["max value in frame is > 1", videos]),
      tf.Assert(
          tf.reduce_min(videos) >= -1.001,
          ["min value in frame is < -1", videos]),
      tf.assert_equal(
          tf.shape(videos)[0],
          batch_size, ["invalid frame batch size: ",
                       tf.shape(videos)],
          summarize=6),
  ]
  with tf.control_dependencies(assert_ops):
    videos = tf.identity(videos)

  module_scope = "%s_apply_default/" % module_name

  # To check whether the module has already been loaded into the graph, we look
  # for a given tensor name. If this tensor name exists, we assume the function
  # has been called before and the graph was imported. Otherwise we import it.
  # Note: in theory, the tensor could exist, but have wrong shapes.
  # This will happen if create_id3_embedding is called with a frames_placehoder
  # of wrong size/batch size, because even though that will throw a tf.Assert
  # on graph-execution time, it will insert the tensor (with wrong shape) into
  # the graph. The static check below guards against that, but only when
  # `warmup` is set.
  if warmup:
    video_batch_size = int(videos.shape[0])
    assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}"

  # gets the kinetics-i3d-400-logits layer
  tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
  if not _is_in_graph(tensor_name):
    i3d_model = hub.Module(module_spec, name=module_name)
    i3d_model(videos)

  tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
  return tensor
def calculate_fvd(real_activations,
                  generated_activations):
  """Compute the Frechet distance between two sets of activations.

  Args:
    real_activations: [num_samples, embedding_size]
    generated_activations: [num_samples, embedding_size]

  Returns:
    The result of tfgan.eval.frechet_classifier_distance_from_activations,
    i.e. the FVD between the two activation sets.
  """
  fvd = tfgan.eval.frechet_classifier_distance_from_activations(
      real_activations, generated_activations)
  return fvd
================================================
FILE: ldm/modules/evaluate/ssim.py
================================================
# MIT Licence
# Methods to predict the SSIM, taken from
# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
from math import exp
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def gaussian(window_size, sigma):
    """1-D Gaussian of length window_size centred at window_size // 2, normalised to sum 1."""
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = []
    for x in range(window_size):
        weights.append(exp(-((x - center) ** 2) / denom))
    gauss = torch.Tensor(weights)
    return gauss / gauss.sum()
def create_window(window_size, channel):
    """Build a (channel, 1, window_size, window_size) 2-D Gaussian window (sigma 1.5).

    The 2-D kernel is the outer product of the 1-D gaussian with itself,
    replicated per channel for use with grouped conv2d (groups=channel).
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # torch.autograd.Variable has been a deprecated no-op wrapper since
    # PyTorch 0.4; return the tensor directly.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
def _ssim(
img1, img2, window, window_size, channel, mask=None, size_average=True
):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = (
F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
- mu1_sq
)
sigma2_sq = (
F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
- mu2_sq
)
sigma12 = (
F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
- mu1_mu2
)
C1 = (0.01) ** 2
C2 = (0.03) ** 2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
if not (mask is None):
b = mask.size(0)
ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(
dim=1
).clamp(min=1)
return ssim_map
import pdb
pdb.set_trace
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
    """nn.Module wrapper around _ssim that caches the Gaussian window.

    The cached window is rebuilt whenever the input's channel count or tensor
    type (as reported by ``.data.type()``) differs from the cached one.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2, mask=None):
        (_, channel, _, _) = img1.size()
        cache_valid = (
            channel == self.channel
            and self.window.data.type() == img1.data.type()
        )
        if not cache_valid:
            fresh = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())
            self.window = fresh.type_as(img1)
            self.channel = channel
        return _ssim(
            img1,
            img2,
            self.window,
            self.window_size,
            channel,
            mask,
            self.size_average,
        )
def ssim(img1, img2, window_size=11, mask=None, size_average=True):
    """Functional SSIM: builds a fresh window matching img1's channels, device and dtype."""
    (_, channel, _, _) = img1.size()
    kernel = create_window(window_size, channel)
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    return _ssim(
        img1, img2, kernel.type_as(img1), window_size, channel, mask, size_average
    )
================================================
FILE: ldm/modules/evaluate/torch_frechet_video_distance.py
================================================
# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!
import os
import numpy as np
import io
import re
import requests
import html
import hashlib
import urllib
import urllib.request
import scipy.linalg
import multiprocessing as mp
import glob
from tqdm import tqdm
from typing import Any, List, Tuple, Union, Dict, Callable
from torchvision.io import read_video
import torch; torch.set_grad_enabled(False)
from einops import rearrange
from nitro.util import isvideo
def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float:
    """Frechet distance between Gaussians N(mu_sample, sigma_sample) and N(mu_ref, sigma_ref).

    ||mu_s - mu_r||^2 + Tr(S_s + S_r - 2 * sqrtm(S_s @ S_r)); the real part is taken
    because sqrtm may return tiny imaginary components.
    """
    print('Calculate frechet distance...')
    mean_term = np.square(mu_sample - mu_ref).sum()
    covmean, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member
    trace_term = np.trace(sigma_sample + sigma_ref - covmean * 2)
    return float(np.real(mean_term + trace_term))
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-dimension mean [d] and covariance [d, d] of row-wise features."""
    return feats.mean(axis=0), np.cov(feats, rowvar=False)
def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:
    """Open a local path or download a URL and return a binary file object.

    Args:
        url: local path, ``file://`` URL or remote URL.
        num_attempts: download attempts before giving up (>= 1).
        verbose: print download progress.
        return_filename: return the resolved local filename instead of an open
            file. Only valid for local paths / file URLs.

    Returns:
        An open binary file, an ``io.BytesIO`` with the downloaded bytes, or a
        filename (see ``return_filename``).

    Raises:
        AssertionError: if ``return_filename`` is requested for a remote URL.
            This is now checked before downloading instead of after.
        The last download error, once all attempts have failed.
    """
    assert num_attempts >= 1

    # Doesn't look like a URL scheme, so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    # Remote URLs are only ever returned as in-memory data.
    assert not return_filename

    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        # Small responses may be an HTML interstitial page.
                        # Decode leniently so a small *binary* payload does
                        # not raise UnicodeDecodeError and get retried until
                        # every attempt fails.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except Exception:
                # Retry ordinary errors; SystemExit and the like propagate.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Return data as file object.
    return io.BytesIO(url_data)
def load_video(ip):
    """Read the video at path `ip` and return its frames as a uint8 't c h w' tensor."""
    frames = read_video(ip)[0]
    return rearrange(frames, 't h w c -> t c h w').to(torch.uint8)
def get_data_from_str(input_str,nprc = None):
    """Load every ``*.mp4`` in the directory ``input_str`` into one float tensor.

    Videos are decoded in parallel by a pool of ``nprc`` worker processes
    (default: ``mp.cpu_count()``, falling back to 1). Because
    ``imap_unordered`` is used, the order of videos is not the glob order. All
    videos must have the same shape for ``torch.stack``.

    Returns:
        Float tensor of the stacked ``load_video`` outputs (video index first).
    """
    assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory'
    vid_filelist = glob.glob(os.path.join(input_str,'*.mp4'))
    print(f'Found {len(vid_filelist)} videos in dir {input_str}')
    if nprc is None:
        try:
            nprc = mp.cpu_count()
        except NotImplementedError:
            print('WARNING: cpu_count() not avlailable, using only 1 cpu for video loading')
            nprc = 1
    # Use the pool as a context manager so worker processes are torn down
    # instead of leaked; all results are consumed inside the block.
    with mp.Pool(processes=nprc) as pool:
        vids = list(tqdm(pool.imap_unordered(load_video, vid_filelist),
                         total=len(vid_filelist), desc='Loading videos...'))
    return torch.stack(vids, dim=0).float()
def get_stats(stats):
    """Load precomputed statistics from an ``.npz`` file into a plain dict of arrays.

    Raises:
        AssertionError: if ``stats`` is not an existing ``.npz`` file.
    """
    assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}'
    print(f'Using precomputed statistics under {stats}')
    # np.load on an .npz returns a lazily-reading NpzFile that holds the file
    # open; read every array and close it deterministically.
    with np.load(stats) as data:
        return {key: data[key] for key in data.files}
@torch.no_grad()
def compute_fvd(ref_input, sample_input, bs=32,
                ref_stats=None,
                sample_stats=None,
                nprc_load=None):
    """Compute FVD between reference and sample videos.

    Either side may be given as precomputed statistics (``*_stats``, a path to
    an ``.npz``). Otherwise it comes from the corresponding input, which may
    be a directory of mp4 files (loaded with ``get_data_from_str``) or data
    accepted by ``compute_statistics``.

    Returns:
        {'FVD': float}
    """
    if ref_stats is not None and sample_stats is not None:
        stats = get_stats(sample_stats)
        stats.update(get_stats(ref_stats))
    else:
        only_ref = sample_stats is not None
        only_sample = ref_stats is not None
        if not only_sample and isinstance(ref_input, str):
            ref_input = get_data_from_str(ref_input, nprc_load)
        if not only_ref and isinstance(sample_input, str):
            sample_input = get_data_from_str(sample_input, nprc_load)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        stats = compute_statistics(sample_input, ref_input,
                                   device=device,
                                   bs=bs,
                                   only_ref=only_ref,
                                   only_sample=only_sample)
        if only_ref:
            stats.update(get_stats(sample_stats))
        elif only_sample:
            stats.update(get_stats(ref_stats))
    return {'FVD': compute_frechet_distance(**stats)}
@torch.no_grad()
def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict:
    """Compute I3D feature statistics (mean, covariance) used by FVD.

    A TorchScript I3D detector is fetched through ``open_url``, and videos are
    embedded in chunks of at most ``bs`` videos.

    Args:
        videos_fake: generated videos; ignored when ``only_ref``.
        videos_real: reference videos; ignored when ``only_sample``.
            Inputs for which ``isvideo`` is False are assumed to be numpy
            arrays of shape (n_vids, t, h, w, c) in [0, 255]. They are
            permuted to (n_vids, c, t, h, w) float tensors.
        device: device for the detector and each chunk.
        bs: maximum number of videos per detector call.
        only_ref / only_sample: embed just one of the two sets (mutually
            exclusive).

    Returns:
        dict with 'mu_sample'/'sigma_sample' and/or 'mu_ref'/'sigma_ref' for
        whichever sets were embedded.
    """
    # Validate arguments before the (potentially slow) detector download.
    assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive'

    def _as_tensor(videos):
        # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
        if not isvideo(videos):
            videos = torch.from_numpy(videos).permute(0, 4, 1, 2, 3).float()
        return videos

    def _num_chunks(n_videos):
        # ceil(n_videos / bs): every chunk holds at most bs videos.
        return n_videos // bs if n_videos % bs == 0 else n_videos // bs + 1

    detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'
    detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer.
    with open_url(detector_url, verbose=False) as f:
        detector = torch.jit.load(f).eval().to(device)

    def _embed(chunk):
        return detector(chunk.to(device).contiguous(), **detector_kwargs).cpu().numpy()

    ref_embed, sample_embed = [], []
    info = f'Computing I3D activations for FVD score with batch size {bs}'
    if only_ref:
        videos_real = _as_tensor(videos_real)
        print(videos_real.shape)
        real_chunks = torch.tensor_split(videos_real, _num_chunks(videos_real.shape[0]), dim=0)
        for ref_v in tqdm(real_chunks, total=len(real_chunks), desc=info):
            ref_embed.append(_embed(ref_v))
    elif only_sample:
        videos_fake = _as_tensor(videos_fake)
        print(videos_fake.shape)
        # Bug fix: the original split `videos_real` here and then iterated the
        # unsplit `videos_fake` one video at a time (no batch dimension).
        fake_chunks = torch.tensor_split(videos_fake, _num_chunks(videos_fake.shape[0]), dim=0)
        for sample_v in tqdm(fake_chunks, total=len(fake_chunks), desc=info):
            sample_embed.append(_embed(sample_v))
    else:
        videos_real = _as_tensor(videos_real)
        videos_fake = _as_tensor(videos_fake)
        # Both sets are split into the same number of chunks (based on the
        # fake set) so they can be zipped.
        n_secs = _num_chunks(videos_fake.shape[0])
        real_chunks = torch.tensor_split(videos_real, n_secs, dim=0)
        fake_chunks = torch.tensor_split(videos_fake, n_secs, dim=0)
        for ref_v, sample_v in tqdm(zip(real_chunks, fake_chunks), total=len(fake_chunks), desc=info):
            feats_sample = _embed(sample_v)
            feats_ref = _embed(ref_v)
            sample_embed.append(feats_sample)
            ref_embed.append(feats_ref)

    out = dict()
    if len(sample_embed) > 0:
        sample_embed = np.concatenate(sample_embed,axis=0)
        mu_sample, sigma_sample = compute_stats(sample_embed)
        out.update({'mu_sample': mu_sample,
                    'sigma_sample': sigma_sample})
    if len(ref_embed) > 0:
        ref_embed = np.concatenate(ref_embed,axis=0)
        mu_ref, sigma_ref = compute_stats(ref_embed)
        out.update({'mu_ref': mu_ref,
                    'sigma_ref': sigma_ref})
    return out
================================================
FILE: ldm/modules/image_degradation/__init__.py
================================================
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
================================================
FILE: ldm/modules/image_degradation/bsrgan.py
================================================
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
import numpy as np
import cv2
import torch
from functools import partial
import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth
import albumentations
import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
    '''
    Crop a copy of `img` so its first two dimensions are multiples of `sf`.

    Args:
        img: numpy image, HxW or HxWxC (only the first two axes are cropped)
        sf: scale factor

    Return:
        cropped copy of the image
    '''
    rows, cols = img.shape[:2]
    cropped = np.copy(img)
    return cropped[:rows - rows % sf, :cols - cols % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper).

    Args:
        k: square 2-D kernel of size k_size x k_size.

    Returns:
        The (3*k_size - 2)-sized kernel built by accumulating ``k[r, c] * k``
        at stride-2 offsets, cropped by k_size // 2 on each side and
        normalised to sum 1.
    """
    k_size = k.shape[0]
    big_size = 3 * k_size - 2
    big_k = np.zeros((big_size, big_size))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and increase run time of SR.
    # Use an explicit end index: `big_k[crop:-crop]` is empty when crop == 0 (k_size == 1).
    crop = k_size // 2
    cropped_big_k = big_k[crop:big_size - crop, crop:big_size - crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel, from gm_blur_kernel with zero mean and covariance V D V^-1
    """
    rotation = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    v = rotation.dot(np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = V.dot(D).dot(np.linalg.inv(V))
    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
def gm_blur_kernel(mean, cov, size=15):
    """size x size kernel sampled from a 2-D Gaussian pdf and normalised to sum 1.

    Row index y and column index x map to offsets (y - center + 1) and
    (x - center + 1), with center = size / 2 + 0.5; the pdf is evaluated at [cx, cy].
    """
    center = size / 2.0 + 0.5
    offsets = [i - center + 1 for i in range(size)]
    k = np.array([
        [ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) for cx in offsets]
        for cy in offsets
    ])
    return k / np.sum(k)
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: HxWxC or HxW numpy image
        sf: scale factor
        upper_left: shift direction
    Returns:
        The image bilinearly resampled on a grid shifted by (sf - 1) / 2 pixels
        (clipped to the image bounds). A 2-D input yields a new float array;
        a 3-D input is updated in place channel by channel and returned.
    """
    # scipy.interpolate.interp2d was removed in SciPy 1.14; linear
    # RegularGridInterpolator on the same rectilinear grid is the documented
    # replacement.
    from scipy.interpolate import RegularGridInterpolator

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)
    # Query points on the (y1, x1) grid, shape (h, w, 2), ordered (row, col).
    query = np.stack(np.meshgrid(y1, x1, indexing='ij'), axis=-1)

    def _bilinear(plane):
        # Bilinear resampling of one HxW plane at the shifted grid.
        return RegularGridInterpolator((yv, xv), plane, method='linear')(query)

    if x.ndim == 2:
        x = _bilinear(x)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = _bilinear(x[:, :, i])

    return x
def blur(x, k):
    '''
    Blur every image in a batch with its own kernel (grouped conv, replicate padding).

    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    Returns an Nxcx(H')x(W') tensor; H' = H and W' = W for odd kernel sizes.
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # F.pad lists padding from the last dim backwards: (left, right, top, bottom).
    # Width padding comes from the kernel width (p2), height padding from the
    # kernel height (p1). The old (p1, p2, p1, p2) only worked for square kernels.
    x = torch.nn.functional.pad(x, pad=(p2, p2, p1, p1), mode='replicate')
    # Fold batch into channels so a single grouped conv applies kernel i to image i.
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])
    return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """
    Random anisotropic Gaussian blur kernel, normalised to sum 1.

    modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    Kai Zhang
    min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    max_var = 2.5 * sf

    Draws two eigenvalues uniformly from [min_var, max_var), a rotation angle
    from [0, pi) and optional multiplicative noise of amplitude noise_level,
    in that order from np.random.
    """
    # Random eigen-values (lambdas) and angle (theta) of the covariance matrix.
    var_range = max_var - min_var
    lambda_1 = min_var + np.random.rand() * var_range
    lambda_2 = min_var + np.random.rand() * var_range
    theta = np.random.rand() * np.pi
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Covariance = R diag(lambda) R^T.
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    covariance = rotation @ np.diag([lambda_1, lambda_2]) @ rotation.T
    inv_cov = np.linalg.inv(covariance)[None, None, :, :]

    # Mean shifted so the kernel is aligned with the downsampled image.
    mean = (k_size // 2 - 0.5 * (scale_factor - 1))[None, None, :, None]

    # Quadratic form (z - mu)^T Sigma^-1 (z - mu) for every kernel pixel.
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    offsets = np.stack([grid_x, grid_y], 2)[:, :, :, None] - mean
    offsets_t = offsets.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(offsets_t @ inv_cov @ offsets)) * (1 + noise)

    return raw_kernel / np.sum(raw_kernel)
def fspecial_gaussian(hsize, sigma):
    """MATLAB-style ``fspecial('gaussian', hsize, sigma)``.

    Args:
        hsize: side length of the square kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        ``hsize x hsize`` kernel. Entries below ``eps * max`` are zeroed, and the kernel is
        normalized to sum to 1 (left unnormalized if every entry is zero).
    """
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # np.finfo rather than scipy.finfo: the NumPy aliases in the top-level scipy
    # namespace are deprecated and removed in recent SciPy releases.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
def fspecial_laplacian(alpha):
    """MATLAB-style ``fspecial('laplacian', alpha)``: a 3x3 Laplacian approximation.

    ``alpha`` is clamped to [0, 1] and sets the balance between the corner and edge weights.
    """
    a = max(0, min(alpha, 1))
    denom = a + 1
    corner = a / denom
    edge = (1 - a) / denom
    centre = -4 / denom
    return np.array([
        [corner, edge, corner],
        [edge, centre, edge],
        [corner, edge, corner],
    ])
def fspecial(filter_type, *args, **kwargs):
    '''
    MATLAB-style ``fspecial`` dispatcher; remaining arguments are forwarded to the
    matching ``fspecial_<type>`` helper.

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py

    Raises:
        ValueError: if ``filter_type`` is neither 'gaussian' nor 'laplacian'
            (previously this silently returned None).
    '''
    if filter_type == 'gaussian':
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == 'laplacian':
        return fspecial_laplacian(*args, **kwargs)
    raise ValueError(f"Unsupported filter_type {filter_type!r}; expected 'gaussian' or 'laplacian'")
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
    '''Bicubically downsample an image by ``sf``.

    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        the downsampled LR image (computed by util.imresize_np with scale 1 / sf)
    '''
    return util.imresize_np(x, scale=1 / sf)
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # ndimage.convolve instead of the deprecated scipy.ndimage.filters namespace.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # ndimage.convolve instead of the deprecated scipy.ndimage.filters namespace.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image (every sf-th pixel of the blurred image, starting at 0)
    '''
    # ndimage.convolve instead of the deprecated scipy.ndimage.filters namespace.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM (unsharp mask) sharpening, borrowed from Real-ESRGAN.

    With input I and its Gaussian-blurred version B:
      1. K = clip(I + weight * (I - B), 0, 1)
      2. Mask = 1 where |I - B| * 255 > threshold, else 0
      3. the mask is softened with the same Gaussian blur
      4. Out = Mask * K + (1 - Mask) * I

    Args:
        img (Numpy array): input image, HWC, float32, [0, 1].
        weight (float): sharpening strength. Default: 0.5.
        radius (int): Gaussian kernel size; an even value is bumped to the next odd one. Default: 50.
        threshold (int): residual threshold on the 0-255 scale. Default: 10.
    """
    ksize = radius + 1 if radius % 2 == 0 else radius
    kernel = (ksize, ksize)
    residual = img - cv2.GaussianBlur(img, kernel, 0)
    edge_mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(edge_mask, kernel, 0)
    sharpened = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
def add_blur(img, sf=4):
    """Blur ``img`` with a random Gaussian kernel.

    With probability 0.5 an anisotropic kernel is used: odd size 2 * randint(2, 11) + 3,
    random angle, and eigenvalue scales up to 4 + sf. Otherwise an isotropic kernel of the
    same size range is used, with sigma up to 2 + 0.2 * sf. Every channel is filtered
    with mirror boundary handling.
    """
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    # ndimage.convolve instead of the deprecated scipy.ndimage.filters namespace.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
    return img
def add_resize(img, sf=4):
    """Randomly rescale ``img`` and clip it to [0, 1].

    20% of the time it is upscaled by a factor in [1, 2], 70% of the time downscaled by a
    factor in [0.5 / sf, 1], and otherwise kept at its size. The cv2 interpolation mode
    is picked at random from {1, 2, 3}.
    """
    draw = np.random.rand()
    if draw > 0.8:  # up
        factor = random.uniform(1, 2)
    elif draw < 0.7:  # down
        factor = random.uniform(0.5 / sf, 1)
    else:
        factor = 1.0
    new_size = (int(factor * img.shape[1]), int(factor * img.shape[0]))
    resized = cv2.resize(img, new_size, interpolation=random.choice([1, 2, 3]))
    return np.clip(resized, 0.0, 1.0)
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add random zero-mean Gaussian noise to ``img`` and clip the result to [0, 1].

    The noise std is noise_level / 255, with noise_level an integer in
    [noise_level1, noise_level2]. With probability 0.4 the noise is independent per
    channel, with probability 0.4 a single channel is broadcast to all channels, and
    otherwise it is 3-channel correlated noise whose scale uses noise_level2.
    """
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    branch = np.random.rand()
    if branch > 0.6:  # colour noise
        noise = np.random.normal(0, sigma, img.shape).astype(np.float32)
    elif branch < 0.4:  # grayscale noise
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1)).astype(np.float32)
    else:  # cross-channel correlated noise
        scale = noise_level2 / 255.
        eig = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eig), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(scale ** 2 * cov), img.shape[:2]).astype(np.float32)
    return np.clip(img + noise, 0.0, 1.0)
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) noise ``img * n`` and clip the result to [0, 1].

    The input is clipped to [0, 1] first, which yields a new array, so the caller's array
    is not modified. The noise std is noise_level / 255 with noise_level in
    [noise_level1, noise_level2]. The branches are 0.4 per-channel, 0.4 single-channel
    broadcast, and 0.2 correlated 3-channel noise scaled by noise_level2.
    """
    level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    draw = random.random()
    if draw > 0.6:
        noise = np.random.normal(0, level / 255.0, img.shape).astype(np.float32)
    elif draw < 0.4:
        noise = np.random.normal(0, level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        amp = noise_level2 / 255.
        eig = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eig), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(amp ** 2 * cov), img.shape[:2]).astype(np.float32)
    # in-place add keeps the dtype of the clipped image, as before
    img += img * noise
    return np.clip(img, 0.0, 1.0)
def add_Poisson_noise(img):
    """Add Poisson (shot) noise and clip the result to [0, 1].

    ``img`` is first quantized to 256 levels. The count scale is ``10 ** e`` with ``e``
    uniform in [2, 4). Half of the time the noise is applied per channel. Otherwise it is
    computed on a luma image (0.299/0.587/0.114 over the first three channels) and added
    to every channel.
    """
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    scale = 10 ** (2 * random.random() + 2.0)
    if random.random() >= 0.5:
        luma = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        luma = np.clip((luma * 255.0).round(), 0, 255) / 255.
        luma_noise = np.random.poisson(luma * scale).astype(np.float32) / scale - luma
        img += luma_noise[:, :, np.newaxis]
    else:
        img = np.random.poisson(img * scale).astype(np.float32) / scale
    return np.clip(img, 0.0, 1.0)
def add_JPEG_noise(img):
    """Round-trip ``img`` through JPEG encoding at a random quality in [30, 95].

    ``img`` is converted with util.single2uint and treated as RGB (swapped to BGR for
    cv2). The decoded result is converted back with util.uint2single and returned as RGB.
    """
    quality = random.randint(30, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    decoded = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Crop aligned patches from a low/high-quality pair.

    A random ``lq_patchsize`` square is taken from ``lq``; the matching region, ``sf``
    times larger and at ``sf`` times the offset, is taken from ``hq``.
    """
    rows, cols = lq.shape[:2]
    top = random.randint(0, rows - lq_patchsize)
    left = random.randint(0, cols - lq_patchsize)
    lq_patch = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]
    hq_top, hq_left = int(top * sf), int(left * sf)
    hq_size = lq_patchsize * sf
    hq_patch = hq[hq_top:hq_top + hq_size, hq_left:hq_left + hq_size, :]
    return lq_patch, hq_patch
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    lq_patchsize: side length of the returned low-quality patch
    isp_model: camera ISP model (optional); its forward(img, hq) is applied with probability isp_prob
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize * sf in either dimension
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    # mod crop: rows are cropped by the height and columns by the width (these were
    # swapped before, which mis-cropped non-square images)
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # downsample2 (2) must run before downsample3 (3), which reuses its (a, b)
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0 or i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3: resize to (size before downsample2) / sf
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    Variant of the BSRGAN degradation model ("Designing a Practical Degradation Model for
    Deep Blind Image Super-Resolution") that degrades the whole image without cropping.
    ----------
    image: HxWxC image accepted by util.uint2single (presumably uint8 - confirm with callers)
    sf: scale factor
    isp_model: accepted for interface compatibility; unused because the ISP step is disabled here
    Returns
    -------
    example: dict {"image": low-quality image converted back with util.single2uint}
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25

    h1, w1 = image.shape[:2]
    # mod crop: rows are cropped by the height and columns by the width (previously swapped)
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # downsample2 (2) must run before downsample3 (3), which reuses its (a, b)
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0 or i == 1:
            image = add_blur(image, sf=sf)

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3: resize to (size before downsample2) / sf
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # i == 6 (processed camera sensor noise via isp_model) is disabled in this variant.

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
# TODO: in case there is a pickle error, replace `a += x` with `a = a + x` in add_speckle_noise etc.
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    shuffle_prob: probability of fully shuffling the 13 degradation steps; otherwise only the
        noise steps (2-5 and 9-12) are shuffled locally
    use_sharp: sharpen the img (and therefore the hq target) before degrading
    lq_patchsize: side length of the returned low-quality patch
    isp_model: camera ISP model (optional)
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize * sf in either dimension
    """
    h1, w1 = img.shape[:2]
    # mod crop: rows are cropped by the height and columns by the width (previously swapped)
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for the noise steps only
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
                     interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq
if __name__ == '__main__':
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    print(img)
    img = img[:448, :448]
    # degradation_bsrgan_variant converts its input with util.uint2single itself and returns
    # a dict, so it gets the uint image; a float copy serves as the high-quality reference.
    img_hq = util.uint2single(img)
    print(img_hq)
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_lq = util.uint2single(deg_fn(img)["image"])
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
================================================
FILE: ldm/modules/image_degradation/bsrgan_light.py
================================================
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import torch
from functools import partial
import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth
import albumentations
import ldm.modules.image_degradation.utils_image as util
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
def modcrop_np(img, sf):
    '''Crop an image so that its first two dimensions are multiples of ``sf``.

    Args:
        img: numpy image, HxW or HxWxC
        sf: scale factor
    Return:
        cropped image (a view into a fresh copy, so the input is never aliased)
    '''
    rows, cols = img.shape[:2]
    return np.copy(img)[:rows - rows % sf, :cols - cols % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper).

    Args:
        k: square kernel of size k_size x k_size.

    Returns:
        The composed kernel, cropped by k_size // 2 on every side and normalized to sum to 1.
    """
    k_size = k.shape[0]
    # Calculate the big kernels size
    big_size = 3 * k_size - 2
    big_k = np.zeros((big_size, big_size))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and increase run time of SR.
    # Use an explicit stop index: `-crop` is `-0` for a 1x1 kernel, which gives an empty slice.
    crop = k_size // 2
    cropped_big_k = big_k[crop:big_size - crop, crop:big_size - crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """Generate an anisotropic Gaussian kernel.

    Args:
        ksize: kernel size, e.g. 15
        theta: rotation angle, in [0, pi]
        l1: first eigenvalue scaling, in [0.1, 50]
        l2: second eigenvalue scaling, in [0.1, l1]
        If l1 == l2 the kernel is isotropic.
    Returns:
        k: kernel sampled by gm_blur_kernel with the resulting covariance
    """
    rotation = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    direction = np.dot(rotation, np.array([1., 0.]))
    eigvecs = np.array([[direction[0], direction[1]], [direction[1], -direction[0]]])
    eigvals = np.array([[l1, 0], [0, l2]])
    covariance = np.dot(np.dot(eigvecs, eigvals), np.linalg.inv(eigvecs))
    return gm_blur_kernel(mean=[0, 0], cov=covariance, size=ksize)
def gm_blur_kernel(mean, cov, size=15):
    """Sample a 2-D Gaussian pdf on a centered ``size x size`` grid and normalize it to sum 1.

    Entry [y, x] holds the pdf evaluated at ``[x - (size - 1) / 2, y - (size - 1) / 2]``.
    All grid points go through a single vectorized ``multivariate_normal.pdf`` call
    instead of size**2 scalar calls.
    """
    center = size / 2.0 + 0.5
    coords = np.arange(size) - center + 1
    cy, cx = np.meshgrid(coords, coords, indexing='ij')
    points = np.stack([cx, cy], axis=-1)  # (size, size, 2), [x, y] order as before
    # pdf squeezes its output (a scalar for size == 1), so restore the 2-D shape.
    k = np.reshape(ss.multivariate_normal.pdf(points, mean=mean, cov=cov), (size, size))
    k = k / np.sum(k)
    return k
def shift_pixel(x, sf, upper_left=True):
    """Shift pixels for super-resolution with different scale factors.

    Each pixel is resampled (bilinear) at its position shifted by (sf - 1) / 2 along both
    axes, with sample positions clamped to the image. Needs at least 2 pixels per axis.

    Args:
        x: HxWxC or HxW array; a 3-D input is modified in place channel by channel.
        sf: scale factor
        upper_left: shift direction (towards the upper left if True)
    """
    # scipy.interpolate.interp2d was removed in SciPy 1.14; RegularGridInterpolator with
    # method='linear' gives the same bilinear interpolation on this regular grid.
    from scipy.interpolate import RegularGridInterpolator

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    # Query points on the (y1, x1) grid: shape (h, w, 2) in (row, col) order.
    qy, qx = np.meshgrid(y1, x1, indexing='ij')
    query = np.stack([qy, qx], axis=-1)

    def _resample(plane):
        return RegularGridInterpolator((yv, xv), plane, method='linear')(query)

    if x.ndim == 2:
        x = _resample(x)
    elif x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = _resample(x[:, :, i])

    return x
def blur(x, k):
    '''Filter each image in a batch with its own kernel, applied to every channel.

    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    Edges are padded by replication; for odd h and w the output is NxcxHxW.
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # F.pad orders padding from the last dim: (left, right, top, bottom). The width gets p2
    # (from the kernel width) and the height gets p1 (from the kernel height). The old
    # (p1, p2, p1, p2) only preserved the spatial size for square kernels.
    x = torch.nn.functional.pad(x, pad=(p2, p2, p1, p1), mode='replicate')
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])
    return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Generate a random anisotropic Gaussian blur kernel.

    Modified version of https://github.com/assafshocher/BlindSR_dataset_generator (Kai Zhang).
    The two covariance eigenvalues are sampled uniformly in [min_var, max_var]; a typical
    choice is min_var = 0.175 * sf and max_var = 2.5 * sf.

    Args:
        k_size: integer array ``[size_x, size_y]``.
        scale_factor: array; the kernel mean is shifted by ``-0.5 * (scale_factor - 1)``.
        min_var, max_var: sampling range of the covariance eigenvalues.
        noise_level: every kernel entry is multiplied by ``1 + u`` with
            ``u`` uniform in ``[-noise_level, noise_level)``.

    Returns:
        Kernel of shape ``(k_size[1], k_size[0])`` that sums to 1.
    """
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    # The meshgrid below produces (k_size[1], k_size[0]) arrays, so the noise must use that
    # layout; drawing it as (k_size[0], k_size[1]) only broadcast for square kernels.
    noise = -noise_level + np.random.rand(k_size[1], k_size[0]) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    # Index out the trailing 1x1 matrix instead of np.squeeze, which would also drop size-1 kernel dims.
    raw_kernel = np.exp(-0.5 * (ZZ_t @ INV_SIGMA @ ZZ)[..., 0, 0]) * (1 + noise)

    # Normalize the kernel and return
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel
def fspecial_gaussian(hsize, sigma):
    """MATLAB-style ``fspecial('gaussian', hsize, sigma)``.

    Args:
        hsize: side length of the square kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        ``hsize x hsize`` kernel. Entries below ``eps * max`` are zeroed, and the kernel is
        normalized to sum to 1 (left unnormalized if every entry is zero).
    """
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # np.finfo rather than scipy.finfo: the NumPy aliases in the top-level scipy
    # namespace are deprecated and removed in recent SciPy releases.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
def fspecial_laplacian(alpha):
    """MATLAB-style ``fspecial('laplacian', alpha)``: a 3x3 Laplacian approximation.

    ``alpha`` is clamped to [0, 1] and sets the balance between the corner and edge weights.
    """
    a = max(0, min(alpha, 1))
    denom = a + 1
    corner = a / denom
    edge = (1 - a) / denom
    centre = -4 / denom
    return np.array([
        [corner, edge, corner],
        [edge, centre, edge],
        [corner, edge, corner],
    ])
def fspecial(filter_type, *args, **kwargs):
    '''
    MATLAB-style ``fspecial`` dispatcher; remaining arguments are forwarded to the
    matching ``fspecial_<type>`` helper.

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py

    Raises:
        ValueError: if ``filter_type`` is neither 'gaussian' nor 'laplacian'
            (previously this silently returned None).
    '''
    if filter_type == 'gaussian':
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == 'laplacian':
        return fspecial_laplacian(*args, **kwargs)
    raise ValueError(f"Unsupported filter_type {filter_type!r}; expected 'gaussian' or 'laplacian'")
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
    '''Bicubically downsample an image by ``sf``.

    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        the downsampled LR image (computed by util.imresize_np with scale 1 / sf)
    '''
    return util.imresize_np(x, scale=1 / sf)
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling (SRMD degradation)
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    kernel3d = np.expand_dims(k, axis=2)
    blurred = ndimage.convolve(x, kernel3d, mode='wrap')  # alternatives: 'nearest' | 'mirror'
    return bicubic_degradation(blurred, sf=sf)
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur (DPSR degradation)
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    downsampled = bicubic_degradation(x, sf=sf)
    kernel3d = np.expand_dims(k, axis=2)
    return ndimage.convolve(downsampled, kernel3d, mode='wrap')
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image: every sf-th pixel (from index 0) of the wrap-padded blur
    '''
    kernel3d = k[:, :, np.newaxis]
    blurred = ndimage.convolve(x, kernel3d, mode='wrap')
    # (correlation alternative: filters.correlate(x, np.expand_dims(np.flip(k), axis=2)))
    return blurred[::sf, ::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM (unsharp mask) sharpening, borrowed from Real-ESRGAN.

    With input I and its Gaussian-blurred version B:
      1. K = clip(I + weight * (I - B), 0, 1)
      2. Mask = 1 where |I - B| * 255 > threshold, else 0
      3. the mask is softened with the same Gaussian blur
      4. Out = Mask * K + (1 - Mask) * I

    Args:
        img (Numpy array): input image, HWC, float32, [0, 1].
        weight (float): sharpening strength. Default: 0.5.
        radius (int): Gaussian kernel size; an even value is bumped to the next odd one. Default: 50.
        threshold (int): residual threshold on the 0-255 scale. Default: 10.
    """
    ksize = radius + 1 if radius % 2 == 0 else radius
    kernel = (ksize, ksize)
    residual = img - cv2.GaussianBlur(img, kernel, 0)
    edge_mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(edge_mask, kernel, 0)
    sharpened = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
def add_blur(img, sf=4):
    """Blur ``img`` with a random, deliberately small Gaussian kernel (light variant).

    The width parameters are a quarter of (4 + sf) and (2 + 0.2 * sf). With probability
    0.5 an anisotropic kernel is used (size randint(2, 11) + 3, random angle); otherwise
    an isotropic kernel of size randint(2, 4) + 3. Every channel is filtered with mirror
    boundary handling.
    """
    aniso_width = (4.0 + sf) / 4
    iso_width = (2.0 + 0.2 * sf) / 4
    if random.random() >= 0.5:
        size = random.randint(2, 4) + 3
        kernel = fspecial('gaussian', size, iso_width * random.random())
    else:
        l1 = aniso_width * random.random()
        l2 = aniso_width * random.random()
        size = random.randint(2, 11) + 3
        angle = random.random() * np.pi
        kernel = anisotropic_Gaussian(ksize=size, theta=angle, l1=l1, l2=l2)
    return ndimage.convolve(img, np.expand_dims(kernel, axis=2), mode='mirror')
def add_resize(img, sf=4):
    """Randomly rescale ``img`` and clip it to [0, 1].

    20% of the time it is upscaled by a factor in [1, 2], 70% of the time downscaled by a
    factor in [0.5 / sf, 1], and otherwise kept at its size. The cv2 interpolation mode
    is picked at random from {1, 2, 3}.
    """
    draw = np.random.rand()
    if draw > 0.8:  # up
        factor = random.uniform(1, 2)
    elif draw < 0.7:  # down
        factor = random.uniform(0.5 / sf, 1)
    else:
        factor = 1.0
    new_size = (int(factor * img.shape[1]), int(factor * img.shape[0]))
    resized = cv2.resize(img, new_size, interpolation=random.choice([1, 2, 3]))
    return np.clip(resized, 0.0, 1.0)
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add random zero-mean Gaussian noise to ``img`` and clip the result to [0, 1].

    The noise std is noise_level / 255, with noise_level an integer in
    [noise_level1, noise_level2]. With probability 0.4 the noise is independent per
    channel, with probability 0.4 a single channel is broadcast to all channels, and
    otherwise it is 3-channel correlated noise whose scale uses noise_level2.
    """
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    branch = np.random.rand()
    if branch > 0.6:  # colour noise
        noise = np.random.normal(0, sigma, img.shape).astype(np.float32)
    elif branch < 0.4:  # grayscale noise
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1)).astype(np.float32)
    else:  # cross-channel correlated noise
        scale = noise_level2 / 255.
        eig = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eig), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(scale ** 2 * cov), img.shape[:2]).astype(np.float32)
    return np.clip(img + noise, 0.0, 1.0)
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) noise ``img * n`` and clip the result to [0, 1].

    The input is clipped to [0, 1] first, which yields a new array, so the caller's array
    is not modified. The noise std is noise_level / 255 with noise_level in
    [noise_level1, noise_level2]. The branches are 0.4 per-channel, 0.4 single-channel
    broadcast, and 0.2 correlated 3-channel noise scaled by noise_level2.
    """
    level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    draw = random.random()
    if draw > 0.6:
        noise = np.random.normal(0, level / 255.0, img.shape).astype(np.float32)
    elif draw < 0.4:
        noise = np.random.normal(0, level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        amp = noise_level2 / 255.
        eig = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eig), basis)
        noise = np.random.multivariate_normal([0, 0, 0], np.abs(amp ** 2 * cov), img.shape[:2]).astype(np.float32)
    # in-place add keeps the dtype of the clipped image, as before
    img += img * noise
    return np.clip(img, 0.0, 1.0)
def add_Poisson_noise(img):
    """Add Poisson (shot) noise and clip the result to [0, 1].

    ``img`` is first quantized to 256 levels. The count scale is ``10 ** e`` with ``e``
    uniform in [2, 4). Half of the time the noise is applied per channel. Otherwise it is
    computed on a luma image (0.299/0.587/0.114 over the first three channels) and added
    to every channel.
    """
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    scale = 10 ** (2 * random.random() + 2.0)
    if random.random() >= 0.5:
        luma = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        luma = np.clip((luma * 255.0).round(), 0, 255) / 255.
        luma_noise = np.random.poisson(luma * scale).astype(np.float32) / scale - luma
        img += luma_noise[:, :, np.newaxis]
    else:
        img = np.random.poisson(img * scale).astype(np.float32) / scale
    return np.clip(img, 0.0, 1.0)
def add_JPEG_noise(img):
    """Round-trip ``img`` through JPEG encoding at a random quality in [80, 95] (light variant).

    ``img`` is converted with util.single2uint and treated as RGB (swapped to BGR for
    cv2). The decoded result is converted back with util.uint2single and returned as RGB.
    """
    quality = random.randint(80, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    decoded = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Crop aligned patches from a low/high-quality pair.

    A random ``lq_patchsize`` square is taken from ``lq``; the matching region, ``sf``
    times larger and at ``sf`` times the offset, is taken from ``hq``.
    """
    rows, cols = lq.shape[:2]
    top = random.randint(0, rows - lq_patchsize)
    left = random.randint(0, cols - lq_patchsize)
    lq_patch = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]
    hq_top, hq_left = int(top * sf), int(left * sf)
    hq_size = lq_patchsize * sf
    hq_patch = hq[hq_top:hq_top + hq_size, hq_left:hq_left + hq_size, :]
    return lq_patch, hq_patch
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    (light variant: weaker blur, Gaussian noise level at most 8, JPEG quality 80-95)
    ----------
    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    lq_patchsize: side length of the returned low-quality patch
    isp_model: camera ISP model (optional); its forward(img, hq) is applied with probability isp_prob
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    Raises
    ------
    ValueError: if the mod-cropped image is smaller than lq_patchsize * sf in either dimension
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    # mod crop: rows are cropped by the height and columns by the width (these were
    # swapped before, which mis-cropped non-square images)
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # downsample2 (2) must run before downsample3 (3), which reuses its (a, b)
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0 or i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3: resize to (size before downsample2) / sf
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    Light variant of the BSRGAN degradation model from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution".

    Ops are applied in a random order. In this variant, op 1 (second blur) and
    op 6 (camera ISP noise) are disabled. Gaussian noise uses
    noise_level1=1, noise_level2=2.
    ----------
    image: input image, converted with util.uint2single (presumably uint8 HxWxC -- TODO confirm)
    sf: scale factor
    isp_model: camera ISP model; accepted for interface compatibility but unused,
        because the ISP step is disabled
    Returns
    -------
    example: dict {"image": degraded image converted back with util.single2uint}
    """
    image = util.uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25

    # Mod crop: trim rows (axis 0) and cols (axis 1) so both are divisible by sf.
    # The row/col limits used to be swapped, which mis-cropped non-square images.
    h1, w1 = image.shape[:2]
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:
        # Make downsample2 (op 2) run before downsample3 (op 3); op 3 reuses the
        # size (a, b) captured by op 2.
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)
        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)
        elif i == 3:
            # downsample3: resize relative to the size captured at the start of op 2
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)
        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)
        # i == 1 (second blur) and i == 6 (camera ISP noise) are disabled in this variant.

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
if __name__ == '__main__':
    # Visual check: write 20 side-by-side panels
    # [bicubic-downscaled HQ | degraded LQ | HQ], each LQ upsampled with nearest neighbour.
    print("hey")
    source = util.imread_uint('utils/test.png', 3)
    source = source[:448, :448]
    target_h = source.shape[0] // 4
    print("resizing to", target_h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for step in range(20):
        print(step)
        hq = source
        lq = deg_fn(source)["image"]
        hq, lq = util.uint2single(hq), util.uint2single(lq)
        print(lq)
        shrink = albumentations.SmallestMaxSize(max_size=target_h, interpolation=cv2.INTER_CUBIC)
        lq_bicubic = shrink(image=hq)["image"]
        print(lq.shape)
        print("bicubic", lq_bicubic.shape)
        print(hq.shape)
        up_size = (int(sf * lq.shape[1]), int(sf * lq.shape[0]))
        lq_nearest = cv2.resize(util.single2uint(lq), up_size, interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(lq_bicubic), up_size, interpolation=0)
        panel = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(hq)], axis=1)
        util.imsave(panel, str(step) + '.png')
================================================
FILE: ldm/modules/image_degradation/utils_image.py
================================================
import os
import math
import random
import numpy as np
import torch
import cv2
from torchvision.utils import make_grid
from datetime import datetime
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
# Image filename suffixes recognised by is_image_file (matched case-sensitively,
# so each extension is listed in lower- and upper-case form).
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif', '.TIF']


def is_image_file(filename):
    """Return True if `filename` ends with one of the suffixes in IMG_EXTENSIONS."""
    # str.endswith accepts a tuple of suffixes; built per call so runtime edits
    # to IMG_EXTENSIONS are still honoured.
    return filename.endswith(tuple(IMG_EXTENSIONS))
def get_timestamp():
    """Return the current local time formatted as 'yymmdd-HHMMSS'."""
    now = datetime.now()
    return now.strftime('%y%m%d-%H%M%S')
def imshow(x, title=None, cbar=False, figsize=None):
    """Display `x` (singleton dims squeezed) as a grayscale image.

    Args:
        x: array-like image; passed through np.squeeze before plotting.
        title: optional figure title (skipped when falsy).
        cbar: if True, draw a colorbar.
        figsize: forwarded to plt.figure.
    """
    # The module-level matplotlib import is commented out, so `plt` is otherwise
    # undefined here and every call raised NameError. Importing it lazily keeps
    # the module importable without matplotlib.
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
def surf(Z, cmap='rainbow', figsize=None):
    """Plot the 2-D array Z as a 3-D surface (x = column index, y = row index).

    Args:
        Z: 2-D array of heights, indexed [row, col].
        cmap: matplotlib colormap name for the surface.
        figsize: forwarded to plt.figure.
    """
    # Lazy import: the module-level matplotlib import is commented out, so
    # `plt` would otherwise be undefined here.
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')
    # np.meshgrid(x, y) returns arrays shaped (len(y), len(x)). To match Z's
    # (rows, cols) shape, x must run over columns and y over rows. The old code
    # had them swapped, so non-square inputs produced mismatched shapes.
    h, w = Z.shape[:2]
    X, Y = np.meshgrid(np.arange(w), np.arange(h))
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    # ax3.contour(X, Y, Z, zdim='z', offset=-2, cmap=cmap)
    plt.show()
'''
# --------------------------------------------
# get image paths
# --------------------------------------------
'''
def get_image_paths(dataroot):
    """Return a sorted list of image file paths under `dataroot`, or None when `dataroot` is None."""
    if dataroot is None:
        return None
    return sorted(_get_paths_from_images(dataroot))
def _get_paths_from_images(path):
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)
images.append(img_path)
assert images, '{:s} has no valid image file'.format(path)
return images
'''
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    """Split an image into overlapping square patches.

    Args:
        img: numpy array HxW or HxWxC.
        p_size: patch side length.
        p_overlap: overlap between neighbouring patches (stride = p_size - p_overlap).
        p_max: only images whose height AND width both exceed p_max are split.

    Returns:
        list of patches (views into `img`). The last row/column of patches is
        aligned to the bottom/right edge. Images not larger than p_max are
        returned unchanged as a one-element list.
    """
    n_rows, n_cols = img.shape[:2]
    patches = []
    if n_rows > p_max and n_cols > p_max:
        step = p_size - p_overlap
        # builtin `int` instead of `np.int`, which was removed in NumPy 1.24
        row_starts = list(np.arange(0, n_rows - p_size, step, dtype=int))
        col_starts = list(np.arange(0, n_cols - p_size, step, dtype=int))
        row_starts.append(n_rows - p_size)
        col_starts.append(n_cols - p_size)
        for i in row_starts:
            for j in col_starts:
                # `...` keeps any trailing channel axis, so 2-D images work too
                patches.append(img[i:i + p_size, j:j + p_size, ...])
    else:
        patches.append(img)
    return patches
def imssave(imgs, img_path):
    """
    Save every image in `imgs` as '<stem>_sNNNN.png' in the directory of `img_path`.

    imgs: list, N images of size WxHxC (3-D arrays are reordered RGB -> BGR for cv2)
    """
    stem = os.path.splitext(os.path.basename(img_path))[0]
    out_dir = os.path.dirname(img_path)
    for idx, im in enumerate(imgs):
        if im.ndim == 3:
            im = im[:, :, [2, 1, 0]]
        out_name = stem + str('_s{:04d}'.format(idx)) + '.png'
        cv2.imwrite(os.path.join(out_dir, out_name), im)
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
    """
    Cut every image under `original_dataroot` into overlapping (p_size)x(p_size)
    patches and write them into `taget_dataroot` with imssave.

    Only images whose height and width both exceed p_max are split. Smaller
    images are written out whole, as a single patch.

    Args:
        original_dataroot: directory scanned recursively for images.
        taget_dataroot: output directory (parameter name kept for caller compatibility).
        n_channels: channel count requested from imread_uint.
        p_size: side length of each patch.
        p_overlap: overlap between neighbouring patches; the training patch size is a good choice.
        p_max: images not larger than (p_max)x(p_max) are kept unchanged.
    """
    for img_path in get_image_paths(original_dataroot):
        picture = imread_uint(img_path, n_channels=n_channels)
        pieces = patches_from_image(picture, p_size, p_overlap, p_max)
        destination = os.path.join(taget_dataroot, os.path.basename(img_path))
        imssave(pieces, destination)
'''
# --------------------------------------------
# makedir
# --------------------------------------------
'''
def mkdir(path):
    """Create directory `path` (with parents) unless something already exists there."""
    if not os.path.exists(path):
        # exist_ok closes the race where another process (e.g. a parallel
        # worker) creates the directory between exists() and makedirs().
        os.makedirs(path, exist_ok=True)
def mkdirs(paths):
    """Call mkdir on a single path (str) or on each element of an iterable of paths."""
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        mkdir(target)
def mkdir_and_rename(path):
    """Create directory `path`.

    If `path` already exists, it is first renamed to
    '<path>_archived_<timestamp>' and the new name is printed.
    """
    if not os.path.exists(path):
        os.makedirs(path)
        return
    archived = path + '_archived_' + get_timestamp()
    print('Path already exists. Rename it to [{:s}]'.format(archived))
    os.rename(path, archived)
    os.makedirs(path)
'''
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
'''
# --------------------------------------------
# get uint8 image of size HxWxn_channels (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
    """Read an image file with OpenCV and return it in RGB (or single-channel) layout.

    Args:
        path: image file path.
        n_channels: 1 -> HxWx1 grayscale; 3 -> HxWx3 RGB (grayscale files become GGG).

    Returns:
        numpy array HxWx1 or HxWx3. The dtype is whatever cv2.imread yields
        (uint8 for 8-bit files).

    Raises:
        ValueError: n_channels is not 1 or 3 (previously an UnboundLocalError).
        OSError: cv2.imread could not read `path` (it returns None for missing
            or undecodable files, which previously crashed with an AttributeError).
    """
    if n_channels not in (1, 3):
        raise ValueError('n_channels must be 1 or 3, got {}'.format(n_channels))
    # flag 0 == cv2.IMREAD_GRAYSCALE
    flag = 0 if n_channels == 1 else cv2.IMREAD_UNCHANGED
    img = cv2.imread(path, flag)
    if img is None:
        raise OSError('cv2 could not read image: {}'.format(path))
    if n_channels == 1:
        return np.expand_dims(img, axis=2)  # HxWx1
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
    """Write `img` to `img_path` with cv2.

    Singleton dims are squeezed first. 3-D arrays are reordered RGB -> BGR
    because OpenCV expects BGR.
    """
    arr = np.squeeze(img)
    bgr = arr[:, :, [2, 1, 0]] if arr.ndim == 3 else arr
    cv2.imwrite(img_path, bgr)
def imwrite(img, img_path):
    """Same behaviour as imsave: squeeze, reorder RGB -> BGR for 3-D arrays, then cv2.imwrite."""
    squeezed = np.squeeze(img)
    to_write = squeezed
    if squeezed.ndim == 3:
        to_write = squeezed[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, to_write)
# --------------------------------------------
# get single (float32) image of size HxWxn_channels (BGR)
# --------------------------------------------
def read_img(path):
    """Read an image with cv2 and return float32 HxWxC in BGR order, divided by 255.

    Grayscale files become HxWx1. Channels beyond the third (e.g. alpha) are
    dropped. The fixed /255 scale assumes 8-bit data; 16-bit files end up
    outside [0, 1].

    Raises:
        OSError: cv2.imread could not read `path` (it returns None).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # keeps the file's own channel layout
    if img is None:
        raise OSError('cv2 could not read image: {}'.format(path))
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img
'''
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <---> numpy(uint)
# numpy(single) <---> tensor
# numpy(uint) <---> tensor
# --------------------------------------------
'''
# --------------------------------------------
# numpy(single) [0, 1] <---> numpy(uint)
# --------------------------------------------
def uint2single(img):
    """Scale [0, 255] data to float32 in [0, 1]."""
    scaled = img / 255.
    return np.float32(scaled)


def single2uint(img):
    """Clip float data to [0, 1] and convert to uint8 in [0, 255], rounding to nearest."""
    scaled = img.clip(0, 1) * 255.
    return np.uint8(scaled.round())


def uint162single(img):
    """Scale [0, 65535] data to float32 in [0, 1]."""
    scaled = img / 65535.
    return np.float32(scaled)


def single2uint16(img):
    """Clip float data to [0, 1] and convert to uint16 in [0, 65535], rounding to nearest."""
    scaled = img.clip(0, 1) * 65535.
    return np.uint16(scaled.round())
# --------------------------------------------
# numpy(uint) (HxWxC or HxW) <---> tensor
# --------------------------------------------
# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    """HxWxC (or HxW) numpy image with values in [0, 255] -> float tensor 1xCxHxW in [0, 1]."""
    return uint2tensor3(img).unsqueeze(0)


# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    """HxWxC (or HxW) numpy image with values in [0, 255] -> float tensor CxHxW in [0, 1]."""
    arr = img[:, :, None] if img.ndim == 2 else img
    return torch.from_numpy(np.ascontiguousarray(arr)).permute(2, 0, 1).float().div(255.)


# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    """Tensor -> uint8 numpy image.

    Singleton dims are squeezed, values are clamped to [0, 1], scaled by 255 and
    rounded. A CxHxW result is transposed to HxWxC. The input tensor is not modified.
    """
    # Out-of-place clamp: for a float tensor, .squeeze().float() is a view of the
    # caller's storage, so the old clamp_ silently overwrote the caller's data.
    arr = img.data.squeeze().float().clamp(0, 1).cpu().numpy()
    if arr.ndim == 3:
        arr = np.transpose(arr, (1, 2, 0))
    return np.uint8((arr * 255.0).round())
# --------------------------------------------
# numpy(single) (HxWxC) <---> tensor
# --------------------------------------------
# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    """HxWxC float numpy image -> float tensor CxHxW (values unchanged)."""
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()


# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    """HxWxC float numpy image -> float tensor 1xCxHxW (values unchanged)."""
    return single2tensor3(img).unsqueeze(0)


# convert torch tensor to single
def tensor2single(img):
    """Tensor -> float32 numpy array with singleton dims squeezed; a 3-D CxHxW result becomes HxWxC."""
    arr = img.data.squeeze().float().cpu().numpy()
    return np.transpose(arr, (1, 2, 0)) if arr.ndim == 3 else arr


# convert torch tensor to single
def tensor2single3(img):
    """Like tensor2single, but a 2-D result gets a trailing channel axis (HxW -> HxWx1)."""
    arr = tensor2single(img)
    return arr[:, :, None] if arr.ndim == 2 else arr


def single2tensor5(img):
    """4-D numpy array (a0, a1, a2, a3) -> float tensor (1, a2, a0, a1, a3)."""
    return single42tensor4(img).unsqueeze(0)


def single32tensor5(img):
    """Numpy array -> float tensor with two leading singleton axes added."""
    return torch.from_numpy(np.ascontiguousarray(img)).float()[None, None]


def single42tensor4(img):
    """4-D numpy array (a0, a1, a2, a3) -> float tensor (a2, a0, a1, a3)."""
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)

    Values are clamped to `min_max` and rescaled to [0, 1] before conversion.
    A 4-D batch is tiled into a single image with make_grid. The input tensor
    is not modified.

    Raises:
        TypeError: the squeezed tensor is not 2-D, 3-D or 4-D.
    '''
    # Squeeze first, then clamp. Out-of-place clamp: for a float CPU tensor,
    # .float() and .cpu() return the same storage, so the old clamp_ overwrote
    # the caller's data.
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Important: unlike MATLAB, numpy's uint8 cast truncates rather than
        # rounds, so round explicitly first.
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)
'''
# --------------------------------------------
# Augmentation: flip and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augment_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''
def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply one of 8 flip/rotation combinations to a numpy image. Rotations act
    on axes 0 and 1; flips are np.flipud. Mode 0 returns `img` itself, and a
    mode outside 0..7 returns None.
    '''
    ops = (
        (0, lambda x: x),
        (1, lambda x: np.flipud(np.rot90(x))),
        (2, lambda x: np.flipud(x)),
        (3, lambda x: np.rot90(x, k=3)),
        (4, lambda x: np.flipud(np.rot90(x, k=2))),
        (5, lambda x: np.rot90(x)),
        (6, lambda x: np.rot90(x, k=2)),
        (7, lambda x: np.flipud(np.rot90(x, k=3))),
    )
    for key, op in ops:
        if mode == key:
            return op(img)
    return None
def augment_img_tensor4(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Tensor counterpart of augment_img. Rotations act on dims 2 and 3, and
    flip([2]) flips dim 2. Mode 0 returns `img` itself, and a mode outside
    0..7 returns None.
    '''
    ops = (
        (0, lambda x: x),
        (1, lambda x: x.rot90(1, [2, 3]).flip([2])),
        (2, lambda x: x.flip([2])),
        (3, lambda x: x.rot90(3, [2, 3])),
        (4, lambda x: x.rot90(2, [2, 3]).flip([2])),
        (5, lambda x: x.rot90(1, [2, 3])),
        (6, lambda x: x.rot90(2, [2, 3])),
        (7, lambda x: x.rot90(3, [2, 3]).flip([2])),
    )
    for key, op in ops:
        if mode == key:
            return op(img)
    return None
def augment_img_tensor(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)

    Apply augment_img(mode) to a CxHxW or BxCxHxW tensor by round-tripping
    through numpy. The spatial axes are moved to the front before augment_img,
    and the result is cast back with type_as(img).
    '''
    ndim = img.dim()
    arr = img.data.cpu().numpy()
    if ndim == 3:
        arr = np.transpose(arr, (1, 2, 0))      # CxHxW -> HxWxC
    elif ndim == 4:
        arr = np.transpose(arr, (2, 3, 1, 0))   # BxCxHxW -> HxWxCxB
    arr = augment_img(arr, mode=mode)
    out = torch.from_numpy(np.ascontiguousarray(arr))
    if ndim == 3:
        out = out.permute(2, 0, 1)
    elif ndim == 4:
        out = out.permute(3, 2, 0, 1)
    return out.type_as(img)
def augment_img_np3(img, mode=0):
    """Apply one of 8 flip/transpose combinations to an HxWxC numpy image.

    The mode decomposes into three optional steps, applied in this order:
    horizontal flip (modes 4-7), vertical flip (modes 2, 3, 6, 7), and an
    H/W transpose (odd modes). Mode 0 returns `img` itself, and a mode outside
    0..7 returns None.
    """
    if mode not in range(8):
        return None
    if mode in (4, 5, 6, 7):
        img = img[:, ::-1, :]
    if mode in (2, 3, 6, 7):
        img = img[::-1, :, :]
    if mode in (1, 3, 5, 7):
        img = img.transpose(1, 0, 2)
    return img
def augment_imgs(img_list, hflip=True, rot=True):
    """Apply one shared random flip/transpose draw to every HxWxC image in `img_list`.

    Each enabled option has a 0.5 chance. `hflip` enables a horizontal flip.
    `rot` enables both a vertical flip and an H/W transpose. Random numbers are
    drawn only for enabled options.
    """
    # horizontal flip OR rotate
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_transpose = rot and random.random() < 0.5

    out = []
    for im in img_list:
        if do_hflip:
            im = im[:, ::-1, :]
        if do_vflip:
            im = im[::-1, :, :]
        if do_transpose:
            im = im.transpose(1, 0, 2)
        out.append(im)
    return out
'''
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''
def modcrop(img_in, scale):
    """Return a copy of an HxW or HxWxC numpy image cropped so H and W are multiples of `scale`.

    Raises:
        ValueError: the input is neither 2-D nor 3-D.
    """
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    H, W = img.shape[:2]
    return img[:H - H % scale, :W - W % scale, ...]
def shave(img_in, border=0):
    """Return a copy of an HxW or HxWxC numpy image with `border` pixels removed from every side."""
    h, w = img_in.shape[:2]
    return np.copy(img_in)[border:h - border, border:w - border]
'''
# --------------------------------------------
# image processing on numpy images
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the input dtype: Y (HxW) when only_y, else YCbCr (HxWx3).
    The input array is not modified.
    '''
    in_img_type = img.dtype
    # Work on a float copy. The old `img.astype(np.float32)` discarded its
    # result, so for float input the `*= 255.` below scaled the caller's array
    # in place. float64 keeps full precision for every input dtype.
    img = img.astype(np.float64)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an HxWx3 RGB array of the input dtype. The input array is not modified.
    NOTE(review): for uint8 input, out-of-gamut results are not clipped before
    the uint8 cast.
    '''
    in_img_type = img.dtype
    # Work on a float copy. The old `img.astype(np.float32)` discarded its
    # result, so float input was scaled in place by the `*= 255.` below.
    img = img.astype(np.float64)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the input dtype: Y (HxW) when only_y, else YCbCr (HxWx3).
    The input array is not modified.
    '''
    in_img_type = img.dtype
    # Work on a float copy. The old `img.astype(np.float32)` discarded its
    # result, so float input was scaled in place by the `*= 255.` below.
    img = img.astype(np.float64)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, gray and Y.

    in_c == 3 with 'gray' -> HxWx1 gray via cv2.
    in_c == 3 with 'y' -> HxWx1 Y via bgr2ycbcr.
    in_c == 1 with 'RGB' -> 3-channel BGR via cv2.
    Any other combination returns `img_list` unchanged.
    """
    if in_c == 3:
        if tar_type == 'gray':  # BGR to gray
            return [np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), axis=2) for img in img_list]
        if tar_type == 'y':  # BGR to y
            return [np.expand_dims(bgr2ycbcr(img, only_y=True), axis=2) for img in img_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    return img_list
'''
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
'''
# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    """PSNR in dB between two images in the [0, 255] range.

    `border` pixels are cropped from each side first. Returns inf for
    identical images. Raises ValueError if the shapes differ.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    a = img1[border:h - border, border:w - border].astype(np.float64)
    b = img2[border:h - border, border:w - border].astype(np.float64)
    mse = np.mean((a - b) ** 2)
    return float('inf') if mse == 0 else 20 * math.log10(255.0 / math.sqrt(mse))
# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    Accepts HxW, HxWx1 or HxWx3 arrays; for 3 channels the per-channel SSIMs
    are averaged. `border` pixels are cropped from each side first.

    Raises:
        ValueError: the shapes differ, or the layout is not one of the supported
            ones. Other ndims previously fell through and returned None silently.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]
    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            return np.array([ssim(img1[:, :, c], img2[:, :, c]) for c in range(3)]).mean()
        if img1.shape[2] == 1:
            # Index the channel instead of np.squeeze, which would also drop a
            # spatial axis of size 1.
            return ssim(img1[:, :, 0], img2[:, :, 0])
    raise ValueError('Wrong input image dimensions.')
def ssim(img1, img2):
    """Mean SSIM of two single-channel images in the [0, 255] range.

    Uses an 11x11 Gaussian window (sigma 1.5). Filtered maps are cropped by
    5 pixels on each side so only fully covered positions are averaged.
    """
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2
    x = img1.astype(np.float64)
    y = img2.astype(np.float64)
    g = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(g, g.transpose())

    def local_mean(arr):
        # Gaussian-weighted local mean, restricted to the "valid" region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    mu_x_sq = mu_x ** 2
    mu_y_sq = mu_y ** 2
    mu_xy = mu_x * mu_y
    var_x = local_mean(x ** 2) - mu_x_sq
    var_y = local_mean(y ** 2) - mu_y_sq
    cov_xy = local_mean(x * y) - mu_xy
    numerator = (2 * mu_xy + C1) * (2 * cov_xy + C2)
    denominator = (mu_x_sq + mu_y_sq + C1) * (var_x + var_y + C2)
    return (numerator / denominator).mean()
'''
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''
# Port of MATLAB's 'imresize'; currently only 'bicubic' is supported
def cubic(x):
    """Element-wise cubic convolution kernel with a = -0.5 (the kernel MATLAB imresize uses).

    For |x| <= 1:     1.5|x|^3 - 2.5|x|^2 + 1
    For 1 < |x| <= 2: -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2
    Otherwise 0. `x` is a torch tensor; the result has the same type.
    """
    t = torch.abs(x)
    t2 = t ** 2
    t3 = t ** 3
    inner = (t <= 1).type_as(t)
    outer = ((t > 1) * (t <= 2)).type_as(t)
    return (1.5 * t3 - 2.5 * t2 + 1) * inner + \
        (-0.5 * t3 + 2.5 * t2 - 4 * t + 2) * outer
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Build per-output-pixel cubic weights and input indices for one axis (MATLAB imresize style).

    Args:
        in_length: input size along the axis.
        out_length: output size along the axis.
        scale: output/input scale factor.
        kernel: unused here; the kernel is always `cubic`.
        kernel_width: kernel support width (imresize / imresize_np pass 4).
        antialiasing: when True and scale < 1, widen the kernel by 1/scale.

    Returns:
        weights: (out_length, P') tensor. Rows are normalised to sum to 1
            before edge columns are trimmed.
        indices: (out_length, P') tensor of 0-based indices into the
            symmetrically padded input (shifted so the minimum index is 0).
        sym_len_s: int, number of mirrored samples to pad before the input.
        sym_len_e: int, number of mirrored samples to pad after the input.
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    # (1-based, MATLAB-style coordinates.)
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    # NOTE(review): weights_zero_tmp counts the zeros in each column, and
    # math.isclose(count, 0, rel_tol=1e-6) with the default abs_tol=0 is an
    # exact `count == 0` test. So a column is trimmed when it contains ANY
    # exact zero, not only when it is all zero.
    # NOTE(review): narrow(1, 1, P - 2) keeps columns 1..P-2, i.e. it drops BOTH
    # edge columns. The second narrow(1, 0, P - 2) is then a no-op if the first
    # branch already ran; otherwise it drops the last two columns.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Padding needed before the input (indices < 1) and after it (indices > in_length).
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    # Re-base the 1-based indices onto the padded array, 0-based.
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    """MATLAB-style bicubic imresize for a torch tensor.

    Args:
        img: tensor, CxHxW or HxW, values presumably in [0, 1]. No clamping or
            rounding is applied. The caller's tensor is not modified.
        scale: one scale factor for both H and W; output size is ceil(in * scale).
        antialiasing: widen the kernel when downscaling (scale < 1).

    Returns:
        float32 CPU tensor, CxH'xW' (or H'xW' for 2-D input).
    """
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = img.dim() == 2
    if need_squeeze:
        # Out-of-place: the old in-place unsqueeze_ reshaped the caller's tensor to 1xHxW.
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)

    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        # Drop only the channel axis added above. A bare squeeze_() would also
        # remove an output H or W of size 1.
        out_2.squeeze_(0)
    return out_2
# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    """MATLAB-style bicubic imresize for a numpy image.

    Args:
        img: numpy array, HxWxC or HxW, values presumably in [0, 1]. No
            clamping or rounding is applied.
        scale: one scale factor for both H and W; output size is ceil(in * scale).
        antialiasing: widen the kernel when downscaling (scale < 1).

    Returns:
        float32 numpy array, H'xW'xC (or H'xW' for 2-D input).
    """
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = img.dim() == 2
    if need_squeeze:
        img = img.unsqueeze(2)
    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)

    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        # Drop only the channel axis added above. A bare squeeze_() would also
        # remove an output H or W of size 1.
        out_2.squeeze_(2)

    return out_2.numpy()
if __name__ == '__main__':
    # Placeholder smoke test. To try the numpy bicubic resize, uncomment:
    #   img = imread_uint('test.bmp', 3)
    #   img = uint2single(img)
    #   img_bicubic = imresize_np(img, 1/4)
    print('---')
================================================
FILE: ldm/modules/losses/__init__.py
================================================
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
================================================
FILE: ldm/modules/losses/contperceptual.py
================================================
import torch
import torch.nn as nn
from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
class LPIPSWithDiscriminator(nn.Module):
    """Autoencoder loss: L1 + LPIPS reconstruction, learned log-variance NLL, KL and a GAN term.

    The module is meant to be called twice per training step:
      * ``optimizer_idx == 0`` -> loss for the autoencoder (generator) update;
      * ``optimizer_idx == 1`` -> loss for the discriminator update.
    Any other ``optimizer_idx`` falls through ``forward`` and returns ``None``.

    ``LPIPS``, ``NLayerDiscriminator``, ``weights_init``, ``hinge_d_loss``,
    ``vanilla_d_loss`` and ``adopt_weight`` are not defined in this file; they
    presumably come from the ``taming.modules.losses.vqperceptual`` star import.
    """
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):

        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        # NOTE(review): stored but never read anywhere in this class.
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        # Learned scalar log-variance; forward() uses rec_loss / exp(logvar) + logvar.
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        # Global step from which the GAN terms are enabled (passed to adopt_weight as threshold).
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Weight for the GAN loss that balances its gradient against the NLL gradient.

        Returns ``||d nll / d last_layer|| / (||d g / d last_layer|| + 1e-4)``,
        clamped to ``[0, 1e4]``, detached, and multiplied by
        ``self.discriminator_weight``.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # NOTE(review): ``self.last_layer`` is never assigned in this class, so this
            # branch raises AttributeError unless a subclass or caller sets it.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        """Compute the generator (``optimizer_idx == 0``) or discriminator (``1``) loss.

        Args:
            inputs, reconstructions: tensors of the same shape (presumably NCHW images — confirm).
            posteriors: object exposing ``kl()`` (presumably a diagonal Gaussian posterior — confirm).
            optimizer_idx: 0 for the autoencoder pass, 1 for the discriminator pass.
            global_step: current step; GAN terms are zeroed before ``disc_start``.
            last_layer: parameter w.r.t. which the adaptive GAN weight is computed.
            cond: optional conditioning, concatenated to the discriminator input along dim 1.
            split: prefix used for the log-dict keys.
            weights: optional multiplier applied element-wise to the NLL before reduction.

        Returns:
            ``(loss, log_dict)`` for optimizer_idx 0 or 1; ``None`` otherwise.
        """
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights*nll_loss
        # Sum over every element, then divide by the size of dim 0 (per-sample sum, batch mean).
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # autograd.grad failing (e.g. no graph) is only tolerated outside training.
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            # 0 before discriminator_iter_start, self.disc_factor afterwards.
            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            # Inputs are detached so only the discriminator receives gradients.
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
================================================
FILE: ldm/modules/losses/vqperceptual.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    """Hinge discriminator loss in which each batch element is weighted.

    The per-sample hinge terms ``relu(1 - real)`` and ``relu(1 + fake)`` are
    averaged over dims 1, 2 and 3. They are then combined into a weighted mean
    using ``weights`` (one weight per batch element). The result is the average
    of the real and fake parts.
    """
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    per_sample_dims = [1, 2, 3]
    weight_total = weights.sum()
    real_term = (weights * F.relu(1. - logits_real).mean(dim=per_sample_dims)).sum() / weight_total
    fake_term = (weights * F.relu(1. + logits_fake).mean(dim=per_sample_dims)).sum() / weight_total
    return 0.5 * (real_term + fake_term)
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Warm-up gate: return ``value`` while ``global_step < threshold``, else ``weight``."""
    return value if global_step < threshold else weight
def measure_perplexity(predicted_indices, n_embed):
    """Return ``(perplexity, cluster_use)`` for a batch of codebook indices.

    Perplexity is ``exp(entropy)`` of the empirical code distribution. It equals
    ``n_embed`` when all codes are used equally often. ``cluster_use`` counts the
    codes that appear at least once.
    src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    """
    one_hot = F.one_hot(predicted_indices, n_embed).float()
    code_freq = one_hot.reshape(-1, n_embed).mean(dim=0)
    entropy = -(code_freq * torch.log(code_freq + 1e-10)).sum()
    used_codes = (code_freq > 0).sum()
    return entropy.exp(), used_codes
def l1(x, y):
    """Element-wise absolute difference |x - y|."""
    diff = x - y
    return diff.abs()
def l2(x, y):
    """Element-wise squared difference (x - y)^2."""
    diff = x - y
    return diff ** 2
class VQLPIPSWithDiscriminator(nn.Module):
    """VQ autoencoder loss: pixel (L1/L2) + LPIPS reconstruction, codebook loss and a GAN term.

    The module is meant to be called twice per training step:
      * ``optimizer_idx == 0`` -> loss for the autoencoder (generator) update;
      * ``optimizer_idx == 1`` -> loss for the discriminator update.
    Any other ``optimizer_idx`` falls through ``forward`` and returns ``None``.

    Args:
        disc_start: global step from which the GAN terms are enabled.
        codebook_weight: multiplier for the mean codebook/quantisation loss.
        pixelloss_weight: stored as ``self.pixel_weight``; not read by ``forward``.
        disc_num_layers, disc_in_channels, disc_ndf, use_actnorm: discriminator configuration.
        disc_factor: GAN-term multiplier once ``disc_start`` has been reached.
        disc_weight: multiplier applied on top of the adaptive GAN weight.
        perceptual_weight: LPIPS multiplier; ``<= 0`` disables LPIPS.
        disc_conditional: whether ``cond`` must be passed to ``forward``.
        disc_loss: ``"hinge"`` or ``"vanilla"``.
        n_classes: codebook size; required when ``predicted_indices`` is passed.
        perceptual_loss: only ``"lpips"`` is implemented (others raise ValueError).
        pixel_loss: ``"l1"`` or ``"l2"``.
    """

    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Weight for the GAN loss that balances its gradient against the NLL gradient.

        Returns ``||d nll / d last_layer|| / (||d g / d last_layer|| + 1e-4)``,
        clamped to ``[0, 1e4]``, detached, and multiplied by
        ``self.discriminator_weight``.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # NOTE(review): ``self.last_layer`` is never assigned in this class, so this
            # branch raises AttributeError unless a subclass or caller sets it.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        """Compute the generator (``optimizer_idx == 0``) or discriminator (``1``) loss.

        Args:
            codebook_loss: quantisation loss tensor, or ``None`` to treat it as zero.
            inputs, reconstructions: tensors of the same shape (presumably NCHW images — confirm).
            optimizer_idx: 0 for the autoencoder pass, 1 for the discriminator pass.
            global_step: current step; GAN terms are zeroed before ``disc_start``.
            last_layer: parameter w.r.t. which the adaptive GAN weight is computed.
            cond: optional conditioning, concatenated to the discriminator input along dim 1.
            split: prefix used for the log-dict keys.
            predicted_indices: optional code indices; adds perplexity/cluster-usage logs.

        Returns:
            ``(loss, log_dict)`` for optimizer_idx 0 or 1; ``None`` otherwise.
        """
        # This module never defines or imports an ``exists`` helper, so the former
        # ``if not exists(codebook_loss)`` raised NameError on every call.
        if codebook_loss is None:
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            # Placeholder so the log always has a p_loss entry; keep it on the same device as the
            # codebook-loss placeholder above.
            p_loss = torch.tensor([0.0], device=inputs.device)

        # Mean over every element (unlike the KL variant, which sums per sample).
        nll_loss = torch.mean(rec_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                # autograd.grad failing (e.g. no graph) is only tolerated outside training.
                assert not self.training
                d_weight = torch.tensor(0.0)

            # 0 before discriminator_iter_start, self.disc_factor afterwards.
            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            # Inputs are detached so only the discriminator receives gradients.
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
================================================
FILE: ldm/modules/x_transformer.py
================================================
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial
from inspect import isfunction
from collections import namedtuple
from einops import rearrange, repeat, reduce
# constants

DEFAULT_DIM_HEAD = 64

# Per-attention-layer tensors returned as the second element of Attention.forward.
Intermediates = namedtuple('Intermediates', [
    'pre_softmax_attn',
    'post_softmax_attn'
])

# Per-stack results returned by AttentionLayers.forward(return_hiddens=True).
# The typename must match the variable name. It used to reuse 'Intermediates',
# which made reprs misleading and broke pickling: pickle resolves instances by
# typename and would find the unrelated ``Intermediates`` class instead.
LayerIntermediates = namedtuple('LayerIntermediates', [
    'hiddens',
    'attn_intermediates'
])
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute position embeddings with a table of ``max_seq_len`` rows of size ``dim``."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)
        self.init_()

    def init_(self):
        """Re-initialise the embedding table from a normal distribution with std 0.02."""
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, x):
        """Return embeddings for positions ``0 .. x.shape[1] - 1`` with a leading axis of size 1."""
        positions = torch.arange(x.shape[1], device=x.device)
        table = self.emb(positions)
        return table.unsqueeze(0)
class FixedPositionalEmbedding(nn.Module):
    """Non-learned sinusoidal position embedding: ``[sin(t * f), cos(t * f)]`` per position ``t``."""

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x, seq_dim=1, offset=0):
        """Embeddings for positions ``offset .. offset + x.shape[seq_dim] - 1`` with a leading axis of size 1."""
        steps = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
        # Outer product: one row per position, one column per frequency.
        angles = steps[:, None] * self.inv_freq[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1).unsqueeze(0)
# helpers
def exists(val):
    """True unless ``val`` is None."""
    return False if val is None else True
def default(val, d):
    """Return ``val`` if it is not None; otherwise ``d``, calling it first when it is a function."""
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def always(val):
    """Build a callable that ignores all of its arguments and returns ``val``."""
    def constant(*_args, **_kwargs):
        return val
    return constant
def not_equals(val):
    """Build a predicate that is True for arguments different from ``val``."""
    def differs(x):
        return x != val
    return differs
def equals(val):
    """Build a predicate that is True for arguments equal to ``val``."""
    def matches(x):
        return x == val
    return matches
def max_neg_value(tensor):
    """Most negative finite value representable in ``tensor``'s floating dtype."""
    info = torch.finfo(tensor.dtype)
    return -info.max
# keyword argument helpers
def pick_and_pop(keys, d):
    """Remove ``keys`` from ``d`` (in order) and return them as a new dict.

    Raises KeyError if any key is missing from ``d``.
    """
    popped = [d.pop(key) for key in keys]
    return dict(zip(keys, popped))
def group_dict_by_key(cond, d):
    """Split ``d`` into ``(matching, non_matching)`` dicts according to ``cond(key)``."""
    matching, non_matching = {}, {}
    for key in d.keys():
        bucket = matching if cond(key) else non_matching
        bucket[key] = d[key]
    return matching, non_matching
def string_begins_with(prefix, str):
    """True if ``str`` starts with ``prefix`` (argument order suits ``functools.partial``)."""
    has_prefix = str.startswith(prefix)
    return has_prefix
def group_by_key_prefix(prefix, d):
    """Split ``d`` into ``(keys starting with prefix, other keys)``; keys are kept unchanged."""
    with_prefix, without_prefix = {}, {}
    for key in d.keys():
        target = with_prefix if key.startswith(prefix) else without_prefix
        target[key] = d[key]
    return with_prefix, without_prefix
def groupby_prefix_and_trim(prefix, d):
    """Split ``d`` by key prefix and strip the prefix from the matching keys.

    Returns ``(prefixed_with_prefix_removed, remaining)``, e.g. with prefix
    ``'ff_'`` the key ``'ff_mult'`` ends up as ``'mult'`` in the first dict.
    """
    cut = len(prefix)
    trimmed, remaining = {}, {}
    for key in d.keys():
        if key.startswith(prefix):
            trimmed[key[cut:]] = d[key]
        else:
            remaining[key] = d[key]
    return trimmed, remaining
# classes
class Scale(nn.Module):
    """Multiply the output of ``fn`` by the constant ``value``.

    ``fn`` may return either a single tensor (e.g. ``FeedForward``) or a tuple
    whose first element is the tensor to scale (e.g. ``Attention`` returns
    ``(out, intermediates)``). Trailing tuple elements pass through unchanged.
    """

    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        # A bare tensor must not be tuple-unpacked, since that would split it along
        # dim 0 (the batch) instead of scaling it. AttentionLayers wraps FeedForward,
        # which returns a single tensor, in Scale when macaron=True.
        if isinstance(out, torch.Tensor):
            return out * self.value
        x, *rest = out
        return (x * self.value, *rest)
class Rezero(nn.Module):
    """ReZero gating: multiply the output of ``fn`` by a learned scalar initialised to 0.

    ``fn`` may return a single tensor or a tuple whose first element is the
    tensor to gate. Trailing tuple elements pass through unchanged, matching
    ``Scale``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        # Tuple-unpacking a bare tensor would split it along the batch dimension.
        if isinstance(out, torch.Tensor):
            return out * self.g
        x, *rest = out
        return (x * self.g, *rest)
class ScaleNorm(nn.Module):
    """Normalise the last dim by its L2 norm times ``dim ** -0.5``, then scale by one learned scalar."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        denom = x.norm(dim=-1, keepdim=True) * self.scale
        # The clamp keeps all-zero vectors from dividing by zero.
        return x / denom.clamp(min=self.eps) * self.g
class RMSNorm(nn.Module):
    """RMS normalisation over the last dim with a learned per-feature gain of size ``dim``."""

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # ||x|| * dim**-0.5 is the root-mean-square of the last dim.
        rms = x.norm(dim=-1, keepdim=True) * self.scale
        return x / rms.clamp(min=self.eps) * self.g
class Residual(nn.Module):
    """Plain additive residual connection: ``x + residual``."""

    def forward(self, x, residual):
        combined = x + residual
        return combined
class GRUGating(nn.Module):
    """Residual connection gated by a GRU cell (x is the input, residual is the hidden state).

    Expects 3-D ``(b, n, d)`` tensors, as required by the einops patterns.
    """

    def __init__(self, dim):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, x, residual):
        def to_rows(t):
            # GRUCell works on 2-D (batch, features); fold the sequence into the batch.
            return rearrange(t, 'b n d -> (b n) d')

        gated = self.gru(to_rows(x), to_rows(residual))
        return gated.reshape_as(x)
# feedforward
class GEGLU(nn.Module):
    """Gated GELU unit: project to ``2 * dim_out``, then multiply one half by GELU(other half)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear+GELU (or GEGLU when ``glu``) -> Dropout -> Linear.

    The hidden width is ``int(dim * mult)``. The output width is ``dim_out``,
    or ``dim`` when ``dim_out`` is None.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_features = dim if dim_out is None else dim_out
        if glu:
            entry = GEGLU(dim, hidden)
        else:
            entry = nn.Sequential(
                nn.Linear(dim, hidden),
                nn.GELU()
            )

        self.net = nn.Sequential(
            entry,
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features)
        )

    def forward(self, x):
        return self.net(x)
# attention.
class Attention(nn.Module):
    """Multi-head attention with optional extras, ported from x-transformers.

    Options include causal masking, query/key padding masks, talking heads,
    top-k sparse attention, learned memory key/values (``num_mem_kv``),
    residual attention (``prev_attn``), a relative-position callable
    (``rel_pos``), a sinusoidal position embedding (``sinusoidal_emb``) and
    attention-on-attention (``on_attn``, a GLU output projection).

    ``forward`` returns ``(output, Intermediates(pre_softmax_attn, post_softmax_attn))``.
    """
    def __init__(
            self,
            dim,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            mask=None,
            talking_heads=False,
            sparse_topk=None,
            use_entmax15=False,
            num_mem_kv=0,
            dropout=0.,
            on_attn=False
    ):
        super().__init__()
        if use_entmax15:
            raise NotImplementedError("Check out entmax activation instead of softmax activation!")
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.causal = causal
        # NOTE(review): stored but never read in this class; forward() takes its own ``mask``.
        self.mask = mask

        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        # talking heads
        # Learned (heads x heads) mixing applied across heads before and after the softmax.
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax
        #self.attn_fn = entmax15 if use_entmax15 else F.softmax
        self.attn_fn = F.softmax

        # add memory key / values
        # Learned per-head key/value slots prepended to every sequence's keys/values.
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            rel_pos=None,
            sinusoidal_emb=None,
            prev_attn=None,
            mem=None
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when ``context`` is None).

        ``x`` must be 3-D (the shape is unpacked into ``b, n, _``).
        ``mask`` and ``context_mask`` are boolean keep-masks over query and key
        positions (True = attend). ``mem`` is concatenated before the keys/values
        along the sequence dim. ``prev_attn`` is added to the attention logits
        (residual attention).
        """
        b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        # Split heads: (b, n, h*d) -> (b, h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            # NOTE(review): for self-attention with ``mem`` given, k_mask reuses q_mask (length n)
            # while the keys have mem_len + n entries — confirm masks and mem are never combined.
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            # Boolean product == logical AND; broadcasts to (b, 1, i, j).
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                # The learned memory slots are always visible.
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        # NOTE: unless talking heads or rel_pos create a new tensor, this aliases ``dots``,
        # so the in-place masked_fill_ calls below also write masked values into it.
        pre_softmax_attn = dots

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            # Mask keys that lie in the future of each query. The left padding by (j - i)
            # leaves the extra leading keys (memory) visible to every query.
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            # Keep only the top-k logits per query row; everything below the k-th value is masked.
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        out = rearrange(out, 'b h n d -> b n (h d)')

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates
class AttentionLayers(nn.Module):
    """Stack of self-attention, cross-attention and feed-forward layers (x-transformers style).

    Layer kinds are single characters in ``self.layer_types``: ``'a'`` is
    self-attention, ``'c'`` is cross-attention over ``context`` and ``'f'`` is
    feed-forward. Keyword arguments prefixed with ``ff_`` go to ``FeedForward``
    and those prefixed with ``attn_`` go to ``Attention`` (prefix stripped).
    Any other extra keyword arguments are silently ignored.
    """
    def __init__(
            self,
            dim,
            depth,
            heads=8,
            causal=False,
            cross_attend=False,
            only_cross=False,
            use_scalenorm=False,
            use_rmsnorm=False,
            use_rezero=False,
            rel_pos_num_buckets=32,
            rel_pos_max_distance=128,
            position_infused_attn=False,
            custom_layers=None,
            sandwich_coef=None,
            par_ratio=None,
            residual_attn=False,
            cross_residual_attn=False,
            macaron=False,
            pre_norm=True,
            gate_residual=False,
            **kwargs
    ):
        super().__init__()
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

        # NOTE(review): computed but not used further in this __init__.
        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])

        self.has_pos_emb = position_infused_attn
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
        # Rotary embeddings are not implemented in this copy; this callable always returns None.
        self.rotary_pos_emb = always(None)

        # NOTE(review): the rel_pos_* arguments are only validated; no relative-position module is
        # built, so self.rel_pos stays None.
        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
        self.rel_pos = None

        self.pre_norm = pre_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        # With rezero, normalisation becomes identity and attention branches are gated by Rezero.
        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            # PAR layout: interleave attention with feed-forward in the first ~2/3 of the stack,
            # then fill the remainder with feed-forward layers.
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            # Sandwich layout: sandwich_coef leading 'a' layers, sandwich_coef trailing 'f' layers.
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        # One entry of ``mems`` is consumed per self-attention layer in forward().
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        for layer_type in self.layer_types:
            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                # NOTE(review): Scale wraps FeedForward here, so Scale.forward must accept a
                # single-tensor return value for macaron=True to work.
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            if isinstance(layer, Attention) and exists(branch_fn):
                layer = branch_fn(layer)

            if gate_residual:
                residual_fn = GRUGating(dim)
            else:
                residual_fn = Residual()

            # Each entry is [norm, layer, residual_fn]; forward() unpacks it in that order.
            self.layers.append(nn.ModuleList([
                norm_fn(),
                layer,
                residual_fn
            ]))

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            mems=None,
            return_hiddens=False
    ):
        """Run the layer stack over ``x``.

        ``mems`` is an optional list with one memory tensor per self-attention
        layer (copied, so the caller's list is not mutated). With
        ``return_hiddens`` the result is ``(x, LayerIntermediates(hiddens,
        attn_intermediates))``, where ``hiddens`` holds the input to every
        self-attention layer. Otherwise only ``x`` is returned.
        """
        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == 'a':
                hiddens.append(x)
                layer_mem = mems.pop(0)

            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == 'a':
                out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
                                   prev_attn=prev_attn, mem=layer_mem)
            elif layer_type == 'c':
                out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
            elif layer_type == 'f':
                out = block(x)

            x = residual_fn(out, residual)

            if layer_type in ('a', 'c'):
                intermediates.append(inter)

            # Residual attention: feed this layer's pre-softmax logits to the next layer of the same kind.
            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            # Post-norm: normalise after the residual, except after the final layer.
            if not self.pre_norm and not is_last:
                x = norm(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens,
                attn_intermediates=intermediates
            )

            return x, intermediates

        return x
class Encoder(AttentionLayers):
    """Non-causal AttentionLayers stack; passing ``causal`` explicitly is rejected."""

    def __init__(self, **kwargs):
        causal_requested = 'causal' in kwargs
        assert not causal_requested, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module):
    """Token-level transformer: embeddings -> AttentionLayers -> LayerNorm -> logits.

    Pipeline in ``forward``:
      1. Embed the token ids and add absolute position embeddings. These are
         skipped when ``use_pos_emb`` is False or the attention layers have
         their own position embedding.
      2. Apply dropout and project ``emb_dim -> dim``.
      3. Prepend the optional learned memory tokens.
      4. Run ``attn_layers``, apply the final LayerNorm, and strip the memory tokens.
      5. Project to vocabulary logits, unless ``return_embeddings`` is set.
    """
    def __init__(
            self,
            *,
            num_tokens,
            max_seq_len,
            attn_layers,
            emb_dim=None,
            max_mem_len=0.,
            emb_dropout=0.,
            num_memory_tokens=None,
            tie_embedding=False,
            use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        # NOTE(review): the tied variant multiplies by token_emb.weight (num_tokens x emb_dim)
        # transposed, so it presumably requires emb_dim == dim — confirm.
        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            # (no class in this file defines ``num_memory_tokens`` on the attention layers)
            if hasattr(attn_layers, 'num_memory_tokens'):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        """Re-initialise the token embedding table from a normal distribution with std 0.02."""
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_mems=False,
            return_attn=False,
            mems=None,
            **kwargs
    ):
        """Encode token ids ``x`` (2-D, since the shape unpacks into ``b, n``).

        Returns logits (or embeddings when ``return_embeddings``). With
        ``return_mems`` it returns ``(out, new_mems)`` and with ``return_attn``
        it returns ``(out, attn_maps)``. ``return_mems`` is checked first.
        """
        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        # Drop the memory-token positions (``mem`` is not used further).
        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
            # NOTE(review): with the default max_mem_len=0. (a float) this slice likely raises
            # TypeError, and an int 0 would keep the entire sequence (since -0 == 0) — confirm intent.
            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
            return out, new_mems

        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            return out, attn_maps

        return out
================================================
FILE: ldm/thirdp/psp/helpers.py
================================================
# https://github.com/eladrich/pixel2style2pixel
from collections import namedtuple
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Flatten(Module):
    """Collapse every dimension after the first (batch) into one, using ``view``."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)
def l2_norm(input, axis=1):
    """Divide ``input`` by its L2 norm along ``axis`` (keepdim, no epsilon)."""
    magnitude = input.norm(2, axis, True)
    return input / magnitude
class Bottleneck(namedtuple('Block', 'in_channel depth stride')):
    """Immutable (in_channel, depth, stride) description of a single ResNet unit."""
def get_block(in_channel, depth, num_units, stride=2):
    """Describe one ResNet stage.

    The first unit maps ``in_channel -> depth`` with ``stride``; the remaining
    ``num_units - 1`` units keep ``depth`` with stride 1.
    """
    units = [Bottleneck(in_channel, depth, stride)]
    units.extend(Bottleneck(depth, depth, 1) for _ in range(num_units - 1))
    return units
def get_blocks(num_layers):
    """Return the four stage descriptions (lists of Bottleneck) for an IR-50/100/152 network.

    Raises:
        ValueError: if ``num_layers`` is not 50, 100 or 152.
    """
    stage_channels = ((64, 64), (64, 128), (128, 256), (256, 512))
    units_by_depth = (
        (50, (3, 4, 14, 3)),
        (100, (3, 13, 30, 3)),
        (152, (3, 8, 36, 3)),
    )
    for layers, units_per_stage in units_by_depth:
        if num_layers == layers:
            return [
                get_block(in_channel=c_in, depth=c_out, num_units=units)
                for (c_in, c_out), units in zip(stage_channels, units_per_stage)
            ]
    raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
class SEModule(Module):
    """Squeeze-and-Excitation: rescale channels by a gate computed from global average pooling."""

    def __init__(self, channels, reduction):
        super().__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        gate = self.avg_pool(x)
        for layer in (self.fc1, self.relu, self.fc2, self.sigmoid):
            gate = layer(gate)
        return x * gate
class bottleneck_IR(Module):
    """Improved-residual (IR) unit: BN-Conv-PReLU-Conv-BN branch plus a shortcut.

    The shortcut is a strided 1x1 max-pool (pure subsampling) when the channel
    count is unchanged, and a strided 1x1 conv + BN projection otherwise.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        if in_channel != depth:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        else:
            self.shortcut_layer = MaxPool2d(1, stride)
        branch = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*branch)

    def forward(self, x):
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class bottleneck_IR_SE(Module):
    """IR unit with a Squeeze-and-Excitation block (reduction 16) at the end of the residual branch.

    The shortcut is a strided 1x1 max-pool when the channel count is
    unchanged, and a strided 1x1 conv + BN projection otherwise.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        if in_channel != depth:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        else:
            self.shortcut_layer = MaxPool2d(1, stride)
        branch = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16),
        ]
        self.res_layer = Sequential(*branch)

    def forward(self, x):
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
================================================
FILE: ldm/thirdp/psp/id_loss.py
================================================
# https://github.com/eladrich/pixel2style2pixel
import torch
from torch import nn
from ldm.thirdp.psp.model_irse import Backbone
class IDFeatures(nn.Module):
    """ArcFace (IR-SE-50, 112x112) identity embedding extractor.

    Args:
        model_path: path to a state dict for the Backbone, loaded onto the CPU.
    """

    def __init__(self, model_path):
        super().__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # NOTE(review): torch.load unpickles arbitrary objects; only use trusted checkpoint files.
        state_dict = torch.load(model_path, map_location="cpu")
        self.facenet.load_state_dict(state_dict)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def forward(self, x, crop=False):
        """Return identity features for ``x``; the expected pixel range is not established here."""
        if crop:
            # Resize to 256x256 and take a fixed face-centred window before pooling.
            x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
            x = x[:, :, 35:223, 32:220]
        pooled = self.face_pool(x)
        return self.facenet(pooled)
================================================
FILE: ldm/thirdp/psp/model_irse.py
================================================
# https://github.com/eladrich/pixel2style2pixel
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Backbone(Module):
    """IR / IR-SE ResNet face backbone returning L2-normalised 512-d embeddings.

    Args:
        input_size: 112 or 224 (square inputs).
        num_layers: 50, 100 or 152.
        mode: ``'ir'`` or ``'ir_se'`` (adds Squeeze-and-Excitation to each unit).
        drop_ratio: dropout probability before the final linear layer.
        affine: whether the final BatchNorm1d has learnable affine parameters.
    """

    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super().__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        # Side length of the final feature map: four stride-2 stages divide the input by 16.
        side = 7 if input_size == 112 else 14
        self.output_layer = Sequential(BatchNorm2d(512),
                                       Dropout(drop_ratio),
                                       Flatten(),
                                       Linear(512 * side * side, 512),
                                       BatchNorm1d(512, affine=affine))
        self.body = Sequential(*[
            unit_module(unit.in_channel, unit.depth, unit.stride)
            for stage in blocks
            for unit in stage
        ])

    def forward(self, x):
        for stage in (self.input_layer, self.body, self.output_layer):
            x = stage(x)
        return l2_norm(x)
def IR_50(input_size):
"""Constructs a ir-50 model."""
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_101(input_size):
"""Constructs a ir-101 model."""
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_152(input_size):
"""Constructs a ir-152 model."""
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_SE_50(input_size):
"""Constructs a ir_se-50 model."""
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_101(input_size):
"""Constructs a ir_se-101 model."""
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_152(input_size):
    """Constructs a ir_se-152 model (IR-SE units, 152 layers, non-affine output BN)."""
    return Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
================================================
FILE: ldm/util.py
================================================
import importlib
import torchvision
import torch
from torch import optim
import numpy as np
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import time
import cv2
import PIL
def pil_rectangle_crop(im):
    """Center-crop an image to the largest square that fits inside it.

    Args:
        im: a PIL image (anything exposing ``.size`` -> (width, height) and
            ``.crop(box)`` works).

    Returns:
        Whatever ``im.crop`` returns for the centered square box
        ``(left, top, right, bottom)``.
    """
    width, height = im.size
    side = min(width, height)
    # Integer offsets keep the box exactly side x side. The previous float
    # offsets (x.5 for odd differences) could be rounded unevenly on the two edges.
    left = (width - side) // 2
    top = (height - side) // 2
    return im.crop((left, top, left + side, top + side))
def log_txt_as_img(wh, xc, size=10):
    """Render each caption as black text on its own white RGB canvas.

    Args:
        wh: (width, height) of every canvas.
        xc: list of caption strings; each is hard-wrapped every ``nc`` characters.
        size: font size passed to ``ImageFont.truetype``.

    Returns:
        float64 torch tensor of shape (len(xc), 3, height, width) with values
        mapped from [0, 255] to [-1, 1].
    """
    if len(xc) == 0:
        # np.stack([]) raises; return an empty batch with the usual layout/dtype.
        return torch.zeros((0, 3, wh[1], wh[0]), dtype=torch.float64)

    # Loop-invariant work: load the font and compute the wrap width once.
    font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
    # Characters per line, scaled from 40 chars at 256 px. Clamp to 1 so very
    # narrow canvases do not hand range() a zero step.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = list()
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")
        # (H, W, 3) uint8 -> (3, H, W) in [-1, 1]
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    return torch.tensor(np.stack(txts))
def ismap(x):
    """True when ``x`` is a 4-D torch tensor with more than 3 entries along dim 1."""
    return isinstance(x, torch.Tensor) and x.dim() == 4 and x.shape[1] > 3
def isimage(x):
    """True when ``x`` is a 4-D torch tensor with exactly 1 or 3 entries along dim 1."""
    if isinstance(x, torch.Tensor):
        return x.dim() == 4 and x.shape[1] in (1, 3)
    return False
def exists(x):
    """Return True unless ``x`` is None (falsy values such as 0 still 'exist')."""
    return False if x is None else True
def default(val, d):
    """Return ``val`` unless it is None; otherwise the fallback ``d``.

    A plain Python function passed as ``d`` is called (lazily) to produce the fallback.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Average over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
def count_params(model, verbose=False):
    """Return the total number of elements across ``model.parameters()``.

    When ``verbose`` is set, also print the count in millions.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params
def instantiate_from_config(config):
    """Build ``config["target"](**config.get("params", {}))``.

    The sentinel strings '__is_first_stage__' and '__is_unconditional__' map to
    None. Any other config without a "target" raises KeyError.
    """
    if "target" in config:
        target_cls = get_obj_from_str(config["target"])
        return target_cls(**config.get("params", dict()))
    if config == '__is_first_stage__' or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path 'pkg.module.Name' to the object it names.

    With ``reload=True`` the module is re-executed via ``importlib.reload`` first.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)
class AdamWwithEMAandWings(optim.Optimizer):
    """AdamW that also maintains an exponential moving average (EMA) of each parameter.

    Credit: https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298

    The EMA copy is kept in ``self.state[p]['param_exp_avg']`` as a float32
    clone of the parameter. After each AdamW update it is blended towards the
    live weights with decay ``min(ema_decay, 1 - step ** -ema_power)``. Early
    steps therefore track the weights closely, and the decay is capped at
    ``ema_decay``.
    """
    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,  # ema decay to match previous code
                 ema_power=1., param_names=()):
        """AdamW that saves EMA versions of the parameters.

        Raises:
            ValueError: if lr, eps, betas, weight_decay or ema_decay is out of range.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
                        ema_power=ema_power, param_names=param_names)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Older pickled optimizers may predate the 'amsgrad' option.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            ema_decay = group['ema_decay']
            ema_power = group['ema_power']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values
                    state['param_exp_avg'] = p.detach().float().clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                ema_params_with_grad.append(state['param_exp_avg'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            # Nothing in this group received a gradient. Without this guard the
            # EMA decay below would read `state`, which is unbound for the first
            # group and stale (from another group) otherwise.
            if not params_with_grad:
                continue

            # NOTE(review): private PyTorch API. Newer releases appear to require
            # `state_steps` to be singleton tensors rather than ints; confirm
            # against the pinned torch version.
            optim._functional.adamw(params_with_grad,
                                    grads,
                                    exp_avgs,
                                    exp_avg_sqs,
                                    max_exp_avg_sqs,
                                    state_steps,
                                    amsgrad=amsgrad,
                                    beta1=beta1,
                                    beta2=beta2,
                                    lr=group['lr'],
                                    weight_decay=group['weight_decay'],
                                    eps=group['eps'],
                                    maximize=False)

            # Decay is derived from the step count of the last updated parameter
            # in this group (identical to the original `state['step']`).
            cur_ema_decay = min(ema_decay, 1 - state_steps[-1] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss
================================================
FILE: main.py
================================================
import torch
import argparse
import pandas as pd
import sys
from nerf.provider import NeRFDataset
from nerf.utils import *
# torch.autograd.set_detect_anomaly(True)
if __name__ == '__main__':
    # Entry point: parse CLI options, apply preset/derived overrides, build the
    # NeRF model, then run six-view export, testing/GUI, or training.

    # See https://stackoverflow.com/questions/27433316/how-to-get-argparse-to-read-arguments-from-a-file-with-an-option-rather-than-pre
    class LoadFromFile (argparse.Action):
        def __call__ (self, parser, namespace, values, option_string = None):
            with values as f:
                # parse arguments in the file and store them in the target namespace
                parser.parse_args(f.read().split(), namespace)

    parser = argparse.ArgumentParser()
    parser.add_argument('--file', type=open, action=LoadFromFile, help="specify a file filled with more arguments")
    parser.add_argument('--text', default=None, help="text prompt")
    parser.add_argument('--negative', default='', type=str, help="negative text prompt")
    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray")
    parser.add_argument('-O2', action='store_true', help="equals --fp16 --backbone vanilla --progressive_level")
    parser.add_argument('--test', action='store_true', help="test mode")
    parser.add_argument('--six_views', action='store_true', help="six_views mode: save the images of the six views")
    parser.add_argument('--eval_interval', type=int, default=1, help="evaluate on the valid set every interval epochs")
    parser.add_argument('--test_interval', type=int, default=100, help="test on the test set every interval epochs")
    parser.add_argument('--workspace', type=str, default='workspace')
    parser.add_argument('--seed', default=None)

    parser.add_argument('--image', default=None, help="image prompt")
    parser.add_argument('--image_config', default=None, help="image config csv")

    parser.add_argument('--known_view_interval', type=int, default=4, help="train default view with RGB loss every N iters, only valid if --image is not None.")

    parser.add_argument('--IF', action='store_true', help="experimental: use DeepFloyd IF as the guidance model for nerf stage")

    parser.add_argument('--guidance', type=str, nargs='*', default=['SD'], help='guidance model')
    parser.add_argument('--guidance_scale', type=float, default=100, help="diffusion model classifier-free guidance scale")

    parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture")
    parser.add_argument('--mcubes_resolution', type=int, default=256, help="mcubes resolution for extracting mesh")
    # argparse only applies `type` to string defaults, so the default must already be an int.
    parser.add_argument('--decimate_target', type=int, default=50000, help="target face number for mesh decimation")

    parser.add_argument('--dmtet', action='store_true', help="use dmtet finetuning")
    parser.add_argument('--tet_grid_size', type=int, default=128, help="tet grid size")
    parser.add_argument('--init_with', type=str, default='', help="ckpt to init dmtet")
    parser.add_argument('--lock_geo', action='store_true', help="disable dmtet to learn geometry")

    ## Perp-Neg options
    parser.add_argument('--perpneg', action='store_true', help="use perp_neg")
    parser.add_argument('--negative_w', type=float, default=-2, help="The scale of the weights of negative prompts. A larger value will help to avoid the Janus problem, but may cause flat faces. Vary between 0 to -4, depending on the prompt")
    parser.add_argument('--front_decay_factor', type=float, default=2, help="decay factor for the front prompt")
    parser.add_argument('--side_decay_factor', type=float, default=10, help="decay factor for the side prompt")

    ### training options
    parser.add_argument('--iters', type=int, default=10000, help="training iters")
    parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate")
    parser.add_argument('--ckpt', type=str, default='latest', help="possible options are ['latest', 'scratch', 'best', 'latest_model']")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--taichi_ray', action='store_true', help="use taichi raymarching")
    parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=32, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
    parser.add_argument('--latent_iter_ratio', type=float, default=0.2, help="training iters that use the latent (as_latent) warmup")
    parser.add_argument('--albedo_iter_ratio', type=float, default=0, help="training iters that only use albedo shading")
    parser.add_argument('--min_ambient_ratio', type=float, default=0.1, help="minimum ambient ratio to use in lambertian shading")
    parser.add_argument('--textureless_ratio', type=float, default=0.2, help="ratio of textureless shading")
    parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
    parser.add_argument('--jitter_center', type=float, default=0.2, help="amount of jitter to add to sampled camera pose's center (camera location)")
    parser.add_argument('--jitter_target', type=float, default=0.2, help="amount of jitter to add to sampled camera pose's target (i.e. 'look-at')")
    parser.add_argument('--jitter_up', type=float, default=0.02, help="amount of jitter to add to sampled camera pose's up-axis (i.e. 'camera roll')")
    parser.add_argument('--uniform_sphere_rate', type=float, default=0, help="likelihood of sampling camera location uniformly on the sphere surface area")
    parser.add_argument('--grad_clip', type=float, default=-1, help="clip grad of all grad to this limit, negative value disables it")
    parser.add_argument('--grad_clip_rgb', type=float, default=-1, help="clip grad of rgb space grad to this limit, negative value disables it")
    # model options
    parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
    parser.add_argument('--density_activation', type=str, default='exp', choices=['softplus', 'exp'], help="density activation function")
    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
    parser.add_argument('--blob_density', type=float, default=5, help="max (center) density for the density blob")
    parser.add_argument('--blob_radius', type=float, default=0.2, help="control the radius for the density blob")
    # network backbone
    parser.add_argument('--backbone', type=str, default='grid', choices=['grid_tcnn', 'grid', 'vanilla', 'grid_taichi'], help="nerf backbone")
    parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam'], help="optimizer")
    parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
    parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
    # try this if CUDA OOM
    parser.add_argument('--fp16', action='store_true', help="use float16 for training")
    parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
    # rendering resolution in training, increase these for better quality / decrease these if CUDA OOM even if --vram_O enabled.
    parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
    parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
    parser.add_argument('--known_view_scale', type=float, default=1.5, help="multiply --h/w by this for known view rendering")
    parser.add_argument('--known_view_noise_scale', type=float, default=2e-3, help="random camera noise added to rays_o and rays_d")
    parser.add_argument('--dmtet_reso_scale', type=float, default=8, help="multiply --h/w by this for dmtet finetuning")
    parser.add_argument('--batch_size', type=int, default=1, help="images to render per batch using NeRF")

    ### dataset options
    parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
    parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.01, help="minimum near distance for camera")

    parser.add_argument('--radius_range', type=float, nargs='*', default=[3.0, 3.5], help="training camera radius range")
    parser.add_argument('--theta_range', type=float, nargs='*', default=[45, 105], help="training camera range along the polar angles (i.e. up and down). See advanced.md for details.")
    parser.add_argument('--phi_range', type=float, nargs='*', default=[-180, 180], help="training camera range along the azimuth angles (i.e. left and right). See advanced.md for details.")
    parser.add_argument('--fovy_range', type=float, nargs='*', default=[10, 30], help="training camera fovy range")

    parser.add_argument('--default_radius', type=float, default=3.2, help="radius for the default view")
    parser.add_argument('--default_polar', type=float, default=90, help="polar for the default view")
    parser.add_argument('--default_azimuth', type=float, default=0, help="azimuth for the default view")
    parser.add_argument('--default_fovy', type=float, default=20, help="fovy for the default view")

    parser.add_argument('--progressive_view', action='store_true', help="progressively expand view sampling range from default to full")
    parser.add_argument('--progressive_view_init_ratio', type=float, default=0.2, help="initial ratio of final range, used for progressive_view")

    parser.add_argument('--progressive_level', action='store_true', help="progressively increase gridencoder's max_level")

    parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
    parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
    parser.add_argument('--t_range', type=float, nargs='*', default=[0.02, 0.98], help="stable diffusion time steps range")
    parser.add_argument('--dont_override_stuff', action='store_true', help="Don't override t_range, etc.")

    ### regularizations
    parser.add_argument('--lambda_entropy', type=float, default=1e-3, help="loss scale for alpha entropy")
    parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
    parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
    parser.add_argument('--lambda_tv', type=float, default=0, help="loss scale for total variation")
    parser.add_argument('--lambda_wd', type=float, default=0, help="loss scale")

    parser.add_argument('--lambda_mesh_normal', type=float, default=0.5, help="loss scale for mesh normal smoothness")
    parser.add_argument('--lambda_mesh_laplacian', type=float, default=0.5, help="loss scale for mesh laplacian")

    parser.add_argument('--lambda_guidance', type=float, default=1, help="loss scale for SDS")
    parser.add_argument('--lambda_rgb', type=float, default=1000, help="loss scale for RGB")
    parser.add_argument('--lambda_mask', type=float, default=500, help="loss scale for mask (alpha)")
    parser.add_argument('--lambda_normal', type=float, default=0, help="loss scale for normal map")
    parser.add_argument('--lambda_depth', type=float, default=10, help="loss scale for relative depth")
    parser.add_argument('--lambda_2d_normal_smooth', type=float, default=0, help="loss scale for 2D normal image smoothness")
    parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0, help="loss scale for 3D normal image smoothness")

    ### debugging options
    parser.add_argument('--save_guidance', action='store_true', help="save images of the per-iteration NeRF renders, added noise, denoised (i.e. guidance), fully-denoised. Useful for debugging, but VERY SLOW and takes lots of memory!")
    parser.add_argument('--save_guidance_interval', type=int, default=10, help="save guidance every X step")

    ### GUI options
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=800, help="GUI width")
    parser.add_argument('--H', type=int, default=800, help="GUI height")
    parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=20, help="default GUI camera fovy")
    parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
    parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth")
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

    parser.add_argument('--zero123_config', type=str, default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml', help="config file for zero123")
    parser.add_argument('--zero123_ckpt', type=str, default='pretrained/zero123/zero123-xl.ckpt', help="ckpt for zero123")
    parser.add_argument('--zero123_grad_scale', type=str, default='angle', help="whether to scale the gradients based on 'angle' or 'None'")

    parser.add_argument('--dataset_size_train', type=int, default=100, help="Length of train dataset i.e. # of iterations per epoch")
    parser.add_argument('--dataset_size_valid', type=int, default=8, help="# of frames to render in the turntable video in validation")
    parser.add_argument('--dataset_size_test', type=int, default=100, help="# of frames to render in the turntable video at test time")

    parser.add_argument('--exp_start_iter', type=int, default=None, help="start iter # for experiment, to calculate progressive_view and progressive_level")
    parser.add_argument('--exp_end_iter', type=int, default=None, help="end iter # for experiment, to calculate progressive_view and progressive_level")

    opt = parser.parse_args()

    # presets
    if opt.O:
        opt.fp16 = True
        opt.cuda_ray = True

    elif opt.O2:
        opt.fp16 = True
        opt.backbone = 'vanilla'
        opt.progressive_level = True

    if opt.IF:
        if 'SD' in opt.guidance:
            opt.guidance.remove('SD')
            opt.guidance.append('IF')
        opt.latent_iter_ratio = 0 # must not do as_latent

    opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], []
    opt.default_zero123_w = 1

    opt.exp_start_iter = opt.exp_start_iter or 0
    opt.exp_end_iter = opt.exp_end_iter or opt.iters

    # parameters for image-conditioned generation
    if opt.image is not None or opt.image_config is not None:

        if opt.text is None:
            # use zero123 guidance model when only providing image
            opt.guidance = ['zero123']
            if not opt.dont_override_stuff:
                opt.fovy_range = [opt.default_fovy, opt.default_fovy] # fix fov as zero123 doesn't support changing fov
                opt.guidance_scale = 5
                opt.lambda_3d_normal_smooth = 10
        else:
            # use stable-diffusion when providing both text and image
            opt.guidance = ['SD', 'clip']

            if not opt.dont_override_stuff:
                opt.guidance_scale = 10
                opt.t_range = [0.2, 0.6]
                opt.known_view_interval = 2
                opt.lambda_3d_normal_smooth = 20
            opt.bg_radius = -1

        # smoothness
        opt.lambda_entropy = 1
        opt.lambda_orient = 1

        # latent warmup is not needed
        opt.latent_iter_ratio = 0
        if not opt.dont_override_stuff:
            opt.albedo_iter_ratio = 0

            # make shape init more stable
            opt.progressive_view = True
            opt.progressive_level = True

        if opt.image is not None:
            opt.images += [opt.image]
            opt.ref_radii += [opt.default_radius]
            opt.ref_polars += [opt.default_polar]
            opt.ref_azimuths += [opt.default_azimuth]
            opt.zero123_ws += [opt.default_zero123_w]

        if opt.image_config is not None:
            # for multiview (zero123)
            conf = pd.read_csv(opt.image_config, skipinitialspace=True)
            opt.images += list(conf.image)
            opt.ref_radii += list(conf.radius)
            opt.ref_polars += list(conf.polar)
            opt.ref_azimuths += list(conf.azimuth)
            opt.zero123_ws += list(conf.zero123_weight)
            if opt.image is None:
                # no single --image: the first csv row becomes the default view
                opt.default_radius = opt.ref_radii[0]
                opt.default_polar = opt.ref_polars[0]
                opt.default_azimuth = opt.ref_azimuths[0]
                opt.default_zero123_w = opt.zero123_ws[0]

    # reset to None
    if len(opt.images) == 0:
        opt.images = None

    # default parameters for finetuning
    if opt.dmtet:

        opt.h = int(opt.h * opt.dmtet_reso_scale)
        opt.w = int(opt.w * opt.dmtet_reso_scale)
        opt.known_view_scale = 1

        if not opt.dont_override_stuff:
            opt.t_range = [0.02, 0.50] # ref: magic3D

        if opt.images is not None:

            opt.lambda_normal = 0
            opt.lambda_depth = 0

            if opt.text is not None and not opt.dont_override_stuff:
                opt.t_range = [0.20, 0.50]

        # assume finetuning
        opt.latent_iter_ratio = 0
        opt.albedo_iter_ratio = 0
        opt.progressive_view = False
        # opt.progressive_level = False

    # record full range for progressive view expansion
    if opt.progressive_view:
        if not opt.dont_override_stuff:
            # disable as they disturb progressive view
            opt.jitter_pose = False
            opt.uniform_sphere_rate = 0
        # back up full range
        opt.full_radius_range = opt.radius_range
        opt.full_theta_range = opt.theta_range
        opt.full_phi_range = opt.phi_range
        opt.full_fovy_range = opt.fovy_range

    if opt.backbone == 'vanilla':
        from nerf.network import NeRFNetwork
    elif opt.backbone == 'grid':
        from nerf.network_grid import NeRFNetwork
    elif opt.backbone == 'grid_tcnn':
        from nerf.network_grid_tcnn import NeRFNetwork
    elif opt.backbone == 'grid_taichi':
        opt.cuda_ray = False
        opt.taichi_ray = True
        import taichi as ti
        from nerf.network_grid_taichi import NeRFNetwork
        taichi_half2_opt = True
        taichi_init_args = {"arch": ti.cuda, "device_memory_GB": 4.0}
        if taichi_half2_opt:
            taichi_init_args["half2_vectorization"] = True
        ti.init(**taichi_init_args)
    else:
        raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')

    print(opt)

    if opt.seed is not None:
        seed_everything(int(opt.seed))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = NeRFNetwork(opt).to(device)

    if opt.dmtet and opt.init_with != '':
        if opt.init_with.endswith('.pth'):
            # load pretrained weights to init dmtet
            state_dict = torch.load(opt.init_with, map_location=device)
            model.load_state_dict(state_dict['model'], strict=False)
            if opt.cuda_ray:
                model.mean_density = state_dict['mean_density']
            model.init_tet()
        else:
            # assume a mesh to init dmtet (experimental, not working well now!)
            import trimesh
            mesh = trimesh.load(opt.init_with, force='mesh', skip_material=True, process=False)
            model.init_tet(mesh=mesh)

    print(model)

    if opt.six_views:
        guidance = None # no need to load guidance model at test

        trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)

        test_loader = NeRFDataset(opt, device=device, type='six_views', H=opt.H, W=opt.W, size=6).dataloader(batch_size=1)
        trainer.test(test_loader, write_video=False)

        if opt.save_mesh:
            trainer.save_mesh()

    elif opt.test:
        guidance = None # no need to load guidance model at test

        trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)

        if opt.gui:
            from nerf.gui import NeRFGUI
            gui = NeRFGUI(opt, trainer)
            gui.render()

        else:
            test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1)
            trainer.test(test_loader)

            if opt.save_mesh:
                trainer.save_mesh()

    else:

        train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=opt.dataset_size_train * opt.batch_size).dataloader()

        if opt.optim == 'adan':
            from optimizer import Adan
            # Adan usually requires a larger LR
            optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
        else: # adam
            optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

        if opt.backbone == 'vanilla':
            scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
        else:
            scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed
            # scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))

        guidance = nn.ModuleDict()

        if 'SD' in opt.guidance:
            from guidance.sd_utils import StableDiffusion
            guidance['SD'] = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range)

        if 'IF' in opt.guidance:
            from guidance.if_utils import IF
            guidance['IF'] = IF(device, opt.vram_O, opt.t_range)

        if 'zero123' in opt.guidance:
            from guidance.zero123_utils import Zero123
            guidance['zero123'] = Zero123(device=device, fp16=opt.fp16, config=opt.zero123_config, ckpt=opt.zero123_ckpt, vram_O=opt.vram_O, t_range=opt.t_range, opt=opt)

        if 'clip' in opt.guidance:
            from guidance.clip_utils import CLIP
            guidance['clip'] = CLIP(device)

        trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, scheduler_update_every_step=True)

        trainer.default_view_data = train_loader._data.get_default_view_data()

        if opt.gui:
            from nerf.gui import NeRFGUI
            gui = NeRFGUI(opt, trainer, train_loader)
            gui.render()

        else:
            valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=opt.dataset_size_valid).dataloader(batch_size=1)
            test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader(batch_size=1)

            max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
            trainer.train(train_loader, valid_loader, test_loader, max_epoch)

            if opt.save_mesh:
                trainer.save_mesh()
================================================
FILE: meshutils.py
================================================
import numpy as np
import pymeshlab as pml
def poisson_mesh_reconstruction(points, normals=None, depth=9, visualize=True):
    """Reconstruct a triangle mesh from a point cloud using Open3D Poisson reconstruction.

    Args:
        points: [N, 3] np.ndarray of point positions.
        normals: optional [N, 3] np.ndarray of per-point normals. When None,
            Open3D estimates them after outlier removal.
        depth: octree depth passed to ``create_from_point_cloud_poisson``
            (default 9, the previously hard-coded value).
        visualize: when True (default, the previous behavior), show the
            filtered point cloud and the resulting mesh with
            ``o3d.visualization.draw_geometries``. Pass False for headless runs.

    Returns:
        (vertices, triangles) as np.ndarray.
    """
    import open3d as o3d

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    # outlier removal; `ind` holds the indices of the points that were kept
    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)

    # normals
    if normals is None:
        pcd.estimate_normals()
    else:
        # keep only the normals of points that survived outlier removal
        pcd.normals = o3d.utility.Vector3dVector(normals[ind])

    if visualize:
        o3d.visualization.draw_geometries([pcd], point_show_normal=False)

    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=depth)
    # drop the 10% of vertices with the lowest density
    vertices_to_remove = densities < np.quantile(densities, 0.1)
    mesh.remove_vertices_by_mask(vertices_to_remove)

    if visualize:
        o3d.visualization.draw_geometries([mesh])

    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)

    print(f'[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}')

    return vertices, triangles
def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True):
    """Reduce a triangle mesh to roughly ``target`` faces.

    Args:
        verts: vertex array passed to the backend.
        faces: face index array passed to the backend.
        target: desired face count. Converted with ``int()`` for both backends,
            so float values such as ``5e4`` are accepted.
        backend: 'pyfqmr' uses ``pyfqmr.Simplify``; anything else uses pymeshlab.
        remesh: pymeshlab only; run isotropic explicit remeshing after decimation.
        optimalplacement: pymeshlab only. Default is True, but for a flat mesh it
            must be False to prevent spike artifacts.

    Returns:
        (verts, faces) of the decimated mesh.
    """
    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    if backend == 'pyfqmr':
        import pyfqmr
        solver = pyfqmr.Simplify()
        solver.setMesh(verts, faces)
        # match the pymeshlab branch: the face count must be an integer
        solver.simplify_mesh(target_count=int(target), preserve_border=False, verbose=False)
        verts, faces, _ = solver.getMesh()  # third value (normals) is unused
    else:
        m = pml.Mesh(verts, faces)
        ms = pml.MeshSet()
        ms.add_mesh(m, 'mesh') # will copy!

        # filters
        # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
        ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement)

        if remesh:
            # ms.apply_coord_taubin_smoothing()
            ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))

        # extract mesh
        m = ms.current_mesh()
        verts = m.vertex_matrix()
        faces = m.face_matrix()

    print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')

    return verts, faces
def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01):
    """Run a fixed sequence of pymeshlab clean-up filters on a mesh.

    Args:
        verts: [N, 3] vertex array.
        faces: [N, 3] face index array.
        v_pct: if > 0, merge vertices closer than this percentage threshold.
        min_d: if > 0, drop connected components below this diameter percentage.
        min_f: if > 0, drop connected components with fewer faces than this.
        repair: repair non-manifold edges and vertices.
        remesh: run isotropic explicit remeshing with absolute target length ``remesh_size``.

    Returns:
        (verts, faces) of the cleaned mesh.
    """
    in_vert_shape, in_face_shape = verts.shape, faces.shape

    source_mesh = pml.Mesh(verts, faces)
    mesh_set = pml.MeshSet()
    mesh_set.add_mesh(source_mesh, 'mesh')  # the mesh set keeps its own copy

    # always-on filters
    mesh_set.meshing_remove_unreferenced_vertices()  # vertices no face refers to

    if v_pct > 0:
        mesh_set.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct))

    mesh_set.meshing_remove_duplicate_faces()  # faces built from the same vertices
    mesh_set.meshing_remove_null_faces()  # zero-area faces

    if min_d > 0:
        mesh_set.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))

    if min_f > 0:
        mesh_set.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)

    if repair:
        # mesh_set.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
        mesh_set.meshing_repair_non_manifold_edges(method=0)
        mesh_set.meshing_repair_non_manifold_vertices(vertdispratio=0)

    if remesh:
        # mesh_set.apply_coord_taubin_smoothing()
        mesh_set.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size))

    cleaned = mesh_set.current_mesh()
    out_verts = cleaned.vertex_matrix()
    out_faces = cleaned.face_matrix()

    print(f'[INFO] mesh cleaning: {in_vert_shape} --> {out_verts.shape}, {in_face_shape} --> {out_faces.shape}')

    return out_verts, out_faces
================================================
FILE: nerf/gui.py
================================================
import math
import torch
import numpy as np
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R
from nerf.utils import *
class OrbitCamera:
    """Mouse-driven orbit camera used by the GUI.

    The camera sits at distance ``radius`` along its local +z axis, is rotated
    by ``rot`` and then offset by ``-center``. It exposes a camera-to-world
    ``pose``, pinhole ``intrinsics`` and an OpenGL-style ``mvp`` matrix.

    Args:
        W: viewport width in pixels.
        H: viewport height in pixels.
        r: initial distance from the orbit center.
        fovy: vertical field of view in degrees.
    """

    def __init__(self, W, H, r=2, fovy=60):
        self.W = W
        self.H = H
        self.radius = r  # distance of the camera from the orbit center
        self.fovy = fovy  # vertical FoV, degrees
        self.center = np.array([0, 0, 0], dtype=np.float32)
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)  # yaw axis; must stay unit length
        self.near = 0.001
        self.far = 1000

    def _focal(self):
        """Focal length in pixels derived from the vertical FoV and image height."""
        half_fov = np.deg2rad(self.fovy) / 2
        return self.H / (2 * np.tan(half_fov))

    @property
    def pose(self):
        """Camera-to-world matrix, float32 [4, 4]."""
        # push the camera out along its local z axis
        offset = np.eye(4, dtype=np.float32)
        offset[2, 3] = self.radius

        # homogeneous form of the current orientation
        orientation = np.eye(4, dtype=np.float32)
        orientation[:3, :3] = self.rot.as_matrix()

        cam2world = orientation @ offset
        # NOTE(review): center is subtracted, so the orbit is effectively about
        # -center; pan() moves center accordingly. Confirm the sign convention.
        cam2world[:3, 3] -= self.center
        return cam2world

    @property
    def intrinsics(self):
        """Pinhole intrinsics as [fx, fy, cx, cy]."""
        f = self._focal()
        return np.array([f, f, self.W // 2, self.H // 2])

    @property
    def mvp(self):
        """Projection @ world-to-camera, float32 [4, 4] (y axis flipped for image space)."""
        f = self._focal()
        depth_range = self.far - self.near
        projection = np.array([
            [2*f/self.W, 0, 0, 0],
            [0, -2*f/self.H, 0, 0],
            [0, 0, -(self.far+self.near)/depth_range, -(2*self.far*self.near)/depth_range],
            [0, 0, -1, 0]
        ], dtype=np.float32)
        world2cam = np.linalg.inv(self.pose)
        return projection @ world2cam

    def orbit(self, dx, dy):
        """Yaw about ``up`` by dx and pitch about the camera side axis by dy (0.1 deg per unit)."""
        # first column of the rotation = camera x axis in world space (already unit length)
        side_axis = self.rot.as_matrix()[:3, 0]
        yaw = R.from_rotvec(self.up * np.deg2rad(-0.1 * dx))
        pitch = R.from_rotvec(side_axis * np.deg2rad(-0.1 * dy))
        self.rot = yaw * pitch * self.rot

    def scale(self, delta):
        """Zoom: each unit of ``delta`` divides the radius by 1.1."""
        self.radius = self.radius * (1.1 ** (-delta))

    def pan(self, dx, dy, dz=0):
        """Translate ``center`` along camera-space axes (0.0005 world units per unit)."""
        basis = self.rot.as_matrix()[:3, :3]
        step = np.array([dx, -dy, dz])
        self.center += (0.0005 * basis) @ step
class NeRFGUI:
    """dearpygui viewer / trainer front-end for a NeRF ``trainer``.

    The rendered view fills the primary window as a raw float RGB texture. A
    separate "Control" window exposes training buttons (unless ``opt.test``),
    rendering options, and, when ``debug`` is set, the live camera pose.

    Args:
        opt: option namespace shared with the trainer. Some widgets write into
            it directly (``dt_gamma``, ``max_steps``) so the trainer picks up
            the change in place. Fields read here: W, H, radius, fovy,
            light_theta, light_phi, dmtet, max_spp, text, negative, test,
            dt_gamma, max_steps, bound.
        trainer: object providing ``train_gui``, ``test_gui``,
            ``save_checkpoint``, ``save_mesh``, ``model``, ``stats``, ``name``
            and ``epoch``.
        loader: passed to ``trainer.train_gui`` on every training step.
        debug: show a "Debug" panel with the current camera pose.
    """

    def __init__(self, opt, trainer, loader=None, debug=True):
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.debug = debug
        self.bg_color = torch.ones(3, dtype=torch.float32)  # default white bg
        self.training = False
        self.step = 0  # training step

        self.trainer = trainer
        self.loader = loader
        # (H, W, 3): H rows of W RGB pixels, the row-major layout read by the
        # W x H raw texture registered in register_dpg.
        self.render_buffer = np.zeros((self.H, self.W, 3), dtype=np.float32)
        self.need_update = True  # camera moved, should reset accumulation
        self.spp = 1  # sample per pixel
        # Explicit float dtype: the light-direction sliders assign float values
        # into this array element-wise, which an integer array would truncate.
        self.light_dir = np.array([opt.light_theta, opt.light_phi], dtype=np.float64)
        self.ambient_ratio = 1.0
        self.mode = 'image'  # choose from ['image', 'depth']
        self.shading = 'albedo'

        self.dynamic_resolution = True if not self.opt.dmtet else False
        self.downscale = 1
        self.train_steps = 16

        dpg.create_context()
        self.register_dpg()
        self.test_step()

    def __del__(self):
        dpg.destroy_context()

    def train_step(self):
        """Run ``self.train_steps`` training iterations and refresh the train log.

        Afterwards, adapt ``self.train_steps`` (clamped to [4, 16]) so that one
        call takes roughly 500 ms. It only changes when the new value differs
        by more than 20%.
        """
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()

        outputs = self.trainer.train_gui(self.loader, step=self.train_steps)

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.step += self.train_steps
        self.need_update = True

        dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
        dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

        # dynamic train steps
        # max allowed train time per-frame is 500 ms
        full_t = t / self.train_steps * 16  # estimated time for 16 steps
        train_steps = min(16, max(4, int(16 * 500 / full_t)))
        if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
            self.train_steps = train_steps

    def prepare_buffer(self, outputs):
        """Convert trainer outputs to a float32 RGB image for the texture.

        In 'image' mode ``outputs['image']`` is used as-is. Otherwise
        ``outputs['depth']`` is min-max normalised to [0, 1] and replicated
        to 3 channels.
        """
        if self.mode == 'image':
            return outputs['image'].astype(np.float32)
        else:
            depth = outputs['depth'].astype(np.float32)
            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
            return np.expand_dims(depth, -1).repeat(3, -1)

    def test_step(self):
        """Render one frame when the view changed or more samples are wanted.

        A render happens when ``need_update`` is set or ``spp < opt.max_spp``.
        A changed view replaces the buffer (spp reset to 1). Otherwise the new
        sample is folded into a running average.
        """
        if self.need_update or self.spp < self.opt.max_spp:

            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            starter.record()

            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.cam.mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)

            ender.record()
            torch.cuda.synchronize()
            t = starter.elapsed_time(ender)

            # update dynamic resolution
            if self.dynamic_resolution:
                # max allowed infer time per-frame is 200 ms
                full_t = t / (self.downscale ** 2)  # estimated full-resolution time
                downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
                if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
                    self.downscale = downscale

            if self.need_update:
                self.render_buffer = self.prepare_buffer(outputs)
                self.spp = 1
                self.need_update = False
            else:
                # running mean over accumulated samples
                self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
                self.spp += 1

            dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
            dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
            dpg.set_value("_log_spp", self.spp)
            dpg.set_value("_texture", self.render_buffer)

    def register_dpg(self):
        """Build all dearpygui widgets, handlers and the viewport, then show it."""

        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")

        ### register window

        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):

            # add the texture
            dpg.add_image("_texture")

        dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(label="Control", tag="_control_window", width=400, height=300):

            # text prompt
            if self.opt.text is not None:
                dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")

            if self.opt.negative != '':
                dpg.add_text("negative text: " + self.opt.negative, tag="_log_prompt_negative_text")

            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # time
            if not self.opt.test:
                with dpg.group(horizontal=True):
                    dpg.add_text("Train time: ")
                    dpg.add_text("no data", tag="_log_train_time")

            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            with dpg.group(horizontal=True):
                dpg.add_text("SPP: ")
                dpg.add_text("1", tag="_log_spp")

            # train button
            if not self.opt.test:
                with dpg.collapsing_header(label="Train", default_open=True):
                    with dpg.group(horizontal=True):
                        dpg.add_text("Train: ")

                        def callback_train(sender, app_data):
                            # toggle training and flip the button label
                            if self.training:
                                self.training = False
                                dpg.configure_item("_button_train", label="start")
                            else:
                                self.training = True
                                dpg.configure_item("_button_train", label="stop")

                        dpg.add_button(label="start", tag="_button_train", callback=callback_train)
                        dpg.bind_item_theme("_button_train", theme_button)

                        def callback_reset(sender, app_data):
                            # NOTE(review): `nn` is not imported in this file; presumably
                            # it comes from `from nerf.utils import *` — confirm.
                            @torch.no_grad()
                            def weight_reset(m: nn.Module):
                                reset_parameters = getattr(m, "reset_parameters", None)
                                if callable(reset_parameters):
                                    m.reset_parameters()
                            self.trainer.model.apply(fn=weight_reset)
                            self.trainer.model.reset_extra_state()  # for cuda_ray density_grid and step_counter
                            self.need_update = True

                        dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
                        dpg.bind_item_theme("_button_reset", theme_button)

                    with dpg.group(horizontal=True):
                        dpg.add_text("Checkpoint: ")

                        def callback_save(sender, app_data):
                            self.trainer.save_checkpoint(full=True, best=False)
                            dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
                            self.trainer.epoch += 1  # use epoch to indicate different calls.

                        dpg.add_button(label="save", tag="_button_save", callback=callback_save)
                        dpg.bind_item_theme("_button_save", theme_button)

                        dpg.add_text("", tag="_log_ckpt")

                    # save mesh
                    with dpg.group(horizontal=True):
                        dpg.add_text("Marching Cubes: ")

                        def callback_mesh(sender, app_data):
                            self.trainer.save_mesh()
                            dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
                            self.trainer.epoch += 1  # use epoch to indicate different calls.

                        dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
                        dpg.bind_item_theme("_button_mesh", theme_button)

                        dpg.add_text("", tag="_log_mesh")

                    with dpg.group(horizontal=True):
                        dpg.add_text("", tag="_log_train_log")

            # rendering options
            with dpg.collapsing_header(label="Options", default_open=True):

                # dynamic rendering resolution
                with dpg.group(horizontal=True):

                    def callback_set_dynamic_resolution(sender, app_data):
                        if self.dynamic_resolution:
                            self.dynamic_resolution = False
                            self.downscale = 1
                        else:
                            self.dynamic_resolution = True
                        self.need_update = True

                    dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
                    dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)

                # bg_color picker
                def callback_change_bg(sender, app_data):
                    self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32)  # only need RGB in [0, 1]
                    self.need_update = True

                dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = app_data
                    self.need_update = True

                dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)

                # dt_gamma slider
                def callback_set_dt_gamma(sender, app_data):
                    self.opt.dt_gamma = app_data
                    self.need_update = True

                dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

                # max_steps slider
                def callback_set_max_steps(sender, app_data):
                    self.opt.max_steps = app_data
                    self.need_update = True

                dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

                # aabb slider
                def callback_set_aabb(sender, app_data, user_data):
                    # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
                    self.trainer.model.aabb_infer[user_data] = app_data

                    # also change train aabb ? [better not...]
                    #self.trainer.model.aabb_train[user_data] = app_data

                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Axis-aligned bounding box:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)

                # light dir
                def callback_set_light_dir(sender, app_data, user_data):
                    # user_data: 0 -> theta, 1 -> phi (writes into the float light_dir array)
                    self.light_dir[user_data] = app_data
                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Plane Light Direction:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)

                # ambient ratio
                def callback_set_abm_ratio(sender, app_data):
                    self.ambient_ratio = app_data
                    self.need_update = True

                dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)

                # shading mode
                def callback_change_shading(sender, app_data):
                    self.shading = app_data
                    self.need_update = True

                dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)

            # debug info
            if self.debug:
                with dpg.collapsing_header(label="Debug"):
                    # pose
                    dpg.add_separator()
                    dpg.add_text("Camera Pose:")
                    dpg.add_text(str(self.cam.pose), tag="_log_pose")

        ### register camera handler
        # Mouse handlers only act while the render (primary) window has focus.

        def callback_camera_drag_rotate(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_wheel_scale(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_drag_pan(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        with dpg.handler_registry():
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan)

        dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)

        # TODO: seems dearpygui doesn't support resizing texture...
        # def callback_resize(sender, app_data):
        #     self.W = app_data[0]
        #     self.H = app_data[1]
        #     # how to reload texture ???

        # dpg.set_viewport_resize_callback(callback_resize)

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        #dpg.show_metrics()

        dpg.show_viewport()

    def render(self):
        """Main loop: optionally train, render a frame, draw the UI, until the window closes."""
        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()
================================================
FILE: nerf/network.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp
from .renderer import NeRFRenderer
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
# TODO: not sure about the details...
class ResBlock(nn.Module):
    """Residual unit: SiLU(LayerNorm(Linear(x)) + skip(x)).

    When ``dim_in != dim_out`` the skip path is a bias-free linear projection,
    otherwise it is the identity (``self.skip`` is None).
    """

    def __init__(self, dim_in, dim_out, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out

        self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)
        self.norm = nn.LayerNorm(self.dim_out)
        self.activation = nn.SiLU(inplace=True)

        needs_projection = self.dim_in != self.dim_out
        self.skip = nn.Linear(self.dim_in, self.dim_out, bias=False) if needs_projection else None

    def forward(self, x):
        # x: [B, C]
        residual = x if self.skip is None else self.skip(x)
        h = self.norm(self.dense(x))
        h += residual
        return self.activation(h)
class BasicBlock(nn.Module):
    """Fully connected layer followed by an in-place ReLU."""

    def __init__(self, dim_in, dim_out, bias=True):
        super().__init__()
        self.dim_in, self.dim_out = dim_in, dim_out
        self.dense = nn.Linear(dim_in, dim_out, bias=bias)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: [B, C]
        return self.activation(self.dense(x))
class MLP(nn.Module):
    """Stack of ``num_layers`` layers mapping ``dim_in`` -> ``dim_out``.

    Layout:
      * layer 0 (when it is not also the last layer): ``BasicBlock(dim_in, dim_hidden)``
      * middle layers: ``block(dim_hidden, dim_hidden)`` (e.g. BasicBlock or ResBlock)
      * last layer: plain ``nn.Linear`` without activation, from ``dim_in`` when
        it is the only layer, otherwise from ``dim_hidden``.

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        dim_hidden: hidden width.
        num_layers: total number of layers (0 gives an identity mapping).
        bias: whether linear layers use a bias.
        block: module class used for hidden -> hidden layers.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True, block=BasicBlock):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            if l == num_layers - 1:
                # Output layer is checked first so a single-layer MLP maps
                # dim_in straight to dim_out (previously it became a
                # BasicBlock(dim_in, dim_hidden) with the wrong width + ReLU).
                net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out, bias=bias))
            elif l == 0:
                net.append(BasicBlock(self.dim_in, self.dim_hidden, bias=bias))
            else:
                net.append(block(self.dim_hidden, self.dim_hidden, bias=bias))

        self.net = nn.ModuleList(net)

    def forward(self, x):
        for layer in self.net:
            x = layer(x)
        return x
class NeRFNetwork(NeRFRenderer):
    """NeRF with a frequency (positional) encoding and a residual MLP.

    ``sigma_net`` outputs 4 channels per point. Channel 0 is a raw density; it
    is offset by ``self.density_blob(x)`` and passed through
    ``self.density_activation``. Channels 1-3 are albedo logits, squashed by a
    sigmoid. When ``opt.bg_radius > 0`` a small MLP maps encoded directions to
    background RGB.
    """
    def __init__(self,
                 opt,
                 num_layers=5, # 5 in paper
                 hidden_dim=64, # 128 in paper
                 num_layers_bg=2, # 3 in paper
                 hidden_dim_bg=32, # 64 in paper
                 encoding='frequency_torch', # pure pytorch
                 ):

        super().__init__(opt)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.encoder, self.in_dim = get_encoder(encoding, input_dim=3, multires=12)
        # 4 outputs: [raw density, albedo logits x3]
        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True, block=ResBlock)

        self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else F.softplus

        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder(encoding, input_dim=3, multires=4)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

        else:
            self.bg_net = None

    def common_forward(self, x):
        """Return (sigma, albedo) for points x; sigma is [N], albedo is [N, 3]."""
        # x: [N, 3], in [-bound, bound]

        # sigma
        enc = self.encoder(x, bound=self.bound, max_level=self.max_level)

        h = self.sigma_net(enc)

        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
    def finite_difference_normal(self, x, epsilon=1e-2):
        """Negated central-difference gradient of sigma, [N, 3] (not normalised).

        Sample points are clamped to [-bound, bound]. Not used by forward() in
        this class; ``normal`` uses autograd instead.
        """
        # x: [N, 3]
        dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
        dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))

        normal = torch.stack([
            0.5 * (dx_pos - dx_neg) / epsilon,
            0.5 * (dy_pos - dy_neg) / epsilon,
            0.5 * (dz_pos - dz_neg) / epsilon
        ], dim=-1)

        return -normal

    def normal(self, x):
        """Unit normals from the negated autograd gradient of sigma w.r.t. x, [N, 3].

        NOTE: calls ``x.requires_grad_(True)``, so the caller's tensor is
        flagged in place. ``create_graph=True`` keeps the gradient itself
        differentiable, and NaNs from normalisation are replaced via
        ``nan_to_num``.
        """

        with torch.enable_grad():
            # autocast disabled so the density/gradient is not computed in reduced precision
            with torch.cuda.amp.autocast(enabled=False):
                x.requires_grad_(True)
                sigma, albedo = self.common_forward(x)
                # query gradient
                normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]

        # normal = self.finite_difference_normal(x)
        normal = safe_normalize(normal)
        normal = torch.nan_to_num(normal)

        return normal

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        """Return (sigma, color, normal); normal is None in 'albedo' shading.

        Non-albedo shading requires ``l``. Same in-place ``requires_grad_`` note
        as ``normal``.
        """
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, nomalized in [-1, 1]
        # l: [3], plane light direction, nomalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)

        if shading == 'albedo':
            # no need to query normal
            sigma, color = self.common_forward(x)
            normal = None

        else:
            # query normal

            # sigma, albedo = self.common_forward(x)
            # normal = self.normal(x)

            # Inlined version of self.normal() so sigma/albedo come from the same
            # forward pass used for the gradient.
            with torch.enable_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    x.requires_grad_(True)
                    sigma, albedo = self.common_forward(x)
                    # query gradient
                    normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
            normal = safe_normalize(normal)
            normal = torch.nan_to_num(normal)

            # lambertian shading
            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else: # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal


    def density(self, x):
        """Return {'sigma': [N], 'albedo': [N, 3]} for points x."""
        # x: [N, 3], in [-bound, bound]

        sigma, albedo = self.common_forward(x)

        return {
            'sigma': sigma,
            'albedo': albedo,
        }


    def background(self, d):
        """Background RGB in [0, 1] for directions d; requires opt.bg_radius > 0."""

        h = self.encoder_bg(d) # [N, C]

        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Optimizer parameter groups, all at learning rate ``lr``.

        DMTet ``sdf``/``deform`` tensors are included only when
        ``opt.dmtet`` is set and ``opt.lock_geo`` is not.
        """

        params = [
            # {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
        ]
        if self.opt.bg_radius > 0:
            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        if self.opt.dmtet and not self.opt.lock_geo:
            params.append({'params': self.sdf, 'lr': lr})
            params.append({'params': self.deform, 'lr': lr})

        return params
================================================
FILE: nerf/network_grid.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp, biased_softplus
from .renderer import NeRFRenderer
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
class MLP(nn.Module):
    """Plain fully connected network with an in-place ReLU after every layer but the last.

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        dim_hidden: width of the hidden layers.
        num_layers: number of ``nn.Linear`` layers (0 gives an identity mapping).
        bias: whether the linear layers use a bias.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        last = num_layers - 1
        layers = [
            nn.Linear(dim_in if i == 0 else dim_hidden, dim_out if i == last else dim_hidden, bias=bias)
            for i in range(num_layers)
        ]
        self.net = nn.ModuleList(layers)

    def forward(self, x):
        last = self.num_layers - 1
        for i, layer in enumerate(self.net):
            x = layer(x)
            if i != last:
                x = F.relu(x, inplace=True)
        return x
class NeRFNetwork(NeRFRenderer):
    """Hash-grid NeRF: multi-resolution hash encoding followed by a small MLP.

    ``sigma_net`` outputs 4 channels per point. Channel 0 is a raw density
    (offset by ``self.density_blob`` before activation). Channels 1-3 are
    albedo logits. Normals come from central finite differences of the
    density. An optional frequency-encoded background MLP maps directions to
    RGB.
    """

    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=32,
                 ):

        super().__init__(opt)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')
        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
        # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)

        use_exp = self.opt.density_activation == 'exp'
        self.density_activation = trunc_exp if use_exp else biased_softplus

        if self.opt.bg_radius <= 0:
            self.bg_net = None
            return

        # background network: kept deliberately tiny so it does not learn the prompt
        self.num_layers_bg = num_layers_bg
        self.hidden_dim_bg = hidden_dim_bg
        self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
        self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

    def common_forward(self, x):
        """Return (sigma [N], albedo [N, 3]) for points x in [-bound, bound]."""
        feats = self.encoder(x, bound=self.bound, max_level=self.max_level)
        raw = self.sigma_net(feats)

        sigma = self.density_activation(raw[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(raw[..., 1:])
        return sigma, albedo

    # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
    def finite_difference_normal(self, x, epsilon=1e-2):
        """Negated central-difference gradient of sigma, [N, 3] (not normalised).

        Probes are evaluated in the order +x, -x, +y, -y, +z, -z and clamped
        to [-bound, bound].
        """
        grads = []
        for axis in range(3):
            probes = []
            for sign in (1, -1):
                shift = [0.00, 0.00, 0.00]
                shift[axis] = sign * epsilon
                shifted = (x + torch.tensor([shift], device=x.device)).clamp(-self.bound, self.bound)
                density, _ = self.common_forward(shifted)
                probes.append(density)
            ahead, behind = probes
            grads.append(0.5 * (ahead - behind) / epsilon)

        return -torch.stack(grads, dim=-1)

    def normal(self, x):
        """Unit normals from finite differences; NaNs replaced by nan_to_num."""
        raw_normal = self.finite_difference_normal(x)
        return torch.nan_to_num(safe_normalize(raw_normal))

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        """Return (sigma, color, normal); normal is None for 'albedo' shading.

        x: [N, 3] points in [-bound, bound]; d: view directions (unused here);
        l: [3] light direction (required unless shading == 'albedo');
        ratio: ambient ratio, 1 == albedo only, 0 == pure shading.
        """
        sigma, albedo = self.common_forward(x)

        if shading == 'albedo':
            return sigma, albedo, None

        # lambertian-style shading
        normal = self.normal(x)
        lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]

        if shading == 'textureless':
            color = lambertian.unsqueeze(-1).repeat(1, 3)
        elif shading == 'normal':
            color = (normal + 1) / 2
        else:  # 'lambertian'
            color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Return {'sigma': [N], 'albedo': [N, 3]} for points x in [-bound, bound]."""
        sigma, albedo = self.common_forward(x)
        return dict(sigma=sigma, albedo=albedo)

    def background(self, d):
        """Background RGB in [0, 1] for directions d; requires opt.bg_radius > 0."""
        return torch.sigmoid(self.bg_net(self.encoder_bg(d)))

    # optimizer utils
    def get_params(self, lr):
        """Optimizer parameter groups; the hash grid uses a 10x learning rate."""
        groups = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
            # {'params': self.normal_net.parameters(), 'lr': lr},
        ]

        if self.opt.bg_radius > 0:
            # groups.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            groups.append({'params': self.bg_net.parameters(), 'lr': lr})

        if self.opt.dmtet and not self.opt.lock_geo:
            groups += [
                {'params': self.sdf, 'lr': lr},
                {'params': self.deform, 'lr': lr},
            ]

        return groups
================================================
FILE: nerf/network_grid_taichi.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp
from .renderer import NeRFRenderer
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
class MLP(nn.Module):
    """Linear layers with an in-place ReLU between them (no activation on the output).

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        dim_hidden: hidden width.
        num_layers: number of ``nn.Linear`` layers (0 gives an identity mapping).
        bias: whether the linear layers use a bias.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        layers = nn.ModuleList()
        for idx in range(num_layers):
            fan_in = self.dim_hidden if idx > 0 else self.dim_in
            fan_out = self.dim_out if idx == num_layers - 1 else self.dim_hidden
            layers.append(nn.Linear(fan_in, fan_out, bias=bias))
        self.net = layers

    def forward(self, x):
        hidden_count = self.num_layers - 1
        for idx, layer in enumerate(self.net):
            x = layer(x)
            if idx < hidden_count:
                x = F.relu(x, inplace=True)
        return x
class NeRFNetwork(NeRFRenderer):
    """Taichi hash-grid NeRF: ``hashgrid_taichi`` encoding followed by a small MLP.

    ``sigma_net`` outputs 4 channels per point. Channel 0 is a raw density
    (offset by ``self.density_blob`` before activation). Channels 1-3 are
    albedo logits. Normals come from central finite differences. An optional
    background MLP on a pure-torch frequency encoding maps directions to RGB.
    """

    def __init__(self,
                 opt,
                 num_layers=2,
                 hidden_dim=32,
                 num_layers_bg=2,
                 hidden_dim_bg=16,
                 ):

        super().__init__(opt)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.encoder, self.in_dim = get_encoder('hashgrid_taichi', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')
        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
        # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)

        use_exp = self.opt.density_activation == 'exp'
        self.density_activation = trunc_exp if use_exp else F.softplus

        if self.opt.bg_radius <= 0:
            self.bg_net = None
            return

        # background network: kept deliberately tiny so it does not learn the prompt
        self.num_layers_bg = num_layers_bg
        self.hidden_dim_bg = hidden_dim_bg
        # TODO: freq encoder can be replaced by a Taichi implementation
        self.encoder_bg, self.in_dim_bg = get_encoder('frequency_torch', input_dim=3, multires=4)
        self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

    def common_forward(self, x):
        """Return (sigma [N], albedo [N, 3]) for points x in [-bound, bound]."""
        # the taichi encoder is called without max_level
        feats = self.encoder(x, bound=self.bound)
        raw = self.sigma_net(feats)

        sigma = self.density_activation(raw[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(raw[..., 1:])
        return sigma, albedo

    # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
    def finite_difference_normal(self, x, epsilon=1e-2):
        """Negated central-difference gradient of sigma, [N, 3] (not normalised).

        Probes are evaluated in the order +x, -x, +y, -y, +z, -z and clamped
        to [-bound, bound].
        """
        grads = []
        for axis in range(3):
            probes = []
            for sign in (1, -1):
                shift = [0.00, 0.00, 0.00]
                shift[axis] = sign * epsilon
                shifted = (x + torch.tensor([shift], device=x.device)).clamp(-self.bound, self.bound)
                density, _ = self.common_forward(shifted)
                probes.append(density)
            ahead, behind = probes
            grads.append(0.5 * (ahead - behind) / epsilon)

        return -torch.stack(grads, dim=-1)

    def normal(self, x):
        """Unit normals from finite differences; NaNs replaced by nan_to_num."""
        raw_normal = self.finite_difference_normal(x)
        return torch.nan_to_num(safe_normalize(raw_normal))

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        """Return (sigma, color, normal); normal is None for 'albedo' shading.

        x: [N, 3] points in [-bound, bound]; d: view directions (unused here);
        l: [3] light direction (required unless shading == 'albedo');
        ratio: ambient ratio, 1 == albedo only, 0 == pure shading.
        """
        sigma, albedo = self.common_forward(x)

        if shading == 'albedo':
            return sigma, albedo, None

        # lambertian-style shading
        normal = self.normal(x)
        lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]

        if shading == 'textureless':
            color = lambertian.unsqueeze(-1).repeat(1, 3)
        elif shading == 'normal':
            color = (normal + 1) / 2
        else:  # 'lambertian'
            color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Return {'sigma': [N], 'albedo': [N, 3]} for points x in [-bound, bound]."""
        sigma, albedo = self.common_forward(x)
        return dict(sigma=sigma, albedo=albedo)

    def background(self, d):
        """Background RGB in [0, 1] for directions d; requires opt.bg_radius > 0."""
        return torch.sigmoid(self.bg_net(self.encoder_bg(d)))

    # optimizer utils
    def get_params(self, lr):
        """Optimizer parameter groups; the hash grid uses a 10x learning rate."""
        groups = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
            # {'params': self.normal_net.parameters(), 'lr': lr},
        ]

        if self.opt.bg_radius > 0:
            # groups.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            groups.append({'params': self.bg_net.parameters(), 'lr': lr})

        if self.opt.dmtet and not self.opt.lock_geo:
            groups += [
                {'params': self.sdf, 'lr': lr},
                {'params': self.deform, 'lr': lr},
            ]

        return groups
================================================
FILE: nerf/network_grid_tcnn.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp, biased_softplus
from .renderer import NeRFRenderer
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
import tinycudann as tcnn
class MLP(nn.Module):
    """Torch MLP (ReLU between layers, none after the last).

    Used instead of a tcnn network so second-order derivatives work.

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        dim_hidden: hidden width.
        num_layers: number of ``nn.Linear`` layers (0 gives an identity mapping).
        bias: whether the linear layers use a bias.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        in_sizes = [dim_in if k == 0 else dim_hidden for k in range(num_layers)]
        out_sizes = [dim_out if k == num_layers - 1 else dim_hidden for k in range(num_layers)]
        self.net = nn.ModuleList([nn.Linear(a, b, bias=bias) for a, b in zip(in_sizes, out_sizes)])

    def forward(self, x):
        for k in range(self.num_layers):
            x = self.net[k](x)
            is_output_layer = k == self.num_layers - 1
            if not is_output_layer:
                x = F.relu(x, inplace=True)
        return x
class NeRFNetwork(NeRFRenderer):
def __init__(self,
opt,
num_layers=3,
hidden_dim=64,
num_layers_bg=2,
hidden_dim_bg=32,
):
super().__init__(opt)
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.encoder = tcnn.Encoding(
n_input_dims=3,
encoding_config={
"otype": "HashGrid",
"n_levels": 16,
"n_features_per_level": 2,
"log2_hashmap_size": 19,
"base_resolution": 16,
"interpolation": "Smoothstep",
"per_level_scale": np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)),
},
dtype=torch.float32, # ENHANCE: default float16 seems unstable...
)
self.in_dim = self.encoder.n_output_dims
# use torch MLP, as tcnn MLP doesn't impl second-order derivative
self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus
# background network
if self.opt.bg_radius > 0:
self.num_layers_bg = num_layers_bg
self.hidden_dim_bg = hidden_dim_bg
# use a very simple network to avoid it learning the prompt...
self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
else:
self.bg_net = None
def common_forward(self, x):
# sigma
enc = self.encoder((x + self.bound) / (2 * self.bound)).float()
h = self.sigma_net(enc)
sigma = self.density_activation(h[..., 0] + self.density_blob(x))
albedo = torch.sigmoid(h[..., 1:])
return sigma, albedo
def normal(self, x):
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
x.requires_grad_(True)
sigma, albedo = self.common_forward(x)
# query gradient
normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
# normal = self.finite_difference_normal(x)
normal = safe_normalize(normal)
normal = torch.nan_to_num(normal)
return normal
def forward(self, x, d, l=None, ratio=1, shading='albedo'):
# x: [N, 3], in [-bound, bound]
# d: [N, 3], view direction, nomalized in [-1, 1]
# l: [3], plane light direction, nomalized in [-1, 1]
# ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
if shading == 'albedo':
sigma, albedo = self.common_forward(x)
normal = None
color = albedo
else: # lambertian shading
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
x.requires_grad_(True)
sigma, albedo = self.common_forward(x)
normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
normal = safe_normalize(normal)
normal = torch.nan_to_num(normal)
lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]
if shading == 'textureless':
color = lambertian.unsqueeze(-1).repeat(1, 3)
elif shading == 'normal':
color = (normal + 1) / 2
else: # 'lambertian'
color = albedo * lambertian.unsqueeze(-1)
return sigma, color, normal
def density(self, x):
# x: [N, 3], in [-bound, bound]
sigma, albedo = self.common_forward(x)
return {
'sigma': sigma,
'albedo': albedo,
}
def background(self, d):
h = self.encoder_bg(d) # [N, C]
h = self.bg_net(h)
# sigmoid activation for rgb
rgbs = torch.sigmoid(h)
return rgbs
# optimizer utils
def get_params(self, lr):
params = [
{'params': self.encoder.parameters(), 'lr': lr * 10},
{'params': self.sigma_net.parameters(), 'lr': lr},
]
if self.opt.bg_radius > 0:
params.append({'params': self.bg_net.parameters(), 'lr': lr})
if self.opt.dmtet and not self.opt.lock_geo:
params.append({'params': self.sdf, 'lr': lr})
params.append({'params': self.deform, 'lr': lr})
return params
================================================
FILE: nerf/provider.py
================================================
import os
import cv2
import glob
import json
import tqdm
import random
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import trimesh
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from .utils import get_rays, safe_normalize
DIR_COLORS = np.array([
[255, 0, 0, 255], # front
[0, 255, 0, 255], # side
[0, 0, 255, 255], # back
[255, 255, 0, 255], # side
[255, 0, 255, 255], # overhead
[0, 255, 255, 255], # bottom
], dtype=np.uint8)
def visualize_poses(poses, dirs, size=0.1):
# poses: [B, 4, 4], dirs: [B]
axes = trimesh.creation.axis(axis_length=4)
sphere = trimesh.creation.icosphere(radius=1)
objects = [axes, sphere]
for pose, dir in zip(poses, dirs):
# a camera is visualized with 8 line segments.
pos = pose[:3, 3]
a = pos + size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2]
b = pos - size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2]
c = pos - size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2]
d = pos + size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2]
segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])
segs = trimesh.load_path(segs)
# different color for different dirs
segs.colors = DIR_COLORS[[dir]].repeat(len(segs.entities), 0)
objects.append(segs)
trimesh.Scene(objects).show()
def get_view_direction(thetas, phis, overhead, front):
# phis: [B,]; thetas: [B,]
# front = 0 [-front/2, front/2)
# side (cam left) = 1 [front/2, 180-front/2)
# back = 2 [180-front/2, 180+front/2)
# side (cam right) = 3 [180+front/2, 360-front/2)
# top = 4 [0, overhead]
# bottom = 5 [180-overhead, 180]
res = torch.zeros(thetas.shape[0], dtype=torch.long)
# first determine by phis
phis = phis % (2 * np.pi)
res[(phis < front / 2) | (phis >= 2 * np.pi - front / 2)] = 0
res[(phis >= front / 2) & (phis < np.pi - front / 2)] = 1
res[(phis >= np.pi - front / 2) & (phis < np.pi + front / 2)] = 2
res[(phis >= np.pi + front / 2) & (phis < 2 * np.pi - front / 2)] = 3
# override by thetas
res[thetas <= overhead] = 4
res[thetas >= (np.pi - overhead)] = 5
return res
def rand_poses(size, device, opt, radius_range=[1, 1.5], theta_range=[0, 120], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5):
''' generate random poses from an orbit camera
Args:
size: batch size of generated poses.
device: where to allocate the output.
radius: camera radius
theta_range: [min, max], should be in [0, pi]
phi_range: [min, max], should be in [0, 2 * pi]
Return:
poses: [size, 4, 4]
'''
theta_range = np.array(theta_range) / 180 * np.pi
phi_range = np.array(phi_range) / 180 * np.pi
angle_overhead = angle_overhead / 180 * np.pi
angle_front = angle_front / 180 * np.pi
radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
if random.random() < uniform_sphere_rate:
unit_centers = F.normalize(
torch.stack([
torch.randn(size, device=device),
torch.abs(torch.randn(size, device=device)),
torch.randn(size, device=device),
], dim=-1), p=2, dim=1
)
thetas = torch.acos(unit_centers[:,1])
phis = torch.atan2(unit_centers[:,0], unit_centers[:,2])
phis[phis < 0] += 2 * np.pi
centers = unit_centers * radius.unsqueeze(-1)
else:
thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
phis[phis < 0] += 2 * np.pi
centers = torch.stack([
radius * torch.sin(thetas) * torch.sin(phis),
radius * torch.cos(thetas),
radius * torch.sin(thetas) * torch.cos(phis),
], dim=-1) # [B, 3]
targets = 0
# jitters
if opt.jitter_pose:
jit_center = opt.jitter_center # 0.015 # was 0.2
jit_target = opt.jitter_target
centers += torch.rand_like(centers) * jit_center - jit_center/2.0
targets += torch.randn_like(centers) * jit_target
# lookat
forward_vector = safe_normalize(centers - targets)
up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
if opt.jitter_pose:
up_noise = torch.randn_like(up_vector) * opt.jitter_up
else:
up_noise = 0
up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
poses[:, :3, 3] = centers
if return_dirs:
dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
else:
dirs = None
# back to degree
thetas = thetas / np.pi * 180
phis = phis / np.pi * 180
return poses, dirs, thetas, phis, radius
def circle_poses(device, radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), return_dirs=False, angle_overhead=30, angle_front=60):
theta = theta / 180 * np.pi
phi = phi / 180 * np.pi
angle_overhead = angle_overhead / 180 * np.pi
angle_front = angle_front / 180 * np.pi
centers = torch.stack([
radius * torch.sin(theta) * torch.sin(phi),
radius * torch.cos(theta),
radius * torch.sin(theta) * torch.cos(phi),
], dim=-1) # [B, 3]
# lookat
forward_vector = safe_normalize(centers)
up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(len(centers), 1)
right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(len(centers), 1, 1)
poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
poses[:, :3, 3] = centers
if return_dirs:
dirs = get_view_direction(theta, phi, angle_overhead, angle_front)
else:
dirs = None
return poses, dirs
class NeRFDataset:
def __init__(self, opt, device, type='train', H=256, W=256, size=100):
super().__init__()
self.opt = opt
self.device = device
self.type = type # train, val, test
self.H = H
self.W = W
self.size = size
self.training = self.type in ['train', 'all']
self.cx = self.H / 2
self.cy = self.W / 2
self.near = self.opt.min_near
self.far = 1000 # infinite
# [debug] visualize poses
# poses, dirs, _, _, _ = rand_poses(100, self.device, opt, radius_range=self.opt.radius_range, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=1)
# visualize_poses(poses.detach().cpu().numpy(), dirs.detach().cpu().numpy())
def get_default_view_data(self):
H = int(self.opt.known_view_scale * self.H)
W = int(self.opt.known_view_scale * self.W)
cx = H / 2
cy = W / 2
radii = torch.FloatTensor(self.opt.ref_radii).to(self.device)
thetas = torch.FloatTensor(self.opt.ref_polars).to(self.device)
phis = torch.FloatTensor(self.opt.ref_azimuths).to(self.device)
poses, dirs = circle_poses(self.device, radius=radii, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
fov = self.opt.default_fovy
focal = H / (2 * np.tan(np.deg2rad(fov) / 2))
intrinsics = np.array([focal, focal, cx, cy])
projection = torch.tensor([
[2*focal/W, 0, 0, 0],
[0, -2*focal/H, 0, 0],
[0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
[0, 0, -1, 0]
], dtype=torch.float32, device=self.device).unsqueeze(0).repeat(len(radii), 1, 1)
mvp = projection @ torch.inverse(poses) # [B, 4, 4]
# sample a low-resolution but full image
rays = get_rays(poses, intrinsics, H, W, -1)
data = {
'H': H,
'W': W,
'rays_o': rays['rays_o'],
'rays_d': rays['rays_d'],
'dir': dirs,
'mvp': mvp,
'polar': self.opt.ref_polars,
'azimuth': self.opt.ref_azimuths,
'radius': self.opt.ref_radii,
}
return data
def collate(self, index):
B = len(index)
if self.training:
# random pose on the fly
poses, dirs, thetas, phis, radius = rand_poses(B, self.device, self.opt, radius_range=self.opt.radius_range, theta_range=self.opt.theta_range, phi_range=self.opt.phi_range, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, uniform_sphere_rate=self.opt.uniform_sphere_rate)
# random focal
fov = random.random() * (self.opt.fovy_range[1] - self.opt.fovy_range[0]) + self.opt.fovy_range[0]
elif self.type == 'six_views':
# six views
thetas_six = [90, 90, 90, 90, 1e-3, 179.999]
phis_six = [ 0, 90, 180, -90, 0, 0]
thetas = torch.FloatTensor([thetas_six[index[0]]]).to(self.device)
phis = torch.FloatTensor([phis_six[index[0]]]).to(self.device)
radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)
poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
# fixed focal
fov = self.opt.default_fovy
else:
# circle pose
thetas = torch.FloatTensor([self.opt.default_polar]).to(self.device)
phis = torch.FloatTensor([(index[0] / self.size) * 360]).to(self.device)
radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)
poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
# fixed focal
fov = self.opt.default_fovy
focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
intrinsics = np.array([focal, focal, self.cx, self.cy])
projection = torch.tensor([
[2*focal/self.W, 0, 0, 0],
[0, -2*focal/self.H, 0, 0],
[0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
[0, 0, -1, 0]
], dtype=torch.float32, device=self.device).unsqueeze(0)
mvp = projection @ torch.inverse(poses) # [1, 4, 4]
# sample a low-resolution but full image
rays = get_rays(poses, intrinsics, self.H, self.W, -1)
# delta polar/azimuth/radius to default view
delta_polar = thetas - self.opt.default_polar
delta_azimuth = phis - self.opt.default_azimuth
delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
delta_radius = radius - self.opt.default_radius
data = {
'H': self.H,
'W': self.W,
'rays_o': rays['rays_o'],
'rays_d': rays['rays_d'],
'dir': dirs,
'mvp': mvp,
'polar': delta_polar,
'azimuth': delta_azimuth,
'radius': delta_radius,
}
return data
def dataloader(self, batch_size=None):
batch_size = batch_size or self.opt.batch_size
loader = DataLoader(list(range(self.size)), batch_size=batch_size, collate_fn=self.collate, shuffle=self.training, num_workers=0)
loader._data = self
return loader
================================================
FILE: nerf/renderer.py
================================================
import os
import math
import cv2
import trimesh
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nvdiffrast.torch as dr
import mcubes
import raymarching
from meshutils import decimate_mesh, clean_mesh, poisson_mesh_reconstruction
from .utils import custom_meshgrid, safe_normalize
def sample_pdf(bins, weights, n_samples, det=False):
# This implementation is from NeRF
# bins: [B, T], old_z_vals
# weights: [B, T - 1], bin weights.
# return: [B, n_samples], new_z_vals
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
# Take uniform samples
if det:
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
# Invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[..., 1] - cdf_g[..., 0])
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
return samples
@torch.cuda.amp.autocast(enabled=False)
def near_far_from_bound(rays_o, rays_d, bound, type='cube', min_near=0.05):
# rays: [B, N, 3], [B, N, 3]
# bound: int, radius for ball or half-edge-length for cube
# return near [B, N, 1], far [B, N, 1]
radius = rays_o.norm(dim=-1, keepdim=True)
if type == 'sphere':
near = radius - bound # [B, N, 1]
far = radius + bound
elif type == 'cube':
tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3]
tmax = (bound - rays_o) / (rays_d + 1e-15)
near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0]
far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0]
# if far < near, means no intersection, set both near and far to inf (1e9 here)
mask = far < near
near[mask] = 1e9
far[mask] = 1e9
# restrict near to a minimal value
near = torch.clamp(near, min=min_near)
return near, far
def plot_pointcloud(pc, color=None):
# pc: [N, 3]
# color: [N, 3/4]
print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
pc = trimesh.PointCloud(pc, color)
# axis
axes = trimesh.creation.axis(axis_length=4)
# sphere
sphere = trimesh.creation.icosphere(radius=1)
trimesh.Scene([pc, axes, sphere]).show()
class DMTet():
def __init__(self, device):
self.device = device
self.triangle_table = torch.tensor([
[-1, -1, -1, -1, -1, -1],
[ 1, 0, 2, -1, -1, -1],
[ 4, 0, 3, -1, -1, -1],
[ 1, 4, 2, 1, 3, 4],
[ 3, 1, 5, -1, -1, -1],
[ 2, 3, 0, 2, 5, 3],
[ 1, 4, 0, 1, 5, 4],
[ 4, 2, 5, -1, -1, -1],
[ 4, 5, 2, -1, -1, -1],
[ 4, 1, 0, 4, 5, 1],
[ 3, 2, 0, 3, 5, 2],
[ 1, 3, 5, -1, -1, -1],
[ 4, 1, 2, 4, 3, 1],
[ 3, 0, 4, -1, -1, -1],
[ 2, 0, 1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]
], dtype=torch.long, device=device)
self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device=device)
self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=device)
def sort_edges(self, edges_ex2):
with torch.no_grad():
order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
order = order.unsqueeze(dim=1)
a = torch.gather(input=edges_ex2, index=order, dim=1)
b = torch.gather(input=edges_ex2, index=1-order, dim=1)
return torch.stack([a, b],-1)
def __call__(self, pos_nx3, sdf_n, tet_fx4):
# pos_nx3: [N, 3]
# sdf_n: [N]
# tet_fx4: [F, 4]
with torch.no_grad():
occ_n = sdf_n > 0
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
occ_sum = torch.sum(occ_fx4, -1) # [F,]
valid_tets = (occ_sum>0) & (occ_sum<4)
occ_sum = occ_sum[valid_tets]
# find all vertices
all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)
all_edges = self.sort_edges(all_edges)
unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True)
unique_edges = unique_edges.long()
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device=self.device)
idx_map = mapping[idx_map] # map edges to verts
interp_v = unique_edges[mask_edges]
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)
edges_to_interp_sdf[:,-1] *= -1
denominator = edges_to_interp_sdf.sum(1,keepdim = True)
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
idx_map = idx_map.reshape(-1,6)
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=self.device))
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
num_triangles = self.num_triangles_table[tetindex]
# Generate triangle indices
faces = torch.cat((
torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),
torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),
), dim=0)
return verts, faces
def compute_edge_to_face_mapping(attr_idx):
with torch.no_grad():
# Get unique edges
# Create all edges, packed by triangle
all_edges = torch.cat((
torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
), dim=-1).view(-1, 2)
# Swap edge order so min index is always first
order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
sorted_edges = torch.cat((
torch.gather(all_edges, 1, order),
torch.gather(all_edges, 1, 1 - order)
), dim=-1)
# Elliminate duplicates and return inverse mapping
unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)
tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda()
tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda()
# Compute edge to face table
mask0 = order[:,0] == 0
mask1 = order[:,0] == 1
tris_per_edge[idx_map[mask0], 0] = tris[mask0]
tris_per_edge[idx_map[mask1], 1] = tris[mask1]
return tris_per_edge
@torch.cuda.amp.autocast(enabled=False)
def normal_consistency(face_normals, t_pos_idx):
tris_per_edge = compute_edge_to_face_mapping(t_pos_idx)
# Fetch normals for both faces sharind an edge
n0 = face_normals[tris_per_edge[:, 0], :]
n1 = face_normals[tris_per_edge[:, 1], :]
# Compute error metric based on normal difference
term = torch.clamp(torch.sum(n0 * n1, -1, keepdim=True), min=-1.0, max=1.0)
term = (1.0 - term)
return torch.mean(torch.abs(term))
def laplacian_uniform(verts, faces):
V = verts.shape[0]
F = faces.shape[0]
# Neighbor indices
ii = faces[:, [1, 2, 0]].flatten()
jj = faces[:, [2, 0, 1]].flatten()
adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1)
adj_values = torch.ones(adj.shape[1], device=verts.device, dtype=torch.float)
# Diagonal indices
diag_idx = adj[0]
# Build the sparse matrix
idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
values = torch.cat((-adj_values, adj_values))
# The coalesce operation sums the duplicate indices, resulting in the
# correct diagonal
return torch.sparse_coo_tensor(idx, values, (V,V)).coalesce()
@torch.cuda.amp.autocast(enabled=False)
def laplacian_smooth_loss(verts, faces):
with torch.no_grad():
L = laplacian_uniform(verts, faces.long())
loss = L.mm(verts)
loss = loss.norm(dim=1)
loss = loss.mean()
return loss
class NeRFRenderer(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.bound = opt.bound
self.cascade = 1 + math.ceil(math.log2(opt.bound))
self.grid_size = 128
self.max_level = None
self.dmtet = opt.dmtet
self.cuda_ray = opt.cuda_ray
self.taichi_ray = opt.taichi_ray
self.min_near = opt.min_near
self.density_thresh = opt.density_thresh
# prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
# NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])
aabb_infer = aabb_train.clone()
self.register_buffer('aabb_train', aabb_train)
self.register_buffer('aabb_infer', aabb_infer)
self.glctx = None
# extra state for cuda raymarching
if self.cuda_ray:
# density grid
density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
self.register_buffer('density_grid', density_grid)
self.register_buffer('density_bitfield', density_bitfield)
self.mean_density = 0
self.iter_density = 0
if self.dmtet:
# load dmtet vertices
tets = np.load('tets/{}_tets.npz'.format(self.opt.tet_grid_size))
self.verts = - torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * 2 # covers [-1, 1]
self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')
self.tet_scale = torch.tensor([1, 1, 1], dtype=torch.float32, device='cuda')
self.dmtet_model = DMTet('cuda')
# vert sdf and deform
sdf = torch.nn.Parameter(torch.zeros_like(self.verts[..., 0]), requires_grad=True)
self.register_parameter('sdf', sdf)
deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
self.register_parameter('deform', deform)
edges = torch.tensor([0,1, 0,2, 0,3, 1,2, 1,3, 2,3], dtype=torch.long, device="cuda") # six edges for each tetrahedron.
all_edges = self.indices[:,edges].reshape(-1,2) # [M * 6, 2]
all_edges_sorted = torch.sort(all_edges, dim=1)[0]
self.all_edges = torch.unique(all_edges_sorted, dim=0)
if self.opt.h <= 2048 and self.opt.w <= 2048:
self.glctx = dr.RasterizeCudaContext()
else:
self.glctx = dr.RasterizeGLContext()
if self.taichi_ray:
from einops import rearrange
from taichi_modules import RayMarcherTaichi
from taichi_modules import VolumeRendererTaichi
from taichi_modules import RayAABBIntersector as RayAABBIntersectorTaichi
from taichi_modules import raymarching_test as raymarching_test_taichi
from taichi_modules import composite_test as composite_test_fw
from taichi_modules import packbits as packbits_taichi
self.rearrange = rearrange
self.packbits_taichi = packbits_taichi
self.ray_aabb_intersector = RayAABBIntersectorTaichi
self.raymarching_test_taichi = raymarching_test_taichi
self.composite_test_fw = composite_test_fw
self.ray_marching = RayMarcherTaichi(batch_size=4096) # TODO: hard encoded batch size
self.volume_render = VolumeRendererTaichi(batch_size=4096) # TODO: hard encoded batch size
# density grid
density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
self.register_buffer('density_grid', density_grid)
self.register_buffer('density_bitfield', density_bitfield)
self.mean_density = 0
self.iter_density = 0
@torch.no_grad()
def density_blob(self, x):
# x: [B, N, 3]
d = (x ** 2).sum(-1)
if self.opt.density_activation == 'exp':
g = self.opt.blob_density * torch.exp(- d / (2 * self.opt.blob_radius ** 2))
else:
g = self.opt.blob_density * (1 - torch.sqrt(d) / self.opt.blob_radius)
return g
def forward(self, x, d):
raise NotImplementedError()
def density(self, x):
raise NotImplementedError()
def reset_extra_state(self):
if not (self.cuda_ray or self.taichi_ray):
return
# density grid
self.density_grid.zero_()
self.mean_density = 0
self.iter_density = 0
@torch.no_grad()
def export_mesh(self, path, resolution=None, decimate_target=-1, S=128):
if self.opt.dmtet:
sdf = self.sdf
deform = torch.tanh(self.deform) / self.opt.tet_grid_size
vertices, triangles = self.dmtet_model(self.verts + deform, sdf, self.indices)
vertices = vertices.detach().cpu().numpy()
triangles = triangles.detach().cpu().numpy()
else:
if resolution is None:
resolution = self.grid_size
if self.cuda_ray:
density_thresh = min(self.mean_density, self.density_thresh) \
if np.greater(self.mean_density, 0) else self.density_thresh
else:
density_thresh = self.density_thresh
# TODO: use a larger thresh to extract a surface mesh from the density field, but this value is very empirical...
if self.opt.density_activation == 'softplus':
density_thresh = density_thresh * 25
sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32)
# query
X = torch.linspace(-1, 1, resolution).split(S)
Y = torch.linspace(-1, 1, resolution).split(S)
Z = torch.linspace(-1, 1, resolution).split(S)
for xi, xs in enumerate(X):
for yi, ys in enumerate(Y):
for zi, zs in enumerate(Z):
xx, yy, zz = custom_meshgrid(xs, ys, zs)
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
val = self.density(pts.to(self.aabb_train.device))
sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
vertices = vertices / (resolution - 1.0) * 2 - 1
# clean
vertices = vertices.astype(np.float32)
triangles = triangles.astype(np.int32)
vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
# decimation
if decimate_target > 0 and triangles.shape[0] > decimate_target:
vertices, triangles = decimate_mesh(vertices, triangles, decimate_target)
v = torch.from_numpy(vertices).contiguous().float().to(self.aabb_train.device)
f = torch.from_numpy(triangles).contiguous().int().to(self.aabb_train.device)
# mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
# mesh.export(os.path.join(path, f'mesh.ply'))
def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
# v, f: torch Tensor
device = v.device
v_np = v.cpu().numpy() # [N, 3]
f_np = f.cpu().numpy() # [M, 3]
print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
# unwrap uvs
import xatlas
import nvdiffrast.torch as dr
from sklearn.neighbors import NearestNeighbors
from scipy.ndimage import binary_dilation, binary_erosion
atlas = xatlas.Atlas()
atlas.add_mesh(v_np, f_np)
chart_options = xatlas.ChartOptions()
chart_options.max_iterations = 4 # for faster unwrap...
atlas.generate(chart_options=chart_options)
vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
# vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]
vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)
# render uv maps
uv = vt * 2.0 - 1.0 # uvs to range [-1, 1]
uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
if ssaa > 1:
h = int(h0 * ssaa)
w = int(w0 * ssaa)
else:
h, w = h0, w0
if self.glctx is None:
if h <= 2048 and w <= 2048:
self.glctx = dr.RasterizeCudaContext()
else:
self.glctx = dr.RasterizeGLContext()
rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
# masked query
xyzs = xyzs.view(-1, 3)
mask = (mask > 0).view(-1)
feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)
if mask.any():
xyzs = xyzs[mask] # [M, 3]
# batched inference to avoid OOM
all_feats = []
head = 0
while head < xyzs.shape[0]:
tail = min(head + 640000, xyzs.shape[0])
results_ = self.density(xyzs[head:tail])
all_feats.append(results_['albedo'].float())
head += 640000
feats[mask] = torch.cat(all_feats, dim=0)
feats = feats.view(h, w, -1)
mask = mask.view(h, w)
# quantize [0.0, 1.0] to [0, 255]
feats = feats.cpu().numpy()
feats = (feats * 255).astype(np.uint8)
### NN search as an antialiasing ...
mask = mask.cpu().numpy()
inpaint_region = binary_dilation(mask, iterations=3)
inpaint_region[mask] = 0
search_region = mask.copy()
not_search_region = binary_erosion(search_region, iterations=2)
search_region[not_search_region] = 0
search_coords = np.stack(np.nonzero(search_region), axis=-1)
inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
_, indices = knn.kneighbors(inpaint_coords)
feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)
# do ssaa after the NN search, in numpy
if ssaa > 1:
feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)
cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)
# save obj (v, vt, f /)
obj_file = os.path.join(path, f'{name}mesh.obj')
mtl_file = os.path.join(path, f'{name}mesh.mtl')
print(f'[INFO] writing obj mesh to {obj_file}')
with open(obj_file, "w") as fp:
fp.write(f'mtllib {name}mesh.mtl \n')
print(f'[INFO] writing vertices {v_np.shape}')
for v in v_np:
fp.write(f'v {v[0]} {v[1]} {v[2]} \n')
print(f'[INFO] writing vertices texture coords {vt_np.shape}')
for v in vt_np:
fp.write(f'vt {v[0]} {1 - v[1]} \n')
print(f'[INFO] writing faces {f_np.shape}')
fp.write(f'usemtl mat0 \n')
for i in range(len(f_np)):
fp.write(f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")
with open(mtl_file, "w") as fp:
fp.write(f'newmtl mat0 \n')
fp.write(f'Ka 1.000000 1.000000 1.000000 \n')
fp.write(f'Kd 1.000000 1.000000 1.000000 \n')
fp.write(f'Ks 0.000000 0.000000 0.000000 \n')
fp.write(f'Tr 1.000000 \n')
fp.write(f'illum 1 \n')
fp.write(f'Ns 0.000000 \n')
fp.write(f'map_Kd {name}albedo.png \n')
_export(v, f)
def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
    # Plain PyTorch volume rendering (no occupancy grid): self.opt.num_steps evenly spaced
    # samples per ray between the sphere-bound near/far, optionally followed by
    # self.opt.upsample_steps extra samples drawn with sample_pdf from the coarse weights,
    # then alpha compositing over a background.
    # rays_o, rays_d: [B, N, 3]
    # bg_color: [BN, 3] in range [0, 1]
    # return: image: [B, N, 3], depth: [B, N], weights_sum: [B, N], weights: [B*N, T(+t)],
    #         plus normal-related loss / image entries when training and the matching lambdas are > 0.

    prefix = rays_o.shape[:-1]
    rays_o = rays_o.contiguous().view(-1, 3)
    rays_d = rays_d.contiguous().view(-1, 3)

    N = rays_o.shape[0] # N = B * N, in fact
    device = rays_o.device

    results = {}

    # choose aabb
    aabb = self.aabb_train if self.training else self.aabb_infer

    # sample steps
    # nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
    # nears.unsqueeze_(-1)
    # fars.unsqueeze_(-1)
    nears, fars = near_far_from_bound(rays_o, rays_d, self.bound, type='sphere', min_near=self.min_near)

    # random sample light_d if not provided
    if light_d is None:
        # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
        light_d = safe_normalize(rays_o + torch.randn(3, device=rays_o.device)) # [N, 3]

    #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')

    z_vals = torch.linspace(0.0, 1.0, self.opt.num_steps, device=device).unsqueeze(0) # [1, T]
    z_vals = z_vals.expand((N, self.opt.num_steps)) # [N, T]
    z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]

    # perturb z_vals
    # (each sample is jittered by up to half a step in either direction)
    sample_dist = (fars - nears) / self.opt.num_steps
    if perturb:
        z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
        #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.

    # generate xyzs
    xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
    xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.

    #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

    # query SDF and RGB
    density_outputs = self.density(xyzs.reshape(-1, 3))

    #sigmas = density_outputs['sigma'].view(N, self.opt.num_steps) # [N, T]
    # reshape every density output (sigma, and any other keys) to [N, T, C]
    for k, v in density_outputs.items():
        density_outputs[k] = v.view(N, self.opt.num_steps, -1)

    # upsample z_vals (nerf-like)
    if self.opt.upsample_steps > 0:
        # the coarse weights only steer where new samples go, so no gradient is needed here
        with torch.no_grad():

            deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
            deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)

            alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T]
            alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
            weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]

            # sample new z_vals
            z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
            new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], self.opt.upsample_steps, det=not self.training).detach() # [N, t]

            new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
            new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.

        # only forward new points to save computation
        new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
        #new_sigmas = new_density_outputs['sigma'].view(N, self.opt.upsample_steps) # [N, t]
        for k, v in new_density_outputs.items():
            new_density_outputs[k] = v.view(N, self.opt.upsample_steps, -1)

        # re-order
        # merge coarse and new samples and sort them by depth, permuting positions and
        # every density output with the same index
        z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
        z_vals, z_index = torch.sort(z_vals, dim=1)

        xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
        xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

        for k in density_outputs:
            tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
            density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

    # final compositing weights over all (sorted) samples; the last interval uses sample_dist
    deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
    deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
    alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
    alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
    weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]

    # broadcast per-ray view / light directions to every sample
    dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
    light_d = light_d.view(-1, 1, 3).expand_as(xyzs)
    for k, v in density_outputs.items():
        density_outputs[k] = v.view(-1, v.shape[-1])

    dirs = safe_normalize(dirs)
    sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d.reshape(-1, 3), ratio=ambient_ratio, shading=shading)
    rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
    if normals is not None:
        normals = normals.view(N, -1, 3)

    # calculate weight_sum (mask)
    weights_sum = weights.sum(dim=-1) # [N]

    # calculate depth
    depth = torch.sum(weights * z_vals, dim=-1)

    # calculate color
    image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]

    # mix background color
    if bg_color is None:
        if self.opt.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d) # [N, 3]
        else:
            bg_color = 1

    image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

    image = image.view(*prefix, 3)
    depth = depth.view(*prefix)
    weights_sum = weights_sum.reshape(*prefix)

    if self.training:
        if self.opt.lambda_orient > 0 and normals is not None:
            # orientation loss
            # penalizes normals with a positive dot product with the view direction
            loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = loss_orient.sum(-1).mean()

        if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:
            # compares normals with those at randomly jittered positions (scale 1e-2)
            normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
            results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()

        if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None:
            # composite normals mapped from [-1, 1] to [0, 1]; note this stays [B*N, 3] (not reshaped by prefix)
            normal_image = torch.sum(weights.unsqueeze(-1) * (normals + 1) / 2, dim=-2) # [N, 3], in [0, 1]
            results['normal_image'] = normal_image

    results['image'] = image
    results['depth'] = depth
    results['weights'] = weights
    results['weights_sum'] = weights_sum

    return results
def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, binarize=False, **kwargs):
    # Occupancy-grid accelerated volume rendering through the CUDA `raymarching` extension.
    # Training: march every ray at once and composite with the differentiable kernel.
    # Inference: repeatedly march the still-alive rays a few steps and composite in place.
    # rays_o, rays_d: [B, N, 3]
    # light_d: optional light direction, either one direction shared by all samples
    #          ([3] or [1, 3]) or one per ray ([B*N, 3]); sampled around the ray origins if None.
    # return: image: [B, N, 3], depth: [B, N], weights_sum: [B, N]
    #         (plus weights and normal-related terms when training)

    prefix = rays_o.shape[:-1]
    rays_o = rays_o.contiguous().view(-1, 3)
    rays_d = rays_d.contiguous().view(-1, 3)

    N = rays_o.shape[0] # B * N, in fact
    device = rays_o.device

    # pre-calculate near far
    nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer)

    # random sample light_d if not provided
    if light_d is None:
        # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
        light_d = safe_normalize(rays_o + torch.randn(3, device=rays_o.device)) # [N, 3]

    # A 2-D light_d with several rows holds one direction per ray and has to be re-indexed to
    # the marched samples; a single direction broadcasts as-is. Checking dim() keeps a 1-D [3]
    # vector from being mistaken for three per-ray rows.
    per_ray_light = light_d.dim() > 1 and light_d.shape[0] > 1

    results = {}

    if self.training:
        xyzs, dirs, ts, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb, self.opt.dt_gamma, self.opt.max_steps)
        dirs = safe_normalize(dirs)

        if per_ray_light:
            flatten_rays = raymarching.flatten_rays(rays, xyzs.shape[0]).long()
            light_d = light_d[flatten_rays]

        sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
        weights, weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ts, rays, T_thresh, binarize)

        # normals related regularizations
        if self.opt.lambda_orient > 0 and normals is not None:
            # orientation loss
            loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = loss_orient.mean()

        if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:
            normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
            results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()

        if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None:
            _, _, _, normal_image = raymarching.composite_rays_train(sigmas.detach(), (normals + 1) / 2, ts, rays, T_thresh, binarize)
            results['normal_image'] = normal_image

        # weights normalization
        results['weights'] = weights

    else:

        # allocate outputs
        dtype = torch.float32

        weights_sum = torch.zeros(N, dtype=dtype, device=device)
        depth = torch.zeros(N, dtype=dtype, device=device)
        image = torch.zeros(N, 3, dtype=dtype, device=device)

        n_alive = N
        rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
        rays_t = nears.clone() # [N]

        step = 0

        while step < self.opt.max_steps: # hard coded max step

            # count alive rays
            n_alive = rays_alive.shape[0]

            # exit loop
            if n_alive <= 0:
                break

            # decide compact_steps
            n_step = max(min(N // n_alive, 8), 1)

            xyzs, dirs, ts = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb if step == 0 else False, self.opt.dt_gamma, self.opt.max_steps)
            dirs = safe_normalize(dirs)

            if per_ray_light:
                # Without this, an [N, 3] light_d no longer matches the n_alive * n_step samples once
                # rays start terminating. Assumes march_rays lays samples out ray-major
                # (n_step consecutive samples for each entry of rays_alive), as the raymarching
                # kernel does -- confirm if that extension changes.
                step_light_d = light_d[rays_alive.long()].repeat_interleave(n_step, dim=0)
            else:
                step_light_d = light_d

            sigmas, rgbs, normals = self(xyzs, dirs, step_light_d, ratio=ambient_ratio, shading=shading)
            raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh, binarize)

            # composite_rays marks terminated rays with negative ids; drop them
            rays_alive = rays_alive[rays_alive >= 0]
            #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

            step += n_step

    # mix background color
    if bg_color is None:
        if self.opt.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d) # [N, 3]
        else:
            bg_color = 1

    image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
    image = image.view(*prefix, 3)

    depth = depth.view(*prefix)

    weights_sum = weights_sum.reshape(*prefix)

    results['image'] = image
    results['depth'] = depth
    results['weights_sum'] = weights_sum

    return results
@torch.no_grad()
def init_tet(self, mesh=None):
    """Initialise the DMTet grid (tet_scale, verts, sdf) before mesh-based fine-tuning.

    With ``mesh`` (an object exposing ``bounds``, ``vertices`` and ``faces``): the mesh is
    rescaled to fit a 0.8 box. The tet grid is scaled to the mesh extent plus a 0.1 margin,
    and ``self.sdf`` is offset by the clamped signed distance, computed via ``cubvh``.
    The distance is negated and multiplied by 10, so the inside is positive.

    Without a mesh: the current density field is thresholded. The grid is scaled to the
    extent of the occupied vertices plus a 0.1 margin, and ``self.sdf`` is offset by
    ``clamp(sigma - threshold, -1, 1)`` at the rescaled vertices.
    """
    if mesh is not None:
        # normalize mesh
        scale = 0.8 / np.array(mesh.bounds[1] - mesh.bounds[0]).max()
        center = np.array(mesh.bounds[1] + mesh.bounds[0]) / 2
        mesh.vertices = (mesh.vertices - center) * scale

        # init scale
        # self.tet_scale = torch.from_numpy(np.abs(mesh.vertices).max(axis=0) + 1e-1).to(self.verts.dtype).cuda()
        # follow the verts' device instead of hard-coding .cuda(), so the product below never mixes devices
        self.tet_scale = torch.from_numpy(np.array([np.abs(mesh.vertices).max()]) + 1e-1).to(device=self.verts.device, dtype=self.verts.dtype)
        self.verts = self.verts * self.tet_scale

        # init sdf
        import cubvh
        BVH = cubvh.cuBVH(mesh.vertices, mesh.faces)
        sdf, _, _ = BVH.signed_distance(self.verts, return_uvw=False, mode='watertight')
        sdf *= -10 # INNER is POSITIVE, also make it stronger
        self.sdf.data += sdf.to(self.sdf.data.dtype).clamp(-1, 1)

    else:

        if self.cuda_ray:
            density_thresh = min(self.mean_density, self.density_thresh)
        else:
            density_thresh = self.density_thresh

        if self.opt.density_activation == 'softplus':
            density_thresh = density_thresh * 25

        # init scale
        sigma = self.density(self.verts)['sigma'] # verts covers [-1, 1] now
        mask = sigma > density_thresh
        valid_verts = self.verts[mask]
        if valid_verts.shape[0] == 0:
            # Nothing is occupied yet: amax over an empty tensor would raise, so keep the
            # full extent of the tet grid instead of shrinking it.
            print('[WARN] init dmtet: no vertex above density threshold, keeping full tet grid extent')
            valid_verts = self.verts
        self.tet_scale = valid_verts.abs().amax(dim=0) + 1e-1
        self.verts = self.verts * self.tet_scale

        # init sigma
        sigma = self.density(self.verts)['sigma'] # new verts
        self.sdf.data += (sigma - density_thresh).clamp(-1, 1)

    print(f'[INFO] init dmtet: scale = {self.tet_scale}')
def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, **kwargs):
    # Rasterize the current DMTet surface with nvdiffrast and shade it.
    # mvp: [B, 4, 4]
    # rays_o, rays_d: [B, N, 3]; only rays_o[:, 0] (camera position) and rays_d (for the
    #                 background model) are used.
    # return: image [B, H, W, 3], depth [B, H, W, 1], weights_sum [B, H, W]
    #         (+ normal_image / mesh regularizers when enabled)

    campos = rays_o[:, 0, :] # only need one ray per batch

    # random sample light_d if not provided
    if light_d is None:
        # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
        light_d = safe_normalize(campos + torch.randn_like(campos)).view(-1, 1, 1, 3) # [B, 1, 1, 3]

    results = {}

    # get mesh
    sdf = self.sdf
    deform = torch.tanh(self.deform) / self.opt.tet_grid_size

    verts, faces = self.dmtet_model(self.verts + deform, sdf, self.indices)

    # get normals
    i0, i1, i2 = faces[:, 0], faces[:, 1], faces[:, 2]
    v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :]

    faces = faces.int()

    # dim=-1 is explicit: without it torch.cross picks the FIRST dimension of size 3,
    # which is the face axis when the mesh has exactly 3 faces.
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
    face_normals = safe_normalize(face_normals)

    # accumulate face normals onto their vertices; degenerate vertices fall back to +z
    vn = torch.zeros_like(verts)
    vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
    vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
    vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)

    vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))

    # rasterization
    verts_clip = torch.bmm(F.pad(verts, pad=(0, 1), mode='constant', value=1.0).unsqueeze(0).repeat(mvp.shape[0], 1, 1),
                           mvp.permute(0,2,1)).float()  # [B, N, 4]
    rast, rast_db = dr.rasterize(self.glctx, verts_clip, faces, (h, w))

    alpha = (rast[..., 3:] > 0).float()
    xyzs, _ = dr.interpolate(verts.unsqueeze(0), rast, faces) # [B, H, W, 3]
    normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, faces)
    normal = safe_normalize(normal)

    xyzs = xyzs.view(-1, 3)
    mask = (rast[..., 3:] > 0).view(-1).detach()

    # do the lighting here since we have normal from mesh now.
    # albedo is queried only at covered pixels; uncovered pixels stay 0
    albedo = torch.zeros_like(xyzs, dtype=torch.float32)
    if mask.any():
        masked_albedo = self.density(xyzs[mask])['albedo']
        albedo[mask] = masked_albedo.float()
    albedo = albedo.view(-1, h, w, 3)

    # these two modes lead to no parameters to optimize if using --lock_geo.
    if self.opt.lock_geo and shading in ['textureless', 'normal']:
        shading = 'lambertian'

    if shading == 'albedo':
        color = albedo
    elif shading == 'textureless':
        lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0)
        color = lambertian.unsqueeze(-1).repeat(1, 1, 1, 3)
    elif shading == 'normal':
        color = (normal + 1) / 2
    else: # 'lambertian'
        lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0)
        color = albedo * lambertian.unsqueeze(-1)

    color = dr.antialias(color, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3]
    alpha = dr.antialias(alpha, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 1]

    # mix background color
    if bg_color is None:
        if self.opt.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d) # [N, 3]
        else:
            bg_color = 1

    if torch.is_tensor(bg_color) and len(bg_color.shape) > 1:
        bg_color = bg_color.view(-1, h, w, 3)

    depth = rast[:, :, :, [2]] # [B, H, W, 1]
    color = color + (1 - alpha) * bg_color

    results['depth'] = depth
    results['image'] = color
    results['weights_sum'] = alpha.squeeze(-1)

    if self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0:
        normal_image = dr.antialias((normal + 1) / 2, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3]
        results['normal_image'] = normal_image

    # regularizations
    if self.training:
        if self.opt.lambda_mesh_normal > 0:
            results['normal_loss'] = normal_consistency(face_normals, faces)
        if self.opt.lambda_mesh_laplacian > 0:
            results['lap_loss'] = laplacian_smooth_loss(verts, faces)

    return results
def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, **kwargs):
    # Occupancy-grid accelerated volume rendering through the taichi kernels.
    # rays_o, rays_d: [B, N, 3], assumes B == 1
    # return: image: [B, N, 3], depth: [B, N], weights_sum: [B, N]
    #         (plus weights and normal-related terms when training)
    # kwargs read here: 'exp_step_factor' (default 0.), 'T_threshold' (default 1e-4).

    prefix = rays_o.shape[:-1]
    rays_o = rays_o.contiguous().view(-1, 3)
    rays_d = rays_d.contiguous().view(-1, 3)

    N = rays_o.shape[0] # N = B * N, in fact
    device = rays_o.device

    # pre-calculate near far
    exp_step_factor = kwargs.get('exp_step_factor', 0.)
    MAX_SAMPLES = 1024
    NEAR_DISTANCE = 0.01
    center = torch.zeros(1, 3)
    half_size = torch.ones(1, 3)
    _, hits_t, _ = self.ray_aabb_intersector.apply(rays_o, rays_d, center, half_size, 1)
    # push hits that start too close to the origin out to NEAR_DISTANCE
    hits_t[(hits_t[:, 0, 0] >= 0) & (hits_t[:, 0, 0] < NEAR_DISTANCE), 0, 0] = NEAR_DISTANCE

    # TODO: should sample different light_d for each batch... but taichi end doesn't have a flatten_ray implemented currently...
    # random sample light_d if not provided
    if light_d is None:
        # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
        light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
        light_d = safe_normalize(light_d)

    results = {}

    if self.training:

        rays_a, xyzs, dirs, deltas, ts, _ = self.ray_marching(rays_o, rays_d, hits_t[:, 0], self.density_bitfield, self.cascade, self.bound, exp_step_factor, self.grid_size, MAX_SAMPLES)
        dirs = safe_normalize(dirs)
        # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
        sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
        _, weights_sum, depth, image, weights = self.volume_render(sigmas, rgbs, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4))

        # normals related regularizations
        if self.opt.lambda_orient > 0 and normals is not None:
            # orientation loss
            loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = loss_orient.mean()

        if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:
            normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
            results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()

        if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None:
            _, _, _, normal_image, _ = self.volume_render(sigmas.detach(), (normals + 1) / 2, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4))
            results['normal_image'] = normal_image

        # weights normalization
        results['weights'] = weights

    else:

        # allocate outputs
        dtype = torch.float32

        weights_sum = torch.zeros(N, dtype=dtype, device=device)
        depth = torch.zeros(N, dtype=dtype, device=device)
        image = torch.zeros(N, 3, dtype=dtype, device=device)

        n_alive = N
        rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]

        step = 0

        min_samples = 1 if exp_step_factor == 0 else 4

        while step < self.opt.max_steps: # hard coded max step

            # count alive rays
            n_alive = rays_alive.shape[0]

            # exit loop
            if n_alive <= 0:
                break

            # decide compact_steps
            # n_step = max(min(N // n_alive, 8), 1)
            n_step = max(min(N // n_alive, 64), min_samples)

            xyzs, dirs, deltas, ts, N_eff_samples = \
                self.raymarching_test_taichi(rays_o, rays_d, hits_t[:, 0], rays_alive,
                                             self.density_bitfield, self.cascade,
                                             self.bound, exp_step_factor,
                                             self.grid_size, MAX_SAMPLES, n_step)

            xyzs = self.rearrange(xyzs, 'n1 n2 c -> (n1 n2) c')
            dirs = self.rearrange(dirs, 'n1 n2 c -> (n1 n2) c')
            dirs = safe_normalize(dirs)
            # padding samples come back with all-zero directions; only evaluate the rest
            valid_mask = ~torch.all(dirs == 0, dim=1)
            if valid_mask.sum() == 0:
                break

            sigmas = torch.zeros(len(xyzs), device=device)
            rgbs = torch.zeros(len(xyzs), 3, device=device)
            normals = torch.zeros(len(xyzs), 3, device=device)

            sigmas[valid_mask], _rgbs, _normals = self(xyzs[valid_mask], dirs[valid_mask], light_d, ratio=ambient_ratio, shading=shading)
            rgbs[valid_mask] = _rgbs.float()
            # Scatter normals into the full-size buffer like rgbs: the network only returns rows for
            # the valid samples, and rearranging that shorter tensor with n2=n_step fails
            # whenever the valid count is not a multiple of n_step.
            if _normals is not None:
                normals[valid_mask] = _normals.float()
            sigmas = self.rearrange(sigmas, '(n1 n2) -> n1 n2', n2=n_step)
            rgbs = self.rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=n_step)
            normals = self.rearrange(normals, '(n1 n2) c -> n1 n2 c', n2=n_step)

            self.composite_test_fw(sigmas, rgbs, deltas, ts, hits_t[:,0], rays_alive,
                                   kwargs.get('T_threshold', 1e-4), N_eff_samples,
                                   weights_sum, depth, image)

            # the compositing kernel marks terminated rays with negative ids; drop them
            rays_alive = rays_alive[rays_alive >= 0]

            step += n_step

    # mix background color
    if bg_color is None:
        if self.opt.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d) # [N, 3]
        else:
            bg_color = 1

    image = image + self.rearrange(1 - weights_sum, 'n -> n 1') * bg_color
    image = image.view(*prefix, 3)

    depth = depth.view(*prefix)

    weights_sum = weights_sum.reshape(*prefix)

    results['image'] = image
    results['depth'] = depth
    results['weights_sum'] = weights_sum

    return results
@torch.no_grad()
def update_extra_state(self, decay=0.95, S=128):
    # call before each epoch to update extra states.
    # Refreshes the occupancy (density) grid used by the cuda_ray / taichi_ray marchers:
    # samples the density at a randomly jittered point inside every cell of every cascade,
    # merges it into self.density_grid (old value * decay, max'ed with the new sample),
    # then re-packs the occupancy bitfield with threshold min(mean_density, density_thresh).
    # decay: factor applied to the previous grid value before the max with the new sample.
    # S: chunk length along each axis, bounding how many cells are queried per density call.
    if not (self.cuda_ray or self.taichi_ray):
        return

    ### update density grid
    tmp_grid = - torch.ones_like(self.density_grid)

    X = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)
    Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)
    Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)

    for xs in X:
        for ys in Y:
            for zs in Z:

                # construct points
                xx, yy, zz = custom_meshgrid(xs, ys, zs)
                coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], integer cell coords in [0, grid_size)
                indices = raymarching.morton3D(coords).long() # [N]
                xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]

                # cascading
                # cascade `cas` covers [-bound, bound] with bound = min(2 ** cas, self.bound)
                for cas in range(self.cascade):
                    bound = min(2 ** cas, self.bound)
                    half_grid_size = bound / self.grid_size
                    # scale to current cascade's resolution
                    # (shrunk by half_grid_size so that, after the jitter below, samples stay within [-bound, bound])
                    cas_xyzs = xyzs * (bound - half_grid_size)
                    # add noise in [-hgs, hgs]
                    cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                    # query density
                    sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                    # assign
                    tmp_grid[cas, indices] = sigmas

    # ema update
    # only cells whose current value is >= 0 are updated; negative entries are left untouched
    # (presumably marked invalid elsewhere -- confirm against where density_grid is initialised)
    valid_mask = self.density_grid >= 0
    self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
    self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
    self.iter_density += 1

    # convert to bitfield
    density_thresh = min(self.mean_density, self.density_thresh)
    if self.cuda_ray:
        self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
    elif self.taichi_ray:
        self.packbits_taichi(self.density_grid.reshape(-1).contiguous(), density_thresh, self.density_bitfield)

    # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f}')
def render(self, rays_o, rays_d, mvp, h, w, staged=False, max_ray_batch=4096, **kwargs):
    """Render a batch of rays with whichever backend this model was built for.

    Dispatch order: DMTet rasterization, CUDA ray marching, taichi ray marching, and finally
    the pure-PyTorch ``run``. Only the pure-PyTorch path honours ``staged``: it then renders
    each batch element in chunks of at most ``max_ray_batch`` rays and only keeps
    ``depth``, ``image`` and ``weights_sum``.

    rays_o, rays_d: [B, N, 3]
    return: dict of render outputs, e.g. image [B, N, 3]
    """
    B, N = rays_o.shape[:2]
    device = rays_o.device

    if self.dmtet:
        return self.run_dmtet(rays_o, rays_d, mvp, h, w, **kwargs)
    if self.cuda_ray:
        return self.run_cuda(rays_o, rays_d, **kwargs)
    if self.taichi_ray:
        return self.run_taichi(rays_o, rays_d, **kwargs)
    if not staged:
        return self.run(rays_o, rays_d, **kwargs)

    # chunked rendering to bound peak memory
    depth = torch.empty((B, N), device=device)
    image = torch.empty((B, N, 3), device=device)
    weights_sum = torch.empty((B, N), device=device)

    for b in range(B):
        start = 0
        while start < N:
            stop = min(start + max_ray_batch, N)
            chunk = self.run(rays_o[b:b+1, start:stop], rays_d[b:b+1, start:stop], **kwargs)
            depth[b:b+1, start:stop] = chunk['depth']
            weights_sum[b:b+1, start:stop] = chunk['weights_sum']
            image[b:b+1, start:stop] = chunk['image']
            start += max_ray_batch

    return {'depth': depth, 'image': image, 'weights_sum': weights_sum}
================================================
FILE: nerf/utils.py
================================================
import os
import gc
import glob
import tqdm
import math
import imageio
import psutil
from pathlib import Path
import random
import shutil
import warnings
import tensorboardX
import numpy as np
import time
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
import torchvision.transforms.functional as TF
from torchmetrics import PearsonCorrCoef
from rich.console import Console
from torch_ema import ExponentialMovingAverage
from packaging import version as pver
def adjust_text_embeddings(embeddings, azimuth, opt):
    """Stack per-view Perp-Neg text embeddings and weights into one interleaved batch.

    get_pos_neg_text_embeddings is called once per azimuth (B = azimuth.shape[0]) and yields
    K_b embeddings with K_b weights. With K = max K_b, the output row ``i * B + b`` holds
    entry ``i`` of batch element ``b``. Entries a view lacks are padded with its first
    embedding and a zero weight.

    Returns:
        text_embeddings: B * K stacked embeddings (presumably [B * K, 77, 768] for SD -- depends on the encoder).
        weights: [B * K] stacked weights.
    """
    per_view = [get_pos_neg_text_embeddings(embeddings, azimuth[b], opt) for b in range(azimuth.shape[0])]
    view_text_zs = [text_z for text_z, _ in per_view]
    view_weights = [w for _, w in per_view]

    K = max([w.shape[0] for w in view_weights] + [0])

    text_embeddings = torch.stack([
        text_z[i] if i < len(text_z) else text_z[0]
        for i in range(K)
        for text_z in view_text_zs
    ], dim=0)

    weights = torch.stack([
        w[i] if i < len(w) else torch.zeros_like(w[0])
        for i in range(K)
        for w in view_weights
    ], dim=0)

    return text_embeddings, weights
def get_pos_neg_text_embeddings(embeddings, azimuth_val, opt):
    """Perp-Neg style positive/negative text embeddings for one camera azimuth (degrees).

    In the front half [-90, 90) the positive prompt blends 'front' into 'side' with
    ``r = 1 - |azimuth| / 90``. The negatives are [front, side] with weights decaying by
    ``opt.front_decay_factor`` / ``opt.side_decay_factor``; a weight is zeroed when r > 0.8
    (front) or r < 0.2 (side).

    In the back half the positive blends 'side' into 'back'. The negatives are [side, front],
    with the front weight fixed at ``opt.negative_w`` and the side weight decaying (halved,
    zero when r > 0.8).

    Returns:
        (text_z, weights): text_z = cat([positive, neg1, neg2], dim=0); weights is a
        3-element float tensor [1.0, w_neg1, w_neg2] on text_z's device.
    """
    front_half = azimuth_val >= -90 and azimuth_val < 90

    if front_half:
        r = 1 - azimuth_val / 90 if azimuth_val >= 0 else 1 + azimuth_val / 90
        near_key, far_key = 'front', 'side'
    else:
        r = 1 - (azimuth_val - 90) / 90 if azimuth_val >= 0 else 1 + (azimuth_val + 90) / 90
        near_key, far_key = 'side', 'back'

    # if random.random() < 0.3:
    #     r = r + random.gauss(0, 0.08)
    pos_z = r * embeddings[near_key] + (1 - r) * embeddings[far_key]

    if front_half:
        text_z = torch.cat([pos_z, embeddings['front'], embeddings['side']], dim=0)
        front_neg_w = 0.0 if r > 0.8 else math.exp(-r * opt.front_decay_factor) * opt.negative_w
        side_neg_w = 0.0 if r < 0.2 else math.exp(-(1-r) * opt.side_decay_factor) * opt.negative_w
        weights = torch.tensor([1.0, front_neg_w, side_neg_w])
    else:
        text_z = torch.cat([pos_z, embeddings['side'], embeddings['front']], dim=0)
        front_neg_w = opt.negative_w
        side_neg_w = 0.0 if r > 0.8 else math.exp(-r * opt.side_decay_factor) * opt.negative_w / 2
        weights = torch.tensor([1.0, side_neg_w, front_neg_w])

    return text_z, weights.to(text_z.device)
def custom_meshgrid(*args):
    """``torch.meshgrid`` with 'ij' (matrix) indexing regardless of the installed torch.

    torch < 1.10 has no ``indexing`` keyword, and its meshgrid already behaves as 'ij'.
    Newer versions are given ``indexing='ij'`` explicitly.
    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    """
    if pver.parse(torch.__version__) >= pver.parse('1.10'):
        return torch.meshgrid(*args, indexing='ij')
    return torch.meshgrid(*args)
def safe_normalize(x, eps=1e-20):
    """Normalize ``x`` along its last dimension.

    The squared norm is clamped to at least ``eps`` before the square root, so an all-zero
    vector maps to zeros instead of NaN.
    """
    squared_norm = (x * x).sum(dim=-1, keepdim=True)
    return x / squared_norm.clamp(min=eps).sqrt()
@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
    ''' Generate camera rays for a batch of poses (autocast disabled, so computed in full precision).

    Args:
        poses: [B, 4, 4], camera-to-world matrices.
        intrinsics: [4], (fx, fy, cx, cy).
        H, W: image height / width in pixels.
        N: rays to sample per image; N <= 0 means one ray per pixel (H*W rays).
        error_map: optional [B, 128 * 128] sampling weights on a coarse 128x128 grid.

    Returns:
        dict with 'rays_o', 'rays_d': [B, N, 3]. When N > 0 also 'inds' ([B, N] flat pixel
        indices), and 'inds_coarse' ([B, N]) when error_map is given.
    '''
    device = poses.device
    B = poses.shape[0]
    fx, fy, cx, cy = intrinsics
    HW = H * W

    # pixel-centre coordinates flattened row-major: i is the column (x), j is the row (y)
    cols, rows = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))
    i = cols.t().reshape([1, HW]).expand([B, HW]) + 0.5
    j = rows.t().reshape([1, HW]).expand([B, HW]) + 0.5

    results = {}

    if N > 0:
        N = min(N, HW)

        if error_map is not None:
            # importance-sample cells of the coarse 128x128 grid, then jitter uniformly inside
            # each cell at full resolution
            inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], in [0, 128*128)
            coarse_row = inds_coarse // 128
            coarse_col = inds_coarse % 128
            cell_h, cell_w = H / 128, W / 128
            row = (coarse_row * cell_h + torch.rand(B, N, device=device) * cell_h).long().clamp(max=H - 1)
            col = (coarse_col * cell_w + torch.rand(B, N, device=device) * cell_w).long().clamp(max=W - 1)
            inds = row * W + col
            results['inds_coarse'] = inds_coarse # needed later to update the error map
        else:
            # uniform sampling, shared by every image of the batch (duplicates possible)
            inds = torch.randint(0, HW, size=[N], device=device).expand([B, N])

        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)
        results['inds'] = inds

    # camera looks along -z with +y up; directions are left unnormalized (z = -1)
    zs = - torch.ones_like(i)
    xs = - (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)

    # rotate into world space; every ray starts at the camera centre
    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
    rays_o = poses[..., :3, 3][..., None, :].expand_as(rays_d) # [B, N, 3]

    results['rays_o'] = rays_o
    results['rays_d'] = rays_d

    return results
def seed_everything(seed):
    """Seed Python's ``random``, NumPy, and torch (CPU and current CUDA device); also export PYTHONHASHSEED."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Full determinism would additionally need (left disabled on purpose, at a speed cost):
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = True
@torch.jit.script
def linear_to_srgb(x):
    # sRGB encoding: linear segment (12.92 * x) below 0.0031308, gamma curve with the exact
    # exponent 1/2.4 above it (the previous literal 0.41666 was a truncated approximation).
    # torch.where back-propagates through BOTH branches, so the power branch is fed an input
    # clamped to its valid domain: otherwise x ** (1/2.4) at x <= 0 produces inf/NaN gradients
    # even where the linear branch is selected. Forward values are unaffected.
    safe_x = torch.clamp(x, min=0.0031308)
    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * safe_x ** (1.0 / 2.4) - 0.055)
@torch.jit.script
def srgb_to_linear(x):
    # sRGB decoding: x / 12.92 below 0.04045, ((x + 0.055) / 1.055) ** 2.4 above it.
    # torch.where back-propagates through BOTH branches; clamping the power-branch input keeps
    # its base non-negative, so inputs below -0.055 no longer yield NaN gradients.
    # Forward values are unaffected.
    safe_x = torch.clamp(x, min=0.04045)
    return torch.where(x < 0.04045, x / 12.92, ((safe_x + 0.055) / 1.055) ** 2.4)
class Trainer(object):
def __init__(self,
             argv, # command line args
             name, # name of this experiment
             opt, # extra conf
             model, # network
             guidance, # guidance network
             criterion=None, # loss function, if None, assume inline implementation in train_step
             optimizer=None, # optimizer
             ema_decay=None, # if use EMA, set the decay
             lr_scheduler=None, # scheduler
             metrics=None, # metrics for evaluation (list), if None/empty, use val_loss to measure performance, else use the first metric.
             local_rank=0, # which GPU am I
             world_size=1, # total num of GPUs
             device=None, # device to use, usually setting to None is OK. (auto choose device)
             mute=False, # whether to mute all print
             fp16=False, # amp optimize level
             max_keep_ckpt=2, # max num of saved ckpts in disk
             workspace='workspace', # workspace to save logs & ckpts
             best_mode='min', # the smaller/larger result, the better
             use_loss_as_metric=True, # use loss as the first metric
             report_metric_at_train=False, # also report metrics at training
             use_checkpoint="latest", # which ckpt to use at init time
             use_tensorboardX=True, # whether to use tensorboard for logging
             scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
             ):
    """Build the trainer.

    Steps, in order:
    - Move the model to the device, wrapping it in SyncBatchNorm + DDP when world_size > 1.
    - Freeze every guidance model and precompute prompt embeddings.
    - Create the optimizer, LR scheduler, optional EMA and the AMP GradScaler.
    - Prepare the workspace: log file, checkpoint dir, and copies of the image prompts/config.
    - Optionally resume from a checkpoint according to ``use_checkpoint``.

    ``optimizer`` / ``lr_scheduler``, when given, are factories called with the model /
    optimizer respectively.
    """
    # Fresh list per instance: the former `metrics=[]` default was a single list shared by every
    # Trainer constructed without metrics, so mutating one trainer's metrics leaked into others.
    if metrics is None:
        metrics = []

    self.argv = argv
    self.name = name
    self.opt = opt
    self.mute = mute
    self.metrics = metrics
    self.local_rank = local_rank
    self.world_size = world_size
    self.workspace = workspace
    self.ema_decay = ema_decay
    self.fp16 = fp16
    self.best_mode = best_mode
    self.use_loss_as_metric = use_loss_as_metric
    self.report_metric_at_train = report_metric_at_train
    self.max_keep_ckpt = max_keep_ckpt
    self.use_checkpoint = use_checkpoint
    self.use_tensorboardX = use_tensorboardX
    self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
    self.scheduler_update_every_step = scheduler_update_every_step
    self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
    self.console = Console()

    model.to(self.device)
    if self.world_size > 1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    self.model = model

    # guide model
    self.guidance = guidance
    self.embeddings = {}

    # text prompt / images
    if self.guidance is not None:
        # guidance networks are frozen; only the rendered model is optimized
        for key in self.guidance:
            for p in self.guidance[key].parameters():
                p.requires_grad = False
            self.embeddings[key] = {}
        self.prepare_embeddings()

    if isinstance(criterion, nn.Module):
        criterion.to(self.device)
    self.criterion = criterion

    if self.opt.images is not None:
        self.pearson = PearsonCorrCoef().to(self.device)

    if optimizer is None:
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
    else:
        self.optimizer = optimizer(self.model)

    if lr_scheduler is None:
        self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
    else:
        self.lr_scheduler = lr_scheduler(self.optimizer)

    if ema_decay is not None:
        self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
    else:
        self.ema = None

    self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

    # variable init
    self.total_train_t = 0
    self.epoch = 0
    self.global_step = 0
    self.local_step = 0
    self.stats = {
        "loss": [],
        "valid_loss": [],
        "results": [], # metrics[0], or valid_loss
        "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
        "best_result": None,
    }

    # auto fix
    if len(metrics) == 0 or self.use_loss_as_metric:
        self.best_mode = 'min'

    # workspace prepare
    self.log_ptr = None
    if self.workspace is not None:
        os.makedirs(self.workspace, exist_ok=True)
        self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
        self.log_ptr = open(self.log_path, "a+")

        self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
        self.best_path = f"{self.ckpt_path}/{self.name}.pth"
        os.makedirs(self.ckpt_path, exist_ok=True)

        # Save a copy of image_config in the experiment workspace
        if opt.image_config is not None:
            shutil.copyfile(opt.image_config, os.path.join(self.workspace, os.path.basename(opt.image_config)))

        # Save a copy of images in the experiment workspace
        if opt.images is not None:
            for image_file in opt.images:
                shutil.copyfile(image_file, os.path.join(self.workspace, os.path.basename(image_file)))

    self.log(f'[INFO] Cmdline: {self.argv}')
    self.log(f'[INFO] opt: {self.opt}')
    self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
    self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')

    if self.workspace is not None:
        if self.use_checkpoint == "scratch":
            self.log("[INFO] Training from scratch ...")
        elif self.use_checkpoint == "latest":
            self.log("[INFO] Loading latest checkpoint ...")
            self.load_checkpoint()
        elif self.use_checkpoint == "latest_model":
            self.log("[INFO] Loading latest checkpoint (model only)...")
            self.load_checkpoint(model_only=True)
        elif self.use_checkpoint == "best":
            if os.path.exists(self.best_path):
                self.log("[INFO] Loading best checkpoint ...")
                self.load_checkpoint(self.best_path)
            else:
                self.log(f"[INFO] {self.best_path} not found, loading latest ...")
                self.load_checkpoint()
        else: # path to ckpt
            self.log(f"[INFO] Loading {self.use_checkpoint} ...")
            self.load_checkpoint(self.use_checkpoint)
# calculate the text embs.
@torch.no_grad()
def prepare_embeddings(self):
    """Precompute and cache the conditioning used by the guidance models.

    Text branch (``opt.text`` set): for each of the 'SD' / 'IF' guidance
    models present, store the prompt embedding ('default'), the negative
    prompt embedding ('uncond') and one "<prompt>, <view> view" embedding for
    each of front / side / back. 'clip' gets the raw prompt under 'text'.

    Image branch (``opt.images`` set): load every ``*_rgba.png`` reference
    image plus its ``*_depth.png`` / ``*_normal.png`` companions, resized to
    the known-view resolution, into ``self.rgb``, ``self.mask``,
    ``self.depth`` and ``self.normal``. Then compute zero123 / CLIP image
    embeddings when those guidance models are present.
    """
    opt = self.opt

    if opt.text is not None:
        # SD and IF share the same prompt scheme: handle them in one loop (SD first, then IF).
        for key in ('SD', 'IF'):
            if key not in self.guidance:
                continue
            encoder = self.guidance[key]
            cache = self.embeddings[key]
            cache['default'] = encoder.get_text_embeds([opt.text])
            cache['uncond'] = encoder.get_text_embeds([opt.negative])
            for view in ['front', 'side', 'back']:
                cache[view] = encoder.get_text_embeds([f"{opt.text}, {view} view"])

        if 'clip' in self.guidance:
            # CLIP receives the bare prompt string rather than a list.
            self.embeddings['clip']['text'] = self.guidance['clip'].get_text_embeds(opt.text)

    if opt.images is None:
        return

    h = int(opt.known_view_scale * opt.h)
    w = int(opt.known_view_scale * opt.w)

    # The companion-file lookups below rely on the '_rgba.png' suffix.
    for image_path in opt.images:
        assert image_path.endswith('_rgba.png')

    # RGBA reference images -> RGB composited on white [B, 3, h, w] and a binary mask [B, h, w]
    rgbas = [cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA) for image_path in opt.images]
    rgba_hw = np.stack([cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in rgbas])
    alpha_hw = rgba_hw[..., 3:]
    rgb_hw = rgba_hw[..., :3] * alpha_hw + (1 - alpha_hw)
    self.rgb = torch.from_numpy(rgb_hw).permute(0, 3, 1, 2).contiguous().to(self.device)
    self.mask = torch.from_numpy(rgba_hw[..., 3] > 0.5).to(self.device)
    print(f'[INFO] dataset: load image prompt {opt.images} {self.rgb.shape}')

    # depth maps, resized and scaled by 1/255
    depth_paths = [image_path.replace('_rgba.png', '_depth.png') for image_path in opt.images]
    depth_raw = [cv2.imread(path, cv2.IMREAD_UNCHANGED) for path in depth_paths]
    depth = np.stack([cv2.resize(d, (w, h), interpolation=cv2.INTER_AREA) for d in depth_raw])
    self.depth = torch.from_numpy(depth.astype(np.float32) / 255).to(self.device)  # TODO: this should be mapped to FP16
    print(f'[INFO] dataset: load depth prompt {depth_paths} {self.depth.shape}')

    # normal maps, resized and scaled by 1/255 (channel order as returned by cv2.imread)
    # TODO: don't load if normal loss is 0
    normal_paths = [image_path.replace('_rgba.png', '_normal.png') for image_path in opt.images]
    normal_raw = [cv2.imread(path, cv2.IMREAD_UNCHANGED) for path in normal_paths]
    normal = np.stack([cv2.resize(n, (w, h), interpolation=cv2.INTER_AREA) for n in normal_raw])
    self.normal = torch.from_numpy(normal.astype(np.float32) / 255).to(self.device)
    print(f'[INFO] dataset: load normal prompt {normal_paths} {self.normal.shape}')

    if 'zero123' in self.guidance:
        # zero123 consumes 256x256 white-composited RGB images
        rgba_256 = np.stack([cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in rgbas])
        alpha_256 = rgba_256[..., 3:]
        rgbs_256 = rgba_256[..., :3] * alpha_256 + (1 - alpha_256)
        rgb_256 = torch.from_numpy(rgbs_256).permute(0, 3, 1, 2).contiguous().to(self.device)
        img_embeds = self.guidance['zero123'].get_img_embeds(rgb_256)
        self.embeddings['zero123']['default'] = {
            'zero123_ws': opt.zero123_ws,
            'c_crossattn': img_embeds[0],
            'c_concat': img_embeds[1],
            'ref_polars': opt.ref_polars,
            'ref_azimuths': opt.ref_azimuths,
            'ref_radii': opt.ref_radii,
        }

    if 'clip' in self.guidance:
        self.embeddings['clip']['image'] = self.guidance['clip'].get_img_embeds(self.rgb)
def __del__(self):
    """Close the training log file, if one was opened, when the trainer is collected."""
    log_file = self.log_ptr
    if log_file:
        log_file.close()
def log(self, *args, **kwargs):
    """Log a message from the rank-0 process only.

    Unless ``self.mute`` is set, the message goes to the rich console with all
    ``kwargs`` (e.g. ``style``). If a log file is open, the positional args are
    also printed to it and flushed immediately. ``kwargs`` are not forwarded to
    the file.
    """
    if self.local_rank != 0:
        return

    if not self.mute:
        self.console.print(*args, **kwargs)

    log_file = self.log_ptr
    if log_file:
        print(*args, file=log_file)
        log_file.flush()  # write immediately to file
### ------------------------------
def train_step(self, data, save_guidance_path:Path=None):
    """Compute the training loss for one batch.

    When reference images are given (``opt.images``), every
    ``opt.known_view_interval``-th global step replaces the random camera in
    ``data`` with ``self.default_view_data``. That step uses a supervised
    RGB / mask / normal / depth loss against the reference images. All other
    steps score the rendering with each guidance model in ``self.guidance``
    ('SD', 'IF', 'zero123', 'clip'). Regularisation terms are added in both cases.

    Args:
        data: batch dict with at least 'rays_o', 'rays_d', 'mvp', 'H', 'W'.
            The guidance branches also read 'azimuth' (zero123 additionally
            reads 'polar' and 'radius').
        save_guidance_path: optional path of an image that combines the NeRF
            render, the added latent noise, the denoised result and optionally
            the fully-denoised image (forwarded to the SD / zero123 guidance).

    Returns:
        (pred_rgb, pred_depth, loss): pred_rgb is [B, 3, H, W] ([B, 4, H, W]
        in latent mode) and pred_depth is [B, 1, H, W].
    """

    # perform RGBD loss instead of SDS if is image-conditioned
    do_rgbd_loss = self.opt.images is not None and \
        (self.global_step % self.opt.known_view_interval == 0)

    # override random camera with fixed known camera
    if do_rgbd_loss:
        data = self.default_view_data

    # experiment iterations ratio
    # i.e. what proportion of this experiment have we completed (in terms of iterations) so far?
    exp_iter_ratio = (self.global_step - self.opt.exp_start_iter) / (self.opt.exp_end_iter - self.opt.exp_start_iter)

    # progressively relaxing view range: interpolate from the default view towards the full ranges
    if self.opt.progressive_view:
        r = min(1.0, self.opt.progressive_view_init_ratio + 2.0*exp_iter_ratio)
        self.opt.phi_range = [self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[0] * r,
                              self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[1] * r]
        self.opt.theta_range = [self.opt.default_polar * (1 - r) + self.opt.full_theta_range[0] * r,
                                self.opt.default_polar * (1 - r) + self.opt.full_theta_range[1] * r]
        self.opt.radius_range = [self.opt.default_radius * (1 - r) + self.opt.full_radius_range[0] * r,
                                 self.opt.default_radius * (1 - r) + self.opt.full_radius_range[1] * r]
        self.opt.fovy_range = [self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[0] * r,
                               self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[1] * r]

    # progressively increase max_level
    if self.opt.progressive_level:
        self.model.max_level = min(1.0, 0.25 + 2.0*exp_iter_ratio)

    rays_o = data['rays_o'] # [B, N, 3]
    rays_d = data['rays_d'] # [B, N, 3]
    mvp = data['mvp'] # [B, 4, 4]

    B, N = rays_o.shape[:2]
    H, W = data['H'], data['W']

    # When ref_data has B images > opt.batch_size
    if B > self.opt.batch_size:
        # choose batch_size images out of those B images
        choice = torch.randperm(B)[:self.opt.batch_size]
        B = self.opt.batch_size
        rays_o = rays_o[choice]
        rays_d = rays_d[choice]
        mvp = mvp[choice]

    if do_rgbd_loss:
        ambient_ratio = 1.0
        shading = 'lambertian' # use lambertian instead of albedo to get normal
        as_latent = False
        binarize = False
        # per-ray random background colour
        bg_color = torch.rand((B * N, 3), device=rays_o.device)

        # add camera noise to avoid grid-like artifact
        if self.opt.known_view_noise_scale > 0:
            noise_scale = self.opt.known_view_noise_scale #* (1 - self.global_step / self.opt.iters)
            rays_o = rays_o + torch.randn(3, device=self.device) * noise_scale
            rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale

    elif exp_iter_ratio <= self.opt.latent_iter_ratio:
        ambient_ratio = 1.0
        shading = 'normal'
        as_latent = True
        binarize = False
        bg_color = None

    else:
        if exp_iter_ratio <= self.opt.albedo_iter_ratio:
            ambient_ratio = 1.0
            shading = 'albedo'
        else:
            # random shading
            ambient_ratio = self.opt.min_ambient_ratio + (1.0-self.opt.min_ambient_ratio) * random.random()
            rand = random.random()
            if rand >= (1.0 - self.opt.textureless_ratio):
                shading = 'textureless'
            else:
                shading = 'lambertian'

        as_latent = False

        # random weights binarization (like mobile-nerf) [NOT WORKING NOW]
        # binarize_thresh = min(0.5, -0.5 + self.global_step / self.opt.iters)
        # binarize = random.random() < binarize_thresh
        binarize = False

        # random background
        rand = random.random()
        if self.opt.bg_radius > 0 and rand > 0.5:
            bg_color = None # use bg_net
        else:
            bg_color = torch.rand(3).to(self.device) # single color random bg

    outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, binarize=binarize)
    pred_depth = outputs['depth'].reshape(B, 1, H, W)
    pred_mask = outputs['weights_sum'].reshape(B, 1, H, W)
    if 'normal_image' in outputs:
        pred_normal = outputs['normal_image'].reshape(B, H, W, 3)

    if as_latent:
        # abuse normal & mask as latent code for faster geometry initialization (ref: fantasia3D)
        pred_rgb = torch.cat([outputs['image'], outputs['weights_sum'].unsqueeze(-1)], dim=-1).reshape(B, H, W, 4).permute(0, 3, 1, 2).contiguous() # [B, 4, H, W]
    else:
        pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]

    # known view loss
    if do_rgbd_loss:
        gt_mask = self.mask # [B, H, W]
        gt_rgb = self.rgb   # [B, 3, H, W]
        gt_normal = self.normal # [B, H, W, 3]
        gt_depth = self.depth   # [B, H, W]

        # keep ground truth aligned with the camera subset chosen above
        if len(gt_rgb) > self.opt.batch_size:
            gt_mask = gt_mask[choice]
            gt_rgb = gt_rgb[choice]
            gt_normal = gt_normal[choice]
            gt_depth = gt_depth[choice]

        # color loss (ground truth composited on the same random background as the render)
        gt_rgb = gt_rgb * gt_mask[:, None].float() + bg_color.reshape(B, H, W, 3).permute(0,3,1,2).contiguous() * (1 - gt_mask[:, None].float())
        loss = self.opt.lambda_rgb * F.mse_loss(pred_rgb, gt_rgb)

        # mask loss
        loss = loss + self.opt.lambda_mask * F.mse_loss(pred_mask[:, 0], gt_mask.float())

        # normal loss
        if self.opt.lambda_normal > 0 and 'normal_image' in outputs:
            valid_gt_normal = 1 - 2 * gt_normal[gt_mask] # [B, 3]
            valid_pred_normal = 2 * pred_normal[gt_mask] - 1 # [B, 3]

            lambda_normal = self.opt.lambda_normal * min(1, self.global_step / self.opt.iters)
            loss = loss + lambda_normal * (1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean())

        # relative depth loss
        if self.opt.lambda_depth > 0:
            valid_gt_depth = gt_depth[gt_mask] # [B,]
            valid_pred_depth = pred_depth[:, 0][gt_mask] # [B,]
            lambda_depth = self.opt.lambda_depth * min(1, self.global_step / self.opt.iters)
            loss = loss + lambda_depth * (1 - self.pearson(valid_pred_depth, valid_gt_depth))

            # # scale-invariant
            # with torch.no_grad():
            #     A = torch.cat([valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1) # [B, 2]
            #     X = torch.linalg.lstsq(A, valid_pred_depth).solution # [2, 1]
            #     valid_gt_depth = A @ X # [B, 1]
            # lambda_depth = self.opt.lambda_depth #* min(1, self.global_step / self.opt.iters)
            # loss = loss + lambda_depth * F.mse_loss(valid_pred_depth, valid_gt_depth)

    # novel view loss
    else:

        loss = 0

        if 'SD' in self.guidance:
            # interpolate text_z
            azimuth = data['azimuth'] # [-180, 180]

            # ENHANCE: remove loop to handle batch size > 1
            text_z = [self.embeddings['SD']['uncond']] * azimuth.shape[0]
            if self.opt.perpneg:
                text_z_comp, weights = adjust_text_embeddings(self.embeddings['SD'], azimuth, self.opt)
                text_z.append(text_z_comp)
            else:
                # blend front->side for |azimuth| < 90, side->back otherwise
                for b in range(azimuth.shape[0]):
                    if azimuth[b] >= -90 and azimuth[b] < 90:
                        if azimuth[b] >= 0:
                            r = 1 - azimuth[b] / 90
                        else:
                            r = 1 + azimuth[b] / 90
                        start_z = self.embeddings['SD']['front']
                        end_z = self.embeddings['SD']['side']
                    else:
                        if azimuth[b] >= 0:
                            r = 1 - (azimuth[b] - 90) / 90
                        else:
                            r = 1 + (azimuth[b] + 90) / 90
                        start_z = self.embeddings['SD']['side']
                        end_z = self.embeddings['SD']['back']
                    text_z.append(r * start_z + (1 - r) * end_z)

            text_z = torch.cat(text_z, dim=0)
            if self.opt.perpneg:
                loss = loss + self.guidance['SD'].train_step_perpneg(text_z, weights, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance,
                                                save_guidance_path=save_guidance_path)
            else:
                loss = loss + self.guidance['SD'].train_step(text_z, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance,
                                                save_guidance_path=save_guidance_path)

        if 'IF' in self.guidance:
            # interpolate text_z
            azimuth = data['azimuth'] # [-180, 180]

            # ENHANCE: remove loop to handle batch size > 1
            text_z = [self.embeddings['IF']['uncond']] * azimuth.shape[0]
            if self.opt.perpneg:
                text_z_comp, weights = adjust_text_embeddings(self.embeddings['IF'], azimuth, self.opt)
                text_z.append(text_z_comp)
            else:
                # blend front->side for |azimuth| < 90, side->back otherwise
                for b in range(azimuth.shape[0]):
                    if azimuth[b] >= -90 and azimuth[b] < 90:
                        if azimuth[b] >= 0:
                            r = 1 - azimuth[b] / 90
                        else:
                            r = 1 + azimuth[b] / 90
                        start_z = self.embeddings['IF']['front']
                        end_z = self.embeddings['IF']['side']
                    else:
                        if azimuth[b] >= 0:
                            r = 1 - (azimuth[b] - 90) / 90
                        else:
                            r = 1 + (azimuth[b] + 90) / 90
                        start_z = self.embeddings['IF']['side']
                        end_z = self.embeddings['IF']['back']
                    text_z.append(r * start_z + (1 - r) * end_z)

            text_z = torch.cat(text_z, dim=0)

            if self.opt.perpneg:
                loss = loss + self.guidance['IF'].train_step_perpneg(text_z, weights, pred_rgb, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance)
            else:
                loss = loss + self.guidance['IF'].train_step(text_z, pred_rgb, guidance_scale=self.opt.guidance_scale, grad_scale=self.opt.lambda_guidance)

        if 'zero123' in self.guidance:

            polar = data['polar']
            azimuth = data['azimuth']
            radius = data['radius']

            loss = loss + self.guidance['zero123'].train_step(self.embeddings['zero123']['default'], pred_rgb, polar, azimuth, radius, guidance_scale=self.opt.guidance_scale,
                                                              as_latent=as_latent, grad_scale=self.opt.lambda_guidance, save_guidance_path=save_guidance_path)

        if 'clip' in self.guidance:
            # Read azimuth here too: when 'clip' is the only guidance, none of the branches
            # above has bound it (previously an UnboundLocalError).
            azimuth = data['azimuth']

            # empirical, far view should apply smaller CLIP loss
            lambda_guidance = 10 * (1 - abs(azimuth) / 180) * self.opt.lambda_guidance

            loss = loss + self.guidance['clip'].train_step(self.embeddings['clip'], pred_rgb, grad_scale=lambda_guidance)

    # regularizations
    if not self.opt.dmtet:

        if self.opt.lambda_opacity > 0:
            loss_opacity = (outputs['weights_sum'] ** 2).mean()
            loss = loss + self.opt.lambda_opacity * loss_opacity

        if self.opt.lambda_entropy > 0:
            alphas = outputs['weights'].clamp(1e-5, 1 - 1e-5)
            # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
            loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
            lambda_entropy = self.opt.lambda_entropy * min(1, 2 * self.global_step / self.opt.iters)
            loss = loss + lambda_entropy * loss_entropy

        if self.opt.lambda_2d_normal_smooth > 0 and 'normal_image' in outputs:
            # pred_vals = outputs['normal_image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()
            # smoothed_vals = TF.gaussian_blur(pred_vals.detach(), kernel_size=9)
            # loss_smooth = F.mse_loss(pred_vals, smoothed_vals)
            # total-variation
            loss_smooth = (pred_normal[:, 1:, :, :] - pred_normal[:, :-1, :, :]).square().mean() + \
                          (pred_normal[:, :, 1:, :] - pred_normal[:, :, :-1, :]).square().mean()
            loss = loss + self.opt.lambda_2d_normal_smooth * loss_smooth

        if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
            loss_orient = outputs['loss_orient']
            loss = loss + self.opt.lambda_orient * loss_orient

        if self.opt.lambda_3d_normal_smooth > 0 and 'loss_normal_perturb' in outputs:
            loss_normal_perturb = outputs['loss_normal_perturb']
            loss = loss + self.opt.lambda_3d_normal_smooth * loss_normal_perturb

    else:

        if self.opt.lambda_mesh_normal > 0:
            loss = loss + self.opt.lambda_mesh_normal * outputs['normal_loss']

        if self.opt.lambda_mesh_laplacian > 0:
            loss = loss + self.opt.lambda_mesh_laplacian * outputs['lap_loss']

    return pred_rgb, pred_depth, loss
def post_train_step(self):
    """Post-process gradients after ``backward`` and before ``scaler.step``.

    The gradients are unscaled first (required before any clipping or in-place
    modification, see the PyTorch AMP gradient-clipping notes). Then optional
    value clipping is applied. For the non-DMTet 'grid' backbone, optional
    total-variation and weight-decay gradients are added to the encoder.
    """
    # ref: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping
    self.scaler.unscale_(self.optimizer)

    opt = self.opt
    if opt.grad_clip >= 0:
        torch.nn.utils.clip_grad_value_(self.model.parameters(), opt.grad_clip)

    is_grid_nerf = (not opt.dmtet) and opt.backbone == 'grid'
    if not is_grid_nerf:
        return

    if opt.lambda_tv > 0:
        # linear warm-up of the TV weight over the first half of training
        tv_weight = min(1.0, self.global_step / (0.5 * opt.iters)) * opt.lambda_tv
        self.model.encoder.grad_total_variation(tv_weight, None, self.model.bound)

    if opt.lambda_wd > 0:
        self.model.encoder.grad_weight_decay(opt.lambda_wd)
def eval_step(self, data):
    """Render one validation batch (staged, no perturbation, default background).

    Optional ``data`` keys 'shading', 'ambient_ratio' and 'light_d' default to
    'albedo', 1.0 and None. Returns ``(rgb [B, H, W, 3], depth [B, H, W], loss)``.
    ``loss`` is a dummy zero tensor of shape [1] with the rgb dtype/device.
    """
    rays_o, rays_d, mvp = data['rays_o'], data['rays_d'], data['mvp']
    batch, _ = rays_o.shape[:2]
    H, W = data['H'], data['W']

    render_kwargs = dict(
        staged=True,
        perturb=False,
        bg_color=None,
        light_d=data.get('light_d', None),
        ambient_ratio=data.get('ambient_ratio', 1.0),
        shading=data.get('shading', 'albedo'),
    )
    outputs = self.model.render(rays_o, rays_d, mvp, H, W, **render_kwargs)

    rgb = outputs['image'].reshape(batch, H, W, 3)
    depth = outputs['depth'].reshape(batch, H, W)

    # dummy
    dummy_loss = torch.zeros([1], device=rgb.device, dtype=rgb.dtype)
    return rgb, depth, dummy_loss
def test_step(self, data, bg_color=None, perturb=False):
rays_o = data['rays_o'] # [B, N, 3]
rays_d = data['rays_d'] # [B, N, 3]
mvp = data['mvp']
B, N = rays_o.shape[:2]
H, W = data['H'], data['W']
if bg_color is not None:
bg_color = bg_color.to(rays_o.device)
shading = data['shading'] if 'shading' in data else 'albedo'
ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
light_d = data['light_d'] if 'light_d' in data else None
outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, bg_color=bg_color)
pred_rgb = outputs['image'].reshape(B, H, W, 3)
pred_depth = outputs['depth'].reshape(B, H, W)
return pred_rgb, pred_depth, None
def save_mesh(self, loader=None, save_path=None):
    """Export the model's mesh into ``save_path`` (default: ``<workspace>/mesh``).

    ``loader`` is accepted for interface compatibility but unused. The marching
    cubes resolution and decimation target come from ``opt.mcubes_resolution``
    and ``opt.decimate_target``.
    """
    target_dir = os.path.join(self.workspace, 'mesh') if save_path is None else save_path

    self.log(f"==> Saving mesh to {target_dir}")
    os.makedirs(target_dir, exist_ok=True)

    self.model.export_mesh(
        target_dir,
        resolution=self.opt.mcubes_resolution,
        decimate_target=self.opt.decimate_target,
    )

    self.log(f"==> Finished saving mesh.")
### ------------------------------
def train(self, train_loader, valid_loader, test_loader, max_epochs):
    """Run the training loop for epochs ``self.epoch + 1 .. max_epochs``.

    After each epoch:
    - rank 0 saves a full checkpoint (when a workspace is set);
    - every ``opt.eval_interval`` epochs, validation runs and the best
      checkpoint is saved;
    - every ``opt.test_interval`` epochs and after the last epoch, the test
      set is rendered.
    The wall time is accumulated into ``self.total_train_t`` and logged.
    """
    if self.use_tensorboardX and self.local_rank == 0:
        self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name))

    t_begin = time.time()

    first_epoch = self.epoch + 1
    for current in range(first_epoch, max_epochs + 1):
        self.epoch = current

        self.train_one_epoch(train_loader, max_epochs)

        is_main_writer = self.workspace is not None and self.local_rank == 0
        if is_main_writer:
            self.save_checkpoint(full=True, best=False)

        if self.epoch % self.opt.eval_interval == 0:
            self.evaluate_one_epoch(valid_loader)
            self.save_checkpoint(full=False, best=True)

        if self.epoch % self.opt.test_interval == 0 or self.epoch == max_epochs:
            self.test(test_loader)

    t_end = time.time()
    self.total_train_t = t_end - t_begin + self.total_train_t

    self.log(f"[INFO] training takes {(self.total_train_t)/ 60:.4f} minutes.")

    if self.use_tensorboardX and self.local_rank == 0:
        self.writer.close()
def evaluate(self, loader, name=None):
    """Run one evaluation pass over ``loader`` with TensorBoard logging disabled.

    ``self.use_tensorboardX`` is switched off for the duration of the pass.
    It is restored afterwards, also when the evaluation raises, so a failed
    evaluation cannot permanently disable TensorBoard logging.
    """
    saved_flag = self.use_tensorboardX
    self.use_tensorboardX = False
    try:
        self.evaluate_one_epoch(loader, name)
    finally:
        self.use_tensorboardX = saved_flag
def test(self, loader, save_path=None, name=None, write_video=True):
    """Render every batch of ``loader`` and save RGB and normalised depth outputs.

    Args:
        loader: test data loader; only the first item of each batch is saved.
        save_path: output directory (default ``<workspace>/results``).
        name: file prefix (default ``<self.name>_ep<epoch:04d>``).
        write_video: if True, write ``<name>_rgb.mp4`` / ``<name>_depth.mp4``
            at 25 fps; otherwise write one PNG pair per batch.
    """
    if save_path is None:
        save_path = os.path.join(self.workspace, 'results')

    if name is None:
        name = f'{self.name}_ep{self.epoch:04d}'

    os.makedirs(save_path, exist_ok=True)

    self.log(f"==> Start Test, save results to {save_path}")

    pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
    self.model.eval()

    if write_video:
        all_preds = []
        all_preds_depth = []

    with torch.no_grad():

        for i, data in enumerate(loader):

            with torch.cuda.amp.autocast(enabled=self.fp16):
                preds, preds_depth, _ = self.test_step(data)

            pred = preds[0].detach().cpu().numpy()
            pred = (pred * 255).astype(np.uint8)

            # min-max normalise depth per frame for visualisation
            pred_depth = preds_depth[0].detach().cpu().numpy()
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min() + 1e-6)
            pred_depth = (pred_depth * 255).astype(np.uint8)

            if write_video:
                all_preds.append(pred)
                all_preds_depth.append(pred_depth)
            else:
                cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth)

            pbar.update(loader.batch_size)

    # finalise the progress bar, as train_one_epoch / evaluate_one_epoch do
    pbar.close()

    if write_video:
        all_preds = np.stack(all_preds, axis=0)
        all_preds_depth = np.stack(all_preds_depth, axis=0)

        imageio.mimwrite(os.path.join(save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, macro_block_size=1)
        imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1)

    self.log(f"==> Finished Test.")
# [GUI] train text step.
def train_gui(self, train_loader, step=16):
    """[GUI] Run ``step`` optimisation steps and report the mean loss and current lr.

    The loader is restarted whenever it is exhausted, so ``step`` may exceed
    ``len(train_loader)``. Returns ``{'loss': float, 'lr': float}``.
    """
    self.model.train()

    total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)

    loader = iter(train_loader)

    for _ in range(step):

        # mimic an infinite loop dataloader (in case the total dataset is smaller than step)
        try:
            data = next(loader)
        except StopIteration:
            loader = iter(train_loader)
            data = next(loader)

        # refresh the model's extra state every opt.update_extra_interval steps; like
        # train_one_epoch, this applies to both the CUDA and the taichi ray-marching backends
        if (self.model.cuda_ray or self.model.taichi_ray) and self.global_step % self.opt.update_extra_interval == 0:
            with torch.cuda.amp.autocast(enabled=self.fp16):
                self.model.update_extra_state()

        self.global_step += 1

        self.optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=self.fp16):
            pred_rgbs, pred_depths, loss = self.train_step(data)

        self.scaler.scale(loss).backward()
        self.post_train_step()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        if self.scheduler_update_every_step:
            self.lr_scheduler.step()

        total_loss += loss.detach()

    if self.ema is not None:
        self.ema.update()

    average_loss = total_loss.item() / step

    if not self.scheduler_update_every_step:
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.lr_scheduler.step(average_loss)
        else:
            self.lr_scheduler.step()

    outputs = {
        'loss': average_loss,
        'lr': self.optimizer.param_groups[0]['lr'],
    }

    return outputs
# [GUI] test on a single image
def test_gui(self, pose, intrinsics, mvp, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):
# render resolution (may need downscale to for better frame rate)
rH = int(H * downscale)
rW = int(W * downscale)
intrinsics = intrinsics * downscale
pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)
mvp = torch.from_numpy(mvp).unsqueeze(0).to(self.device)
rays = get_rays(pose, intrinsics, rH, rW, -1)
# from degree theta/phi to 3D normalized vec
light_d = np.deg2rad(light_d)
light_d = np.array([
np.sin(light_d[0]) * np.sin(light_d[1]),
np.cos(light_d[0]),
np.sin(light_d[0]) * np.cos(light_d[1]),
], dtype=np.float32)
light_d = torch.from_numpy(light_d).to(self.device)
data = {
'rays_o': rays['rays_o'],
'rays_d': rays['rays_d'],
'mvp': mvp,
'H': rH,
'W': rW,
'light_d': light_d,
'ambient_ratio': ambient_ratio,
'shading': shading,
}
self.model.eval()
if self.ema is not None:
self.ema.store()
self.ema.copy_to()
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=self.fp16):
# here spp is used as perturb random seed!
preds, preds_depth, _ = self.test_step(data, bg_color=bg_color, perturb=False if spp == 1 else spp)
if self.ema is not None:
self.ema.restore()
# interpolation to the original resolution
if downscale != 1:
# have to permute twice with torch...
preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()
preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
outputs = {
'image': preds[0].detach().cpu().numpy(),
'depth': preds_depth[0].detach().cpu().numpy(),
}
return outputs
def train_one_epoch(self, loader, max_epochs):
    """Run one training pass over ``loader``.

    Per batch: periodically refresh the model's extra state, compute the loss
    with ``train_step`` under autocast, optionally clamp the gradient w.r.t.
    the rendered RGB, back-propagate through the grad scaler, call
    ``post_train_step`` and step the optimizer (and the lr scheduler when it
    is updated every step). Rank 0 drives the progress bar and TensorBoard
    logging. The epoch-mean loss is appended to ``self.stats["loss"]``; a
    per-epoch scheduler is stepped afterwards.

    NOTE(review): ``average_loss`` divides by ``self.local_step``, so an empty
    loader would raise ZeroDivisionError.
    """
    self.log(f"==> [{time.strftime('%Y-%m-%d_%H-%M-%S')}] Start Training {self.workspace} Epoch {self.epoch}/{max_epochs}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")

    total_loss = 0
    if self.local_rank == 0 and self.report_metric_at_train:
        for metric in self.metrics:
            metric.clear()

    self.model.train()

    # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
    # ref: https://pytorch.org/docs/stable/data.html
    if self.world_size > 1:
        loader.sampler.set_epoch(self.epoch)

    # only rank 0 owns a progress bar
    if self.local_rank == 0:
        pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

    self.local_step = 0

    if self.opt.save_guidance:
        save_guidance_folder = Path(self.workspace) / 'guidance'
        save_guidance_folder.mkdir(parents=True, exist_ok=True)

    for data in loader:

        # update the model's extra state every opt.update_extra_interval global steps
        # (cuda / taichi ray-marching backends only)
        if (self.model.cuda_ray or self.model.taichi_ray) and self.global_step % self.opt.update_extra_interval == 0:
            with torch.cuda.amp.autocast(enabled=self.fp16):
                self.model.update_extra_state()

        self.local_step += 1
        self.global_step += 1

        self.optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=self.fp16):
            # dump a guidance visualisation every opt.save_guidance_interval global steps
            if self.opt.save_guidance and (self.global_step % self.opt.save_guidance_interval == 0):
                save_guidance_path = save_guidance_folder / f'step_{self.global_step:07d}.png'
            else:
                save_guidance_path = None
            pred_rgbs, pred_depths, loss = self.train_step(data, save_guidance_path=save_guidance_path)

        # hooked grad clipping for RGB space
        if self.opt.grad_clip_rgb >= 0:
            def _hook(grad):
                if self.opt.fp16:
                    # correctly handle the scale: the incoming grad is still multiplied by the loss scale
                    grad_scale = self.scaler._get_scale_async()
                    return grad.clamp(grad_scale * -self.opt.grad_clip_rgb, grad_scale * self.opt.grad_clip_rgb)
                else:
                    return grad.clamp(-self.opt.grad_clip_rgb, self.opt.grad_clip_rgb)
            pred_rgbs.register_hook(_hook)
            # pred_rgbs.retain_grad()

        self.scaler.scale(loss).backward()

        # unscale / clip / extra gradient terms, then the actual parameter update
        self.post_train_step()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        if self.scheduler_update_every_step:
            self.lr_scheduler.step()

        loss_val = loss.item()
        total_loss += loss_val

        if self.local_rank == 0:
            # if self.report_metric_at_train:
            #     for metric in self.metrics:
            #         metric.update(preds, truths)

            if self.use_tensorboardX:
                self.writer.add_scalar("train/loss", loss_val, self.global_step)
                self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)

            if self.scheduler_update_every_step:
                pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}")
            else:
                pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
            pbar.update(loader.batch_size)

    if self.ema is not None:
        self.ema.update()

    average_loss = total_loss / self.local_step
    self.stats["loss"].append(average_loss)

    if self.local_rank == 0:
        pbar.close()
        if self.report_metric_at_train:
            for metric in self.metrics:
                self.log(metric.report(), style="red")
                if self.use_tensorboardX:
                    metric.write(self.writer, self.epoch, prefix="train")
                metric.clear()

    # per-epoch scheduler step (ReduceLROnPlateau needs the monitored value)
    if not self.scheduler_update_every_step:
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.lr_scheduler.step(average_loss)
        else:
            self.lr_scheduler.step()

    cpu_mem, gpu_mem = get_CPU_mem(), get_GPU_mem()[0]
    self.log(f"==> [{time.strftime('%Y-%m-%d_%H-%M-%S')}] Finished Epoch {self.epoch}/{max_epochs}. CPU={cpu_mem:.1f}GB, GPU={gpu_mem:.1f}GB.")
def evaluate_one_epoch(self, loader, name=None):
    """Run one validation pass over ``loader`` and record the result.

    With more than one process, the loss is averaged and predictions are
    gathered across ranks. Rank 0 writes the first RGB / normalised-depth
    image of every batch to ``<workspace>/validation/`` and appends either the
    first metric's result (negated in 'max' mode) or the mean loss to
    ``self.stats["results"]``. EMA weights are swapped in for the pass and
    restored afterwards.

    Args:
        loader: validation data loader.
        name: file prefix (default ``<self.name>_ep<epoch:04d>``).
    """
    self.log(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...")

    if name is None:
        name = f'{self.name}_ep{self.epoch:04d}'

    total_loss = 0
    if self.local_rank == 0:
        for metric in self.metrics:
            metric.clear()

    self.model.eval()

    # evaluate with the EMA shadow weights; restored at the end
    if self.ema is not None:
        self.ema.store()
        self.ema.copy_to()

    if self.local_rank == 0:
        pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

    with torch.no_grad():
        self.local_step = 0

        for data in loader:
            self.local_step += 1

            with torch.cuda.amp.autocast(enabled=self.fp16):
                preds, preds_depth, loss = self.eval_step(data)

            # all_gather/reduce the statistics (NCCL only support all_*)
            if self.world_size > 1:
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss = loss / self.world_size

                preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
                dist.all_gather(preds_list, preds)
                preds = torch.cat(preds_list, dim=0)

                preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
                dist.all_gather(preds_depth_list, preds_depth)
                preds_depth = torch.cat(preds_depth_list, dim=0)

            loss_val = loss.item()
            total_loss += loss_val

            # only rank = 0 will perform evaluation.
            if self.local_rank == 0:

                # save image
                save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')
                save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png')

                #self.log(f"==> Saving validation image to {save_path}")
                os.makedirs(os.path.dirname(save_path), exist_ok=True)

                pred = preds[0].detach().cpu().numpy()
                pred = (pred * 255).astype(np.uint8)

                # min-max normalise depth for visualisation
                pred_depth = preds_depth[0].detach().cpu().numpy()
                pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min() + 1e-6)
                pred_depth = (pred_depth * 255).astype(np.uint8)

                cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
                cv2.imwrite(save_path_depth, pred_depth)

                pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
                pbar.update(loader.batch_size)

    # NOTE(review): an empty loader leaves local_step == 0 and this raises ZeroDivisionError
    average_loss = total_loss / self.local_step
    self.stats["valid_loss"].append(average_loss)

    if self.local_rank == 0:
        pbar.close()
        if not self.use_loss_as_metric and len(self.metrics) > 0:
            result = self.metrics[0].measure()
            self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result
        else:
            self.stats["results"].append(average_loss) # if no metric, choose best by min loss

        for metric in self.metrics:
            self.log(metric.report(), style="blue")
            if self.use_tensorboardX:
                metric.write(self.writer, self.epoch, prefix="evaluate")
            metric.clear()

    if self.ema is not None:
        self.ema.restore()

    self.log(f"++> Evaluate epoch {self.epoch} Finished.")
def save_checkpoint(self, name=None, full=False, best=False):
    """Write a checkpoint to disk.

    Args:
        name: file stem without ``.pth`` (default ``<self.name>_ep<epoch:04d>``).
        full: also store the optimizer, lr-scheduler, grad-scaler and (if
            present) EMA state.
        best: if False, save the raw model weights as ``<ckpt_path>/<name>.pth``
            and delete the oldest such file once more than
            ``self.max_keep_ckpt`` are recorded. If True, save the weights (the
            EMA shadow when EMA is enabled) to ``self.best_path``, provided at
            least one evaluation result exists.
    """
    stem = f'{self.name}_ep{self.epoch:04d}' if name is None else name

    state = {'epoch': self.epoch, 'global_step': self.global_step, 'stats': self.stats}

    if self.model.cuda_ray:
        state['mean_density'] = self.model.mean_density

    if self.opt.dmtet:
        state['tet_scale'] = self.model.tet_scale.cpu().numpy()

    if full:
        state['optimizer'] = self.optimizer.state_dict()
        state['lr_scheduler'] = self.lr_scheduler.state_dict()
        state['scaler'] = self.scaler.state_dict()
        if self.ema is not None:
            state['ema'] = self.ema.state_dict()

    if best:
        if len(self.stats["results"]) == 0:
            self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")
            return
        # The best checkpoint is always overwritten, since the loss cannot reflect performance.
        # Swap in the EMA weights (if any) only while reading the state dict.
        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()
        state['model'] = self.model.state_dict()
        if self.ema is not None:
            self.ema.restore()
        torch.save(state, self.best_path)
        return

    state['model'] = self.model.state_dict()

    file_path = f"{stem}.pth"
    history = self.stats["checkpoints"]
    history.append(file_path)

    # rotate: drop the oldest recorded checkpoint once over the keep limit
    if len(history) > self.max_keep_ckpt:
        stale_ckpt = os.path.join(self.ckpt_path, history.pop(0))
        if os.path.exists(stale_ckpt):
            os.remove(stale_ckpt)

    torch.save(state, os.path.join(self.ckpt_path, file_path))
def load_checkpoint(self, checkpoint=None, model_only=False):
    """Restore model (and optionally training) state from a checkpoint file.

    Args:
        checkpoint: path to a ``.pth`` file. When None, the lexicographically
            last ``*.pth`` in ``self.ckpt_path`` is used; if there is none, a
            warning is logged and nothing is loaded.
        model_only: stop after restoring the model-related state (weights,
            EMA, mean density, tet scale) without touching stats, epoch,
            global step, optimizer, scheduler or scaler.

    A file without a 'model' key is treated as a bare model state dict.
    Failures to restore the EMA / optimizer / scheduler / scaler are logged
    as warnings, with the reason. Only ``Exception`` subclasses are caught,
    so KeyboardInterrupt / SystemExit still propagate.
    """
    if checkpoint is None:
        checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
        if checkpoint_list:
            checkpoint = checkpoint_list[-1]
            self.log(f"[INFO] Latest checkpoint is {checkpoint}")
        else:
            self.log("[WARN] No checkpoint found, model randomly initialized.")
            return

    checkpoint_dict = torch.load(checkpoint, map_location=self.device)

    # bare state dict (no trainer metadata)
    if 'model' not in checkpoint_dict:
        self.model.load_state_dict(checkpoint_dict)
        self.log("[INFO] loaded model.")
        return

    missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
    self.log("[INFO] loaded model.")
    if len(missing_keys) > 0:
        self.log(f"[WARN] missing keys: {missing_keys}")
    if len(unexpected_keys) > 0:
        self.log(f"[WARN] unexpected keys: {unexpected_keys}")

    if self.ema is not None and 'ema' in checkpoint_dict:
        try:
            self.ema.load_state_dict(checkpoint_dict['ema'])
            self.log("[INFO] loaded EMA.")
        except Exception as e:
            self.log(f"[WARN] failed to loaded EMA: {e}")

    if self.model.cuda_ray:
        if 'mean_density' in checkpoint_dict:
            self.model.mean_density = checkpoint_dict['mean_density']

    if self.opt.dmtet:
        if 'tet_scale' in checkpoint_dict:
            # rescale the current vertices from the old tet scale to the saved one
            new_scale = torch.from_numpy(checkpoint_dict['tet_scale']).to(self.device)
            self.model.verts *= new_scale / self.model.tet_scale
            self.model.tet_scale = new_scale

    if model_only:
        return

    self.stats = checkpoint_dict['stats']
    self.epoch = checkpoint_dict['epoch']
    self.global_step = checkpoint_dict['global_step']
    self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")

    if self.optimizer and 'optimizer' in checkpoint_dict:
        try:
            self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
            self.log("[INFO] loaded optimizer.")
        except Exception as e:
            self.log(f"[WARN] Failed to load optimizer: {e}")

    if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
        try:
            self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
            self.log("[INFO] loaded scheduler.")
        except Exception as e:
            self.log(f"[WARN] Failed to load scheduler: {e}")

    if self.scaler and 'scaler' in checkpoint_dict:
        try:
            self.scaler.load_state_dict(checkpoint_dict['scaler'])
            self.log("[INFO] loaded scaler.")
        except Exception as e:
            self.log(f"[WARN] Failed to load scaler: {e}")
def get_CPU_mem():
    """Return the resident set size (RSS) of the current process, in GiB."""
    bytes_per_gib = 1024 ** 3
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_gib
def get_GPU_mem():
    """Report the memory currently in use on every visible CUDA device.

    Returns:
        (total, per_device): ``per_device`` lists each device's used memory
        (total minus free, from ``torch.cuda.mem_get_info``) in GiB, truncated
        to three decimals; ``total`` is their running sum (0 when no device is
        visible).
    """
    per_device = []
    for device_index in range(torch.cuda.device_count()):
        free_bytes, total_bytes = torch.cuda.mem_get_info(device_index)
        used_gib = (total_bytes - free_bytes) / 1024 ** 3
        # int() truncates, keeping three decimal places for compact logging
        per_device.append(int(used_gib * 1000) / 1000)
    total = 0
    for used in per_device:
        total += used
    return total, per_device
================================================
FILE: optimizer.py
================================================
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
class Adan(Optimizer):
    """
    Implements a pytorch variant of Adan

    Adan was proposed in
    Adan: Adaptive Nesterov Momentum Algorithm for
        Faster Optimizing Deep Models[J].arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Arguments:
        params (iterable): iterable of parameters to optimize or
            dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): decay rates of the moving
            averages of, respectively, the gradient, the gradient difference
            (grad - prev_grad) and the square of
            ``grad + beta2 * (grad - prev_grad)``. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability; the optimizer-wide value also guards the
            global-norm clipping division. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay
            (L2 penalty) (default: 0)
        max_grad_norm (float, optional): value used to clip
            global grad norm (default: 0.0 no clip)
        no_prox (bool): how to perform the decoupled weight decay. True
            multiplies parameters by ``1 - lr * weight_decay`` before the
            update; False divides them by ``1 + lr * weight_decay`` after it.
            (default: False)
        foreach (bool): if True would use torch._foreach implementation.
            It's faster but uses slightly more memory. (default: True)

    Raises:
        ValueError: if lr, eps or max_grad_norm is negative, or a beta is
            outside [0, 1).
    """

    def __init__(self,
                 params,
                 lr=1e-3,
                 betas=(0.98, 0.92, 0.99),
                 eps=1e-8,
                 weight_decay=0.0,
                 max_grad_norm=0.0,
                 no_prox=False,
                 foreach: bool = True):
        if not 0.0 <= max_grad_norm:
            raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= eps:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(
                betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(
                betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError('Invalid beta parameter at index 2: {}'.format(
                betas[2]))
        defaults = dict(lr=lr,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm,
                        no_prox=no_prox,
                        foreach=foreach)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled state, filling group keys that older pickles may lack."""
        super(Adan, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('no_prox', False)
            # step() reads group['foreach']; fall back to the constructor default
            group.setdefault('foreach', True)

    @torch.no_grad()
    def restart_opt(self):
        """Reset the step counter and zero all moving averages of trainable params.

        ``neg_pre_grad`` is not cleared here; ``step`` re-seeds it because the
        group step counter restarts at 1.
        """
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # State initialization

                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if self.defaults['max_grad_norm'] > 0:
            device = self.param_groups[0]['params'][0].device
            global_grad_norm = torch.zeros(1, device=device)

            max_grad_norm = torch.tensor(self.defaults['max_grad_norm'],
                                         device=device)
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        # .to(device) lets parameters live on several devices
                        global_grad_norm.add_(p.grad.pow(2).sum().to(device))

            global_grad_norm = torch.sqrt(global_grad_norm)

            # Use the optimizer-wide eps (read from defaults, like max_grad_norm)
            # rather than `group['eps']` of whichever group the loop ended on.
            clip_global_grad_norm = torch.clamp(
                max_grad_norm / (global_grad_norm + self.defaults['eps']),
                max=1.0).item()
        else:
            clip_global_grad_norm = 1.0

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            exp_avg_diffs = []
            neg_pre_grads = []

            beta1, beta2, beta3 = group['betas']
            # assume same step across group now to simplify things
            # per parameter step can be easily supported
            # by making it tensor, or pass list into kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            bias_correction1 = 1.0 - beta1**group['step']
            bias_correction2 = 1.0 - beta2**group['step']
            bias_correction3 = 1.0 - beta3**group['step']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)

                # Seed the "previous gradient" on the first step (also after
                # restart_opt) so the first gradient difference is zero.
                if 'neg_pre_grad' not in state or group['step'] == 1:
                    state['neg_pre_grad'] = p.grad.clone().mul_(
                        -clip_global_grad_norm)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                exp_avg_diffs.append(state['exp_avg_diff'])
                neg_pre_grads.append(state['neg_pre_grad'])

            kwargs = dict(
                params=params_with_grad,
                grads=grads,
                exp_avgs=exp_avgs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avg_diffs=exp_avg_diffs,
                neg_pre_grads=neg_pre_grads,
                beta1=beta1,
                beta2=beta2,
                beta3=beta3,
                bias_correction1=bias_correction1,
                bias_correction2=bias_correction2,
                bias_correction3_sqrt=math.sqrt(bias_correction3),
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                eps=group['eps'],
                no_prox=group['no_prox'],
                clip_global_grad_norm=clip_global_grad_norm,
            )

            if group['foreach']:
                _multi_tensor_adan(**kwargs)
            else:
                _single_tensor_adan(**kwargs)

        return loss
def _single_tensor_adan(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
exp_avg_diffs: List[Tensor],
neg_pre_grads: List[Tensor],
*,
beta1: float,
beta2: float,
beta3: float,
bias_correction1: float,
bias_correction2: float,
bias_correction3_sqrt: float,
lr: float,
weight_decay: float,
eps: float,
no_prox: bool,
clip_global_grad_norm: Tensor,
):
for i, param in enumerate(params):
grad = grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
exp_avg_diff = exp_avg_diffs[i]
neg_grad_or_diff = neg_pre_grads[i]
grad.mul_(clip_global_grad_norm)
# for memory saving, we use `neg_grad_or_diff`
# to get some temp variable in a inplace way
neg_grad_or_diff.add_(grad)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t
exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff,
alpha=1 - beta2) # diff_t
neg_grad_or_diff.mul_(beta2).add_(grad)
exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff,
neg_grad_or_diff,
value=1 - beta3) # n_t
denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps)
step_size_diff = lr * beta2 / bias_correction2
step_size = lr / bias_correction1
if no_prox:
param.mul_(1 - lr * weight_decay)
param.addcdiv_(exp_avg, denom, value=-step_size)
param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
else:
param.addcdiv_(exp_avg, denom, value=-step_size)
param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
param.div_(1 + lr * weight_decay)
neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)
def _multi_tensor_adan(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
exp_avg_diffs: List[Tensor],
neg_pre_grads: List[Tensor],
*,
beta1: float,
beta2: float,
beta3: float,
bias_correction1: float,
bias_correction2: float,
bias_correction3_sqrt: float,
lr: float,
weight_decay: float,
eps: float,
no_prox: bool,
clip_global_grad_norm: Tensor,
):
if len(params) == 0:
return
torch._foreach_mul_(grads, clip_global_grad_norm)
# for memory saving, we use `neg_pre_grads`
# to get some temp variable in a inplace way
torch._foreach_add_(neg_pre_grads, grads)
torch._foreach_mul_(exp_avgs, beta1)
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t
torch._foreach_mul_(exp_avg_diffs, beta2)
torch._foreach_add_(exp_avg_diffs, neg_pre_grads,
alpha=1 - beta2) # diff_t
torch._foreach_mul_(neg_pre_grads, beta2)
torch._foreach_add_(neg_pre_grads, grads)
torch._foreach_mul_(exp_avg_sqs, beta3)
torch._foreach_addcmul_(exp_avg_sqs,
neg_pre_grads,
neg_pre_grads,
value=1 - beta3) # n_t
denom = torch._foreach_sqrt(exp_avg_sqs)
torch._foreach_div_(denom, bias_correction3_sqrt)
torch._foreach_add_(denom, eps)
step_size_diff = lr * beta2 / bias_correction2
step_size = lr / bias_correction1
if no_prox:
torch._foreach_mul_(params, 1 - lr * weight_decay)
torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
torch._foreach_addcdiv_(params,
exp_avg_diffs,
denom,
value=-step_size_diff)
else:
torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
torch._foreach_addcdiv_(params,
exp_avg_diffs,
denom,
value=-step_size_diff)
torch._foreach_div_(params, 1 + lr * weight_decay)
torch._foreach_zero_(neg_pre_grads)
torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)
================================================
FILE: preprocess_image.py
================================================
import os
import sys
import cv2
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
class BackgroundRemoval():
    """Foreground extraction through carvekit's high-level interface.

    carvekit is imported lazily in ``__init__`` so the module can be imported
    without it.
    """

    def __init__(self, device='cuda'):
        from carvekit.api.high import HiInterface
        interface_kwargs = dict(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device=device,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=True,
        )
        self.interface = HiInterface(**interface_kwargs)

    @torch.no_grad()
    def __call__(self, image):
        """Run background removal on a single image.

        Args:
            image: [H, W, 3] array in [0, 255].

        Returns:
            The carvekit output converted to a numpy array. The main script
            reads its last channel as alpha (presumably [H, W, 4] RGBA).
        """
        outputs = self.interface([Image.fromarray(image)])
        return np.array(outputs[0])
class BLIP2():
    """Image captioning with Salesforce BLIP-2 (OPT-2.7b) loaded in float16.

    transformers is imported lazily in ``__init__``.
    """

    def __init__(self, device='cuda'):
        self.device = device
        from transformers import AutoProcessor, Blip2ForConditionalGeneration
        model_id = "Salesforce/blip2-opt-2.7b"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = Blip2ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to(device)

    @torch.no_grad()
    def __call__(self, image):
        """Return a caption (at most 20 newly generated tokens) for an image array."""
        pil_image = Image.fromarray(image)
        inputs = self.processor(pil_image, return_tensors="pt").to(self.device, torch.float16)
        generated_ids = self.model.generate(**inputs, max_new_tokens=20)
        captions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return captions[0].strip()
class DPT():
    """Omnidata DPT wrapper that predicts a depth map or a surface-normal map.

    Args:
        task: 'depth' loads the depth checkpoint; any other value loads the
            3-channel normal checkpoint.
        device: torch device the model runs on.
    """

    def __init__(self, task='depth', device='cuda'):
        self.task = task
        self.device = device

        from dpt import DPTDepthModel

        if task == 'depth':
            path = 'pretrained/omnidata/omnidata_dpt_depth_v2.ckpt'
            self.model = DPTDepthModel(backbone='vitb_rn50_384')
            self.aug = transforms.Compose([
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(mean=0.5, std=0.5)
            ])
        else:  # normal
            path = 'pretrained/omnidata/omnidata_dpt_normal_v2.ckpt'
            self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3)
            self.aug = transforms.Compose([
                transforms.Resize((384, 384)),
                transforms.ToTensor()
            ])

        # load model
        checkpoint = torch.load(path, map_location='cpu')
        if 'state_dict' in checkpoint:
            # NOTE(review): drops the first 6 characters of every key, presumably
            # a 'model.' prefix from the training wrapper -- confirm against the ckpt.
            state_dict = {k[6:]: v for k, v in checkpoint['state_dict'].items()}
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)
        self.model.eval().to(device)

    @torch.no_grad()
    def __call__(self, image):
        """Predict depth or normals for one image at its original resolution.

        Args:
            image: np.ndarray, uint8, [H, W, 3]

        Returns:
            depth task: numpy array [1, H, W] with values in [0, 1].
            normal task: numpy array [1, C, H, W] with values in [0, 1].
        """
        H, W = image.shape[:2]
        image = Image.fromarray(image)
        image = self.aug(image).unsqueeze(0).to(self.device)

        if self.task == 'depth':
            depth = self.model(image).clamp(0, 1)
            depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False)
            # bicubic resampling can overshoot the [0, 1] range established above
            depth = depth.clamp(0, 1)
            depth = depth.squeeze(1).cpu().numpy()
            return depth
        else:
            normal = self.model(image).clamp(0, 1)
            normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False)
            # re-clamp: callers scale by 255 and cast to uint8, where overshoot would wrap
            normal = normal.clamp(0, 1)
            normal = normal.cpu().numpy()
            return normal
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)")
parser.add_argument('--size', default=256, type=int, help="output resolution")
parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio")
parser.add_argument('--recenter', type=bool, default=True, help="recenter, potentially not helpful for multiview zero123")
parser.add_argument('--dont_recenter', dest='recenter', action='store_false')
opt = parser.parse_args()
out_dir = os.path.dirname(opt.path)
out_rgba = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_rgba.png')
out_depth = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_depth.png')
out_normal = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_normal.png')
out_caption = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_caption.txt')
# load image
print(f'[INFO] loading image...')
image = cv2.imread(opt.path, cv2.IMREAD_UNCHANGED)
if image.shape[-1] == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
else:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# carve background
print(f'[INFO] background removal...')
carved_image = BackgroundRemoval()(image) # [H, W, 4]
mask = carved_image[..., -1] > 0
# predict depth
print(f'[INFO] depth estimation...')
dpt_depth_model = DPT(task='depth')
depth = dpt_depth_model(image)[0]
depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9)
depth[~mask] = 0
depth = (depth * 255).astype(np.uint8)
del dpt_depth_model
# predict normal
print(f'[INFO] normal estimation...')
dpt_normal_model = DPT(task='normal')
normal = dpt_normal_model(image)[0]
normal = (normal * 255).astype(np.uint8).transpose(1, 2, 0)
normal[~mask] = 0
del dpt_normal_model
# recenter
if opt.recenter:
print(f'[INFO] recenter...')
final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8)
final_depth = np.zeros((opt.size, opt.size), dtype=np.uint8)
final_normal = np.zeros((opt.size, opt.size, 3), dtype=np.uint8)
coords = np.nonzero(mask)
x_min, x_max = coords[0].min(), coords[0].max()
y_min, y_max = coords[1].min(), coords[1].max()
h = x_max - x_min
w = y_max - y_min
desired_size = int(opt.size * (1 - opt.border_ratio))
scale = desired_size / max(h, w)
h2 = int(h * scale)
w2 = int(w * scale)
x2_min = (opt.size - h2) // 2
x2_max = x2_min + h2
y2_min = (opt.size - w2) // 2
y2_max = y2_min + w2
final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
final_depth[x2_min:x2_max, y2_min:y2_max] = cv2.resize(depth[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
final_normal[x2_min:x2_max, y2_min:y2_max] = cv2.resize(normal[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
else:
final_rgba = carved_image
final_depth = depth
final_normal = normal
# write output
cv2.imwrite(out_rgba, cv2.cvtColor(final_rgba, cv2.COLOR_RGBA2BGRA))
cv2.imwrite(out_depth, final_depth)
cv2.imwrite(out_normal, final_normal)
# predict caption (it's too slow... use your brain instead)
# print(f'[INFO] captioning...')
# blip2 = BLIP2()
# caption = blip2(image)
# with open(out_caption, 'w') as f:
# f.write(caption)
================================================
FILE: pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml
================================================
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "image_target"
cond_stage_key: "image_cond"
image_size: 32
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid
monitor: val/loss_simple_ema
scale_factor: 0.18215
scheduler_config: # 100 warmup steps (see warm_up_steps below)
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 100 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
# data:
# target: ldm.data.simple.ObjaverseDataModuleFromConfig
# params:
# root_dir: 'views_whole_sphere'
# batch_size: 192
# num_workers: 16
# total_view: 4
# train:
# validation: False
# image_transforms:
# size: 256
# validation:
# validation: True
# image_transforms:
# size: 256
# lightning:
# find_unused_parameters: false
# metrics_over_trainsteps_checkpoint: True
# modelcheckpoint:
# params:
# every_n_train_steps: 5000
# callbacks:
# image_logger:
# target: main.ImageLogger
# params:
# batch_frequency: 500
# max_images: 32
# increase_log_steps: False
# log_first_step: True
# log_images_kwargs:
# use_ema_scope: False
# inpaint: False
# plot_progressive_rows: False
# plot_diffusion_rows: False
# N: 32
# unconditional_scale: 3.0
# unconditional_label: [""]
# trainer:
# benchmark: True
# val_check_interval: 5000000 # really sorry
# num_sanity_val_steps: 0
# accumulate_grad_batches: 1
================================================
FILE: raymarching/__init__.py
================================================
from .raymarching import *
================================================
FILE: raymarching/backend.py
================================================
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))


def _cpp_std():
    """Pick the C++ standard for the extension build.

    PyTorch >= 2.1 headers refuse to compile below C++17, while older
    toolchains paired with older PyTorch may not accept C++17 in nvcc.
    Unparseable version strings default to C++17.
    """
    import torch
    try:
        major, minor = (int(part) for part in torch.__version__.split('.')[:2])
    except ValueError:
        return 'c++17'
    return 'c++17' if (major, minor) >= (2, 1) else 'c++14'


_cxx_std = _cpp_std()

nvcc_flags = [
    '-O3', f'-std={_cxx_std}',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
                if paths:
                    return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path
else:
    # posix (and any other non-Windows platform): gcc/clang-style flags, so
    # c_flags is always defined
    c_flags = ['-O3', f'-std={_cxx_std}']

_backend = load(name='_raymarching',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'raymarching.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
================================================
FILE: raymarching/raymarching.py
================================================
import numpy as np
import time
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
# lazy building:
# `import raymarching` will not immediately build the extension, only if you actually call any functions.

BACKEND = None


def get_backend():
    """Return the raymarching extension module, loading it on first call.

    A prebuilt ``_raymarching`` module (installed via setup.py) is tried first;
    if it cannot be imported, the extension is obtained from ``.backend``,
    which JIT-compiles it. The module is cached in the global ``BACKEND``.
    """
    global BACKEND
    if BACKEND is not None:
        return BACKEND
    try:
        import _raymarching as _backend
    except ImportError:
        from .backend import _backend
    BACKEND = _backend
    return BACKEND
# ----------------------------------------
# utils
# ----------------------------------------
class _near_far_from_aabb(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
        ''' near_far_from_aabb, CUDA implementation
        Calculate rays' intersection time (near and far) with aabb
        Args:
            rays_o: float, [N, 3]
            rays_d: float, [N, 3]
            aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
            min_near: float, scalar
        Returns:
            nears: float, [N]
            fars: float, [N]
        '''
        if not rays_o.is_cuda: rays_o = rays_o.cuda()
        if not rays_d.is_cuda: rays_d = rays_d.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)
        # aabb goes to the same kernel as the rays: give it the same device and a
        # contiguous layout, as is done for rays_o/rays_d above
        aabb = aabb.to(rays_o.device).contiguous()

        N = rays_o.shape[0]  # num rays

        nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
        fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)

        get_backend().near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)

        return nears, fars


near_far_from_aabb = _near_far_from_aabb.apply
class _sph_from_ray(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, radius):
        ''' sph_from_ray, CUDA implementation
        Spherical coordinates where rays hit the background sphere.
        Rays are assumed to start inside Sphere(radius).
        Args:
            rays_o: [N, 3]
            rays_d: [N, 3]
            radius: scalar, float
        Return:
            coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
        '''
        rays_o = (rays_o if rays_o.is_cuda else rays_o.cuda()).contiguous().view(-1, 3)
        rays_d = (rays_d if rays_d.is_cuda else rays_d.cuda()).contiguous().view(-1, 3)

        num_rays = rays_o.shape[0]
        coords = torch.empty(num_rays, 2, dtype=rays_o.dtype, device=rays_o.device)
        get_backend().sph_from_ray(rays_o, rays_d, radius, num_rays, coords)
        return coords


sph_from_ray = _sph_from_ray.apply
class _morton3D(Function):
    @staticmethod
    def forward(ctx, coords):
        ''' morton3D, CUDA implementation
        Args:
            coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
            TODO: check if the coord range is valid! (current 128 is safe)
        Returns:
            indices: [N], int32, in [0, 128^3)
        '''
        if not coords.is_cuda: coords = coords.cuda()
        # int32 and contiguous, like the other inputs handed to the backend
        # (a strided view such as a column slice would otherwise be passed as-is)
        coords = coords.int().contiguous()

        N = coords.shape[0]

        indices = torch.empty(N, dtype=torch.int32, device=coords.device)

        get_backend().morton3D(coords, N, indices)

        return indices


morton3D = _morton3D.apply
class _morton3D_invert(Function):
    @staticmethod
    def forward(ctx, indices):
        ''' morton3D_invert, CUDA implementation
        Args:
            indices: [N], int32, in [0, 128^3)
        Returns:
            coords: [N, 3], int32, in [0, 128)
        '''
        if not indices.is_cuda: indices = indices.cuda()
        # int32 and contiguous, like the other inputs handed to the backend
        indices = indices.int().contiguous()

        N = indices.shape[0]

        coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)

        get_backend().morton3D_invert(indices, N, coords)

        return coords


morton3D_invert = _morton3D_invert.apply
class _packbits(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, grid, thresh, bitfield=None):
        ''' packbits, CUDA implementation
        Pack the density grid into a bit field (one bit per cell) to accelerate ray marching.
        Args:
            grid: float, [C, H * H * H], assume H % 2 == 0
            thresh: float, threshold
            bitfield: optional preallocated uint8 output; allocated when None
        Returns:
            bitfield: uint8, [C, H * H * H / 8]
        '''
        grid = (grid if grid.is_cuda else grid.cuda()).contiguous()

        num_bytes = grid.shape[0] * grid.shape[1] // 8
        if bitfield is None:
            bitfield = torch.empty(num_bytes, dtype=torch.uint8, device=grid.device)

        get_backend().packbits(grid, num_bytes, thresh, bitfield)
        return bitfield


packbits = _packbits.apply
class _flatten_rays(Function):
    @staticmethod
    def forward(ctx, rays, M):
        ''' flatten rays
        Args:
            rays: [N, 2], all rays' (point_offset, point_count),
            M: scalar, int, count of points (we cannot get this info from rays unfortunately...)
        Returns:
            res: [M], flattened ray index.
        '''
        rays = (rays if rays.is_cuda else rays.cuda()).contiguous()
        num_rays = rays.shape[0]
        res = torch.zeros(M, dtype=torch.int, device=rays.device)
        get_backend().flatten_rays(rays, num_rays, M, res)
        return res


flatten_rays = _flatten_rays.apply
# ----------------------------------------
# train functions
# ----------------------------------------
class _march_rays_train(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, perturb=False, dt_gamma=0, max_steps=1024, contract=False):
        ''' march rays to generate points (forward only)
        Args:
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [CHHH // 8]
            C: int
            H: int
            nears/fars: float, [N]
            perturb: bool, if True a per-ray value from torch.rand is passed to the kernel as `noises` (zeros otherwise).
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
            contract: bool, forwarded to the kernel (presumably enables scene contraction -- confirm in raymarching.cu).
        Returns:
            xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
            dirs: float, [M, 3], all generated points' view dirs.
            ts: float, [M, 2], all generated points' ts.
            rays: int32, [N, 2], all rays' (point_offset, point_count), e.g., xyzs[rays[i, 0]:(rays[i, 0] + rays[i, 1])] --> points belonging to ray i
        '''

        if not rays_o.is_cuda: rays_o = rays_o.cuda()
        if not rays_d.is_cuda: rays_d = rays_d.cuda()
        if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
        if not nears.is_cuda: nears = nears.cuda()
        if not fars.is_cuda: fars = fars.cuda()

        rays_o = rays_o.float().contiguous().view(-1, 3)
        rays_d = rays_d.float().contiguous().view(-1, 3)
        density_bitfield = density_bitfield.contiguous()
        # nears/fars receive the same device/layout treatment as the other inputs
        nears = nears.contiguous()
        fars = fars.contiguous()

        N = rays_o.shape[0]  # num rays

        step_counter = torch.zeros(1, dtype=torch.int32, device=rays_o.device)  # total point counter, read back as M

        if perturb:
            noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
        else:
            noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)

        # first pass: write rays, get total number of points M to render
        rays = torch.empty(N, 2, dtype=torch.int32, device=rays_o.device)  # (point_offset, point_count) per ray
        get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, None, None, None, rays, step_counter, noises)

        # allocate based on M
        M = step_counter.item()

        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)

        # second pass: write outputs
        get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, xyzs, dirs, ts, rays, step_counter, noises)

        return xyzs, dirs, ts, rays


march_rays_train = _march_rays_train.apply
class _composite_rays_train(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, sigmas, rgbs, ts, rays, T_thresh=1e-4, binarize=False):
        ''' composite rays' rgbs, according to the ray marching formula.
        Args:
            sigmas: float, [M,]
            rgbs: float, [M, 3]
            ts: float, [M, 2]
            rays: int32, [N, 2], per-ray (point_offset, point_count) as returned by march_rays_train
            T_thresh: float, forwarded to the kernel (presumably the transmittance early-stop threshold)
            binarize: bool, forwarded to the kernel
        Returns:
            weights: float, [M]
            weights_sum: float, [N,], the alpha channel
            depth: float, [N, ], the Depth
            image: float, [N, 3], the RGB channel (after multiplying alpha!)
        '''

        sigmas = sigmas.float().contiguous()
        rgbs = rgbs.float().contiguous()
        # ts/rays get the same contiguous layout as sigmas/rgbs before reaching the kernel
        ts = ts.contiguous()
        rays = rays.contiguous()

        M = sigmas.shape[0]
        N = rays.shape[0]

        weights = torch.zeros(M, dtype=sigmas.dtype, device=sigmas.device)  # may leave unmodified, so init with 0
        weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)

        depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
        image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)

        get_backend().composite_rays_train_forward(sigmas, rgbs, ts, rays, M, N, T_thresh, binarize, weights, weights_sum, depth, image)

        ctx.save_for_backward(sigmas, rgbs, ts, rays, weights_sum, depth, image)
        ctx.dims = [M, N, T_thresh, binarize]

        return weights, weights_sum, depth, image

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_weights, grad_weights_sum, grad_depth, grad_image):
        """Gradients w.r.t. sigmas and rgbs only; ts, rays, T_thresh and binarize get None."""
        grad_weights = grad_weights.contiguous()
        grad_weights_sum = grad_weights_sum.contiguous()
        grad_depth = grad_depth.contiguous()
        grad_image = grad_image.contiguous()

        sigmas, rgbs, ts, rays, weights_sum, depth, image = ctx.saved_tensors
        M, N, T_thresh, binarize = ctx.dims

        grad_sigmas = torch.zeros_like(sigmas)
        grad_rgbs = torch.zeros_like(rgbs)

        get_backend().composite_rays_train_backward(grad_weights, grad_weights_sum, grad_depth, grad_image, sigmas, rgbs, ts, rays, weights_sum, depth, image, M, N, T_thresh, binarize, grad_sigmas, grad_rgbs)

        return grad_sigmas, grad_rgbs, None, None, None, None


composite_rays_train = _composite_rays_train.apply
# ----------------------------------------
# infer functions
# ----------------------------------------
class _march_rays(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, perturb=False, dt_gamma=0, max_steps=1024, contract=False):
        ''' march rays to generate points (forward only, for inference)
        Args:
            n_alive: int, number of alive rays
            n_step: int, how many steps we march
            rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
            rays_t: float, [N], the alive rays' time, we only use the first n_alive.
                Passed to the kernel as-is (no copy), so any in-place update is visible to the caller.
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [CHHH // 8]
            C: int
            H: int
            near/far: float, [N] (presumably per-ray, as produced by near_far_from_aabb -- confirm)
            perturb: bool/int, if truthy a per-alive-ray value from torch.rand is passed to the kernel as `noises` (zeros otherwise).
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
            contract: bool, forwarded to the kernel.
        Returns:
            xyzs: float, [n_alive * n_step, 3], all generated points' coords
            dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
            ts: float, [n_alive * n_step, 2], all generated points' ts
        '''

        if not rays_o.is_cuda: rays_o = rays_o.cuda()
        if not rays_d.is_cuda: rays_d = rays_d.cuda()

        rays_o = rays_o.float().contiguous().view(-1, 3)
        rays_d = rays_d.float().contiguous().view(-1, 3)

        M = n_alive * n_step

        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)  # 2 vals, one for rgb, one for depth

        if perturb:
            # torch.manual_seed(perturb) # test_gui uses spp index as seed
            noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
        else:
            noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)

        get_backend().march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, contract, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, ts, noises)

        return xyzs, dirs, ts


march_rays = _march_rays.apply
class _composite_rays(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # need to cast sigmas & rgbs to float
    def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh=1e-2, binarize=False):
        ''' composite rays' rgbs, according to the ray marching formula. (for inference)
        Args:
            n_alive: int, number of alive rays
            n_step: int, how many steps we march
            rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
            rays_t: float, [N], the alive rays' time
            sigmas: float, [n_alive * n_step,]
            rgbs: float, [n_alive * n_step, 3]
            ts: float, [n_alive * n_step, 2]
        In-place Outputs:
            weights_sum: float, [N,], the alpha channel
            depth: float, [N,], the depth value
            image: float, [N, 3], the RGB channel (after multiplying alpha!)
        Returns:
            an empty tuple; all results are written into the output tensors above.
        '''
        sigmas, rgbs = sigmas.float().contiguous(), rgbs.float().contiguous()
        get_backend().composite_rays(n_alive, n_step, T_thresh, binarize, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image)
        return ()


composite_rays = _composite_rays.apply
================================================
FILE: raymarching/setup.py
================================================
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))

# PyTorch >= 2.1 headers refuse to compile below C++17, and an explicit -std flag stops
# BuildExtension from adding its own, so every toolchain builds as C++17 (the MSVC
# flags below already did). Note: nvcc accepts -std=c++17 from CUDA 11 onwards.
nvcc_flags = [
    '-O3', '-std=c++17',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        """Return the lexicographically greatest (typically newest) MSVC Hostx64/x64 bin dir, or None."""
        import glob
        for program_files in [r"C:\Program Files (x86)", r"C:\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                pattern = r"%s\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % (program_files, edition)
                paths = sorted(glob.glob(pattern), reverse=True)
                if paths:
                    return paths[0]
        return None

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += os.pathsep + cl_path
else:
    # posix (gcc/clang) flags; also the fallback so c_flags is always defined
    c_flags = ['-O3', '-std=c++17']

'''
Usage:

python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)

python setup.py install # build extensions and install (copy) to PATH.
pip install . # ditto but better (e.g., dependency & metadata handling)

python setup.py develop # build extensions and install (symbolic) to PATH.
pip install -e . # ditto but better (e.g., dependency & metadata handling)

'''
setup(
    name='raymarching', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_raymarching', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'raymarching.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
================================================
FILE: raymarching/src/bindings.cpp
================================================
#include
#include "raymarching.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // module_::def returns the module itself, so every CUDA entry point is
    // registered in a single chained expression (same names, order and docstrings).
    m
        // utils
        .def("flatten_rays", &flatten_rays, "flatten_rays (CUDA)")
        .def("packbits", &packbits, "packbits (CUDA)")
        .def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)")
        .def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)")
        .def("morton3D", &morton3D, "morton3D (CUDA)")
        .def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)")
        // train
        .def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)")
        .def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)")
        .def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)")
        // infer
        .def("march_rays", &march_rays, "march rays (CUDA)")
        .def("composite_rays", &composite_rays, "composite rays (CUDA)");
}
================================================
FILE: raymarching/src/raymarching.cu
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
// Tensor argument validators; each expands to a single TORCH_CHECK with a message naming the argument.
#define CHECK_CUDA(x) \
    TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) \
    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x)                                   \
    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||    \
                x.scalar_type() == at::ScalarType::Half ||     \
                x.scalar_type() == at::ScalarType::Double,     \
                #x " must be a floating tensor")
// Compile-time float constants for device code.
__device__ inline constexpr float SQRT3() {
    return 1.7320508075688772f;  // sqrt(3)
}
__device__ inline constexpr float RSQRT3() {
    return 0.5773502691896258f;  // 1 / sqrt(3)
}
__device__ inline constexpr float PI() {
    return 3.141592653589793f;   // pi
}
__device__ inline constexpr float RPI() {
    return 0.3183098861837907f;  // 1 / pi
}
// Integer ceiling division: the smallest q with q * divisor >= val (val >= 0, divisor > 0).
template <typename T>
inline __host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}
// +1.0f or -1.0f carrying the sign bit of x (so signf(-0.0f) == -1.0f).
inline __host__ __device__ float signf(const float x) {
    return copysignf(1.0f, x);
}

// Clamp x into [min, max] using fmaxf then fminf (for min <= max, a NaN x yields min).
inline __host__ __device__ float clamp(const float x, const float min, const float max) {
    const float at_least_min = fmaxf(min, x);
    return fminf(max, at_least_min);
}

// Exchange two floats in place.
inline __host__ __device__ void swapf(float& a, float& b) {
    const float held = b;
    b = a;
    a = held;
}
// Cascade level for a point: the binary exponent of its L-inf norm, clamped to [0, max_cascade - 1].
inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
    const float linf = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
    int exponent;
    // frexpf: [0.25, 0.5) --> -1 (smaller --> more negative, 0 --> 0), [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
    frexpf(linf, &exponent);
    const float non_negative = fmaxf(0, exponent);
    return fminf(max_cascade - 1, non_negative);
}
// Cascade level for a step size: dt * H / 2 expresses dt in level-0 voxel widths (2 / H);
// its binary exponent, clamped to [0, max_cascade - 1], is the level.
inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
    // 0.5f instead of 0.5 keeps this in FP32 on the device; halving is exact, so the value is unchanged.
    const float mx = dt * H * 0.5f;
    int exponent;
    frexpf(mx, &exponent);
    return fminf(max_cascade - 1, fmaxf(0, exponent));
}
// Spread the low 10 bits of v so that bit i lands on bit 3*i (two zero bits between each),
// using four multiply-and-mask stages.
inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
{
    const uint32_t stage1 = (v      * 0x00010001u) & 0xFF0000FFu;
    const uint32_t stage2 = (stage1 * 0x00000101u) & 0x0F00F00Fu;
    const uint32_t stage3 = (stage2 * 0x00000011u) & 0xC30C30C3u;
    return                  (stage3 * 0x00000005u) & 0x49249249u;
}

// Interleave three 10-bit coordinates into a 30-bit Morton code: x -> bits 0,3,..., y -> 1,4,..., z -> 2,5,...
inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
{
    return __expand_bits(x) | (__expand_bits(y) << 1) | (__expand_bits(z) << 2);
}

// Inverse of __expand_bits: gather bits 0,3,6,... of x back into a contiguous integer.
// Callers select the y/z component by pre-shifting the code right by 1/2.
inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
{
    uint32_t packed = x & 0x49249249;
    packed = (packed | (packed >> 2)) & 0xc30c30c3;
    packed = (packed | (packed >> 4)) & 0x0f00f00f;
    packed = (packed | (packed >> 8)) & 0xff0000ff;
    packed = (packed | (packed >> 16)) & 0x0000ffff;
    return packed;
}
////////////////////////////////////////////////////
///////////// utils /////////////
////////////////////////////////////////////////////
// rays_o/d: [N, 3]
// nears/fars: [N]
// scalar_t should always be float in use.
template
__global__ void kernel_near_far_from_aabb(
const scalar_t * __restrict__ rays_o,
const scalar_t * __restrict__ rays_d,
const scalar_t * __restrict__ aabb,
const uint32_t N,
const float min_near,
scalar_t * nears, scalar_t * fars
) {
// parallel per ray
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
if (n >= N) return;
// locate
rays_o += n * 3;
rays_d += n * 3;
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
// get near far (assume cube scene)
float near = (aabb[0] - ox) * rdx;
float far = (aabb[3] - ox) * rdx;
if (near > far) swapf(near, far);
float near_y = (aabb[1] - oy) * rdy;
float far_y = (aabb[4] - oy) * rdy;
if (near_y > far_y) swapf(near_y, far_y);
if (near > far_y || near_y > far) {
nears[n] = fars[n] = std::numeric_limits::max();
return;
}
if (near_y > near) near = near_y;
if (far_y < far) far = far_y;
float near_z = (aabb[2] - oz) * rdz;
float far_z = (aabb[5] - oz) * rdz;
if (near_z > far_z) swapf(near_z, far_z);
if (near > far_z || near_z > far) {
nears[n] = fars[n] = std::numeric_limits::max();
return;
}
if (near_z > near) near = near_z;
if (far_z < far) far = far_z;
if (near < min_near) near = min_near;
nears[n] = near;
fars[n] = far;
}
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "near_far_from_aabb", ([&] {
kernel_near_far_from_aabb<<>>(rays_o.data_ptr(), rays_d.data_ptr(), aabb.data_ptr(), N, min_near, nears.data_ptr(), fars.data_ptr());
}));
}
// rays_o/d: [N, 3]
// radius: float
// coords: [N, 2]
template
__global__ void kernel_sph_from_ray(
const scalar_t * __restrict__ rays_o,
const scalar_t * __restrict__ rays_d,
const float radius,
const uint32_t N,
scalar_t * coords
) {
// parallel per ray
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
if (n >= N) return;
// locate
rays_o += n * 3;
rays_d += n * 3;
coords += n * 2;
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
// const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
// solve t from || o + td || = radius
const float A = dx * dx + dy * dy + dz * dz;
const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
const float C = ox * ox + oy * oy + oz * oz - radius * radius;
const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
// solve theta, phi (assume y is the up axis)
const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
const float phi = atan2(z, x); // [-PI, PI)
// normalize to [-1, 1]
coords[0] = 2 * theta * RPI() - 1;
coords[1] = phi * RPI();
}
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "sph_from_ray", ([&] {
kernel_sph_from_ray<<>>(rays_o.data_ptr(), rays_d.data_ptr(), radius, N, coords.data_ptr());
}));
}
// Morton-encode integer grid coordinates, one thread per point.
// coords: int32, [N, 3]
// indices: int32, [N]
__global__ void kernel_morton3D(
    const int * __restrict__ coords,
    const uint32_t N,
    int * indices
) {
    // parallel
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    coords += n * 3;
    indices[n] = __morton3D(coords[0], coords[1], coords[2]);
}

void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
}
// Decode Morton indices back to integer grid coordinates, one thread per index.
// indices: int32, [N]
// coords: int32, [N, 3]
__global__ void kernel_morton3D_invert(
    const int * __restrict__ indices,
    const uint32_t N,
    int * coords
) {
    // parallel
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    coords += n * 3;

    const int ind = indices[n];

    // x/y/z are interleaved at bit offsets 0/1/2
    coords[0] = __morton3D_invert(ind >> 0);
    coords[1] = __morton3D_invert(ind >> 1);
    coords[2] = __morton3D_invert(ind >> 2);
}

void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
}
// grid: float, [C, H, H, H]
// N: int, C * H * H * H / 8
// density_thresh: float
// bitfield: uint8, [N]
template
__global__ void kernel_packbits(
const scalar_t * __restrict__ grid,
const uint32_t N,
const float density_thresh,
uint8_t * bitfield
) {
// parallel per byte
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
if (n >= N) return;
// locate
grid += n * 8;
uint8_t bits = 0;
#pragma unroll
for (uint8_t i = 0; i < 8; i++) {
bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
}
bitfield[n] = bits;
}
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grid.scalar_type(), "packbits", ([&] {
kernel_packbits<<>>(grid.data_ptr(), N, density_thresh, bitfield.data_ptr());
}));
}
// Expand per-ray (offset, num_steps) pairs into a per-sample ray id: res[offset + i] = n.
// rays: int32, [N, 2], (offset, num_steps)
// res: int32, [M] — M is assumed to be res's length (the sample count); writes at or beyond M are skipped.
__global__ void kernel_flatten_rays(
    const int * __restrict__ rays,
    const uint32_t N, const uint32_t M,
    int * res
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    const uint32_t offset = rays[n * 2];
    const uint32_t num_steps = rays[n * 2 + 1];

    // write to res; unsigned index matches num_steps, and the M bound prevents out-of-range writes
    for (uint32_t i = 0; i < num_steps && offset + i < M; i++) res[offset + i] = n;
}

void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_flatten_rays<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays.data_ptr<int>(), N, M, res.data_ptr<int>());
}
////////////////////////////////////////////////////
///////////// training /////////////
////////////////////////////////////////////////////
// rays_o/d: [N, 3]
// grid: [CHHH / 8]
// xyzs, dirs, ts: [M, 3], [M, 3], [M, 2]
// dirs: [M, 3]
// rays: [N, 3], idx, offset, num_steps
template
__global__ void kernel_march_rays_train(
const scalar_t * __restrict__ rays_o,
const scalar_t * __restrict__ rays_d,
const uint8_t * __restrict__ grid,
const float bound, const bool contract,
const float dt_gamma, const uint32_t max_steps,
const uint32_t N, const uint32_t C, const uint32_t H,
const scalar_t* __restrict__ nears,
const scalar_t* __restrict__ fars,
scalar_t * xyzs, scalar_t * dirs, scalar_t * ts,
int * rays,
int * counter,
const scalar_t* __restrict__ noises
) {
// parallel per ray
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
if (n >= N) return;
// is first pass running.
const bool first_pass = (xyzs == nullptr);
// locate
rays_o += n * 3;
rays_d += n * 3;
rays += n * 2;
uint32_t num_steps = max_steps;
if (!first_pass) {
uint32_t point_index = rays[0];
num_steps = rays[1];
xyzs += point_index * 3;
dirs += point_index * 3;
ts += point_index * 2;
}
// ray marching
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float rH = 1 / (float)H;
const float H3 = H * H * H;
const float near = nears[n];
const float far = fars[n];
const float noise = noises[n];
const float dt_min = 2 * SQRT3() / max_steps;
const float dt_max = 2 * SQRT3() * bound / H;
// const float dt_max = 1e10f;
float t0 = near;
t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
float t = t0;
uint32_t step = 0;
//if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
while (t < far && step < num_steps) {
// current point
const float x = clamp(ox + t * dx, -bound, bound);
const float y = clamp(oy + t * dy, -bound, bound);
const float z = clamp(oz + t * dz, -bound, bound);
float dt = clamp(t * dt_gamma, dt_min, dt_max);
// get mip level
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
const float mip_bound = fminf(scalbnf(1.0f, level), bound);
const float mip_rbound = 1 / mip_bound;
// contraction
float cx = x, cy = y, cz = z;
const float mag = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
if (contract && mag > 1) {
// L-INF norm
const float Linf_scale = (2 - 1 / mag) / mag;
cx *= Linf_scale;
cy *= Linf_scale;
cz *= Linf_scale;
}
// convert to nearest grid position
const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
const int nz = clamp(0.5 * (cz * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
const bool occ = grid[index / 8] & (1 << (index % 8));
// if occpuied, advance a small step, and write to output
//if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, step);
if (occ) {
step++;
t += dt;
if (!first_pass) {
xyzs[0] = cx; // write contracted coordinates!
xyzs[1] = cy;
xyzs[2] = cz;
dirs[0] = dx;
dirs[1] = dy;
dirs[2] = dz;
ts[0] = t;
ts[1] = dt;
xyzs += 3;
dirs += 3;
ts += 2;
}
// contraction case: cannot apply voxel skipping.
} else if (contract && mag > 1) {
t += dt;
// else, skip a large step (basically skip a voxel grid)
} else {
// calc distance to next voxel
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx;
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy;
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz;
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
// step until next voxel
do {
dt = clamp(t * dt_gamma, dt_min, dt_max);
t += dt;
} while (t < tt);
}
}
//printf("[n=%d] step=%d, near=%f, far=%f, dt=%f, num_steps=%f\n", n, step, near, far, dt_min, (far - near) / dt_min);
// write rays
if (first_pass) {
uint32_t point_index = atomicAdd(counter, step);
rays[0] = point_index;
rays[1] = step;
}
}
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional xyzs, at::optional dirs, at::optional ts, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "march_rays_train", ([&] {
kernel_march_rays_train<<>>(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), bound, contract, dt_gamma, max_steps, N, C, H, nears.data_ptr(), fars.data_ptr(),
xyzs.has_value() ? xyzs.value().data_ptr() : nullptr,
dirs.has_value() ? dirs.value().data_ptr() : nullptr,
ts.has_value() ? ts.value().data_ptr() : nullptr,
rays.data_ptr(), counter.data_ptr(), noises.data_ptr());
}));
}
// sigmas: [M]
// rgbs: [M, 3]
// ts: [M, 2]
// rays: [N, 2], offset, num_steps
// weights: [M]
// weights_sum: [N], final pixel alpha
// depth: [N,]
// image: [N, 3]
template
__global__ void kernel_composite_rays_train_forward(
const scalar_t * __restrict__ sigmas,
const scalar_t * __restrict__ rgbs,
const scalar_t * __restrict__ ts,
const int * __restrict__ rays,
const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize,
scalar_t * weights,
scalar_t * weights_sum,
scalar_t * depth,
scalar_t * image
) {
// parallel per ray
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
if (n >= N) return;
// locate
uint32_t offset = rays[n * 2];
uint32_t num_steps = rays[n * 2 + 1];
// empty ray, or ray that exceed max step count.
if (num_steps == 0 || offset + num_steps > M) {
weights_sum[n] = 0;
depth[n] = 0;
image[n * 3] = 0;
image[n * 3 + 1] = 0;
image[n * 3 + 2] = 0;
return;
}
ts += offset * 2;
weights += offset;
sigmas += offset;
rgbs += offset * 3;
// accumulate
uint32_t step = 0;
float T = 1.0f;
float r = 0, g = 0, b = 0, ws = 0, d = 0;
while (step < num_steps) {
const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);
const float alpha = binarize ? (real_alpha > 0.5 ? 1.0 : 0.0) : real_alpha;
const float weight = alpha * T;
weights[0] = weight;
r += weight * rgbs[0];
g += weight * rgbs[1];
b += weight * rgbs[2];
ws += weight;
d += weight * ts[0];
T *= 1.0f - alpha;
// minimal remained transmittence
if (T < T_thresh) break;
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
// locate
weights++;
sigmas++;
rgbs += 3;
ts += 2;
step++;
}
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
// write
weights_sum[n] = ws; // weights_sum
depth[n] = d;
image[n * 3] = r;
image[n * 3 + 1] = g;
image[n * 3 + 2] = b;
}
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
kernel_composite_rays_train_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ts.data_ptr(), rays.data_ptr(), M, N, T_thresh, binarize, weights.data_ptr(), weights_sum.data_ptr(), depth.data_ptr(), image.data_ptr());
}));
}
// grad_weights: [M,]
// grad_weights_sum: [N,]
// grad_image: [N, 3]
// grad_depth: [N,]
// sigmas: [M]
// rgbs: [M, 3]
// ts: [M, 2]
// rays: [N, 2], offset, num_steps
// weights_sum: [N,], weights_sum here
// image: [N, 3]
// grad_sigmas: [M]
// grad_rgbs: [M, 3]
template
__global__ void kernel_composite_rays_train_backward(
const scalar_t * __restrict__ grad_weights,
const scalar_t * __restrict__ grad_weights_sum,
const scalar_t * __restrict__ grad_depth,
const scalar_t * __restrict__ grad_image,
const scalar_t * __restrict__ sigmas,
const scalar_t * __restrict__ rgbs,
const scalar_t * __restrict__ ts,
const int * __restrict__ rays,
const scalar_t * __restrict__ weights_sum,
const scalar_t * __restrict__ depth,
const scalar_t * __restrict__ image,
const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize,
scalar_t * grad_sigmas,
scalar_t * grad_rgbs
) {
// parallel per ray
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
if (n >= N) return;
// locate
uint32_t offset = rays[n * 2];
uint32_t num_steps = rays[n * 2 + 1];
if (num_steps == 0 || offset + num_steps > M) return;
grad_weights += offset;
grad_weights_sum += n;
grad_depth += n;
grad_image += n * 3;
weights_sum += n;
depth += n;
image += n * 3;
sigmas += offset;
rgbs += offset * 3;
ts += offset * 2;
grad_sigmas += offset;
grad_rgbs += offset * 3;
// accumulate
uint32_t step = 0;
float T = 1.0f;
const float r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], d_final = depth[0];
float r = 0, g = 0, b = 0, ws = 0, d = 0;
while (step < num_steps) {
const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);
const float alpha = binarize ? (real_alpha > 0.5 ? 1.0 : 0.0) : real_alpha;
const float weight = alpha * T;
r += weight * rgbs[0];
g += weight * rgbs[1];
b += weight * rgbs[2];
ws += weight;
d += weight * ts[0];
T *= 1.0f - alpha;
// check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
// write grad_rgbs
grad_rgbs[0] = grad_image[0] * weight;
grad_rgbs[1] = grad_image[1] * weight;
grad_rgbs[2] = grad_image[2] * weight;
// write grad_sigmas
grad_sigmas[0] = ts[1] * (
grad_image[0] * (T * rgbs[0] - (r_final - r)) +
grad_image[1] * (T * rgbs[1] - (g_final - g)) +
grad_image[2] * (T * rgbs[2] - (b_final - b)) +
(grad_weights_sum[0] + grad_weights[0]) * (T - (ws_final - ws)) +
grad_depth[0] * (T * ts[0] - (d_final - d))
);
//printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
// minimal remained transmittence
if (T < T_thresh) break;
// locate
sigmas++;
rgbs += 3;
ts += 2;
grad_weights++;
grad_sigmas++;
grad_rgbs += 3;
step++;
}
}
void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
kernel_composite_rays_train_backward<<>>(grad_weights.data_ptr(), grad_weights_sum.data_ptr(), grad_depth.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ts.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), depth.data_ptr(), image.data_ptr(), M, N, T_thresh, binarize, grad_sigmas.data_ptr(), grad_rgbs.data_ptr());
}));
}
////////////////////////////////////////////////////
///////////// inference /////////////
////////////////////////////////////////////////////
template
__global__ void kernel_march_rays(
const uint32_t n_alive,
const uint32_t n_step,
const int* __restrict__ rays_alive,
const scalar_t* __restrict__ rays_t,
const scalar_t* __restrict__ rays_o,
const scalar_t* __restrict__ rays_d,
const float bound, const bool contract,
const float dt_gamma, const uint32_t max_steps,
const uint32_t C, const uint32_t H,
const uint8_t * __restrict__ grid,
const scalar_t* __restrict__ nears,
const scalar_t* __restrict__ fars,
scalar_t* xyzs, scalar_t* dirs, scalar_t* ts,
const scalar_t* __restrict__ noises
) {
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
if (n >= n_alive) return;
const int index = rays_alive[n]; // ray id
const float noise = noises[n];
// locate
rays_o += index * 3;
rays_d += index * 3;
xyzs += n * n_step * 3;
dirs += n * n_step * 3;
ts += n * n_step * 2;
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float rH = 1 / (float)H;
const float H3 = H * H * H;
const float near = nears[index], far = fars[index];
const float dt_min = 2 * SQRT3() / max_steps;
const float dt_max = 2 * SQRT3() * bound / H;
// const float dt_max = 1e10f;
// march for n_step steps, record points
float t = rays_t[index];
t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
uint32_t step = 0;
while (t < far && step < n_step) {
// current point
const float x = clamp(ox + t * dx, -bound, bound);
const float y = clamp(oy + t * dy, -bound, bound);
const float z = clamp(oz + t * dz, -bound, bound);
float dt = clamp(t * dt_gamma, dt_min, dt_max);
// get mip level
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
const float mip_bound = fminf(scalbnf(1, level), bound);
const float mip_rbound = 1 / mip_bound;
// contraction
float cx = x, cy = y, cz = z;
const float mag = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
if (contract && mag > 1) {
// L-INF norm
const float Linf_scale = (2 - 1 / mag) / mag;
cx *= Linf_scale;
cy *= Linf_scale;
cz *= Linf_scale;
}
// convert to nearest grid position
const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
const int nz = clamp(0.5 * (cz * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
const bool occ = grid[index / 8] & (1 << (index % 8));
// if occpuied, advance a small step, and write to output
if (occ) {
// write step
xyzs[0] = cx;
xyzs[1] = cy;
xyzs[2] = cz;
dirs[0] = dx;
dirs[1] = dy;
dirs[2] = dz;
// calc dt
t += dt;
ts[0] = t;
ts[1] = dt;
// step
xyzs += 3;
dirs += 3;
ts += 2;
step++;
// contraction case
} else if (contract && mag > 1) {
t += dt;
// else, skip a large step (basically skip a voxel grid)
} else {
// calc distance to next voxel
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx;
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy;
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz;
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
// step until next voxel
do {
dt = clamp(t * dt_gamma, dt_min, dt_max);
t += dt;
} while (t < tt);
}
}
}
void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "march_rays", ([&] {
kernel_march_rays<<>>(n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, contract, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), ts.data_ptr(), noises.data_ptr());
}));
}
template
__global__ void kernel_composite_rays(
const uint32_t n_alive,
const uint32_t n_step,
const float T_thresh, const bool binarize,
int* rays_alive,
scalar_t* rays_t,
const scalar_t* __restrict__ sigmas,
const scalar_t* __restrict__ rgbs,
const scalar_t* __restrict__ ts,
scalar_t* weights_sum, scalar_t* depth, scalar_t* image
) {
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
if (n >= n_alive) return;
const int index = rays_alive[n]; // ray id
// locate
sigmas += n * n_step;
rgbs += n * n_step * 3;
ts += n * n_step * 2;
rays_t += index;
weights_sum += index;
depth += index;
image += index * 3;
float t;
float d = depth[0], r = image[0], g = image[1], b = image[2], weight_sum = weights_sum[0];
// accumulate
uint32_t step = 0;
while (step < n_step) {
// ray is terminated if t == 0
if (ts[0] == 0) break;
const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);
const float alpha = binarize ? (real_alpha > 0.5 ? 1.0 : 0.0) : real_alpha;
/*
T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
w_i = alpha_i * T_i
-->
T_i = 1 - \sum_{j=0}^{i-1} w_j
*/
const float T = 1 - weight_sum;
const float weight = alpha * T;
weight_sum += weight;
t = ts[0];
d += weight * t; // real depth
r += weight * rgbs[0];
g += weight * rgbs[1];
b += weight * rgbs[2];
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
// ray is terminated if T is too small
// use a larger bound to further accelerate inference
if (T < T_thresh) break;
// locate
sigmas++;
rgbs += 3;
ts += 2;
step++;
}
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
// rays_alive = -1 means ray is terminated early.
if (step < n_step) {
rays_alive[n] = -1;
} else {
rays_t[0] = t;
}
weights_sum[0] = weight_sum; // this is the thing I needed!
depth[0] = d;
image[0] = r;
image[1] = g;
image[2] = b;
}
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights, at::Tensor depth, at::Tensor image) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
image.scalar_type(), "composite_rays", ([&] {
kernel_composite_rays<<>>(n_alive, n_step, T_thresh, binarize, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr