Repository: openai/shap-e
Branch: main
Commit: 50131012ee11
Files: 79
Total size: 17.7 MB
Directory structure:
gitextract__yazl4ub/
├── .gitignore
├── LICENSE
├── README.md
├── model-card.md
├── samples.md
├── setup.py
└── shap_e/
    ├── __init__.py
    ├── diffusion/
    │   ├── __init__.py
    │   ├── gaussian_diffusion.py
    │   ├── k_diffusion.py
    │   └── sample.py
    ├── examples/
    │   ├── encode_model.ipynb
    │   ├── example_data/
    │   │   └── cactus/
    │   │       ├── material.mtl
    │   │       └── object.obj
    │   ├── sample_image_to_3d.ipynb
    │   └── sample_text_to_3d.ipynb
    ├── models/
    │   ├── __init__.py
    │   ├── configs.py
    │   ├── download.py
    │   ├── generation/
    │   │   ├── __init__.py
    │   │   ├── latent_diffusion.py
    │   │   ├── perceiver.py
    │   │   ├── pooled_mlp.py
    │   │   ├── pretrained_clip.py
    │   │   ├── transformer.py
    │   │   └── util.py
    │   ├── nerf/
    │   │   ├── __init__.py
    │   │   ├── model.py
    │   │   ├── ray.py
    │   │   └── renderer.py
    │   ├── nerstf/
    │   │   ├── mlp.py
    │   │   └── renderer.py
    │   ├── nn/
    │   │   ├── __init__.py
    │   │   ├── camera.py
    │   │   ├── checkpoint.py
    │   │   ├── encoding.py
    │   │   ├── meta.py
    │   │   ├── ops.py
    │   │   ├── pointnet2_utils.py
    │   │   └── utils.py
    │   ├── query.py
    │   ├── renderer.py
    │   ├── stf/
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── mlp.py
    │   │   └── renderer.py
    │   ├── transmitter/
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── bottleneck.py
    │   │   ├── channels_encoder.py
    │   │   ├── multiview_encoder.py
    │   │   ├── params_proj.py
    │   │   └── pc_encoder.py
    │   └── volume.py
    ├── rendering/
    │   ├── __init__.py
    │   ├── _mc_table.py
    │   ├── blender/
    │   │   ├── __init__.py
    │   │   ├── blender_script.py
    │   │   ├── constants.py
    │   │   ├── render.py
    │   │   └── view_data.py
    │   ├── mc.py
    │   ├── mesh.py
    │   ├── ply_util.py
    │   ├── point_cloud.py
    │   ├── pytorch3d_util.py
    │   ├── raycast/
    │   │   ├── __init__.py
    │   │   ├── _utils.py
    │   │   ├── cast.py
    │   │   ├── render.py
    │   │   └── types.py
    │   ├── torch_mesh.py
    │   └── view_data.py
    └── util/
        ├── __init__.py
        ├── collections.py
        ├── data_util.py
        ├── image_util.py
        ├── io.py
        └── notebooks.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
.DS_Store
*.egg-info/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 OpenAI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Shap-E
This is the official code and model release for [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463).
* See [Usage](#usage) for guidance on how to use this repository.
* See [Samples](#samples) for examples of what our text-conditional model can generate.
# Samples
Here are some highlighted samples from our text-conditional model. For random samples on selected prompts, see [samples.md](samples.md).
| A chair that looks like an avocado | An airplane that looks like a banana | A spaceship |
|---|---|---|
| A birthday cupcake | A chair that looks like a tree | A green boot |
| A penguin | Ube ice cream cone | A bowl of vegetables |
# Usage
Install with `pip install -e .`.
To get started with examples, see the following notebooks:
* [sample_text_to_3d.ipynb](shap_e/examples/sample_text_to_3d.ipynb) - sample a 3D model, conditioned on a text prompt.
* [sample_image_to_3d.ipynb](shap_e/examples/sample_image_to_3d.ipynb) - sample a 3D model, conditioned on a synthetic view image. For best results, remove the background from the input image.
* [encode_model.ipynb](shap_e/examples/encode_model.ipynb) - loads a 3D model or a trimesh, creates a batch of multiview renders and a point cloud, encodes them into a latent, and renders it back. For this to work, install Blender version 3.3.1 or higher, and set the environment variable `BLENDER_PATH` to the path of the Blender executable.
================================================
FILE: model-card.md
================================================
# Model Card: Shap-E
This is the official codebase for running the latent diffusion models described in [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463). These models were trained and released by OpenAI. Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about how the models were trained and evaluated.
# Model Details
Shap-E includes two kinds of models: an encoder and a latent diffusion model.
1. **The encoder** converts 3D assets into the parameters of small neural networks which represent the 3D shape and texture as an implicit function. The resulting implicit function can be rendered from arbitrary viewpoints or imported into downstream applications as a mesh.
2. **The latent diffusion model** generates novel implicit functions conditioned on either images or text descriptions. As above, these samples can be rendered or exported as a mesh. Specifically, these models produce latents which must be linearly projected to get the final implicit function parameters. The final projection layer of the encoder is used for this purpose.
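The projection step described in item 2 can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in `shap_e/models/transmitter/params_proj.py`: the helper names `project_latent` and `split_params`, the toy sizes, and the parameter names are all hypothetical. It only shows the general idea of applying one affine map to a latent vector and cutting the result into per-layer weight matrices.

```python
def project_latent(latent, weight, bias):
    """Affine map: out[j] = sum_i latent[i] * weight[i][j] + bias[j]."""
    return [
        sum(latent[i] * weight[i][j] for i in range(len(latent))) + bias[j]
        for j in range(len(bias))
    ]


def split_params(flat, shapes):
    """Cut a flat vector into named (rows x cols) weight matrices."""
    params, offset = {}, 0
    for name, (rows, cols) in shapes.items():
        chunk = flat[offset : offset + rows * cols]
        params[name] = [chunk[r * cols : (r + 1) * cols] for r in range(rows)]
        offset += rows * cols
    if offset != len(flat):
        raise ValueError("projection output does not match the requested shapes")
    return params


# Toy sizes (hypothetical): a 3-dimensional latent projected to 2x2 + 1x2 = 6 values.
latent = [1.0, 0.0, -1.0]
weight = [[float(i + j) for j in range(6)] for i in range(3)]
bias = [0.5] * 6
shapes = {"layer0.weight": (2, 2), "layer1.weight": (1, 2)}
params = split_params(project_latent(latent, weight, bias), shapes)
```

In this toy setup the diffusion model would only need to produce `latent`; the projection weights are shared across samples, which is consistent with the `decoder` checkpoint being described as the minimum needed to turn diffusion outputs into implicit neural representations.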
Like [Point-E](https://github.com/openai/point-e/blob/main/model-card.md), Shap-E can often generate coherent 3D objects when conditioned on a rendering from a single viewpoint. When conditioned on text prompts directly, Shap-E is also often capable of producing recognizable objects, although it sometimes struggles to combine multiple objects or concepts.
Samples from Shap-E are typically lower fidelity than professional 3D assets and often have rough edges, holes, or blurry surface textures.
# Model Date
April 2023
# Model Versions
The following model checkpoints are available in this repository:
* `transmitter` - the encoder and corresponding projection layers for converting encoder outputs into implicit neural representations.
* `decoder` - just the final projection layer component of `transmitter`. This is a smaller checkpoint than `transmitter` since it does not include parameters for encoding 3D assets. This is the minimum required model to convert diffusion outputs into implicit neural representations.
* `text300M` - the text-conditional latent diffusion model.
* `image300M` - the image-conditional latent diffusion model.
# Paper & Samples
[Paper link](https://arxiv.org/abs/2305.02463) / [Samples](samples.md)
# Training data
The encoder and image-conditional diffusion models are trained on the [same dataset as Point-E](https://github.com/openai/point-e/blob/main/model-card.md#training-data). However, a few changes to the post-processing were made:
* We rendered 60 views (instead of 20) of each model when computing point clouds, to avoid small cracks.
* We produced 16K points in each point cloud instead of 4K.
* We simplified the lighting and material setup to only include diffuse materials.
For our text-conditional diffusion model, we expanded our dataset with roughly a million more 3D assets. Additionally, we collected 120K captions from human annotators for a high-quality subset of our 3D assets.
# Evaluated Use
We release these models with the intention of furthering progress in the field of generative modeling. However, we acknowledge that our models have certain constraints and biases, which is why we advise against employing them for commercial purposes at this time. We are aware that the utilization of our models could extend to areas beyond our expectations, and defining specific criteria for what is considered suitable for "research" purposes presents a challenge. Specifically, we advise caution when using these models in contexts that demand high accuracy, where minor imperfections in the generated 3D assets could have adverse consequences.
Specifically, these models have been evaluated on the following tasks for research purposes:
* Generating 3D renderings or meshes conditioned on single, synthetic images
* Generating 3D renderings or meshes conditioned on text descriptions
# Performance & Limitations
Our image-conditional model has only been evaluated on a highly specific distribution of synthetic renderings. Even in these cases, the model still sometimes fails to infer the correct occluded parts of an object or produces geometry that is inconsistent with the given rendered images. These failure modes are similar to those of Point-E. The resulting 3D assets often have rough edges, holes, or blurry surface textures.
Our text-conditional model can also produce a somewhat large and diverse vocabulary of objects. This model is often capable of producing objects with requested colors and textures, and sometimes even combining multiple objects. However, it often fails for more complex prompts that require placing multiple objects in a scene or binding attributes to objects. It also typically fails to produce a desired number of objects when a certain quantity is requested.
We find that our text-conditional model can sometimes produce samples which reflect gender biases. For example, samples for "a nurse" typically have a different body shape than samples for "a doctor". When probing for potential misuses, we also found that our text-conditional model is capable of producing 3D assets related to violence, such as guns or tanks. However, the resulting quality of these samples is poor enough that they look unrealistic and toy-like.
As with Point-E, our dataset consists of many simple, cartoonish 3D assets, and our generative models are prone to imitating this style.
We believe our models will have many potential use cases. For example, our text-conditional model could enable users to quickly produce many 3D assets, allowing for rapid prototyping for computer graphics applications or 3D printing.
The use of 3D printing in concert with our models could potentially be harmful, for example if used to create dangerous objects or fabricate tools or parts that are deployed without external validation.
Generative 3D models share many challenges and constraints with image generation models. This includes the tendency to generate content that may be biased or detrimental, as well as the potential for dual-use applications. As the capabilities of these models evolve, further investigation is required to gain a clearer understanding of how these risks manifest.
================================================
FILE: samples.md
================================================
# Samples
Here is a collection of prompts and four random text-conditional samples for each prompt. Samples are rendered at 128x128 resolution with NeRF.
| Prompt | | | | |
|---|---|---|---|---|
| a penguin |  |  |  |  |
| a campfire |  |  |  |  |
| an elephant |  |  |  |  |
| a donut with pink icing |  |  |  |  |
| a voxelized dog |  |  |  |  |
| ube ice cream cone |  |  |  |  |
| a birthday cupcake |  |  |  |  |
| shepherds pie |  |  |  |  |
| a bowl of vegetables |  |  |  |  |
| a cheeseburger |  |  |  |  |
| a plate of mushy green peas |  |  |  |  |
| a traffic cone |  |  |  |  |
| a car that looks like an avocado |  |  |  |  |
| an airplane that looks like a banana |  |  |  |  |
| a stop sign |  |  |  |  |
| a spaceship |  |  |  |  |
| a race car |  |  |  |  |
| a schoolbus |  |  |  |  |
| a firetruck |  |  |  |  |
| a rusty old car |  |  |  |  |
| a fast car |  |  |  |  |
| a chair that looks like an avocado |  |  |  |  |
| a chair that looks like fruit |  |  |  |  |
| a chair that looks like a tree |  |  |  |  |
| a chair that looks like a zebra |  |  |  |  |
| a chair that looks like a swimming pool |  |  |  |  |
| the person is running |  |  |  |  |
| the person is sitting |  |  |  |  |
| the person is lying down |  |  |  |  |
| a person that looks like a zebra |  |  |  |  |
| a person that looks like a leopard |  |  |  |  |
| a pair of shorts |  |  |  |  |
| a designer dress |  |  |  |  |
| banana shoes |  |  |  |  |
| a green boot |  |  |  |  |
| a pair of sunglasses |  |  |  |  |
================================================
FILE: setup.py
================================================
from setuptools import setup

setup(
    name="shap-e",
    packages=[
        "shap_e",
        "shap_e.diffusion",
        "shap_e.models",
        "shap_e.models.generation",
        "shap_e.models.nerf",
        "shap_e.models.nerstf",
        "shap_e.models.nn",
        "shap_e.models.stf",
        "shap_e.models.transmitter",
        "shap_e.rendering",
        "shap_e.rendering.blender",
        "shap_e.rendering.raycast",
        "shap_e.util",
    ],
    install_requires=[
        "filelock",
        "Pillow",
        "torch",
        "fire",
        "humanize",
        "requests",
        "tqdm",
        "matplotlib",
        "scikit-image",
        "scipy",
        "numpy",
        "blobfile",
        "clip @ git+https://github.com/openai/CLIP.git",
    ],
    author="OpenAI",
)
================================================
FILE: shap_e/__init__.py
================================================
================================================
FILE: shap_e/diffusion/__init__.py
================================================
================================================
FILE: shap_e/diffusion/gaussian_diffusion.py
================================================
"""
Based on https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
"""
import math
from typing import Any, Dict, Iterable, Optional, Sequence, Union
import blobfile as bf
import numpy as np
import torch as th
import yaml


def diffusion_from_config(config: Union[str, Dict[str, Any]]) -> "GaussianDiffusion":
    if isinstance(config, str):
        with bf.BlobFile(config, "rb") as f:
            obj = yaml.load(f, Loader=yaml.SafeLoader)
        return diffusion_from_config(obj)

    schedule = config["schedule"]
    steps = config["timesteps"]
    respace = config.get("respacing", None)
    mean_type = config.get("mean_type", "epsilon")
    betas = get_named_beta_schedule(schedule, steps, **config.get("schedule_args", {}))
    channel_scales = config.get("channel_scales", None)
    channel_biases = config.get("channel_biases", None)
    if channel_scales is not None:
        channel_scales = np.array(channel_scales)
    if channel_biases is not None:
        channel_biases = np.array(channel_biases)
    kwargs = dict(
        betas=betas,
        model_mean_type=mean_type,
        model_var_type="learned_range",
        loss_type="mse",
        channel_scales=channel_scales,
        channel_biases=channel_biases,
    )
    if respace is None:
        return GaussianDiffusion(**kwargs)
    else:
        return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs)


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.

    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, **extra_args: float):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule_name == "inv_parabola":
        exponent = extra_args.get("power", 2.0)
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: 1 - t**exponent,
        )
    elif schedule_name == "translated_parabola":
        exponent = extra_args.get("power", 2.0)
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: (1 - t) ** exponent,
        )
    elif schedule_name == "exp":
        coefficient = extra_args.get("coefficient", -12.0)
        return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: math.exp(t * coefficient))
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
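
# Editor's illustrative sketch, not part of the upstream shap-e source.
# The docstring of get_named_beta_schedule() says schedules "remain similar in
# the limit of num_diffusion_timesteps". The standalone snippet below checks
# that claim for the "linear" schedule using only the standard library. The
# helper names scaled_linear_betas and final_alpha_bar are hypothetical, and
# the list comprehension is assumed to match np.linspace with both endpoints
# included.

```python
def scaled_linear_betas(num_steps, base_steps=1000, beta_start=0.0001, beta_end=0.02):
    """Linear betas whose endpoints are rescaled by base_steps / num_steps."""
    scale = base_steps / num_steps
    lo, hi = scale * beta_start, scale * beta_end
    return [lo + (hi - lo) * i / (num_steps - 1) for i in range(num_steps)]


def final_alpha_bar(betas):
    """Cumulative product of (1 - beta) over the whole chain."""
    prod = 1.0
    for beta in betas:
        prod *= 1.0 - beta
    return prod


# Quadrupling the step count shrinks every beta by 4x, so the total amount of
# noise added over the chain stays roughly the same.
alpha_1000 = final_alpha_bar(scaled_linear_betas(1000))
alpha_4000 = final_alpha_bar(scaled_linear_betas(4000))
```

# With these settings both products land close to each other, which is the
# sense in which the schedule is step-count independent.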


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
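
# Editor's illustrative sketch, not part of the upstream shap-e source.
# It re-implements the discretization above with plain Python lists (the name
# betas_from_alpha_bar and the helper cumulative_alpha are hypothetical) to
# show two properties: when no beta is clipped, the running product of
# (1 - beta) reproduces alpha_bar, and for the cosine schedule the final beta
# is capped by max_beta because alpha_bar(1) is essentially zero.

```python
import math


def betas_from_alpha_bar(num_steps, alpha_bar, max_beta=0.999):
    """Discretize alpha_bar(t), t in [0, 1], into per-step betas."""
    betas = []
    for i in range(num_steps):
        t1 = i / num_steps
        t2 = (i + 1) / num_steps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), clipped at max_beta.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


def cumulative_alpha(betas):
    """Running product of (1 - beta)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


# "exp" schedule with the default coefficient of -12: every beta is the same.
exp_betas = betas_from_alpha_bar(10, lambda t: math.exp(-12.0 * t))
exp_alphas = cumulative_alpha(exp_betas)

# "cosine" schedule: the last step would need beta ~ 1, so it is clipped.
cosine_betas = betas_from_alpha_bar(
    1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
)
```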


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            # Report the requested step count, not the length of the original process.
            raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")
        elif section_counts.startswith("exact"):
            res = set(int(x) for x in section_counts[len("exact") :].split(","))
            for x in res:
                if x < 0 or x >= num_timesteps:
                    raise ValueError(f"timestep out of bounds: {x}")
            return res
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(f"cannot divide section of {size} steps into {section_count}")
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
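
# Editor's illustrative sketch, not part of the upstream shap-e source.
# It works through the docstring example (300 timesteps, section counts
# [10, 15, 20]) with a simplified, list-only version of the section logic
# above. The helper names strided_section and spaced_steps are hypothetical,
# and the "ddimN" / "exact" string forms are deliberately left out.

```python
def strided_section(start, size, count):
    """Pick `count` roughly evenly spaced indices from [start, start + size)."""
    if size < count:
        raise ValueError(f"cannot divide section of {size} steps into {count}")
    stride = 1 if count <= 1 else (size - 1) / (count - 1)
    cur, taken = 0.0, []
    for _ in range(count):
        taken.append(start + round(cur))
        cur += stride
    return taken


def spaced_steps(num_timesteps, section_counts):
    """Split the process into equal sections and stride each one separately."""
    size_per, extra = divmod(num_timesteps, len(section_counts))
    steps, start = [], 0
    for i, count in enumerate(section_counts):
        # The first `extra` sections absorb the remainder, one step each.
        size = size_per + (1 if i < extra else 0)
        steps += strided_section(start, size, count)
        start += size
    return set(steps)


steps = spaced_steps(300, [10, 15, 20])
first_section = sorted(s for s in steps if s < 100)
```

# Each section always keeps its first and last timestep, so the later (noisier)
# third of the process is sampled twice as densely as the first third here.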


class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here:
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D array of betas for each diffusion timestep from T to 1.
    :param model_mean_type: a string determining what the model outputs.
    :param model_var_type: a string determining how variance is output.
    :param loss_type: a string determining the loss function to use.
    :param discretized_t0: if True, use discrete gaussian loss for t=0. Only
                           makes sense for images.
    :param channel_scales: a multiplier to apply to x_start in training_losses
                           and sampling functions.
    """

    def __init__(
        self,
        *,
        betas: Sequence[float],
        model_mean_type: str,
        model_var_type: str,
        loss_type: str,
        discretized_t0: bool = False,
        channel_scales: Optional[np.ndarray] = None,
        channel_biases: Optional[np.ndarray] = None,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.discretized_t0 = discretized_t0
        self.channel_scales = channel_scales
        self.channel_biases = channel_biases

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
        )

    def get_sigmas(self, t):
        return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape)

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        if self.model_var_type in ["learned", "learned_range"]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.model_var_type == "learned":
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                "fixed_large": (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                "fixed_small": (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == "x_prev":
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in ["x_start", "epsilon"]:
            if self.model_mean_type == "x_start":
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
        else:
            raise NotImplementedError(self.model_mean_type)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
            "extra": extra,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, **(model_kwargs or {}))
        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **(model_kwargs or {}))

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor x_t.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=False,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
temp=1.0,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:param temp: scale applied to the initial Gaussian noise when `noise`
is not given.
:return: a non-differentiable batch of samples.
"""
final = None
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
temp=temp,
):
final = sample
return final["sample"]
def p_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=False,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
temp=1.0,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device) * temp
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
yield self.unscale_out_dict(out)
img = out["sample"]
def ddim_sample(
self,
model,
x,
t,
clip_denoised=False,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
eta
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12 of the DDIM paper (Song et al., 2020).
noise = th.randn_like(x)
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=False,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, "Reverse ODE only for deterministic path"
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12 of the DDIM paper, run in reverse (deterministic, eta == 0).
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=False,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
temp=1.0,
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
for sample in self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
eta=eta,
temp=temp,
):
final = sample
return final["sample"]
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=False,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
temp=1.0,
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device) * temp
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.ddim_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
eta=eta,
)
yield self.unscale_out_dict(out)
img = out["sample"]
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False, model_kwargs=None):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
if not self.discretized_t0:
decoder_nll = th.zeros_like(decoder_nll)
assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
return {
"output": output,
"pred_xstart": out["pred_xstart"],
"extra": out["extra"],
}
def training_losses(
self, model, x_start, t, model_kwargs=None, noise=None
) -> Dict[str, th.Tensor]:
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
x_start = self.scale_channels(x_start)
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
terms = {}
if self.loss_type == "kl" or self.loss_type == "rescaled_kl":
vb_terms = self._vb_terms_bpd(
model=model,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
model_kwargs=model_kwargs,
)
terms["loss"] = vb_terms["output"]
if self.loss_type == "rescaled_kl":
terms["loss"] *= self.num_timesteps
extra = vb_terms["extra"]
elif self.loss_type == "mse" or self.loss_type == "rescaled_mse":
model_output = model(x_t, t, **model_kwargs)
if isinstance(model_output, tuple):
model_output, extra = model_output
else:
extra = {}
if self.model_var_type in [
"learned",
"learned_range",
]:
B, C = x_t.shape[:2]
assert model_output.shape == (
B,
C * 2,
*x_t.shape[2:],
), f"{model_output.shape} != {(B, C * 2, *x_t.shape[2:])}"
model_output, model_var_values = th.split(model_output, C, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)["output"]
if self.loss_type == "rescaled_mse":
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
target = {
"x_prev": self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
"x_start": x_start,
"epsilon": noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
terms["loss"] = terms["mse"]
else:
raise NotImplementedError(self.loss_type)
if "losses" in extra:
terms.update({k: loss for k, (loss, _scale) in extra["losses"].items()})
for loss, scale in extra["losses"].values():
terms["loss"] = terms["loss"] + loss * scale
return terms
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None):
"""
Compute the entire variational lower-bound, measured in bits-per-dim,
as well as other related quantities.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param clip_denoised: if True, clip denoised samples.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- total_bpd: the total variational lower-bound, per batch element.
- prior_bpd: the prior term in the lower-bound.
- vb: an [N x T] tensor of terms in the lower-bound.
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
"""
device = x_start.device
batch_size = x_start.shape[0]
vb = []
xstart_mse = []
mse = []
for t in list(range(self.num_timesteps))[::-1]:
t_batch = th.tensor([t] * batch_size, device=device)
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
# Calculate VLB term at the current timestep
with th.no_grad():
out = self._vb_terms_bpd(
model,
x_start=x_start,
x_t=x_t,
t=t_batch,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
)
vb.append(out["output"])
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
mse.append(mean_flat((eps - noise) ** 2))
vb = th.stack(vb, dim=1)
xstart_mse = th.stack(xstart_mse, dim=1)
mse = th.stack(mse, dim=1)
prior_bpd = self._prior_bpd(x_start)
total_bpd = vb.sum(dim=1) + prior_bpd
return {
"total_bpd": total_bpd,
"prior_bpd": prior_bpd,
"vb": vb,
"xstart_mse": xstart_mse,
"mse": mse,
}
def scale_channels(self, x: th.Tensor) -> th.Tensor:
if self.channel_scales is not None:
x = x * th.from_numpy(self.channel_scales).to(x).reshape(
[1, -1, *([1] * (len(x.shape) - 2))]
)
if self.channel_biases is not None:
x = x + th.from_numpy(self.channel_biases).to(x).reshape(
[1, -1, *([1] * (len(x.shape) - 2))]
)
return x
def unscale_channels(self, x: th.Tensor) -> th.Tensor:
if self.channel_biases is not None:
x = x - th.from_numpy(self.channel_biases).to(x).reshape(
[1, -1, *([1] * (len(x.shape) - 2))]
)
if self.channel_scales is not None:
x = x / th.from_numpy(self.channel_scales).to(x).reshape(
[1, -1, *([1] * (len(x.shape) - 2))]
)
return x
def unscale_out_dict(
self, out: Dict[str, Union[th.Tensor, Any]]
) -> Dict[str, Union[th.Tensor, Any]]:
return {
k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v) for k, v in out.items()
}
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: (unordered) timesteps from the original diffusion
process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps: Iterable[int], **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(self, model, *args, **kwargs):
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses(self, model, *args, **kwargs):
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(model, self.timestep_map, self.original_num_steps)
class _WrappedModel:
def __init__(self, model, timestep_map, original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
return self.model(x, new_ts, **kwargs)
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, th.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ th.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
)
def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
"""
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. These are assumed to have been uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = th.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = th.where(
x < -0.999,
log_cdf_plus,
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.flatten(1).mean(1)
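The beta recomputation in `SpacedDiffusion.__init__` can be checked in isolation. The sketch below is a pure-Python illustration, not part of the repository: `respace_betas` and `cumprod` are hypothetical helpers, the linear beta schedule is an arbitrary example, and it assumes `alphas_cumprod` is the cumulative product of `1 - betas`. It shows that the respaced betas reproduce the original cumulative alphas at the retained timesteps.

```python
import math


def cumprod(values):
    """Running product of a sequence, returned as a list."""
    out, running = [], 1.0
    for v in values:
        running *= v
        out.append(running)
    return out


def respace_betas(betas, use_timesteps):
    """Hypothetical helper mirroring the loop in SpacedDiffusion.__init__."""
    use_timesteps = set(use_timesteps)
    alphas_cumprod = cumprod([1.0 - b for b in betas])
    last_alpha_cumprod = 1.0
    new_betas, timestep_map = [], []
    for i, alpha_cumprod in enumerate(alphas_cumprod):
        if i in use_timesteps:
            # Choose beta so the cumulative alpha matches the base process.
            new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
            timestep_map.append(i)
    return new_betas, timestep_map


# Example schedule (an assumption for illustration): 1000 linear betas.
betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]
new_betas, timestep_map = respace_betas(betas, range(0, 1000, 100))

full = cumprod([1.0 - b for b in betas])
spaced = cumprod([1.0 - b for b in new_betas])
```

Because each new beta is defined as one minus a ratio of consecutive retained cumulative alphas, the product telescopes, so `spaced[k]` agrees with `full[timestep_map[k]]` up to floating-point error. `timestep_map` is what `_WrappedModel` uses to translate the shortened step index back to the original timestep before calling the model.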
================================================
FILE: shap_e/diffusion/k_diffusion.py
================================================
"""
Based on: https://github.com/crowsonkb/k-diffusion
Copyright (c) 2022 Katherine Crowson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import numpy as np
import torch as th
from .gaussian_diffusion import GaussianDiffusion, mean_flat
class KarrasDenoiser:
def __init__(self, sigma_data: float = 0.5):
self.sigma_data = sigma_data
def get_snr(self, sigmas):
return sigmas**-2
def get_sigmas(self, sigmas):
return sigmas
def get_scalings(self, sigma):
"""Return the (c_skip, c_out, c_in) preconditioning coefficients of Karras et al. (2022)."""
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out, c_in
def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None):
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = th.randn_like(x_start)
terms = {}
dims = x_start.ndim
x_t = x_start + noise * append_dims(sigmas, dims)
c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)]
model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)
target = (x_start - c_skip * x_t) / c_out
terms["mse"] = mean_flat((model_output - target) ** 2)
terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
terms["loss"] = terms["mse"]
return terms
def denoise(self, model, x_t, sigmas, **model_kwargs):
c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)]
rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
denoised = c_out * model_output + c_skip * x_t
return model_output, denoised
class GaussianToKarrasDenoiser:
def __init__(self, model, diffusion):
from scipy import interpolate
self.model = model
self.diffusion = diffusion
self.alpha_cumprod_to_t = interpolate.interp1d(
diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps)
)
def sigma_to_t(self, sigma):
alpha_cumprod = 1.0 / (sigma**2 + 1)
if alpha_cumprod > self.diffusion.alphas_cumprod[0]:
return 0
elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]:
return self.diffusion.num_timesteps - 1
else:
return float(self.alpha_cumprod_to_t(alpha_cumprod))
def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None):
t = th.tensor(
[self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()],
dtype=th.long,
device=sigmas.device,
)
c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim)
out = self.diffusion.p_mean_variance(
self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
return None, out["pred_xstart"]
def karras_sample(*args, **kwargs):
"""Run karras_sample_progressive() to completion and return the final sample."""
last = None
for x in karras_sample_progressive(*args, **kwargs):
last = x["x"]
return last
def karras_sample_progressive(
diffusion,
model,
shape,
steps,
clip_denoised=True,
progress=False,
model_kwargs=None,
device=None,
sigma_min=0.002,
sigma_max=80, # higher for highres?
rho=7.0,
sampler="heun",
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
guidance_scale=0.0,
):
sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
x_T = th.randn(*shape, device=device) * sigma_max
sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
sampler
]
if sampler != "ancestral":
sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
else:
sampler_args = {}
if isinstance(diffusion, KarrasDenoiser):
def denoiser(x_t, sigma):
# model_kwargs defaults to None, which cannot be unpacked with **.
_, denoised = diffusion.denoise(model, x_t, sigma, **(model_kwargs or {}))
if clip_denoised:
denoised = denoised.clamp(-1, 1)
return denoised
elif isinstance(diffusion, GaussianDiffusion):
model = GaussianToKarrasDenoiser(model, diffusion)
def denoiser(x_t, sigma):
_, denoised = model.denoise(
x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
return denoised
else:
raise NotImplementedError
if guidance_scale != 0 and guidance_scale != 1:
def guided_denoiser(x_t, sigma):
x_t = th.cat([x_t, x_t], dim=0)
sigma = th.cat([sigma, sigma], dim=0)
x_0 = denoiser(x_t, sigma)
cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)
x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)
return x_0
else:
guided_denoiser = denoiser
for obj in sample_fn(
guided_denoiser,
x_T,
sigmas,
progress=progress,
**sampler_args,
):
if isinstance(diffusion, GaussianDiffusion):
yield diffusion.unscale_out_dict(obj)
else:
yield obj
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = th.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device)
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / append_dims(sigma, x.ndim)
def get_ancestral_step(sigma_from, sigma_to):
"""Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step."""
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
return sigma_down, sigma_up
@th.no_grad()
def sample_euler_ancestral(model, x, sigmas, progress=False):
"""Ancestral sampling with Euler method steps."""
s_in = x.new_ones([x.shape[0]])
indices = range(len(sigmas) - 1)
if progress:
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
denoised = model(x, sigmas[i] * s_in)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised}
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
x = x + th.randn_like(x) * sigma_up
yield {"x": x, "pred_xstart": x}
@th.no_grad()
def sample_heun(
denoiser,
x,
sigmas,
progress=False,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
s_in = x.new_ones([x.shape[0]])
indices = range(len(sigmas) - 1)
if progress:
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
gamma = (
min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
)
eps = th.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = denoiser(x, sigma_hat * s_in)
d = to_d(x, sigma_hat, denoised)
yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised}
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
yield {"x": x, "pred_xstart": denoised}
@th.no_grad()
def sample_dpm(
denoiser,
x,
sigmas,
progress=False,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
s_in = x.new_ones([x.shape[0]])
indices = range(len(sigmas) - 1)
if progress:
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
gamma = (
min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
)
eps = th.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = denoiser(x, sigma_hat * s_in)
d = to_d(x, sigma_hat, denoised)
yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised}
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
dt_1 = sigma_mid - sigma_hat
dt_2 = sigmas[i + 1] - sigma_hat
x_2 = x + d * dt_1
denoised_2 = denoiser(x_2, sigma_mid * s_in)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
yield {"x": x, "pred_xstart": denoised}
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
def append_zero(x):
return th.cat([x, x.new_zeros([1])])
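The noise schedule in `get_sigmas_karras` and the split in `get_ancestral_step` are easier to inspect with plain numbers. The sketch below re-expresses both formulas in pure Python so it runs without torch. `karras_sigmas` and `ancestral_step` are hypothetical stand-ins, not repository functions, and the step count of 8 is arbitrary; the sigma bounds match the `DEFAULT_KARRAS_SIGMA_MIN` and `DEFAULT_KARRAS_SIGMA_MAX` values in `sample.py`.

```python
def karras_sigmas(n, sigma_min, sigma_max, rho=7.0):
    """Pure-Python version of the schedule formula; requires n >= 2."""
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    ramp = [i / (n - 1) for i in range(n)]  # like linspace(0, 1, n)
    sigmas = [(max_inv_rho + r * (min_inv_rho - max_inv_rho)) ** rho for r in ramp]
    return sigmas + [0.0]  # trailing zero, as append_zero() does


def ancestral_step(sigma_from, sigma_to):
    """Split a step into a deterministic part and fresh noise."""
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


sigmas = karras_sigmas(8, 1e-3, 160.0)
sigma_down, sigma_up = ancestral_step(sigmas[0], sigmas[1])
```

The schedule runs from `sigma_max` down to `sigma_min` and then ends at exactly zero, with `rho` controlling how densely the low-noise end is sampled. For the ancestral step, `sigma_down**2 + sigma_up**2` equals `sigma_to**2`, so stepping deterministically to `sigma_down` and then adding noise of scale `sigma_up` lands at the intended noise level.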
================================================
FILE: shap_e/diffusion/sample.py
================================================
from typing import Any, Callable, Dict, Optional
import torch
import torch.nn as nn
from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample
DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0
def uncond_guide_model(
model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
def model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // 2]
combined = torch.cat([half, half], dim=0)
model_out = model(combined, ts, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
return model_fn
def sample_latents(
*,
batch_size: int,
model: nn.Module,
diffusion: GaussianDiffusion,
model_kwargs: Dict[str, Any],
guidance_scale: float,
clip_denoised: bool,
use_fp16: bool,
use_karras: bool,
karras_steps: int,
sigma_min: float,
sigma_max: float,
s_churn: float,
device: Optional[torch.device] = None,
progress: bool = False,
) -> torch.Tensor:
sample_shape = (batch_size, model.d_latent)
if device is None:
device = next(model.parameters()).device
if hasattr(model, "cached_model_kwargs"):
model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
if guidance_scale != 1.0 and guidance_scale != 0.0:
for k, v in model_kwargs.copy().items():
model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)
sample_shape = (batch_size, model.d_latent)
with torch.autocast(device_type=device.type, enabled=use_fp16):
if use_karras:
samples = karras_sample(
diffusion=diffusion,
model=model,
shape=sample_shape,
steps=karras_steps,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
device=device,
sigma_min=sigma_min,
sigma_max=sigma_max,
s_churn=s_churn,
guidance_scale=guidance_scale,
progress=progress,
)
else:
internal_batch_size = batch_size
# Match the check above: model_kwargs are only doubled for these scales.
if guidance_scale != 1.0 and guidance_scale != 0.0:
model = uncond_guide_model(model, guidance_scale)
internal_batch_size *= 2
samples = diffusion.p_sample_loop(
model,
shape=(internal_batch_size, *sample_shape[1:]),
model_kwargs=model_kwargs,
device=device,
clip_denoised=clip_denoised,
progress=progress,
)
return samples
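Both `uncond_guide_model` and the `guided_denoiser` in `k_diffusion.py` combine a conditional and an unconditional prediction with the same linear rule. The sketch below applies that rule to plain Python lists to show what the scale does. `guide` is a hypothetical helper and the numbers are made up for illustration.

```python
def guide(cond, uncond, scale):
    """Classifier-free guidance: move from uncond toward (and past) cond."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


cond = [1.0, 2.0]    # prediction made with conditioning
uncond = [0.0, 1.0]  # prediction made with zeroed conditioning

same_as_cond = guide(cond, uncond, 1.0)
same_as_uncond = guide(cond, uncond, 0.0)
extrapolated = guide(cond, uncond, 3.0)
```

A scale of 1 returns the conditional prediction unchanged and a scale of 0 returns the unconditional one, while larger scales extrapolate beyond the conditional prediction. In `sample_latents`, the unconditional half of the batch is produced by concatenating zeroed copies of each `model_kwargs` tensor, which is why guided sampling evaluates the model on twice the batch size.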
================================================
FILE: shap_e/examples/encode_model.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from shap_e.models.download import load_model\n",
"from shap_e.util.data_util import load_or_create_multimodal_batch\n",
"from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xm = load_model('transmitter', device=device)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"model_path = \"example_data/cactus/object.obj\"\n",
"\n",
"# This may take a few minutes, since it requires rendering the model twice\n",
"# in two different modes.\n",
"batch = load_or_create_multimodal_batch(\n",
" device,\n",
" model_path=model_path,\n",
" mv_light_mode=\"basic\",\n",
" mv_image_size=256,\n",
" cache_dir=\"example_data/cactus/cached\",\n",
" verbose=True, # this will show Blender output during renders\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad():\n",
" latent = xm.encoder.encode_to_bottleneck(batch)\n",
"\n",
" render_mode = 'stf' # you can change this to 'nerf'\n",
" size = 128 # recommended that you lower resolution when using nerf\n",
"\n",
" cameras = create_pan_cameras(size, device)\n",
" images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
" display(gif_widget(images))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: shap_e/examples/example_data/cactus/material.mtl
================================================
newmtl mat0
Ka 0.0000 0.7000 0.0000
Kd 0.0000 0.7000 0.0000
Ks 0.0000 0.0000 0.0000
newmtl mat1
Ka 0.6600 0.4400 0.2000
Kd 0.6600 0.4400 0.2000
Ks 0.0000 0.0000 0.0000
newmtl mat2
Ka 0.3000 0.3000 0.3000
Kd 0.3000 0.3000 0.3000
Ks 0.0000 0.0000 0.0000
newmtl mat3
Ka 0.0000 0.5000 0.0000
Kd 0.0000 0.5000 0.0000
Ks 0.0000 0.0000 0.0000
newmtl mat4
Ka 0.0000 0.5667 0.0000
Kd 0.0000 0.5667 0.0000
Ks 0.0000 0.0000 0.0000
newmtl mat5
Ka 0.5400 0.3933 0.2333
Kd 0.5400 0.3933 0.2333
Ks 0.0000 0.0000 0.0000
newmtl mat6
Ka 0.0000 0.6333 0.0000
Kd 0.0000 0.6333 0.0000
Ks 0.0000 0.0000 0.0000
newmtl mat7
Ka 0.2000 0.3667 0.2000
Kd 0.2000 0.3667 0.2000
Ks 0.0000 0.0000 0.0000
newmtl mat8
Ka 0.4200 0.3467 0.2667
Kd 0.4200 0.3467 0.2667
Ks 0.0000 0.0000 0.0000
newmtl mat9
Ka 0.1000 0.4333 0.1000
Kd 0.1000 0.4333 0.1000
Ks 0.0000 0.0000 0.0000
newmtl mat10
Ka 0.1000 0.5667 0.1000
Kd 0.1000 0.5667 0.1000
Ks 0.0000 0.0000 0.0000
newmtl mat11
Ka 0.2000 0.4333 0.2000
Kd 0.2000 0.4333 0.2000
Ks 0.0000 0.0000 0.0000
newmtl mat12
Ka 0.1000 0.5000 0.1000
Kd 0.1000 0.5000 0.1000
Ks 0.0000 0.0000 0.0000
================================================
FILE: shap_e/examples/example_data/cactus/object.obj
================================================
[File too large to display: 17.2 MB]
================================================
FILE: shap_e/examples/sample_image_to_3d.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "964ccced",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from shap_e.diffusion.sample import sample_latents\n",
"from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
"from shap_e.models.download import load_model, load_config\n",
"from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget\n",
"from shap_e.util.image_util import load_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8eed3a76",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d922637",
"metadata": {},
"outputs": [],
"source": [
"xm = load_model('transmitter', device=device)\n",
"model = load_model('image300M', device=device)\n",
"diffusion = diffusion_from_config(load_config('diffusion'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53d329d0",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 4\n",
"guidance_scale = 3.0\n",
"\n",
"# To get the best result, you should remove the background and show only the object of interest to the model.\n",
"image = load_image(\"example_data/corgi.png\")\n",
"\n",
"latents = sample_latents(\n",
" batch_size=batch_size,\n",
" model=model,\n",
" diffusion=diffusion,\n",
" guidance_scale=guidance_scale,\n",
" model_kwargs=dict(images=[image] * batch_size),\n",
" progress=True,\n",
" clip_denoised=True,\n",
" use_fp16=True,\n",
" use_karras=True,\n",
" karras_steps=64,\n",
" sigma_min=1e-3,\n",
" sigma_max=160,\n",
" s_churn=0,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "633da2ec",
"metadata": {},
"outputs": [],
"source": [
"render_mode = 'nerf' # you can change this to 'stf' for mesh rendering\n",
"size = 64 # this is the size of the renders; higher values take longer to render.\n",
"\n",
"cameras = create_pan_cameras(size, device)\n",
"for i, latent in enumerate(latents):\n",
" images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
" display(gif_widget(images))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: shap_e/examples/sample_text_to_3d.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "964ccced",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from shap_e.diffusion.sample import sample_latents\n",
"from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
"from shap_e.models.download import load_model, load_config\n",
"from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8eed3a76",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d922637",
"metadata": {},
"outputs": [],
"source": [
"xm = load_model('transmitter', device=device)\n",
"model = load_model('text300M', device=device)\n",
"diffusion = diffusion_from_config(load_config('diffusion'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53d329d0",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 4\n",
"guidance_scale = 15.0\n",
"prompt = \"a shark\"\n",
"\n",
"latents = sample_latents(\n",
" batch_size=batch_size,\n",
" model=model,\n",
" diffusion=diffusion,\n",
" guidance_scale=guidance_scale,\n",
" model_kwargs=dict(texts=[prompt] * batch_size),\n",
" progress=True,\n",
" clip_denoised=True,\n",
" use_fp16=True,\n",
" use_karras=True,\n",
" karras_steps=64,\n",
" sigma_min=1e-3,\n",
" sigma_max=160,\n",
" s_churn=0,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "633da2ec",
"metadata": {},
"outputs": [],
"source": [
"render_mode = 'nerf' # you can change this to 'stf'\n",
"size = 64 # this is the size of the renders; higher values take longer to render.\n",
"\n",
"cameras = create_pan_cameras(size, device)\n",
"for i, latent in enumerate(latents):\n",
" images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
" display(gif_widget(images))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85a4dce4",
"metadata": {},
"outputs": [],
"source": [
"# Example of saving the latents as meshes.\n",
"from shap_e.util.notebooks import decode_latent_mesh\n",
"\n",
"for i, latent in enumerate(latents):\n",
" t = decode_latent_mesh(xm, latent).tri_mesh()\n",
" with open(f'example_mesh_{i}.ply', 'wb') as f:\n",
" t.write_ply(f)\n",
" with open(f'example_mesh_{i}.obj', 'w') as f:\n",
" t.write_obj(f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: shap_e/models/__init__.py
================================================
================================================
FILE: shap_e/models/configs.py
================================================
from typing import Any, Dict, Union
import blobfile as bf
import torch
import torch.nn as nn
import yaml
from shap_e.models.generation.latent_diffusion import SplitVectorDiffusion
from shap_e.models.generation.perceiver import PointDiffusionPerceiver
from shap_e.models.generation.pooled_mlp import PooledMLP
from shap_e.models.generation.transformer import (
CLIPImageGridPointDiffusionTransformer,
CLIPImageGridUpsamplePointDiffusionTransformer,
CLIPImagePointDiffusionTransformer,
PointDiffusionTransformer,
UpsamplePointDiffusionTransformer,
)
from shap_e.models.nerf.model import MLPNeRFModel, VoidNeRFModel
from shap_e.models.nerf.renderer import OneStepNeRFRenderer, TwoStepNeRFRenderer
from shap_e.models.nerstf.mlp import MLPDensitySDFModel, MLPNeRSTFModel
from shap_e.models.nerstf.renderer import NeRSTFRenderer
from shap_e.models.nn.meta import batch_meta_state_dict
from shap_e.models.stf.mlp import MLPSDFModel, MLPTextureFieldModel
from shap_e.models.stf.renderer import STFRenderer
from shap_e.models.transmitter.base import ChannelsDecoder, Transmitter, VectorDecoder
from shap_e.models.transmitter.channels_encoder import (
PointCloudPerceiverChannelsEncoder,
PointCloudTransformerChannelsEncoder,
)
from shap_e.models.transmitter.multiview_encoder import MultiviewTransformerEncoder
from shap_e.models.transmitter.pc_encoder import (
PointCloudPerceiverEncoder,
PointCloudTransformerEncoder,
)
from shap_e.models.volume import BoundingBoxVolume, SphericalVolume, UnboundedVolume
def model_from_config(config: Union[str, Dict[str, Any]], device: torch.device) -> nn.Module:
if isinstance(config, str):
with bf.BlobFile(config, "rb") as f:
obj = yaml.load(f, Loader=yaml.SafeLoader)
return model_from_config(obj, device=device)
config = config.copy()
name = config.pop("name")
if name == "PointCloudTransformerEncoder":
return PointCloudTransformerEncoder(device=device, dtype=torch.float32, **config)
elif name == "PointCloudPerceiverEncoder":
return PointCloudPerceiverEncoder(device=device, dtype=torch.float32, **config)
elif name == "PointCloudTransformerChannelsEncoder":
return PointCloudTransformerChannelsEncoder(device=device, dtype=torch.float32, **config)
elif name == "PointCloudPerceiverChannelsEncoder":
return PointCloudPerceiverChannelsEncoder(device=device, dtype=torch.float32, **config)
elif name == "MultiviewTransformerEncoder":
return MultiviewTransformerEncoder(device=device, dtype=torch.float32, **config)
elif name == "Transmitter":
renderer = model_from_config(config.pop("renderer"), device=device)
param_shapes = {
k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
}
encoder_config = config.pop("encoder").copy()
encoder_config["param_shapes"] = param_shapes
encoder = model_from_config(encoder_config, device=device)
return Transmitter(encoder=encoder, renderer=renderer, **config)
elif name == "VectorDecoder":
renderer = model_from_config(config.pop("renderer"), device=device)
param_shapes = {
k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
}
return VectorDecoder(param_shapes=param_shapes, renderer=renderer, device=device, **config)
elif name == "ChannelsDecoder":
renderer = model_from_config(config.pop("renderer"), device=device)
param_shapes = {
k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
}
return ChannelsDecoder(
param_shapes=param_shapes, renderer=renderer, device=device, **config
)
elif name == "OneStepNeRFRenderer":
config = config.copy()
for field in [
# Required
"void_model",
"foreground_model",
"volume",
# Optional to use NeRF++
"background_model",
"outer_volume",
]:
if field in config:
config[field] = model_from_config(config.pop(field).copy(), device)
return OneStepNeRFRenderer(device=device, **config)
elif name == "TwoStepNeRFRenderer":
config = config.copy()
for field in [
# Required
"void_model",
"coarse_model",
"fine_model",
"volume",
# Optional to use NeRF++
"coarse_background_model",
"fine_background_model",
"outer_volume",
]:
if field in config:
config[field] = model_from_config(config.pop(field).copy(), device)
return TwoStepNeRFRenderer(device=device, **config)
elif name == "PooledMLP":
return PooledMLP(device, **config)
elif name == "PointDiffusionTransformer":
return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)
elif name == "PointDiffusionPerceiver":
return PointDiffusionPerceiver(device=device, dtype=torch.float32, **config)
elif name == "CLIPImagePointDiffusionTransformer":
return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
elif name == "CLIPImageGridPointDiffusionTransformer":
return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)
elif name == "UpsamplePointDiffusionTransformer":
return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
return CLIPImageGridUpsamplePointDiffusionTransformer(
device=device, dtype=torch.float32, **config
)
elif name == "SplitVectorDiffusion":
inner_config = config.pop("inner")
d_latent = config.pop("d_latent")
latent_ctx = config.pop("latent_ctx", 1)
inner_config["input_channels"] = d_latent // latent_ctx
inner_config["n_ctx"] = latent_ctx
inner_config["output_channels"] = d_latent // latent_ctx * 2
inner_model = model_from_config(inner_config, device)
return SplitVectorDiffusion(
device=device, wrapped=inner_model, n_ctx=latent_ctx, d_latent=d_latent
)
elif name == "STFRenderer":
config = config.copy()
for field in ["sdf", "tf", "volume"]:
config[field] = model_from_config(config.pop(field), device)
return STFRenderer(device=device, **config)
elif name == "NeRSTFRenderer":
config = config.copy()
for field in ["sdf", "tf", "nerstf", "void", "volume"]:
if field not in config:
continue
config[field] = model_from_config(config.pop(field), device)
config.setdefault("sdf", None)
config.setdefault("tf", None)
config.setdefault("nerstf", None)
return NeRSTFRenderer(device=device, **config)
model_cls = {
"MLPSDFModel": MLPSDFModel,
"MLPTextureFieldModel": MLPTextureFieldModel,
"MLPNeRFModel": MLPNeRFModel,
"MLPDensitySDFModel": MLPDensitySDFModel,
"MLPNeRSTFModel": MLPNeRSTFModel,
"VoidNeRFModel": VoidNeRFModel,
"BoundingBoxVolume": BoundingBoxVolume,
"SphericalVolume": SphericalVolume,
"UnboundedVolume": UnboundedVolume,
}[name]
return model_cls(device=device, **config)
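# The dispatch in model_from_config() above can be sketched in isolation as
# follows. This is a minimal illustration only: ToyVolume, ToyRenderer,
# NESTED_FIELDS, REGISTRY, and toy_from_config are hypothetical stand-ins and
# are not part of shap_e. It shows the same three steps: pop "name", build any
# nested sub-configs recursively, then pass the remaining keys as keyword
# arguments to the selected class.

```python
from typing import Any, Dict


class ToyVolume:
    """Hypothetical leaf object built from a config with no nested fields."""

    def __init__(self, *, radius: float = 1.0):
        self.radius = radius


class ToyRenderer:
    """Hypothetical object whose "volume" field is itself a nested config."""

    def __init__(self, *, volume: ToyVolume, n_samples: int = 64):
        self.volume = volume
        self.n_samples = n_samples


# Fields that hold nested configs and must be built before the parent.
NESTED_FIELDS = {"ToyRenderer": ["volume"]}
REGISTRY = {"ToyVolume": ToyVolume, "ToyRenderer": ToyRenderer}


def toy_from_config(config: Dict[str, Any]) -> Any:
    config = dict(config)  # shallow copy so the caller's dict is not mutated
    name = config.pop("name")
    for field in NESTED_FIELDS.get(name, []):
        if field in config:
            config[field] = toy_from_config(config[field])
    return REGISTRY[name](**config)


example_config = {
    "name": "ToyRenderer",
    "n_samples": 16,
    "volume": {"name": "ToyVolume", "radius": 2.0},
}
renderer = toy_from_config(example_config)
```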
================================================
FILE: shap_e/models/download.py
================================================
"""
Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py
"""
import hashlib
import os
from functools import lru_cache
from typing import Dict, Optional
import requests
import torch
import yaml
from filelock import FileLock
from tqdm.auto import tqdm
MODEL_PATHS = {
"transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt",
"decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt",
"text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt",
"image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt",
}
CONFIG_PATHS = {
"transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml",
"decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml",
"text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml",
"image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml",
"diffusion": "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml",
}
URL_HASHES = {
"https://openaipublic.azureedge.net/main/shap-e/transmitter.pt": "af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b",
"https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt": "d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98",
"https://openaipublic.azureedge.net/main/shap-e/text_cond.pt": "e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4",
"https://openaipublic.azureedge.net/main/shap-e/image_cond.pt": "cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa",
"https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml": "ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e",
"https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml": "e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c",
"https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml": "f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1",
"https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml": "4e0745605a533c543c72add803a78d233e2a6401e0abfa0cad58afb4d74ad0b0",
"https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml": "efcb2cd7ee545b2d27223979d41857802448143990572a42645cd09c2942ed57",
}
@lru_cache()
def default_cache_dir() -> str:
return os.path.join(os.path.abspath(os.getcwd()), "shap_e_model_cache")
def fetch_file_cached(
url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
) -> str:
"""
    Download the file at the given URL into a local file and return the path.
    If cache_dir is specified, the file is stored there; otherwise,
    default_cache_dir() is used. A file that is already cached is reused
    after its SHA-256 hash is checked against URL_HASHES.
"""
expected_hash = URL_HASHES[url]
if cache_dir is None:
cache_dir = default_cache_dir()
os.makedirs(cache_dir, exist_ok=True)
local_path = os.path.join(cache_dir, url.split("/")[-1])
if os.path.exists(local_path):
check_hash(local_path, expected_hash)
return local_path
response = requests.get(url, stream=True)
size = int(response.headers.get("content-length", "0"))
with FileLock(local_path + ".lock"):
if progress:
pbar = tqdm(total=size, unit="iB", unit_scale=True)
tmp_path = local_path + ".tmp"
with open(tmp_path, "wb") as f:
for chunk in response.iter_content(chunk_size):
if progress:
pbar.update(len(chunk))
f.write(chunk)
os.rename(tmp_path, local_path)
if progress:
pbar.close()
check_hash(local_path, expected_hash)
return local_path
def check_hash(path: str, expected_hash: str):
actual_hash = hash_file(path)
if actual_hash != expected_hash:
raise RuntimeError(
f"The file {path} should have hash {expected_hash} but has {actual_hash}. "
"Try deleting it and running this call again."
)
def hash_file(path: str) -> str:
sha256_hash = hashlib.sha256()
with open(path, "rb") as file:
while True:
data = file.read(4096)
if not len(data):
break
sha256_hash.update(data)
return sha256_hash.hexdigest()
def load_config(
config_name: str,
progress: bool = False,
cache_dir: Optional[str] = None,
chunk_size: int = 4096,
):
if config_name not in CONFIG_PATHS:
raise ValueError(
            f"Unknown config name {config_name}. Known names are: {list(CONFIG_PATHS)}."
)
path = fetch_file_cached(
CONFIG_PATHS[config_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
)
with open(path, "r") as f:
return yaml.safe_load(f)
def load_checkpoint(
checkpoint_name: str,
device: torch.device,
progress: bool = True,
cache_dir: Optional[str] = None,
chunk_size: int = 4096,
) -> Dict[str, torch.Tensor]:
if checkpoint_name not in MODEL_PATHS:
raise ValueError(
            f"Unknown checkpoint name {checkpoint_name}. Known names are: {list(MODEL_PATHS)}."
)
path = fetch_file_cached(
MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
)
return torch.load(path, map_location=device)
def load_model(
model_name: str,
device: torch.device,
**kwargs,
) -> torch.nn.Module:
from .configs import model_from_config
model = model_from_config(load_config(model_name, **kwargs), device=device)
model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs))
model.eval()
return model
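# The integrity check used by fetch_file_cached() above can be sketched with
# the standard library alone. This is an illustrative sketch, not the module's
# API: sha256_of_file, verify_file, and the demo payload are hypothetical
# names. It streams a file through SHA-256 in fixed-size chunks (so large
# checkpoints never need to fit in memory) and raises when the digest does not
# match the expected value.

```python
import hashlib
import os
import tempfile


def sha256_of_file(path: str, chunk_size: int = 4096) -> str:
    """Hash a file incrementally, reading chunk_size bytes at a time."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_file(path: str, expected_hash: str) -> None:
    """Raise RuntimeError if the file's SHA-256 digest is not expected_hash."""
    actual_hash = sha256_of_file(path)
    if actual_hash != expected_hash:
        raise RuntimeError(
            f"The file {path} should have hash {expected_hash} but has {actual_hash}."
        )


payload = b"shap-e" * 10000  # arbitrary demo bytes, larger than one chunk
expected = hashlib.sha256(payload).hexdigest()

with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.bin")
    with open(demo_path, "wb") as f:
        f.write(payload)

    streamed_hash = sha256_of_file(demo_path)
    verify_file(demo_path, expected)  # passes silently

    try:
        verify_file(demo_path, "0" * 64)  # deliberately wrong digest
        mismatch_detected = False
    except RuntimeError:
        mismatch_detected = True
```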
================================================
FILE: shap_e/models/generation/__init__.py
================================================
================================================
FILE: shap_e/models/generation/latent_diffusion.py
================================================
from typing import Any, Dict
import torch
import torch.nn as nn
class SplitVectorDiffusion(nn.Module):
def __init__(self, *, device: torch.device, wrapped: nn.Module, n_ctx: int, d_latent: int):
super().__init__()
self.device = device
self.n_ctx = n_ctx
self.d_latent = d_latent
self.wrapped = wrapped
if hasattr(self.wrapped, "cached_model_kwargs"):
self.cached_model_kwargs = self.wrapped.cached_model_kwargs
def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs):
h = x.reshape(x.shape[0], self.n_ctx, -1).permute(0, 2, 1)
pre_channels = h.shape[1]
h = self.wrapped(h, t, **kwargs)
assert (
h.shape[1] == pre_channels * 2
), "expected twice as many outputs for variance prediction"
eps, var = torch.chunk(h, 2, dim=1)
return torch.cat(
[
eps.permute(0, 2, 1).flatten(1),
var.permute(0, 2, 1).flatten(1),
],
dim=1,
)
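# The tensor layout handled by SplitVectorDiffusion.forward() above can be
# traced on a single sample with plain Python lists. This is a hedged sketch
# that avoids PyTorch so it runs anywhere; split_latent, merge_channels, and
# toy_wrapped are hypothetical helpers, and toy_wrapped merely stands in for
# the wrapped model (it echoes its input as "eps" and emits zeros as "var").

```python
from typing import List


def split_latent(flat: List[float], n_ctx: int) -> List[List[float]]:
    """Mirror reshape(n_ctx, -1) followed by a transpose to [C][n_ctx]."""
    channels = len(flat) // n_ctx
    tokens = [flat[i * channels : (i + 1) * channels] for i in range(n_ctx)]
    return [[tokens[t][c] for t in range(n_ctx)] for c in range(channels)]


def merge_channels(h: List[List[float]]) -> List[float]:
    """Mirror the transpose back to [n_ctx][C] followed by flatten."""
    n_ctx = len(h[0])
    return [h[c][t] for t in range(n_ctx) for c in range(len(h))]


def toy_wrapped(h: List[List[float]]) -> List[List[float]]:
    """Stand-in model: returns 2C channels (input copy, then zeros)."""
    return [row[:] for row in h] + [[0.0] * len(row) for row in h]


flat = [float(i) for i in range(6)]  # one sample with n_ctx=2 and C=3
h = split_latent(flat, n_ctx=2)
out = toy_wrapped(h)
assert len(out) == 2 * len(h), "expected twice as many outputs for variance prediction"

half = len(out) // 2
eps, var = out[:half], out[half:]
result = merge_channels(eps) + merge_channels(var)
```

# The result has twice the input length: the first half is the flattened
# "eps" prediction and the second half is the flattened "var" prediction.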
================================================
FILE: shap_e/models/generation/perceiver.py
================================================
import math
from typing import Optional
import torch
import torch.nn as nn
from shap_e.models.nn.checkpoint import checkpoint
from .transformer import MLP, Transformer, init_linear
from .util import timestep_embedding
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
n_data: int,
width: int,
heads: int,
init_scale: float,
data_width: Optional[int] = None,
):
super().__init__()
self.n_ctx = n_ctx
self.n_data = n_data
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
self.c_q = nn.Linear(width, width, device=device, dtype=dtype)
self.c_kv = nn.Linear(self.data_width, width * 2, device=device, dtype=dtype)
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
self.attention = QKVMultiheadCrossAttention(
device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, n_data=n_data
)
init_linear(self.c_q, init_scale)
init_linear(self.c_kv, init_scale)
init_linear(self.c_proj, init_scale)
def forward(self, x, data):
x = self.c_q(x)
data = self.c_kv(data)
x = checkpoint(self.attention, (x, data), (), True)
x = self.c_proj(x)
return x
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, n_data: int
):
super().__init__()
self.device = device
self.dtype = dtype
self.heads = heads
self.n_ctx = n_ctx
self.n_data = n_data
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
scale = 1 / math.sqrt(math.sqrt(attn_ch))
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
weight = torch.einsum(
"bthc,bshc->bhts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
wdtype = weight.dtype
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
n_data: int,
width: int,
heads: int,
data_width: Optional[int] = None,
init_scale: float = 1.0,
):
super().__init__()
if data_width is None:
data_width = width
self.attn = MultiheadCrossAttention(
device=device,
dtype=dtype,
n_ctx=n_ctx,
n_data=n_data,
width=width,
heads=heads,
data_width=data_width,
init_scale=init_scale,
)
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
x = x + self.mlp(self.ln_3(x))
return x
class SimplePerceiver(nn.Module):
"""
    Applies only cross-attention blocks (no self-attention layers).
"""
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
n_data: int,
width: int,
layers: int,
heads: int,
init_scale: float = 0.25,
data_width: Optional[int] = None,
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
init_scale = init_scale * math.sqrt(1.0 / width)
self.resblocks = nn.ModuleList(
[
ResidualCrossAttentionBlock(
device=device,
dtype=dtype,
n_ctx=n_ctx,
n_data=n_data,
width=width,
heads=heads,
init_scale=init_scale,
data_width=data_width,
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor, data: torch.Tensor):
for block in self.resblocks:
x = block(x, data)
return x
class PointDiffusionPerceiver(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
input_channels: int = 3,
output_channels: int = 3,
n_ctx: int = 1024,
n_latent: int = 128,
width: int = 512,
encoder_layers: int = 12,
latent_layers: int = 12,
decoder_layers: int = 12,
heads: int = 8,
init_scale: float = 0.25,
):
super().__init__()
self.time_embed = MLP(
device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
)
self.latent_embed = MLP(
device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
)
self.n_latent = n_latent
self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
self.encoder = SimplePerceiver(
device=device,
dtype=dtype,
n_ctx=n_latent,
n_data=n_ctx,
width=width,
layers=encoder_layers,
heads=heads,
init_scale=init_scale,
)
self.processor = Transformer(
device=device,
dtype=dtype,
n_ctx=n_latent,
width=width,
layers=latent_layers,
heads=heads,
init_scale=init_scale,
)
self.decoder = SimplePerceiver(
device=device,
dtype=dtype,
n_ctx=n_ctx,
n_data=n_latent,
width=width,
layers=decoder_layers,
heads=heads,
init_scale=init_scale,
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
with torch.no_grad():
self.output_proj.weight.zero_()
self.output_proj.bias.zero_()
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
:param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:return: an [N x C' x T] tensor.
"""
assert x.shape[-1] == self.decoder.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.encoder.width))
data = self.input_proj(x.permute(0, 2, 1)) + t_embed[:, None]
data = self.ln_pre(data)
l = torch.arange(self.n_latent).to(x.device)
h = self.latent_embed(timestep_embedding(l, self.decoder.width))
h = h.unsqueeze(0).repeat(x.shape[0], 1, 1)
h = self.encoder(h, data)
h = self.processor(h)
h = self.decoder(data, h)
h = self.ln_post(h)
h = self.output_proj(h)
return h.permute(0, 2, 1)
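# QKVMultiheadCrossAttention above multiplies both q and k by
# 1 / sqrt(sqrt(attn_ch)) instead of dividing the product by sqrt(attn_ch).
# The sketch below checks, with plain Python floats, that the two forms give
# the same attention logit; the vectors q and k are made-up demo values. The
# source comment notes that pre-scaling is more stable in f16, presumably
# because the intermediate products stay smaller.

```python
import math
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


attn_ch = 64  # per-head channel count (demo value)
q = [0.01 * (i + 1) for i in range(attn_ch)]
k = [0.02 * (attn_ch - i) for i in range(attn_ch)]

# Scale q and k separately before the dot product, as in the module above.
scale = 1 / math.sqrt(math.sqrt(attn_ch))
pre_scaled = dot([x * scale for x in q], [x * scale for x in k])

# Conventional form: divide the unscaled logit afterwards.
post_scaled = dot(q, k) / math.sqrt(attn_ch)
```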
================================================
FILE: shap_e/models/generation/pooled_mlp.py
================================================
import torch
import torch.nn as nn
from .util import timestep_embedding
class PooledMLP(nn.Module):
def __init__(
self,
device: torch.device,
*,
input_channels: int = 3,
output_channels: int = 6,
hidden_size: int = 256,
resblocks: int = 4,
pool_op: str = "max",
):
super().__init__()
self.input_embed = nn.Conv1d(input_channels, hidden_size, kernel_size=1, device=device)
self.time_embed = nn.Linear(hidden_size, hidden_size, device=device)
blocks = []
for _ in range(resblocks):
blocks.append(ResBlock(hidden_size, pool_op, device=device))
self.sequence = nn.Sequential(*blocks)
self.out = nn.Conv1d(hidden_size, output_channels, kernel_size=1, device=device)
with torch.no_grad():
self.out.bias.zero_()
self.out.weight.zero_()
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
in_embed = self.input_embed(x)
t_embed = self.time_embed(timestep_embedding(t, in_embed.shape[1]))
h = in_embed + t_embed[..., None]
h = self.sequence(h)
h = self.out(h)
return h
class ResBlock(nn.Module):
def __init__(self, hidden_size: int, pool_op: str, device: torch.device):
super().__init__()
assert pool_op in ["mean", "max"]
self.pool_op = pool_op
self.body = nn.Sequential(
nn.SiLU(),
nn.LayerNorm((hidden_size,), device=device),
nn.Linear(hidden_size, hidden_size, device=device),
nn.SiLU(),
nn.LayerNorm((hidden_size,), device=device),
nn.Linear(hidden_size, hidden_size, device=device),
)
self.gate = nn.Sequential(
nn.Linear(hidden_size, hidden_size, device=device),
nn.Tanh(),
)
def forward(self, x: torch.Tensor):
N, C, T = x.shape
out = self.body(x.permute(0, 2, 1).reshape(N * T, C)).reshape([N, T, C]).permute(0, 2, 1)
pooled = pool(self.pool_op, x)
gate = self.gate(pooled)
return x + out * gate[..., None]
def pool(op_name: str, x: torch.Tensor) -> torch.Tensor:
if op_name == "max":
pooled, _ = torch.max(x, dim=-1)
elif op_name == "mean":
        pooled = torch.mean(x, dim=-1)
else:
raise ValueError(f"unknown pool op: {op_name}")
return pooled
================================================
FILE: shap_e/models/generation/pretrained_clip.py
================================================
from typing import Iterable, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from shap_e.models.download import default_cache_dir
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
class ImageCLIP(nn.Module):
"""
A wrapper around a pre-trained CLIP model that automatically handles
batches of texts, images, and embeddings.
"""
def __init__(
self,
device: torch.device,
dtype: Optional[torch.dtype] = torch.float32,
ensure_used_params: bool = True,
clip_name: str = "ViT-L/14",
cache_dir: Optional[str] = None,
):
super().__init__()
assert clip_name in ["ViT-L/14", "ViT-B/32"]
self.device = device
self.ensure_used_params = ensure_used_params
# Lazy import because of torchvision.
import clip
self.clip_model, self.preprocess = clip.load(
clip_name, device=device, download_root=cache_dir or default_cache_dir()
)
self.clip_name = clip_name
if dtype is not None:
self.clip_model.to(dtype)
self._tokenize = clip.tokenize
@property
def feature_dim(self) -> int:
if self.clip_name == "ViT-L/14":
return 768
else:
return 512
@property
def grid_size(self) -> int:
if self.clip_name == "ViT-L/14":
return 16
else:
return 7
@property
def grid_feature_dim(self) -> int:
if self.clip_name == "ViT-L/14":
return 1024
else:
return 768
def forward(
self,
batch_size: int,
images: Optional[Iterable[Optional[ImageType]]] = None,
texts: Optional[Iterable[Optional[str]]] = None,
embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
"""
Generate a batch of embeddings from a mixture of images, texts,
precomputed embeddings, and possibly empty values.
For each batch element, at most one of images, texts, and embeddings
should have a non-None value. Embeddings from multiple modalities
cannot be mixed for a single batch element. If no modality is provided,
a zero embedding will be used for the batch element.
"""
image_seq = [None] * batch_size if images is None else list(images)
text_seq = [None] * batch_size if texts is None else list(texts)
embedding_seq = [None] * batch_size if embeddings is None else list(embeddings)
assert len(image_seq) == batch_size, "number of images should match batch size"
assert len(text_seq) == batch_size, "number of texts should match batch size"
assert len(embedding_seq) == batch_size, "number of embeddings should match batch size"
if self.ensure_used_params:
return self._static_multimodal_embed(
images=image_seq, texts=text_seq, embeddings=embedding_seq
)
result = torch.zeros((batch_size, self.feature_dim), device=self.device)
index_images = []
index_texts = []
for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):
assert (
sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2
), "only one modality may be non-None per batch element"
if image is not None:
index_images.append((i, image))
elif text is not None:
index_texts.append((i, text))
elif emb is not None:
result[i] = emb.to(result)
if len(index_images):
embs = self.embed_images((img for _, img in index_images))
for (i, _), emb in zip(index_images, embs):
result[i] = emb.to(result)
if len(index_texts):
embs = self.embed_text((text for _, text in index_texts))
for (i, _), emb in zip(index_texts, embs):
result[i] = emb.to(result)
return result
def _static_multimodal_embed(
self,
images: List[Optional[ImageType]] = None,
texts: List[Optional[str]] = None,
embeddings: List[Optional[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Like forward(), but always runs all encoders to ensure that
the forward graph looks the same on every rank.
"""
image_emb = self.embed_images(images)
text_emb = self.embed_text(t if t else "" for t in texts)
joined_embs = torch.stack(
[
emb.to(device=self.device, dtype=torch.float32)
if emb is not None
else torch.zeros(self.feature_dim, device=self.device)
for emb in embeddings
],
dim=0,
)
image_flag = torch.tensor([x is not None for x in images], device=self.device)[
:, None
].expand_as(image_emb)
text_flag = torch.tensor([x is not None for x in texts], device=self.device)[
:, None
].expand_as(image_emb)
emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[
:, None
].expand_as(image_emb)
return (
image_flag.float() * image_emb
+ text_flag.float() * text_emb
+ emb_flag.float() * joined_embs
+ self.clip_model.logit_scale * 0 # avoid unused parameters
)
def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
"""
:param xs: N images, stored as numpy arrays, tensors, or PIL images.
:return: an [N x D] tensor of features.
"""
clip_inputs = self.images_to_tensor(xs)
results = self.clip_model.encode_image(clip_inputs).float()
return results / torch.linalg.norm(results, dim=-1, keepdim=True)
def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
"""
Embed text prompts as an [N x D] tensor.
"""
enc = self.clip_model.encode_text(
self._tokenize(list(prompts), truncate=True).to(self.device)
).float()
return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)
def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
"""
Embed images into latent grids.
:param xs: an iterable of images to embed.
:return: a tensor of shape [N x C x L], where L = self.grid_size**2.
"""
if self.ensure_used_params:
extra_value = 0.0
for p in self.parameters():
extra_value = extra_value + p.mean() * 0.0
else:
extra_value = 0.0
x = self.images_to_tensor(xs).to(self.clip_model.dtype)
# https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225
vt = self.clip_model.visual
x = vt.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[
vt.class_embedding.to(x.dtype)
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + vt.positional_embedding.to(x.dtype)
x = vt.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = vt.transformer(x)
x = x.permute(1, 2, 0) # LND -> NDL
return x[..., 1:].contiguous().float() + extra_value
def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device)
class FrozenImageCLIP:
def __init__(self, device: torch.device, **kwargs):
self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs)
for parameter in self.model.parameters():
parameter.requires_grad_(False)
@property
def feature_dim(self) -> int:
return self.model.feature_dim
@property
def grid_size(self) -> int:
return self.model.grid_size
@property
def grid_feature_dim(self) -> int:
return self.model.grid_feature_dim
def __call__(
self,
batch_size: int,
images: Optional[Iterable[Optional[ImageType]]] = None,
texts: Optional[Iterable[Optional[str]]] = None,
embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
# We don't do a no_grad() here so that gradients could still
# flow to the input embeddings argument.
# This behavior is currently not used, but it could be.
return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings)
def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
with torch.no_grad():
return self.model.embed_images(xs)
def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
with torch.no_grad():
return self.model.embed_text(prompts)
def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
with torch.no_grad():
return self.model.embed_images_grid(xs)
def _image_to_pil(obj: Optional[ImageType]) -> Image.Image:
if obj is None:
return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))
if isinstance(obj, np.ndarray):
return Image.fromarray(obj.astype(np.uint8))
elif isinstance(obj, torch.Tensor):
return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8))
else:
return obj
================================================
FILE: shap_e/models/generation/transformer.py
================================================
import math
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
from shap_e.models.nn.checkpoint import checkpoint
from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType
from .util import timestep_embedding
def init_linear(l, stddev):
nn.init.normal_(l.weight, std=stddev)
if l.bias is not None:
nn.init.constant_(l.bias, 0.0)
class MultiheadAttention(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
heads: int,
init_scale: float,
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.heads = heads
self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
init_linear(self.c_qkv, init_scale)
init_linear(self.c_proj, init_scale)
def forward(self, x):
x = self.c_qkv(x)
x = checkpoint(self.attention, (x,), (), True)
x = self.c_proj(x)
return x
class MLP(nn.Module):
def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
super().__init__()
self.width = width
self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
self.gelu = nn.GELU()
init_linear(self.c_fc, init_scale)
init_linear(self.c_proj, init_scale)
def forward(self, x):
return self.c_proj(self.gelu(self.c_fc(x)))
class QKVMultiheadAttention(nn.Module):
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
super().__init__()
self.device = device
self.dtype = dtype
self.heads = heads
self.n_ctx = n_ctx
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.heads // 3
scale = 1 / math.sqrt(math.sqrt(attn_ch))
qkv = qkv.view(bs, n_ctx, self.heads, -1)
q, k, v = torch.split(qkv, attn_ch, dim=-1)
weight = torch.einsum(
"bthc,bshc->bhts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
wdtype = weight.dtype
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
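# A minimal sketch in plain Python (not part of shap_e) of the scaling used in
# QKVMultiheadAttention.forward above. Multiplying q and k each by
# 1 / sqrt(sqrt(attn_ch)) yields the same attention logit as dividing the raw
# dot product by sqrt(attn_ch) afterwards; the source comment notes that the
# first form is more stable in f16. The helper names below are hypothetical.

```python
import math


def scaled_logit_pre(q, k):
    """Scale q and k before the dot product, as the forward pass above does."""
    scale = 1 / math.sqrt(math.sqrt(len(q)))
    return sum((a * scale) * (b * scale) for a, b in zip(q, k))


def scaled_logit_post(q, k):
    """Textbook form: divide the raw dot product by sqrt(channels)."""
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))


# One query vector and one key vector with attn_ch = 4 channels.
q = [0.5, -1.0, 2.0, 4.0]
k = [1.5, 0.25, -3.0, 1.0]
pre = scaled_logit_pre(q, k)
post = scaled_logit_post(q, k)
```

# Both forms agree up to floating-point rounding; the difference only matters
# for the magnitude of intermediate values in low-precision arithmetic.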
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
heads: int,
init_scale: float = 1.0,
):
super().__init__()
self.attn = MultiheadAttention(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
)
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
def forward(self, x: torch.Tensor):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
layers: int,
heads: int,
init_scale: float = 0.25,
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
init_scale = init_scale * math.sqrt(1.0 / width)
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor):
for block in self.resblocks:
x = block(x)
return x
class PointDiffusionTransformer(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
input_channels: int = 3,
output_channels: int = 3,
n_ctx: int = 1024,
width: int = 512,
layers: int = 12,
heads: int = 8,
init_scale: float = 0.25,
time_token_cond: bool = False,
use_pos_emb: bool = False,
pos_emb_init_scale: float = 1.0,
pos_emb_n_ctx: Optional[int] = None,
):
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
self.n_ctx = n_ctx
self.time_token_cond = time_token_cond
self.use_pos_emb = use_pos_emb
self.time_embed = MLP(
device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
)
self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
self.backbone = Transformer(
device=device,
dtype=dtype,
n_ctx=n_ctx + int(time_token_cond),
width=width,
layers=layers,
heads=heads,
init_scale=init_scale,
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
with torch.no_grad():
self.output_proj.weight.zero_()
self.output_proj.bias.zero_()
if self.use_pos_emb:
self.register_parameter(
"pos_emb",
nn.Parameter(
pos_emb_init_scale
* torch.randn(pos_emb_n_ctx or self.n_ctx, width, device=device, dtype=dtype)
),
)
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
:param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:return: an [N x C' x T] tensor.
"""
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])
def _forward_with_cond(
self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
) -> torch.Tensor:
h = self.input_proj(x.permute(0, 2, 1)) # NCL -> NLC
for emb, as_token in cond_as_token:
if not as_token:
h = h + emb[:, None]
if self.use_pos_emb:
h = h + self.pos_emb
extra_tokens = [
(emb[:, None] if len(emb.shape) == 2 else emb)
for emb, as_token in cond_as_token
if as_token
]
if len(extra_tokens):
h = torch.cat(extra_tokens + [h], dim=1)
h = self.ln_pre(h)
h = self.backbone(h)
h = self.ln_post(h)
if len(extra_tokens):
h = h[:, sum(tok.shape[1] for tok in extra_tokens) :]
h = self.output_proj(h)
return h.permute(0, 2, 1)
class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int = 1024,
token_cond: bool = False,
cond_drop_prob: float = 0.0,
frozen_clip: bool = True,
**kwargs,
):
super().__init__(
device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), pos_emb_n_ctx=n_ctx, **kwargs
)
self.n_ctx = n_ctx
self.token_cond = token_cond
self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
self.clip_embed = nn.Linear(
self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype
)
self.cond_drop_prob = cond_drop_prob
def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
with torch.no_grad():
return dict(embeddings=self.clip(batch_size, **model_kwargs))
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
images: Optional[Iterable[Optional[ImageType]]] = None,
texts: Optional[Iterable[Optional[str]]] = None,
embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
):
"""
:param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:param images: a batch of images to condition on.
:param texts: a batch of texts to condition on.
:param embeddings: a batch of CLIP embeddings to condition on.
:return: an [N x C' x T] tensor.
"""
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings)
assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]
if self.training:
mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
clip_out = clip_out * mask[:, None].to(clip_out)
# Rescale the features to have unit variance
clip_out = math.sqrt(clip_out.shape[1]) * clip_out
clip_embed = self.clip_embed(clip_out)
cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]
return self._forward_with_cond(x, cond)
class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int = 1024,
cond_drop_prob: float = 0.0,
frozen_clip: bool = True,
**kwargs,
):
clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
super().__init__(
device=device,
dtype=dtype,
n_ctx=n_ctx + clip.grid_size**2,
pos_emb_n_ctx=n_ctx,
**kwargs,
)
self.n_ctx = n_ctx
self.clip = clip
self.clip_embed = nn.Sequential(
nn.LayerNorm(
normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
),
nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
)
self.cond_drop_prob = cond_drop_prob
def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
_ = batch_size
with torch.no_grad():
return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"]))
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
images: Optional[Iterable[ImageType]] = None,
embeddings: Optional[Iterable[torch.Tensor]] = None,
):
"""
:param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:param images: a batch of images to condition on.
:param embeddings: a batch of CLIP latent grids to condition on.
:return: an [N x C' x T] tensor.
"""
assert images is not None or embeddings is not None, "must specify images or embeddings"
assert images is None or embeddings is None, "cannot specify both images and embeddings"
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
if images is not None:
clip_out = self.clip.embed_images_grid(images)
else:
clip_out = embeddings
if self.training:
mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
clip_out = clip_out * mask[:, None, None].to(clip_out)
clip_out = clip_out.permute(0, 2, 1) # NCL -> NLC
clip_embed = self.clip_embed(clip_out)
cond = [(t_embed, self.time_token_cond), (clip_embed, True)]
return self._forward_with_cond(x, cond)
class UpsamplePointDiffusionTransformer(PointDiffusionTransformer):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
cond_input_channels: Optional[int] = None,
cond_ctx: int = 1024,
n_ctx: int = 4096 - 1024,
channel_scales: Optional[Sequence[float]] = None,
channel_biases: Optional[Sequence[float]] = None,
**kwargs,
):
super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs)
self.n_ctx = n_ctx
self.cond_input_channels = cond_input_channels or self.input_channels
self.cond_point_proj = nn.Linear(
self.cond_input_channels, self.backbone.width, device=device, dtype=dtype
)
self.register_buffer(
"channel_scales",
torch.tensor(channel_scales, dtype=dtype, device=device)
if channel_scales is not None
else None,
)
self.register_buffer(
"channel_biases",
torch.tensor(channel_biases, dtype=dtype, device=device)
if channel_biases is not None
else None,
)
def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):
"""
:param x: an [N x C1 x T] tensor.
:param t: an [N] tensor.
:param low_res: an [N x C2 x T'] tensor of conditioning points.
:return: an [N x C3 x T] tensor.
"""
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
low_res_embed = self._embed_low_res(low_res)
cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]
return self._forward_with_cond(x, cond)
def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:
if self.channel_scales is not None:
x = x * self.channel_scales[None, :, None]
if self.channel_biases is not None:
x = x + self.channel_biases[None, :, None]
return self.cond_point_proj(x.permute(0, 2, 1))
class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int = 4096 - 1024,
cond_drop_prob: float = 0.0,
frozen_clip: bool = True,
**kwargs,
):
clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)
self.n_ctx = n_ctx
self.clip = clip
self.clip_embed = nn.Sequential(
nn.LayerNorm(
normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
),
nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
)
self.cond_drop_prob = cond_drop_prob
def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
_ = batch_size
with torch.no_grad():
return dict(
embeddings=self.clip.embed_images_grid(model_kwargs["images"]),
low_res=model_kwargs["low_res"],
)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
*,
low_res: torch.Tensor,
images: Optional[Iterable[ImageType]] = None,
embeddings: Optional[Iterable[torch.Tensor]] = None,
):
"""
:param x: an [N x C1 x T] tensor.
:param t: an [N] tensor.
:param low_res: an [N x C2 x T'] tensor of conditioning points.
:param images: a batch of images to condition on.
:param embeddings: a batch of CLIP latent grids to condition on.
:return: an [N x C3 x T] tensor.
"""
assert x.shape[-1] == self.n_ctx
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
low_res_embed = self._embed_low_res(low_res)
if images is not None:
clip_out = self.clip.embed_images_grid(images)
else:
clip_out = embeddings
if self.training:
mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
clip_out = clip_out * mask[:, None, None].to(clip_out)
clip_out = clip_out.permute(0, 2, 1) # NCL -> NLC
clip_embed = self.clip_embed(clip_out)
cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)]
return self._forward_with_cond(x, cond)
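# A minimal sketch in plain Python (not part of shap_e) of the training-time
# conditioning dropout applied in the CLIP-conditioned forward passes above,
# where `mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob` zeroes the
# CLIP features of each batch element with probability `cond_drop_prob`.
# The helper `drop_conditioning` and its `rng` argument are hypothetical.

```python
import random


def drop_conditioning(features, cond_drop_prob, rng):
    """Zero each row of `features` with probability `cond_drop_prob`."""
    out = []
    for row in features:
        # Same comparison as the mask above: keep when the draw >= drop prob.
        keep = rng.random() >= cond_drop_prob
        out.append(list(row) if keep else [0.0] * len(row))
    return out


rng = random.Random(0)
features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
kept_all = drop_conditioning(features, 0.0, rng)     # nothing is dropped
dropped_all = drop_conditioning(features, 1.0, rng)  # everything is dropped
```

# Because the draw lies in [0, 1), a probability of 0.0 keeps every row and a
# probability of 1.0 drops every row, matching the `>=` comparison above.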
================================================
FILE: shap_e/models/generation/util.py
================================================
import math
import torch
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
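# A minimal sketch in plain Python (not part of shap_e) of what
# timestep_embedding above computes for a single timestep: `dim // 2` cosine
# terms followed by `dim // 2` sine terms at geometrically spaced frequencies,
# plus one zero of padding when `dim` is odd. The helper name
# `timestep_embedding_1d` is hypothetical.

```python
import math


def timestep_embedding_1d(t, dim, max_period=10000):
    """Embed one (possibly fractional) timestep as [cos..., sin..., (pad)]."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    emb = [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]
    if dim % 2:
        emb.append(0.0)  # zero-pad odd dimensions, as in the function above
    return emb


emb_even = timestep_embedding_1d(0.0, 8)  # four cosines, then four sines
emb_odd = timestep_embedding_1d(3.5, 7)   # three cosines, three sines, one pad
```

# At t = 0 every cosine term is 1 and every sine term is 0, which makes the
# layout of the two halves easy to check.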
================================================
FILE: shap_e/models/nerf/__init__.py
================================================
================================================
FILE: shap_e/models/nerf/model.py
================================================
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init
from shap_e.models.nn.utils import ArrayType
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict
class NeRFModel(ABC):
"""
Parametric scene representation whose outputs are integrated by NeRFRenderer
"""
@abstractmethod
def forward(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> AttrDict:
"""
:param query: the points in the field to query.
:param params: Meta parameters
:param options: Optional hyperparameters
:return: An AttrDict containing at least
- density: [batch_size x ... x 1]
- channels: [batch_size x ... x n_channels]
- aux_losses: [batch_size x ... x 1]
"""
class VoidNeRFModel(MetaModule, NeRFModel):
"""
Implements the default empty space model where all queries are rendered as
background.
"""
def __init__(
self,
background: ArrayType,
trainable: bool = False,
channel_scale: float = 255.0,
device: torch.device = torch.device("cuda"),
):
super().__init__()
background = nn.Parameter(
torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device)
/ channel_scale
)
if trainable:
self.register_parameter("background", background)
else:
self.register_buffer("background", background)
def forward(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
_ = params
default_bg = self.background[None]
background = options.get("background", default_bg) if options is not None else default_bg
shape = query.position.shape[:-1]
ones = [1] * (len(shape) - 1)
n_channels = background.shape[-1]
background = torch.broadcast_to(
background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]
)
return background
class MLPNeRFModel(MetaModule, NeRFModel):
def __init__(
self,
# Positional encoding parameters
n_levels: int = 10,
# MLP parameters
d_hidden: int = 256,
n_density_layers: int = 4,
n_channel_layers: int = 1,
n_channels: int = 3,
sh_degree: int = 4,
activation: str = "relu",
density_activation: str = "exp",
init: Optional[str] = None,
init_scale: float = 1.0,
output_activation: str = "sigmoid",
meta_parameters: bool = False,
trainable_meta: bool = False,
zero_out: bool = True,
register_freqs: bool = True,
posenc_version: str = "v1",
device: torch.device = torch.device("cuda"),
):
super().__init__()
# Positional encoding
if register_freqs:
# not used anymore
self.register_buffer(
"freqs",
2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels),
)
self.posenc_version = posenc_version
dummy = torch.eye(1, 3)
d_input = encode_position(posenc_version, position=dummy).shape[-1]
self.n_levels = n_levels
self.sh_degree = sh_degree
d_sh_coeffs = sh_degree**2
self.meta_parameters = meta_parameters
mlp_cls = (
partial(
MetaMLP,
meta_scale=False,
meta_shift=False,
meta_proj=True,
meta_bias=True,
trainable_meta=trainable_meta,
)
if meta_parameters
else MLP
)
self.density_mlp = mlp_cls(
d_input=d_input,
d_hidden=[d_hidden] * (n_density_layers - 1),
d_output=d_hidden,
act_name=activation,
init_scale=init_scale,
)
self.channel_mlp = mlp_cls(
d_input=d_hidden + d_sh_coeffs,
d_hidden=[d_hidden] * n_channel_layers,
d_output=n_channels,
act_name=activation,
init_scale=init_scale,
)
self.act = get_act(output_activation)
self.density_act = get_act(density_activation)
mlp_init(
list(self.density_mlp.affines) + list(self.channel_mlp.affines),
init=init,
init_scale=init_scale,
)
if zero_out:
zero_init(self.channel_mlp.affines[-1])
self.to(device)
def encode_position(self, query: Query):
h = encode_position(self.posenc_version, position=query.position)
return h
def forward(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> AttrDict:
params = self.update(params)
options = AttrDict() if options is None else AttrDict(options)
query = query.copy()
h_position = self.encode_position(query)
if self.meta_parameters:
density_params = subdict(params, "density_mlp")
density_mlp = partial(
self.density_mlp, params=density_params, options=options, log_prefix="density_"
)
density_mlp_parameters = list(density_params.values())
else:
density_mlp = partial(self.density_mlp, options=options, log_prefix="density_")
density_mlp_parameters = self.density_mlp.parameters()
h_density = checkpoint(
density_mlp,
(h_position,),
density_mlp_parameters,
options.checkpoint_nerf_mlp,
)
h_direction = maybe_get_spherical_harmonics_basis(
sh_degree=self.sh_degree,
coords_shape=query.position.shape,
coords=query.direction,
device=query.position.device,
)
if self.meta_parameters:
channel_params = subdict(params, "channel_mlp")
channel_mlp = partial(
self.channel_mlp, params=channel_params, options=options, log_prefix="channel_"
)
channel_mlp_parameters = list(channel_params.values())
else:
channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_")
channel_mlp_parameters = self.channel_mlp.parameters()
h_channel = checkpoint(
channel_mlp,
(torch.cat([h_density, h_direction], dim=-1),),
channel_mlp_parameters,
options.checkpoint_nerf_mlp,
)
density_logit = h_density[..., :1]
res = AttrDict(
density_logit=density_logit,
density=self.density_act(density_logit),
channels=self.act(h_channel),
aux_losses=AttrDict(),
no_weight_grad_aux_losses=AttrDict(),
)
if options.return_h_density:
res.h_density = h_density
return res
def maybe_get_spherical_harmonics_basis(
sh_degree: int,
coords_shape: Tuple[int],
coords: Optional[torch.Tensor] = None,
device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
"""
:param sh_degree: Spherical harmonics degree
:param coords_shape: [*shape, 3]
:param coords: optional coordinate tensor of coords_shape
"""
if coords is None:
return torch.zeros(*coords_shape[:-1], sh_degree**2).to(device)
return spherical_harmonics_basis(coords, sh_degree)
================================================
FILE: shap_e/models/nerf/ray.py
================================================
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Tuple
import torch
from shap_e.models.nn.utils import sample_pmf
from shap_e.models.volume import Volume, VolumeRange
from shap_e.util.collections import AttrDict
from .model import NeRFModel, Query
def render_rays(
rays: torch.Tensor,
parts: List["RayVolumeIntegral"],
void_model: NeRFModel,
shared: bool = False,
prev_raw_outputs: Optional[List[AttrDict]] = None,
render_with_direction: bool = True,
importance_sampling_options: Optional[Dict[str, Any]] = None,
) -> Tuple["RayVolumeIntegralResults", List["RaySampler"], List[AttrDict]]:
"""
Perform volumetric rendering over a partition of possible t's in the union
of rendering volumes (written below with some abuse of notation)
C(r) := sum(
transmittance(t[i]) *
integrate(
lambda t: density(t) * channels(t) * transmittance(t),
[t[i], t[i + 1]],
)
for i in range(len(parts))
) + transmittance(t[-1]) * void_model(t[-1]).channels
where
1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the
probability of light passing through the volume specified by [t[0], s].
(transmittance of 1 means light can pass freely)
2) density and channels are obtained by evaluating the appropriate
part.model at time t.
3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects
(parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface
of the shell (if bounded). If the ray does not intersect, the integral over
this segment is evaluated as 0 and transmittance(t[i + 1]) :=
transmittance(t[i]).
4) The last term is integration to infinity (i.e. [t[-1], math.inf]) that
is evaluated by the void_model (i.e. we consider this space to be empty).
:param rays: [batch_size x ... x 2 x 3] origin and direction.
:param parts: disjoint volume integrals.
:param void_model: use this model to integrate over the empty space
:param shared: All RayVolumeIntegrals are calculated with the same model.
:param prev_raw_outputs: Raw outputs from the previous rendering step
:return: A tuple of
- AttrDict containing the rendered `channels`, `distances`, and the `aux_losses`
- A list of importance samplers for additional fine-grained rendering
- A list of raw output for each interval
"""
if importance_sampling_options is None:
importance_sampling_options = {}
origin, direc = rays[..., 0, :], rays[..., 1, :]
if prev_raw_outputs is None:
prev_raw_outputs = [None] * len(parts)
samplers = []
raw_outputs = []
t0 = None
results = None
for part_i, prev_raw_i in zip(parts, prev_raw_outputs):
# Integrate over [t[i], t[i + 1]]
results_i = part_i.render_rays(
origin,
direc,
t0=t0,
prev_raw=prev_raw_i,
shared=shared,
render_with_direction=render_with_direction,
)
# Create an importance sampler for (optional) fine rendering
samplers.append(
ImportanceRaySampler(
results_i.volume_range, results_i.raw, **importance_sampling_options
)
)
raw_outputs.append(results_i.raw)
# Pass t[i + 1] as the start of integration for the next interval.
t0 = results_i.volume_range.next_t0()
# Combine the results from [t[0], t[i]] and [t[i], t[i+1]]
results = results_i if results is None else results.combine(results_i)
# While integrating out [t[-1], math.inf] is the correct thing to do, this
# erases a lot of useful information. Also, void_model is meant to predict
# the channels at t=math.inf.
# # Add the void background over [t[-1], math.inf] to complete integration.
# results = results.combine(
# RayVolumeIntegralResults(
# output=AttrDict(
# channels=void_model(origin, direc),
# distances=torch.zeros_like(t0),
# aux_losses=AttrDict(),
# ),
# volume_range=VolumeRange(
# t0=t0,
# t1=torch.full_like(t0, math.inf),
# intersected=torch.full_like(results.volume_range.intersected, True),
# ),
# # Void space extends to infinity. It is assumed that no light
# # passes beyond the void.
# transmittance=torch.zeros_like(results_i.transmittance),
# )
# )
results.output.channels = results.output.channels + results.transmittance * void_model(
Query(origin, direc)
)
return results, samplers, raw_outputs
@dataclass
class RayVolumeIntegralResults:
"""
Stores the relevant state and results of
integrate(
lambda t: density(t) * channels(t) * transmittance(t),
[t0, t1],
)
"""
# Rendered output and auxiliary losses
# output.channels has shape [batch_size, *inner_shape, n_channels]
output: AttrDict
"""
Optional values
"""
# Raw values contain the sampled `ts`, `density`, `channels`, etc.
raw: Optional[AttrDict] = None
# Integration
volume_range: Optional[VolumeRange] = None
# If a ray intersects, the transmittance from t0 to t1 (i.e. the
# probability that the ray passes through this volume).
# Has shape [batch_size, *inner_shape, 1]
transmittance: Optional[torch.Tensor] = None
def combine(self, cur: "RayVolumeIntegralResults") -> "RayVolumeIntegralResults":
"""
Combines the integration results of `self` over [t0, t1] and
`cur` over [t1, t2] to produce a new set of results over [t0, t2] by
using a similar equation to (4) in NeRF++:
integrate(
lambda t: density(t) * channels(t) * transmittance(t),
[t0, t2]
)
= integrate(
lambda t: density(t) * channels(t) * transmittance(t),
[t0, t1]
) + transmittance(t1) * integrate(
lambda t: density(t) * channels(t) * transmittance(t),
[t1, t2]
)
"""
assert torch.allclose(self.volume_range.next_t0(), cur.volume_range.t0)
def _combine_fn(
prev_val: Optional[torch.Tensor],
cur_val: Optional[torch.Tensor],
*,
prev_transmittance: torch.Tensor,
):
assert prev_val is not None
if cur_val is None:
# cur_output.aux_losses are empty for the void_model.
return prev_val
return prev_val + prev_transmittance * cur_val
output = self.output.combine(
cur.output, combine_fn=partial(_combine_fn, prev_transmittance=self.transmittance)
)
combined = RayVolumeIntegralResults(
output=output,
volume_range=self.volume_range.extend(cur.volume_range),
transmittance=self.transmittance * cur.transmittance,
)
return combined
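# A minimal sketch in plain Python (not part of shap_e) of the composition
# rule stated in the `combine` docstring above: the output over [t0, t2] is
# the output over [t0, t1] plus transmittance(t1) times the output over
# [t1, t2], and the transmittances multiply. It assumes constant density and
# a single constant channel so that each segment has a closed form; the
# helper names are hypothetical.

```python
import math


def integrate_constant(density, channel, length):
    """Closed-form result for a segment with constant density and channel."""
    transmittance = math.exp(-density * length)
    return channel * (1.0 - transmittance), transmittance


def combine(prev, cur):
    """Compose two adjacent segments: prev over [t0, t1], cur over [t1, t2]."""
    prev_out, prev_T = prev
    cur_out, cur_T = cur
    return prev_out + prev_T * cur_out, prev_T * cur_T


first = integrate_constant(density=0.7, channel=0.9, length=1.5)
second = integrate_constant(density=0.7, channel=0.9, length=2.5)
combined_out, combined_T = combine(first, second)
# Integrating the whole range in one step should give the same answer.
whole_out, whole_T = integrate_constant(density=0.7, channel=0.9, length=4.0)
```

# Splitting a homogeneous segment in two and recombining reproduces the
# single-segment result, which is the property render_rays relies on when it
# folds the per-part results together.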
@dataclass
class RayVolumeIntegral:
model: NeRFModel
volume: Volume
sampler: "RaySampler"
n_samples: int
def render_rays(
self,
origin: torch.Tensor,
direction: torch.Tensor,
t0: Optional[torch.Tensor] = None,
prev_raw: Optional[AttrDict] = None,
shared: bool = False,
render_with_direction: bool = True,
) -> "RayVolumeIntegralResults":
"""
Perform volumetric rendering over the given volume.
:param origin: [batch_size, *shape, 3]
:param direction: [batch_size, *shape, 3]
:param t0: Optional [batch_size, *shape, 1]
:param prev_raw: the raw outputs when using multiple levels with this model.
:param shared: means the same model is used for all RayVolumeIntegral's
:param render_with_direction: use the incoming ray direction when querying the model.
:return: RayVolumeIntegralResults
"""
# 1. Intersect the rays with the current volume and sample ts to
# integrate along.
vrange = self.volume.intersect(origin, direction, t0_lower=t0)
ts = self.sampler.sample(vrange.t0, vrange.t1, self.n_samples)
if prev_raw is not None and not shared:
# Append the previous ts now before fprop because previous
# rendering used a different model and we can't reuse the output.
ts = torch.sort(torch.cat([ts, prev_raw.ts], dim=-2), dim=-2).values
# Shape sanity checks
batch_size, *_shape, _t0_dim = vrange.t0.shape
_, *ts_shape, _ts_dim = ts.shape
# 2. Get the points along the ray and query the model
directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
positions = origin.unsqueeze(-2) + ts * directions
optional_directions = directions if render_with_direction else None
mids = (ts[..., 1:, :] + ts[..., :-1, :]) / 2
raw = self.model(
Query(
position=positions,
direction=optional_directions,
t_min=torch.cat([vrange.t0[..., None, :], mids], dim=-2),
t_max=torch.cat([mids, vrange.t1[..., None, :]], dim=-2),
)
)
raw.ts = ts
if prev_raw is not None and shared:
# We can append the additional queries to previous raw outputs
# before integration
copy = prev_raw.copy()
result = torch.sort(torch.cat([raw.pop("ts"), copy.pop("ts")], dim=-2), dim=-2)
merge_results = partial(self._merge_results, dim=-2, indices=result.indices)
raw = raw.combine(copy, merge_results)
raw.ts = result.values
# 3. Integrate the raw results
output, transmittance = self.integrate_samples(vrange, raw)
# 4. Clean up results that do not intersect with the volume.
transmittance = torch.where(
vrange.intersected, transmittance, torch.ones_like(transmittance)
)
def _mask_fn(_key: str, tensor: torch.Tensor):
return torch.where(vrange.intersected, tensor, torch.zeros_like(tensor))
def _is_tensor(_key: str, value: Any):
return isinstance(value, torch.Tensor)
output = output.map(map_fn=_mask_fn, should_map=_is_tensor)
return RayVolumeIntegralResults(
output=output,
raw=raw,
volume_range=vrange,
transmittance=transmittance,
)
def integrate_samples(
self,
volume_range: VolumeRange,
raw: AttrDict,
) -> Tuple[AttrDict, torch.Tensor]:
"""
Integrate the raw.channels along with other aux_losses and values to
produce the final output dictionary containing rendered `channels`,
estimated `distances` and `aux_losses`.
:param volume_range: Specifies the integral range [t0, t1]
:param raw: Contains a dict of function evaluations at ts. Should have
density: torch.Tensor [batch_size, *shape, n_samples, 1]
channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
aux_losses: {key: torch.Tensor [batch_size, *shape, n_samples, 1] for each key}
no_weight_grad_aux_losses: an optional set of losses for which the weights
should be detached before integration.
After the call, integrate_samples also stores some intermediate
calculations on raw for later use, such as
weights: torch.Tensor [batch_size, *shape, n_samples, 1], the discretized
(density * transmittance)[i] weight for each rgb output at [..., i, :].
:returns: a tuple of (
a dictionary of rendered outputs and aux_losses,
transmittance of this volume,
)
"""
# 1. Calculate the weights
_, _, dt = volume_range.partition(raw.ts)
ddensity = raw.density * dt
mass = torch.cumsum(ddensity, dim=-2)
transmittance = torch.exp(-mass[..., -1, :])
alphas = 1.0 - torch.exp(-ddensity)
Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
# This is the probability of light hitting and reflecting off of
# something at depth [..., i, :].
weights = alphas * Ts
# 2. Integrate all results
def _integrate(key: str, samples: torch.Tensor, weights: torch.Tensor):
if key == "density":
# Omit integrating the density, because we don't need it
return None
return torch.sum(samples * weights, dim=-2)
def _is_tensor(_key: str, value: Any):
return isinstance(value, torch.Tensor)
if raw.no_weight_grad_aux_losses:
extra_aux_losses = raw.no_weight_grad_aux_losses.map(
partial(_integrate, weights=weights.detach()), should_map=_is_tensor
)
else:
extra_aux_losses = {}
output = raw.map(partial(_integrate, weights=weights), should_map=_is_tensor)
if "no_weight_grad_aux_losses" in output:
del output["no_weight_grad_aux_losses"]
output.aux_losses.update(extra_aux_losses)
# Integrating the ts yields the distance away from the origin; rename the variable.
output.distances = output.ts
del output["ts"]
del output["density"]
assert output.distances.shape == (*output.channels.shape[:-1], 1)
assert output.channels.shape[:-1] == raw.channels.shape[:-2]
assert output.channels.shape[-1] == raw.channels.shape[-1]
# 3. Reduce loss
def _reduce_loss(_key: str, loss: torch.Tensor):
return loss.view(loss.shape[0], -1).sum(dim=-1)
# 4. Store other useful calculations
raw.weights = weights
output.aux_losses = output.aux_losses.map(_reduce_loss)
return output, transmittance
def _merge_results(
self, a: Optional[torch.Tensor], b: torch.Tensor, dim: int, indices: torch.Tensor
):
"""
:param a: [..., n_a, ...]. The other dictionary containing the b's may
contain extra tensors from earlier calculations, so a can be None.
:param b: [..., n_b, ...]
:param dim: dimension to merge
:param indices: how the merged results should be sorted at the end
:return: a concatenated and sorted tensor of size [..., n_a + n_b, ...]
"""
if a is None:
return None
merged = torch.cat([a, b], dim=dim)
return torch.gather(merged, dim=dim, index=torch.broadcast_to(indices, merged.shape))
class RaySampler(ABC):
@abstractmethod
def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
"""
:param t0: start time has shape [batch_size, *shape, 1]
:param t1: finish time has shape [batch_size, *shape, 1]
:param n_samples: number of ts to sample
:return: sampled ts of shape [batch_size, *shape, n_samples, 1]
"""
class StratifiedRaySampler(RaySampler):
"""
Instead of fixed intervals, a sample is drawn uniformly at random from each
interval.
"""
def __init__(self, depth_mode: str = "linear"):
"""
:param depth_mode: "linear" samples ts linearly in depth, "geometric"
interpolates linearly in log-depth, and "harmonic" interpolates linearly
in inverse depth so that closer points are sampled more densely.
"""
self.depth_mode = depth_mode
assert self.depth_mode in ("linear", "geometric", "harmonic")
def sample(
self,
t0: torch.Tensor,
t1: torch.Tensor,
n_samples: int,
epsilon: float = 1e-3,
) -> torch.Tensor:
"""
:param t0: start time has shape [batch_size, *shape, 1]
:param t1: finish time has shape [batch_size, *shape, 1]
:param n_samples: number of ts to sample
:param epsilon: lower bound that t0 and t1 are clamped to in the
"geometric" and "harmonic" modes.
:return: sampled ts of shape [batch_size, *shape, n_samples, 1]
"""
ones = [1] * (len(t0.shape) - 1)
ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)
if self.depth_mode == "linear":
ts = t0 * (1.0 - ts) + t1 * ts
elif self.depth_mode == "geometric":
ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
elif self.depth_mode == "harmonic":
# The original NeRF recommends this interpolation scheme for
# spherical scenes, but there could be some weird edge cases when
# the observer crosses from the inner to outer volume.
ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)
mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
upper = torch.cat([mids, t1], dim=-1)
lower = torch.cat([t0, mids], dim=-1)
t_rand = torch.rand_like(ts)
ts = lower + (upper - lower) * t_rand
return ts.unsqueeze(-1)
class ImportanceRaySampler(RaySampler):
"""
Given the initial estimate of densities, this samples more from
regions/bins expected to have objects.
"""
def __init__(
self, volume_range: VolumeRange, raw: AttrDict, blur_pool: bool = False, alpha: float = 1e-5
):
"""
:param volume_range: the range in which a ray intersects the given volume.
:param raw: dictionary of raw outputs from the NeRF models of shape
[batch_size, *shape, n_coarse_samples, 1]. Should at least contain
ts: earlier samples from the coarse rendering step
weights: discretized version of density * transmittance
:param blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
:param alpha: small value to add to weights.
"""
self.volume_range = volume_range
self.ts = raw.ts.clone().detach()
self.weights = raw.weights.clone().detach()
self.blur_pool = blur_pool
self.alpha = alpha
@torch.no_grad()
def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
"""
:param t0: start time has shape [batch_size, *shape, 1]
:param t1: finish time has shape [batch_size, *shape, 1]
:param n_samples: number of ts to sample
:return: sampled ts of shape [batch_size, *shape, n_samples, 1]
"""
lower, upper, _ = self.volume_range.partition(self.ts)
batch_size, *shape, n_coarse_samples, _ = self.ts.shape
weights = self.weights
if self.blur_pool:
padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
weights = weights + self.alpha
pmf = weights / weights.sum(dim=-2, keepdim=True)
inds = sample_pmf(pmf, n_samples)
assert inds.shape == (batch_size, *shape, n_samples, 1)
assert (inds >= 0).all() and (inds < n_coarse_samples).all()
t_rand = torch.rand(inds.shape, device=inds.device)
lower_ = torch.gather(lower, -2, inds)
upper_ = torch.gather(upper, -2, inds)
ts = lower_ + (upper_ - lower_) * t_rand
ts = torch.sort(ts, dim=-2).values
return ts
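# Illustrative sketch, not part of the repository: the weight computation in
# integrate_samples above, rewritten for a single ray with plain Python floats
# instead of torch tensors. The function name volume_weights and the sample
# densities are assumptions made for this example only. It shows that the
# per-sample weights (alpha * transmittance up to that sample) plus the final
# transmittance add up to one.

```python
import math
from typing import List, Tuple


def volume_weights(density: List[float], dt: List[float]) -> Tuple[List[float], float]:
    """Return per-sample compositing weights and the final transmittance for one ray."""
    weights = []
    mass = 0.0  # optical depth accumulated before the current sample
    for sigma, delta in zip(density, dt):
        ddensity = sigma * delta
        alpha = 1.0 - math.exp(-ddensity)  # fraction of light absorbed inside this bin
        t_before = math.exp(-mass)  # fraction of light that reaches this bin
        weights.append(alpha * t_before)
        mass += ddensity
    return weights, math.exp(-mass)


weights, transmittance = volume_weights([0.0, 2.0, 5.0, 0.5], [0.25] * 4)
# The weights and the leftover transmittance partition the light along the ray.
total = sum(weights) + transmittance
```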
================================================
FILE: shap_e/models/nerf/renderer.py
================================================
from functools import partial
from typing import Any, Dict, Optional
import torch
from shap_e.models.nn.meta import subdict
from shap_e.models.renderer import RayRenderer
from shap_e.models.volume import Volume
from shap_e.util.collections import AttrDict
from .model import NeRFModel
from .ray import RayVolumeIntegral, StratifiedRaySampler, render_rays
class TwoStepNeRFRenderer(RayRenderer):
"""
Coarse and fine-grained rendering as proposed by NeRF. This class
additionally supports background rendering like NeRF++.
"""
def __init__(
self,
n_coarse_samples: int,
n_fine_samples: int,
void_model: NeRFModel,
fine_model: NeRFModel,
volume: Volume,
coarse_model: Optional[NeRFModel] = None,
coarse_background_model: Optional[NeRFModel] = None,
fine_background_model: Optional[NeRFModel] = None,
outer_volume: Optional[Volume] = None,
foreground_stratified_depth_sampling_mode: str = "linear",
background_stratified_depth_sampling_mode: str = "linear",
importance_sampling_options: Optional[Dict[str, Any]] = None,
channel_scale: float = 255,
device: torch.device = torch.device("cuda"),
**kwargs,
):
"""
:param outer_volume: is where distant objects are encoded.
"""
super().__init__(**kwargs)
if coarse_model is None:
assert (
fine_background_model is None or coarse_background_model is None
), "models should be shared for both fg and bg"
self.n_coarse_samples = n_coarse_samples
self.n_fine_samples = n_fine_samples
self.void_model = void_model
self.coarse_model = coarse_model
self.fine_model = fine_model
self.volume = volume
self.coarse_background_model = coarse_background_model
self.fine_background_model = fine_background_model
self.outer_volume = outer_volume
self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
self.importance_sampling_options = AttrDict(importance_sampling_options or {})
self.channel_scale = channel_scale
self.device = device
self.to(device)
if self.coarse_background_model is not None:
assert self.fine_background_model is not None
assert self.outer_volume is not None
def render_rays(
self,
batch: Dict,
params: Optional[Dict] = None,
options: Optional[Dict] = None,
) -> AttrDict:
params = self.update(params)
batch = AttrDict(batch)
# Wrap plain dicts so that the attribute-style lookups below work.
options = AttrDict() if options is None else AttrDict(options)
options.setdefault("render_background", True)
options.setdefault("render_with_direction", True)
options.setdefault("n_coarse_samples", self.n_coarse_samples)
options.setdefault("n_fine_samples", self.n_fine_samples)
options.setdefault(
"foreground_stratified_depth_sampling_mode",
self.foreground_stratified_depth_sampling_mode,
)
options.setdefault(
"background_stratified_depth_sampling_mode",
self.background_stratified_depth_sampling_mode,
)
shared = self.coarse_model is None
# First, render rays using the coarse models with stratified ray samples.
coarse_model, coarse_key = (
(self.fine_model, "fine_model") if shared else (self.coarse_model, "coarse_model")
)
coarse_model = partial(
coarse_model,
params=subdict(params, coarse_key),
options=options,
)
parts = [
RayVolumeIntegral(
model=coarse_model,
volume=self.volume,
sampler=StratifiedRaySampler(
depth_mode=options.foreground_stratified_depth_sampling_mode,
),
n_samples=options.n_coarse_samples,
),
]
if options.render_background and self.outer_volume is not None:
coarse_background_model, coarse_background_key = (
(self.fine_background_model, "fine_background_model")
if shared
else (self.coarse_background_model, "coarse_background_model")
)
coarse_background_model = partial(
coarse_background_model,
params=subdict(params, coarse_background_key),
options=options,
)
parts.append(
RayVolumeIntegral(
model=coarse_background_model,
volume=self.outer_volume,
sampler=StratifiedRaySampler(
depth_mode=options.background_stratified_depth_sampling_mode,
),
n_samples=options.n_coarse_samples,
)
)
coarse_results, samplers, coarse_raw_outputs = render_rays(
batch.rays,
parts,
partial(self.void_model, options=options),
shared=shared,
render_with_direction=options.render_with_direction,
importance_sampling_options=AttrDict(self.importance_sampling_options),
)
# Then, render rays using the fine models with importance-weighted ray samples.
fine_model = partial(
self.fine_model,
params=subdict(params, "fine_model"),
options=options,
)
parts = [
RayVolumeIntegral(
model=fine_model,
volume=self.volume,
sampler=samplers[0],
n_samples=options.n_fine_samples,
),
]
if options.render_background and self.outer_volume is not None:
fine_background_model = partial(
self.fine_background_model,
params=subdict(params, "fine_background_model"),
options=options,
)
parts.append(
RayVolumeIntegral(
model=fine_background_model,
volume=self.outer_volume,
sampler=samplers[1],
n_samples=options.n_fine_samples,
)
)
fine_results, *_ = render_rays(
batch.rays,
parts,
partial(self.void_model, options=options),
shared=shared,
prev_raw_outputs=coarse_raw_outputs,
render_with_direction=options.render_with_direction,
)
# Combine results
aux_losses = fine_results.output.aux_losses.copy()
for key, val in coarse_results.output.aux_losses.items():
aux_losses[key + "_coarse"] = val
return AttrDict(
channels=fine_results.output.channels * self.channel_scale,
channels_coarse=coarse_results.output.channels * self.channel_scale,
distances=fine_results.output.distances,
transmittance=fine_results.transmittance,
transmittance_coarse=coarse_results.transmittance,
t0=fine_results.volume_range.t0,
t1=fine_results.volume_range.t1,
intersected=fine_results.volume_range.intersected,
aux_losses=aux_losses,
)
class OneStepNeRFRenderer(RayRenderer):
"""
Renders rays using stratified sampling only, unlike vanilla NeRF.
This is the same setup as NeRF++.
"""
def __init__(
self,
n_samples: int,
void_model: NeRFModel,
foreground_model: NeRFModel,
volume: Volume,
background_model: Optional[NeRFModel] = None,
outer_volume: Optional[Volume] = None,
foreground_stratified_depth_sampling_mode: str = "linear",
background_stratified_depth_sampling_mode: str = "linear",
channel_scale: float = 255,
device: torch.device = torch.device("cuda"),
**kwargs,
):
super().__init__(**kwargs)
self.n_samples = n_samples
self.void_model = void_model
self.foreground_model = foreground_model
self.volume = volume
self.background_model = background_model
self.outer_volume = outer_volume
self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
self.channel_scale = channel_scale
self.device = device
self.to(device)
def render_rays(
self,
batch: Dict,
params: Optional[Dict] = None,
options: Optional[Dict] = None,
) -> AttrDict:
params = self.update(params)
batch = AttrDict(batch)
# Wrap plain dicts so that the attribute-style lookups below work.
options = AttrDict() if options is None else AttrDict(options)
options.setdefault("render_background", True)
options.setdefault("render_with_direction", True)
options.setdefault("n_samples", self.n_samples)
options.setdefault(
"foreground_stratified_depth_sampling_mode",
self.foreground_stratified_depth_sampling_mode,
)
options.setdefault(
"background_stratified_depth_sampling_mode",
self.background_stratified_depth_sampling_mode,
)
foreground_model = partial(
self.foreground_model,
params=subdict(params, "foreground_model"),
options=options,
)
parts = [
RayVolumeIntegral(
model=foreground_model,
volume=self.volume,
sampler=StratifiedRaySampler(
depth_mode=options.foreground_stratified_depth_sampling_mode
),
n_samples=options.n_samples,
),
]
if options.render_background and self.outer_volume is not None:
background_model = partial(
self.background_model,
params=subdict(params, "background_model"),
options=options,
)
parts.append(
RayVolumeIntegral(
model=background_model,
volume=self.outer_volume,
sampler=StratifiedRaySampler(
depth_mode=options.background_stratified_depth_sampling_mode
),
n_samples=options.n_samples,
)
)
results, *_ = render_rays(
batch.rays,
parts,
self.void_model,
render_with_direction=options.render_with_direction,
)
return AttrDict(
channels=results.output.channels * self.channel_scale,
distances=results.output.distances,
transmittance=results.transmittance,
t0=results.volume_range.t0,
t1=results.volume_range.t1,
intersected=results.volume_range.intersected,
aux_losses=results.output.aux_losses,
)
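# Illustrative sketch, not part of the repository: both renderers above draw
# their first set of depths with StratifiedRaySampler. The pure-Python
# function below mimics its "linear" depth mode for a single ray. The name
# stratified_samples and the explicit rng argument are assumptions made for
# this example only.

```python
import random
from typing import List


def stratified_samples(t0: float, t1: float, n_samples: int, rng: random.Random) -> List[float]:
    """Draw one uniform sample from each of n_samples bins covering [t0, t1]."""
    steps = max(n_samples - 1, 1)
    # Evenly spaced depths, then bin edges at the midpoints between neighbours.
    grid = [t0 + (t1 - t0) * i / steps for i in range(n_samples)]
    mids = [0.5 * (a + b) for a, b in zip(grid[:-1], grid[1:])]
    lower = [t0] + mids
    upper = mids + [t1]
    # Jitter each sample uniformly inside its own bin.
    return [lo + (hi - lo) * rng.random() for lo, hi in zip(lower, upper)]


samples = stratified_samples(2.0, 6.0, n_samples=8, rng=random.Random(0))
```

# Because every sample stays inside its own bin, the returned depths are
# already in ascending order and no extra sort is needed.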
================================================
FILE: shap_e/models/nerstf/mlp.py
================================================
from typing import Any, Dict, Optional, Tuple
import torch
from shap_e.models.nn.ops import get_act
from shap_e.models.query import Query
from shap_e.models.stf.mlp import MLPModel
from shap_e.util.collections import AttrDict
class MLPDensitySDFModel(MLPModel):
def __init__(
self,
initial_bias: float = -0.1,
sdf_activation="tanh",
density_activation="exp",
**kwargs,
):
super().__init__(
n_output=2,
output_activation="identity",
**kwargs,
)
self.mlp[-1].bias[0].data.fill_(initial_bias)
self.sdf_activation = get_act(sdf_activation)
self.density_activation = get_act(density_activation)
def forward(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> AttrDict[str, Any]:
# query.direction is None typically for SDF models and training
h, _h_directionless = self._mlp(
query.position, query.direction, params=params, options=options
)
h_sdf, h_density = h.split(1, dim=-1)
return AttrDict(
density=self.density_activation(h_density),
signed_distance=self.sdf_activation(h_sdf),
)
class MLPNeRSTFModel(MLPModel):
def __init__(
self,
sdf_activation="tanh",
density_activation="exp",
channel_activation="sigmoid",
direction_dependent_shape: bool = True, # To be able to load old models. Set this to be False in future models.
separate_nerf_channels: bool = False,
separate_coarse_channels: bool = False,
initial_density_bias: float = 0.0,
initial_sdf_bias: float = -0.1,
**kwargs,
):
h_map, h_directionless_map = indices_for_output_mode(
direction_dependent_shape=direction_dependent_shape,
separate_nerf_channels=separate_nerf_channels,
separate_coarse_channels=separate_coarse_channels,
)
n_output = index_mapping_max(h_map)
super().__init__(
n_output=n_output,
output_activation="identity",
**kwargs,
)
self.direction_dependent_shape = direction_dependent_shape
self.separate_nerf_channels = separate_nerf_channels
self.separate_coarse_channels = separate_coarse_channels
self.sdf_activation = get_act(sdf_activation)
self.density_activation = get_act(density_activation)
self.channel_activation = get_act(channel_activation)
self.h_map = h_map
self.h_directionless_map = h_directionless_map
self.mlp[-1].bias.data.zero_()
layer = -1 if self.direction_dependent_shape else self.insert_direction_at
self.mlp[layer].bias[0].data.fill_(initial_sdf_bias)
self.mlp[layer].bias[1].data.fill_(initial_density_bias)
def forward(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> AttrDict[str, Any]:
options = AttrDict() if options is None else AttrDict(options)
h, h_directionless = self._mlp(
query.position, query.direction, params=params, options=options
)
activations = map_indices_to_keys(self.h_map, h)
activations.update(map_indices_to_keys(self.h_directionless_map, h_directionless))
if options.nerf_level == "coarse":
h_density = activations.density_coarse
else:
h_density = activations.density_fine
if options.get("rendering_mode", "stf") == "nerf":
if options.nerf_level == "coarse":
h_channels = activations.nerf_coarse
else:
h_channels = activations.nerf_fine
else:
h_channels = activations.stf
return AttrDict(
density=self.density_activation(h_density),
signed_distance=self.sdf_activation(activations.sdf),
channels=self.channel_activation(h_channels),
)
IndexMapping = AttrDict[str, Tuple[int, int]]
def indices_for_output_mode(
direction_dependent_shape: bool,
separate_nerf_channels: bool,
separate_coarse_channels: bool,
) -> Tuple[IndexMapping, IndexMapping]:
"""
Get output mappings for (h, h_directionless).
"""
h_map = AttrDict()
h_directionless_map = AttrDict()
if direction_dependent_shape:
h_map.sdf = (0, 1)
if separate_coarse_channels:
assert separate_nerf_channels
h_map.density_coarse = (1, 2)
h_map.density_fine = (2, 3)
h_map.stf = (3, 6)
h_map.nerf_coarse = (6, 9)
h_map.nerf_fine = (9, 12)
else:
h_map.density_coarse = (1, 2)
h_map.density_fine = (1, 2)
if separate_nerf_channels:
h_map.stf = (2, 5)
h_map.nerf_coarse = (5, 8)
h_map.nerf_fine = (5, 8)
else:
h_map.stf = (2, 5)
h_map.nerf_coarse = (2, 5)
h_map.nerf_fine = (2, 5)
else:
h_directionless_map.sdf = (0, 1)
h_directionless_map.density_coarse = (1, 2)
if separate_coarse_channels:
h_directionless_map.density_fine = (2, 3)
else:
h_directionless_map.density_fine = h_directionless_map.density_coarse
h_map.stf = (0, 3)
if separate_coarse_channels:
assert separate_nerf_channels
h_map.nerf_coarse = (3, 6)
h_map.nerf_fine = (6, 9)
else:
if separate_nerf_channels:
h_map.nerf_coarse = (3, 6)
else:
h_map.nerf_coarse = (0, 3)
h_map.nerf_fine = h_map.nerf_coarse
return h_map, h_directionless_map
def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]:
return AttrDict({k: data[..., start:end] for k, (start, end) in mapping.items()})
def index_mapping_max(mapping: IndexMapping) -> int:
return max(end for _, (_, end) in mapping.items())
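# Illustrative sketch, not part of the repository: how the index mappings
# above carve one MLP output vector into named slices. The dictionary is the
# layout that indices_for_output_mode builds when direction_dependent_shape,
# separate_nerf_channels and separate_coarse_channels are all True. A plain
# list stands in for the tensor, and split_outputs is a hypothetical
# stand-in for map_indices_to_keys.

```python
from typing import Dict, List, Tuple

h_map: Dict[str, Tuple[int, int]] = {
    "sdf": (0, 1),
    "density_coarse": (1, 2),
    "density_fine": (2, 3),
    "stf": (3, 6),
    "nerf_coarse": (6, 9),
    "nerf_fine": (9, 12),
}


def split_outputs(mapping: Dict[str, Tuple[int, int]], row: List[int]) -> Dict[str, List[int]]:
    """Slice one output row into named pieces, one per (start, end) entry."""
    return {key: row[start:end] for key, (start, end) in mapping.items()}


# The largest end index is the number of output units the MLP needs.
n_output = max(end for _, end in h_map.values())
row = list(range(n_output))
parts = split_outputs(h_map, row)
```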
================================================
FILE: shap_e/models/nerstf/renderer.py
================================================
from functools import partial
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
from shap_e.models.nerf.model import NeRFModel
from shap_e.models.nerf.ray import RayVolumeIntegral, StratifiedRaySampler, render_rays
from shap_e.models.nn.meta import subdict
from shap_e.models.nn.utils import to_torch
from shap_e.models.query import Query
from shap_e.models.renderer import RayRenderer, render_views_from_rays
from shap_e.models.stf.base import Model
from shap_e.models.stf.renderer import STFRendererBase, render_views_from_stf
from shap_e.models.volume import BoundingBoxVolume, Volume
from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR
from shap_e.util.collections import AttrDict
class NeRSTFRenderer(RayRenderer, STFRendererBase):
def __init__(
self,
sdf: Optional[Model],
tf: Optional[Model],
nerstf: Optional[Model],
void: NeRFModel,
volume: Volume,
grid_size: int,
n_coarse_samples: int,
n_fine_samples: int,
importance_sampling_options: Optional[Dict[str, Any]] = None,
separate_shared_samples: bool = False,
texture_channels: Sequence[str] = ("R", "G", "B"),
channel_scale: Sequence[float] = (255.0, 255.0, 255.0),
ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,
diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,
specular_color: Union[float, Tuple[float]] = 0.0,
output_srgb: bool = True,
device: torch.device = torch.device("cuda"),
**kwargs,
):
super().__init__(**kwargs)
assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume"
assert (nerstf is not None) ^ (sdf is not None and tf is not None)
self.sdf = sdf
self.tf = tf
self.nerstf = nerstf
self.void = void
self.volume = volume
self.grid_size = grid_size
self.n_coarse_samples = n_coarse_samples
self.n_fine_samples = n_fine_samples
self.importance_sampling_options = AttrDict(importance_sampling_options or {})
self.separate_shared_samples = separate_shared_samples
self.texture_channels = texture_channels
self.channel_scale = to_torch(channel_scale).to(device)
self.ambient_color = ambient_color
self.diffuse_color = diffuse_color
self.specular_color = specular_color
self.output_srgb = output_srgb
self.device = device
self.to(device)
def _query(
self,
query: Query,
params: AttrDict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> AttrDict:
no_dir_query = query.copy()
no_dir_query.direction = None
if options.get("rendering_mode", "stf") == "stf":
assert query.direction is None
if self.nerstf is not None:
sdf = tf = self.nerstf(
query,
params=subdict(params, "nerstf"),
options=options,
)
else:
sdf = self.sdf(no_dir_query, params=subdict(params, "sdf"), options=options)
tf = self.tf(query, params=subdict(params, "tf"), options=options)
return AttrDict(
density=sdf.density,
signed_distance=sdf.signed_distance,
channels=tf.channels,
aux_losses=dict(),
)
def render_rays(
self,
batch: AttrDict,
params: Optional[Dict] = None,
options: Optional[AttrDict] = None,
) -> AttrDict:
"""
:param batch: has
- rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray.
:param options: optional rendering options, e.g. render_with_direction.
"""
params = self.update(params)
options = AttrDict() if options is None else AttrDict(options)
# Necessary to tell the TF to use specific NeRF channels.
options.rendering_mode = "nerf"
model = partial(self._query, params=params, options=options)
# First, render rays with coarse, stratified samples.
options.nerf_level = "coarse"
parts = [
RayVolumeIntegral(
model=model,
volume=self.volume,
sampler=StratifiedRaySampler(),
n_samples=self.n_coarse_samples,
),
]
coarse_results, samplers, coarse_raw_outputs = render_rays(
batch.rays,
parts,
self.void,
shared=not self.separate_shared_samples,
render_with_direction=options.render_with_direction,
importance_sampling_options=self.importance_sampling_options,
)
# Then, render with additional importance-weighted ray samples.
options.nerf_level = "fine"
parts = [
RayVolumeIntegral(
model=model,
volume=self.volume,
sampler=samplers[0],
n_samples=self.n_fine_samples,
),
]
fine_results, _, raw_outputs = render_rays(
batch.rays,
parts,
self.void,
shared=not self.separate_shared_samples,
prev_raw_outputs=coarse_raw_outputs,
render_with_direction=options.render_with_direction,
)
raw = raw_outputs[0]
aux_losses = fine_results.output.aux_losses.copy()
if self.separate_shared_samples:
for key, val in coarse_results.output.aux_losses.items():
aux_losses[key + "_coarse"] = val
channels = fine_results.output.channels
shape = [1] * (channels.ndim - 1) + [len(self.texture_channels)]
channels = channels * self.channel_scale.view(*shape)
res = AttrDict(
channels=channels,
transmittance=fine_results.transmittance,
raw_signed_distance=raw.signed_distance,
raw_density=raw.density,
distances=fine_results.output.distances,
t0=fine_results.volume_range.t0,
t1=fine_results.volume_range.t1,
intersected=fine_results.volume_range.intersected,
aux_losses=aux_losses,
)
if self.separate_shared_samples:
res.update(
dict(
channels_coarse=(
coarse_results.output.channels * self.channel_scale.view(*shape)
),
distances_coarse=coarse_results.output.distances,
transmittance_coarse=coarse_results.transmittance,
)
)
return res
def render_views(
self,
batch: AttrDict,
params: Optional[Dict] = None,
options: Optional[AttrDict] = None,
) -> AttrDict:
"""
Returns a rendering of a view that supports backpropagation.
:param batch: contains either ["poses", "camera"], or ["cameras"]. Can
optionally contain any of ["height", "width", "query_batch_size"]
:param params: Meta parameters
:param options: controls checkpointing, caching, and rendering.
Can provide a `rendering_mode` in ["stf", "nerf"]
"""
params = self.update(params)
options = AttrDict() if options is None else AttrDict(options)
if options.cache is None:
created_cache = True
options.cache = AttrDict()
else:
created_cache = False
rendering_mode = options.get("rendering_mode", "stf")
if rendering_mode == "nerf":
output = render_views_from_rays(
self.render_rays,
batch,
params=params,
options=options,
device=self.device,
)
elif rendering_mode == "stf":
sdf_fn = tf_fn = nerstf_fn = None
if self.nerstf is not None:
nerstf_fn = partial(
self.nerstf.forward_batched,
params=subdict(params, "nerstf"),
options=options,
)
else:
sdf_fn = partial(
self.sdf.forward_batched,
params=subdict(params, "sdf"),
options=options,
)
tf_fn = partial(
self.tf.forward_batched,
params=subdict(params, "tf"),
options=options,
)
output = render_views_from_stf(
batch,
options,
sdf_fn=sdf_fn,
tf_fn=tf_fn,
nerstf_fn=nerstf_fn,
volume=self.volume,
grid_size=self.grid_size,
channel_scale=self.channel_scale,
texture_channels=self.texture_channels,
ambient_color=self.ambient_color,
diffuse_color=self.diffuse_color,
specular_color=self.specular_color,
output_srgb=self.output_srgb,
device=self.device,
)
else:
raise NotImplementedError
if created_cache:
del options["cache"]
return output
def get_signed_distance(
self,
query: Query,
params: Dict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> torch.Tensor:
if self.sdf is not None:
return self.sdf(query, params=subdict(params, "sdf"), options=options).signed_distance
assert self.nerstf is not None
return self.nerstf(query, params=subdict(params, "nerstf"), options=options).signed_distance
def get_texture(
self,
query: Query,
params: Dict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> torch.Tensor:
if self.tf is not None:
return self.tf(query, params=subdict(params, "tf"), options=options).channels
assert self.nerstf is not None
return self.nerstf(query, params=subdict(params, "nerstf"), options=options).channels
================================================
FILE: shap_e/models/nn/__init__.py
================================================
from .meta import *
from .ops import *
================================================
FILE: shap_e/models/nn/camera.py
================================================
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from shap_e.rendering.view_data import ProjectiveCamera
@dataclass
class DifferentiableCamera(ABC):
"""
An object describing how a camera corresponds to pixels in an image.
"""
@abstractmethod
def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
"""
For every (x, y) coordinate in a rendered image, compute the ray of the
corresponding pixel.
:param coords: an [N x ... x 2] integer array of 2D image coordinates.
:return: an [N x ... x 2 x 3] array of [2 x 3] (origin, direction) tuples.
The direction should always be unit length.
"""
@abstractmethod
def resize_image(self, width: int, height: int) -> "DifferentiableCamera":
"""
Creates a new camera with the same intrinsics and direction as this one,
but with resized image dimensions.
"""
@dataclass
class DifferentiableProjectiveCamera(DifferentiableCamera):
"""
Implements a batched, differentiable, standard pinhole camera.
"""
origin: torch.Tensor # [batch_size x 3]
x: torch.Tensor # [batch_size x 3]
y: torch.Tensor # [batch_size x 3]
z: torch.Tensor # [batch_size x 3]
width: int
height: int
x_fov: float
y_fov: float
def __post_init__(self):
assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
assert (
len(self.x.shape)
== len(self.y.shape)
== len(self.z.shape)
== len(self.origin.shape)
== 2
)
def resolution(self):
return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))
def fov(self):
return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))
def image_coords(self) -> torch.Tensor:
"""
:return: coords of shape (width * height, 2)
"""
pixel_indices = torch.arange(self.height * self.width)
coords = torch.stack(
[
pixel_indices % self.width,
torch.div(pixel_indices, self.width, rounding_mode="trunc"),
],
dim=1,
)
return coords
def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
batch_size, *shape, n_coords = coords.shape
assert n_coords == 2
assert batch_size == self.origin.shape[0]
flat = coords.view(batch_size, -1, 2)
res = self.resolution().to(flat.device)
fov = self.fov().to(flat.device)
fracs = (flat.float() / (res - 1)) * 2 - 1
fracs = fracs * torch.tan(fov / 2)
fracs = fracs.view(batch_size, -1, 2)
directions = (
self.z.view(batch_size, 1, 3)
+ self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
+ self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
)
directions = directions / directions.norm(dim=-1, keepdim=True)
rays = torch.stack(
[
torch.broadcast_to(
self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]
),
directions,
],
dim=2,
)
return rays.view(batch_size, *shape, 2, 3)
def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
"""
Creates a new camera for the resized view assuming the aspect ratio does not change.
"""
assert width * self.height == height * self.width, "The aspect ratio should not change."
return DifferentiableProjectiveCamera(
origin=self.origin,
x=self.x,
y=self.y,
z=self.z,
width=width,
height=height,
x_fov=self.x_fov,
y_fov=self.y_fov,
)
@dataclass
class DifferentiableCameraBatch(ABC):
"""
Annotate a differentiable camera with a multi-dimensional batch shape.
"""
shape: Tuple[int, ...]
flat_camera: DifferentiableCamera
def normalize(vec: torch.Tensor) -> torch.Tensor:
    return vec / vec.norm(dim=-1, keepdim=True)


def project_out(vec1: torch.Tensor, vec2: torch.Tensor) -> torch.Tensor:
    """
    Removes the vec2 component from vec1
    """
    vec2 = normalize(vec2)
    proj = (vec1 * vec2).sum(dim=-1, keepdim=True)
    return vec1 - proj * vec2


def camera_orientation(toward: torch.Tensor, up: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    :param toward: [batch_size x 3] unit vector from camera position to the object
    :param up: Optional [batch_size x 3] specifying the physical up direction in the world frame.
    :return: [batch_size x 3 x 3]
    """
    if up is None:
        up = torch.zeros_like(toward)
        up[:, 2] = 1
    assert len(toward.shape) == 2
    assert toward.shape[1] == 3
    assert len(up.shape) == 2
    assert up.shape[1] == 3
    z = toward / toward.norm(dim=-1, keepdim=True)
    y = -normalize(project_out(up, toward))
    x = torch.cross(y, z, dim=1)
    return torch.stack([x, y, z], dim=1)
def projective_camera_frame(
origin: torch.Tensor,
toward: torch.Tensor,
camera_params: Union[ProjectiveCamera, DifferentiableProjectiveCamera],
) -> DifferentiableProjectiveCamera:
"""
Given the origin and the direction of a view, return a differentiable
projective camera with the given parameters.
TODO: We need to support the rotation of the camera frame about the
`toward` vector to fully implement 6 degrees of freedom.
"""
rot = camera_orientation(toward)
camera = DifferentiableProjectiveCamera(
origin=origin,
x=rot[:, 0],
y=rot[:, 1],
z=rot[:, 2],
width=camera_params.width,
height=camera_params.height,
x_fov=camera_params.x_fov,
y_fov=camera_params.y_fov,
)
return camera
@torch.no_grad()
def get_image_coords(width, height) -> torch.Tensor:
    pixel_indices = torch.arange(height * width)
    # torch throws warnings for pixel_indices // width
    pixel_indices_div = torch.div(pixel_indices, width, rounding_mode="trunc")
    coords = torch.stack([pixel_indices % width, pixel_indices_div], dim=1)
    return coords
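
The pixel-to-ray mapping used by `DifferentiableProjectiveCamera.camera_rays` can be traced by hand for a single pixel. The block below is a minimal pure-Python sketch of that arithmetic, assuming one camera and one pixel and using no torch; `pixel_ray_direction` and the example camera axes are hypothetical names chosen for illustration and are not part of shap_e.

```python
import math


def pixel_ray_direction(px, py, width, height, x_fov, y_fov, x_axis, y_axis, z_axis):
    """Unit ray direction for one pixel, mirroring the steps in camera_rays."""
    # Map integer pixel coordinates to [-1, 1]: (coord / (res - 1)) * 2 - 1.
    fx = (px / (width - 1)) * 2 - 1
    fy = (py / (height - 1)) * 2 - 1
    # Scale by tan(fov / 2) so the image edges sit at half the field of view.
    fx *= math.tan(x_fov / 2)
    fy *= math.tan(y_fov / 2)
    # direction = z + x * fx + y * fy, then normalize to unit length.
    d = [z + x * fx + y * fy for x, y, z in zip(x_axis, y_axis, z_axis)]
    norm = math.sqrt(sum(c * c for c in d))
    return [c / norm for c in d]


# Hypothetical camera frame: x right, y down, z forward, 90 degree field of view.
axes = ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))
center = pixel_ray_direction(1, 1, 3, 3, math.pi / 2, math.pi / 2, *axes)
corner = pixel_ray_direction(2, 2, 3, 3, math.pi / 2, math.pi / 2, *axes)
```

In this sketch the center pixel of a 3 x 3 image looks straight along the camera's z axis, while the corner pixel is tilted equally toward x and y.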
================================================
FILE: shap_e/models/nn/checkpoint.py
================================================
from typing import Callable, Iterable, Sequence, Union
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.length = length
input_tensors = list(args[:length])
input_params = list(args[length:])
ctx.save_for_backward(*input_tensors, *input_params)
with torch.no_grad():
output_tensors = ctx.run_function(*input_tensors)
return output_tensors
@staticmethod
@custom_bwd
def backward(ctx, *output_grads):
inputs = ctx.saved_tensors
input_tensors = inputs[: ctx.length]
input_params = inputs[ctx.length :]
res = CheckpointFunctionGradFunction.apply(
ctx.run_function,
len(input_tensors),
len(input_params),
*input_tensors,
*input_params,
*output_grads
)
return (None, None) + res
class CheckpointFunctionGradFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, run_function, length_1, length_2, *args):
ctx.run_function = run_function
ctx.length_1 = length_1
ctx.length_2 = length_2
input_tensors = [x.detach().requires_grad_(True) for x in args[:length_1]]
input_params = list(args[length_1 : length_1 + length_2])
output_grads = list(args[length_1 + length_2 :])
ctx.save_for_backward(*input_tensors, *input_params, *output_grads)
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
input_tensors + input_params,
output_grads,
allow_unused=True,
)
return input_grads
@staticmethod
@custom_bwd
def backward(ctx, *all_output_grads):
args = ctx.saved_tensors
input_tensors = [x.detach().requires_grad_(True) for x in args[: ctx.length_1]]
input_params = list(args[ctx.length_1 : ctx.length_1 + ctx.length_2])
output_grads = [
x.detach().requires_grad_(True) for x in args[ctx.length_1 + ctx.length_2 :]
]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
input_tensors + input_params,
output_grads,
allow_unused=True,
create_graph=True,
retain_graph=True,
)
input_grads_grads = torch.autograd.grad(
input_grads,
input_tensors + input_params + output_grads,
all_output_grads,
allow_unused=True,
)
del input_grads
return (None, None, None) + input_grads_grads
================================================
FILE: shap_e/models/nn/encoding.py
================================================
import math
from functools import lru_cache
from typing import Optional
import torch
import torch.nn as nn
def encode_position(version: str, *, position: torch.Tensor):
    if version == "v1":
        freqs = get_scales(0, 10, position.dtype, position.device).view(1, -1)
        freqs = position.reshape(-1, 1) * freqs
        return torch.cat([freqs.cos(), freqs.sin()], dim=1).reshape(*position.shape[:-1], -1)
    elif version == "nerf":
        return posenc_nerf(position, min_deg=0, max_deg=15)
    else:
        raise ValueError(version)


def encode_channels(version: str, *, channels: torch.Tensor):
    if version == "v1":
        freqs = get_scales(0, 10, channels.dtype, channels.device).view(1, -1)
        freqs = channels.reshape(-1, 1) * freqs
        return torch.cat([freqs.cos(), freqs.sin()], dim=1).reshape(*channels.shape[:-1], -1)
    elif version == "nerf":
        return posenc_nerf(channels, min_deg=0, max_deg=15)
    else:
        raise ValueError(version)


def position_encoding_channels(version: Optional[str] = None) -> int:
    if version is None:
        return 1
    return encode_position(version, position=torch.zeros(1, 1)).shape[-1]


def channel_encoding_channels(version: Optional[str] = None) -> int:
    if version is None:
        return 1
    return encode_channels(version, channels=torch.zeros(1, 1)).shape[-1]
class PosEmbLinear(nn.Linear):
def __init__(
self, posemb_version: Optional[str], in_features: int, out_features: int, **kwargs
):
super().__init__(
in_features * position_encoding_channels(posemb_version),
out_features,
**kwargs,
)
self.posemb_version = posemb_version
def forward(self, x: torch.Tensor):
if self.posemb_version is not None:
x = encode_position(self.posemb_version, position=x)
return super().forward(x)
class MultiviewPoseEmbedding(nn.Conv2d):
def __init__(
self,
posemb_version: Optional[str],
n_channels: int,
out_features: int,
stride: int = 1,
**kwargs,
):
in_features = (
n_channels * channel_encoding_channels(version=posemb_version)
+ 3 * position_encoding_channels(version=posemb_version)
+ 3 * position_encoding_channels(version=posemb_version)
)
super().__init__(
in_features,
out_features,
kernel_size=3,
stride=stride,
padding=1,
**kwargs,
)
self.posemb_version = posemb_version
def forward(
self, channels: torch.Tensor, position: torch.Tensor, direction: torch.Tensor
) -> torch.Tensor:
"""
:param channels: [batch_shape, inner_batch_shape, n_channels, height, width]
:param position: [batch_shape, inner_batch_shape, 3, height, width]
:param direction: [batch_shape, inner_batch_shape, 3, height, width]
:return: [*batch_shape, out_features, height, width]
"""
if self.posemb_version is not None:
channels = channels.permute(0, 1, 3, 4, 2)
position = position.permute(0, 1, 3, 4, 2)
direction = direction.permute(0, 1, 3, 4, 2)
channels = encode_channels(self.posemb_version, channels=channels).permute(
0, 1, 4, 2, 3
)
direction = maybe_encode_direction(
self.posemb_version, position=position, direction=direction
).permute(0, 1, 4, 2, 3)
position = encode_position(self.posemb_version, position=position).permute(
0, 1, 4, 2, 3
)
x = torch.cat([channels, position, direction], dim=-3)
*batch_shape, in_features, height, width = x.shape
return (
super()
.forward(x.view(-1, in_features, height, width))
.view(*batch_shape, -1, height, width)
)
class MultiviewPointCloudEmbedding(nn.Conv2d):
def __init__(
self,
posemb_version: Optional[str],
n_channels: int,
out_features: int,
stride: int = 1,
**kwargs,
):
in_features = (
n_channels * channel_encoding_channels(version=posemb_version)
+ 3 * position_encoding_channels(version=posemb_version)
+ 3 * position_encoding_channels(version=posemb_version)
)
super().__init__(
in_features,
out_features,
kernel_size=3,
stride=stride,
padding=1,
**kwargs,
)
self.posemb_version = posemb_version
self.register_parameter(
"unk_token", nn.Parameter(torch.randn(in_features, **kwargs) * 0.01)
)
self.unk_token: torch.Tensor
def forward(
self,
channels: torch.Tensor,
origin: torch.Tensor,
position: torch.Tensor,
mask: torch.Tensor,
) -> torch.Tensor:
"""
:param channels: [batch_shape, inner_batch_shape, n_channels, height, width]
:param origin: [batch_shape, inner_batch_shape, 3, height, width]
:param position: [batch_shape, inner_batch_shape, 3, height, width]
:param mask: boolean tensor broadcastable to the encoded input; entries
where it is False are replaced by the learned `unk_token`.
:return: [*batch_shape, out_features, height, width]
"""
if self.posemb_version is not None:
channels = channels.permute(0, 1, 3, 4, 2)
origin = origin.permute(0, 1, 3, 4, 2)
position = position.permute(0, 1, 3, 4, 2)
channels = encode_channels(self.posemb_version, channels=channels).permute(
0, 1, 4, 2, 3
)
origin = encode_position(self.posemb_version, position=origin).permute(0, 1, 4, 2, 3)
position = encode_position(self.posemb_version, position=position).permute(
0, 1, 4, 2, 3
)
x = torch.cat([channels, origin, position], dim=-3)
unk_token = torch.broadcast_to(self.unk_token.view(1, 1, -1, 1, 1), x.shape)
x = torch.where(mask, x, unk_token)
*batch_shape, in_features, height, width = x.shape
return (
super()
.forward(x.view(-1, in_features, height, width))
.view(*batch_shape, -1, height, width)
)
def maybe_encode_direction(
    version: str,
    *,
    position: torch.Tensor,
    direction: Optional[torch.Tensor] = None,
):
    if version == "v1":
        sh_degree = 4
        if direction is None:
            return torch.zeros(*position.shape[:-1], sh_degree**2).to(position)
        return spherical_harmonics_basis(direction, sh_degree=sh_degree)
    elif version == "nerf":
        if direction is None:
            return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
        return posenc_nerf(direction, min_deg=0, max_deg=8)
    else:
        raise ValueError(version)


def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
    """
    Concatenate x and its positional encodings, following NeRF.

    Reference: https://arxiv.org/pdf/2210.04628.pdf
    """
    if min_deg == max_deg:
        return x
    scales = get_scales(min_deg, max_deg, x.dtype, x.device)
    *shape, dim = x.shape
    xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
    assert xb.shape[-1] == dim * (max_deg - min_deg)
    # sin(t + pi / 2) == cos(t), so one sin() call yields both halves.
    emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin()
    return torch.cat([x, emb], dim=-1)


@lru_cache
def get_scales(
    min_deg: int,
    max_deg: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    return 2.0 ** torch.arange(min_deg, max_deg, device=device, dtype=dtype)
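
The output layout of `posenc_nerf` is easier to see on a single point. The block below is a minimal pure-Python sketch of the same arithmetic for one coordinate vector, with no torch; `posenc_nerf_list` is a hypothetical name used only for illustration. It keeps the raw input first, then the sine terms ordered by scale and then by coordinate, then the matching cosine terms (computed as `sin(t + pi / 2)`).

```python
import math


def posenc_nerf_list(x, min_deg=0, max_deg=15):
    """NeRF-style encoding of one point given as a list of floats."""
    if min_deg == max_deg:
        return list(x)
    scales = [2.0 ** d for d in range(min_deg, max_deg)]
    # Scale-major, then coordinate: matches reshaping [n_scales, dim] to a flat vector.
    xb = [v * s for s in scales for v in x]
    # First all sines, then all cosines via the pi / 2 phase shift.
    emb = [math.sin(t) for t in xb] + [math.sin(t + math.pi / 2.0) for t in xb]
    return list(x) + emb


point = [0.1, 0.2, 0.3]
encoded = posenc_nerf_list(point)                    # 3 + 2 * 3 * 15 values
per_coordinate = len(posenc_nerf_list([0.0]))        # 1 + 2 * 15 values
small = posenc_nerf_list([0.5], min_deg=0, max_deg=2)
```

With the default `max_deg=15`, each input coordinate expands to 31 values, which is the multiplier that `position_encoding_channels("nerf")` reports.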
def spherical_harmonics_basis(
coords: torch.Tensor,
sh_degree: int,
) -> torch.Tensor:
"""
Calculate the spherical harmonics basis
:param coords: [batch_size, *shape, 3] of unit norm
:param sh_degree: Spherical harmonics degree
:return: [batch_size, *shape, sh_degree**2]
"""
if sh_degree > 8:
raise NotImplementedError
batch_size, *shape, _ = coords.shape
x, y, z = coords.reshape(-1, 3).split(1, dim=-1)
x = x.squeeze(dim=-1)
y = y.squeeze(dim=-1)
z = z.squeeze(dim=-1)
xy, xz, yz = x * y, x * z, y * z
x2, y2, z2 = x * x, y * y, z * z
x4, y4, z4 = x2 * x2, y2 * y2, z2 * z2
x6, y6, z6 = x4 * x2, y4 * y2, z4 * z2
xyz = xy * z
# https://github.com/NVlabs/tiny-cuda-nn/blob/8575542682cb67cddfc748cc3d3cfc12593799aa/include/tiny-cuda-nn/encodings/spherical_harmonics.h#L76
out = torch.zeros(x.shape[0], sh_degree**2, dtype=x.dtype, device=x.device)
def _sh():
out[:, 0] = 0.28209479177387814 # 1/(2*sqrt(pi))
if sh_degree <= 1:
return
out[:, 1] = -0.48860251190291987 * y # -sqrt(3)*y/(2*sqrt(pi))
out[:, 2] = 0.48860251190291987 * z # sqrt(3)*z/(2*sqrt(pi))
out[:, 3] = -0.48860251190291987 * x # -sqrt(3)*x/(2*sqrt(pi))
if sh_degree <= 2:
return
out[:, 4] = 1.0925484305920792 * xy # sqrt(15)*xy/(2*sqrt(pi))
out[:, 5] = -1.0925484305920792 * yz # -sqrt(15)*yz/(2*sqrt(pi))
out[:, 6] = (
0.94617469575755997 * z2 - 0.31539156525251999
) # sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))
out[:, 7] = -1.0925484305920792 * xz # -sqrt(15)*xz/(2*sqrt(pi))
out[:, 8] = (
0.54627421529603959 * x2 - 0.54627421529603959 * y2
) # sqrt(15)*(x2 - y2)/(4*sqrt(pi))
if sh_degree <= 3:
return
out[:, 9] = (
0.59004358992664352 * y * (-3.0 * x2 + y2)
) # sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
out[:, 10] = 2.8906114426405538 * xy * z # sqrt(105)*xy*z/(2*sqrt(pi))
out[:, 11] = (
0.45704579946446572 * y * (1.0 - 5.0 * z2)
) # sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))
out[:, 12] = 0.3731763325901154 * z * (5.0 * z2 - 3.0) # sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))
out[:, 13] = (
0.45704579946446572 * x * (1.0 - 5.0 * z2)
) # sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))
out[:, 14] = 1.4453057213202769 * z * (x2 - y2) # sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))
out[:, 15] = (
0.59004358992664352 * x * (-x2 + 3.0 * y2)
) # sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
if sh_degree <= 4:
return
out[:, 16] = 2.5033429417967046 * xy * (x2 - y2) # 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))
out[:, 17] = (
1.7701307697799304 * yz * (-3.0 * x2 + y2)
) # 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))
out[:, 18] = (
0.94617469575756008 * xy * (7.0 * z2 - 1.0)
) # 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))
out[:, 19] = (
0.66904654355728921 * yz * (3.0 - 7.0 * z2)
) # 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))
out[:, 20] = (
-3.1735664074561294 * z2 + 3.7024941420321507 * z4 + 0.31735664074561293
) # 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))
out[:, 21] = (
0.66904654355728921 * xz * (3.0 - 7.0 * z2)
) # 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))
out[:, 22] = (
0.47308734787878004 * (x2 - y2) * (7.0 * z2 - 1.0)
) # 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))
out[:, 23] = (
1.7701307697799304 * xz * (-x2 + 3.0 * y2)
) # 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))
out[:, 24] = (
-3.7550144126950569 * x2 * y2 + 0.62583573544917614 * x4 + 0.62583573544917614 * y4
) # 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
if sh_degree <= 5:
return
out[:, 25] = (
0.65638205684017015 * y * (10.0 * x2 * y2 - 5.0 * x4 - y4)
) # 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
out[:, 26] = (
8.3026492595241645 * xy * z * (x2 - y2)
) # 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))
out[:, 27] = (
-0.48923829943525038 * y * (3.0 * x2 - y2) * (9.0 * z2 - 1.0)
) # -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
out[:, 28] = (
4.7935367849733241 * xy * z * (3.0 * z2 - 1.0)
) # sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))
out[:, 29] = (
0.45294665119569694 * y * (14.0 * z2 - 21.0 * z4 - 1.0)
) # sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
out[:, 30] = (
0.1169503224534236 * z * (-70.0 * z2 + 63.0 * z4 + 15.0)
) # sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))
out[:, 31] = (
0.45294665119569694 * x * (14.0 * z2 - 21.0 * z4 - 1.0)
) # sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
out[:, 32] = (
2.3967683924866621 * z * (x2 - y2) * (3.0 * z2 - 1.0)
) # sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))
out[:, 33] = (
-0.48923829943525038 * x * (x2 - 3.0 * y2) * (9.0 * z2 - 1.0)
) # -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))
out[:, 34] = (
2.0756623148810411 * z * (-6.0 * x2 * y2 + x4 + y4)
) # 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
out[:, 35] = (
0.65638205684017015 * x * (10.0 * x2 * y2 - x4 - 5.0 * y4)
) # 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
if sh_degree <= 6:
return
out[:, 36] = (
1.3663682103838286 * xy * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4)
) # sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
out[:, 37] = (
2.3666191622317521 * yz * (10.0 * x2 * y2 - 5.0 * x4 - y4)
) # 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
out[:, 38] = (
2.0182596029148963 * xy * (x2 - y2) * (11.0 * z2 - 1.0)
) # 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
out[:, 39] = (
-0.92120525951492349 * yz * (3.0 * x2 - y2) * (11.0 * z2 - 3.0)
) # -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
out[:, 40] = (
0.92120525951492349 * xy * (-18.0 * z2 + 33.0 * z4 + 1.0)
) # sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
out[:, 41] = (
0.58262136251873131 * yz * (30.0 * z2 - 33.0 * z4 - 5.0)
) # sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
out[:, 42] = (
6.6747662381009842 * z2
- 20.024298714302954 * z4
+ 14.684485723822165 * z6
- 0.31784601133814211
) # sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))
out[:, 43] = (
0.58262136251873131 * xz * (30.0 * z2 - 33.0 * z4 - 5.0)
) # sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
out[:, 44] = (
0.46060262975746175 * (x2 - y2) * (11.0 * z2 * (3.0 * z2 - 1.0) - 7.0 * z2 + 1.0)
) # sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))
out[:, 45] = (
-0.92120525951492349 * xz * (x2 - 3.0 * y2) * (11.0 * z2 - 3.0)
) # -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))
out[:, 46] = (
0.50456490072872406 * (11.0 * z2 - 1.0) * (-6.0 * x2 * y2 + x4 + y4)
) # 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
out[:, 47] = (
2.3666191622317521 * xz * (10.0 * x2 * y2 - x4 - 5.0 * y4)
) # 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
out[:, 48] = (
10.247761577878714 * x2 * y4
- 10.247761577878714 * x4 * y2
+ 0.6831841051919143 * x6
- 0.6831841051919143 * y6
) # sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
if sh_degree <= 7:
return
out[:, 49] = (
0.70716273252459627 * y * (-21.0 * x2 * y4 + 35.0 * x4 * y2 - 7.0 * x6 + y6)
) # 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))
out[:, 50] = (
5.2919213236038001 * xy * z * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4)
) # 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
out[:, 51] = (
-0.51891557872026028 * y * (13.0 * z2 - 1.0) * (-10.0 * x2 * y2 + 5.0 * x4 + y4)
) # -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))
out[:, 52] = (
4.1513246297620823 * xy * z * (x2 - y2) * (13.0 * z2 - 3.0)
) # 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
out[:, 53] = (
-0.15645893386229404
* y
* (3.0 * x2 - y2)
* (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0)
) # -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
out[:, 54] = (
0.44253269244498261 * xy * z * (-110.0 * z2 + 143.0 * z4 + 15.0)
) # 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
out[:, 55] = (
0.090331607582517306 * y * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0)
) # sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
out[:, 56] = (
0.068284276912004949 * z * (315.0 * z2 - 693.0 * z4 + 429.0 * z6 - 35.0)
) # sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))
out[:, 57] = (
0.090331607582517306 * x * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0)
) # sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
out[:, 58] = (
0.07375544874083044
* z
* (x2 - y2)
* (143.0 * z2 * (3.0 * z2 - 1.0) - 187.0 * z2 + 45.0)
) # sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))
out[:, 59] = (
-0.15645893386229404
* x
* (x2 - 3.0 * y2)
* (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0)
) # -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
out[:, 60] = (
1.0378311574405206 * z * (13.0 * z2 - 3.0) * (-6.0 * x2 * y2 + x4 + y4)
) # 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
out[:, 61] = (
-0.51891557872026028 * x * (13.0 * z2 - 1.0) * (-10.0 * x2 * y2 + x4 + 5.0 * y4)
) # -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))
out[:, 62] = (
2.6459606618019 * z * (15.0 * x2 * y4 - 15.0 * x4 * y2 + x6 - y6)
) # 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
out[:, 63] = (
0.70716273252459627 * x * (-35.0 * x2 * y4 + 21.0 * x4 * y2 - x6 + 7.0 * y6)
) # 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))
_sh()
return out.view(batch_size, *shape, sh_degree**2)
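
As a quick sanity check on the table above, the first four basis functions (the entries written when `sh_degree` is 2) can be evaluated without torch. This is a minimal sketch for a single unit vector; `sh_basis_degree2` is a hypothetical helper name, and the constants are copied from the function above.

```python
import math


def sh_basis_degree2(x, y, z):
    """First four real spherical-harmonics basis values for a unit vector."""
    return [
        0.28209479177387814,        # 1/(2*sqrt(pi))
        -0.48860251190291987 * y,   # -sqrt(3)*y/(2*sqrt(pi))
        0.48860251190291987 * z,    # sqrt(3)*z/(2*sqrt(pi))
        -0.48860251190291987 * x,   # -sqrt(3)*x/(2*sqrt(pi))
    ]


# A direction pointing along +z only activates the constant term and the z term.
up = sh_basis_degree2(0.0, 0.0, 1.0)
```

The output length is `sh_degree**2`, so the `"v1"` direction encoding with `sh_degree = 4` produces 16 values per direction.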
================================================
FILE: shap_e/models/nn/meta.py
================================================
"""
Meta-learning modules based on: https://github.com/tristandeleu/pytorch-meta
MIT License
Copyright (c) 2019-2020 Tristan Deleu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import itertools
import re
from collections import OrderedDict
import torch.nn as nn
from shap_e.util.collections import AttrDict
__all__ = [
"MetaModule",
"subdict",
"superdict",
"leveldict",
"leveliter",
"batch_meta_parameters",
"batch_meta_state_dict",
]
def subdict(dictionary, key=None):
    if dictionary is None:
        return None
    if (key is None) or (key == ""):
        return dictionary
    key_re = re.compile(r"^{0}\.(.+)".format(re.escape(key)))
    return AttrDict(
        OrderedDict(
            (key_re.sub(r"\1", k), value)
            for (k, value) in dictionary.items()
            if key_re.match(k) is not None
        )
    )


def superdict(dictionary, key=None):
    if dictionary is None:
        return None
    if (key is None) or (key == ""):
        return dictionary
    return AttrDict(OrderedDict((key + "." + k, value) for (k, value) in dictionary.items()))


def leveldict(dictionary, depth=0):
    return AttrDict(leveliter(dictionary, depth=depth))


def leveliter(dictionary, depth=0):
    """
    depth == 0 is root
    """
    for key, value in dictionary.items():
        if key.count(".") == depth:
            yield key, value
class MetaModule(nn.Module):
"""
Base class for PyTorch meta-learning modules. These modules accept an
additional argument `params` in their `forward` method.
Notes
-----
Objects inherited from `MetaModule` are fully compatible with PyTorch
modules from `torch.nn.Module`. The argument `params` is a dictionary of
tensors, with full support of the computation graph (for differentiation).
Based on SIREN's torchmeta with some additional features/changes.
All meta weights must not have the batch dimension, as they are later tiled
to the given batch size after unsqueezing the first dimension (e.g. a
weight of dimension [d_out x d_in] is tiled to have the dimension [batch x
d_out x d_in]). Requiring all meta weights to have a batch dimension of 1
(e.g. [1 x d_out x d_in] from the earlier example) could be a more natural
choice, but this results in silent failures.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._meta_state_dict = set()
self._meta_params = set()
def register_meta_buffer(self, name: str, param: nn.Parameter):
"""
Registers a trainable or nontrainable parameter as a meta buffer. This
can be later retrieved by meta_state_dict
"""
self.register_buffer(name, param)
self._meta_state_dict.add(name)
def register_meta_parameter(self, name: str, parameter: nn.Parameter):
"""
Registers a meta parameter so it is included in named_meta_parameters
and meta_state_dict.
"""
self.register_parameter(name, parameter)
self._meta_params.add(name)
self._meta_state_dict.add(name)
def register_meta(self, name: str, parameter: nn.Parameter, trainable: bool = True):
if trainable:
self.register_meta_parameter(name, parameter)
else:
self.register_meta_buffer(name, parameter)
def register(self, name: str, parameter: nn.Parameter, meta: bool, trainable: bool = True):
if meta:
if trainable:
self.register_meta_parameter(name, parameter)
else:
self.register_meta_buffer(name, parameter)
else:
if trainable:
self.register_parameter(name, parameter)
else:
self.register_buffer(name, parameter)
def named_meta_parameters(self, prefix="", recurse=True):
"""
Returns an iterator over all the names and meta parameters
"""
def meta_iterator(module):
meta = module._meta_params if isinstance(module, MetaModule) else set()
for name, param in module._parameters.items():
if name in meta:
yield name, param
gen = self._named_members(
meta_iterator,
prefix=prefix,
recurse=recurse,
)
for name, param in gen:
yield name, param
def named_nonmeta_parameters(self, prefix="", recurse=True):
def _iterator(module):
meta = module._meta_params if isinstance(module, MetaModule) else set()
for name, param in module._parameters.items():
if name not in meta:
yield name, param
gen = self._named_members(
_iterator,
prefix=prefix,
recurse=recurse,
)
for name, param in gen:
yield name, param
def nonmeta_parameters(self, prefix="", recurse=True):
for _, param in self.named_nonmeta_parameters(prefix=prefix, recurse=recurse):
yield param
def meta_state_dict(self, prefix="", recurse=True):
"""
Returns a dict mapping names to all the meta parameters/buffers.
One difference from module.state_dict() is that this preserves
requires_grad, because we may want to compute the gradient w.r.t. meta
buffers, but don't necessarily update them automatically.
"""
def meta_iterator(module):
meta = module._meta_state_dict if isinstance(module, MetaModule) else set()
for name, param in itertools.chain(module._buffers.items(), module._parameters.items()):
if name in meta:
yield name, param
gen = self._named_members(
meta_iterator,
prefix=prefix,
recurse=recurse,
)
return dict(gen)
def update(self, params=None):
"""
Updates the parameter list before the forward prop so that if `params`
is None or doesn't have a certain key, the module uses the default
parameter/buffer registered in the module.
"""
if params is None:
params = AttrDict()
params = AttrDict(params)
named_params = set([name for name, _ in self.named_parameters()])
for name, param in self.named_parameters():
params.setdefault(name, param)
for name, param in self.state_dict().items():
if name not in named_params:
params.setdefault(name, param)
return params
def batch_meta_parameters(net, batch_size):
    params = AttrDict()
    for name, param in net.named_meta_parameters():
        # Tile each meta parameter to [batch_size, *param.shape].
        params[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape))
    return params


def batch_meta_state_dict(net, batch_size):
    state_dict = AttrDict()
    for name, param in net.meta_state_dict().items():
        # Tile each meta parameter/buffer to [batch_size, *param.shape].
        state_dict[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape))
    return state_dict
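
The prefix handling in `subdict` and `leveliter` is what lets a parent module hand each child only the entries that belong to it. The block below is a minimal sketch of that behavior on plain `OrderedDict` objects instead of `AttrDict`; the names `subdict_plain` and `leveliter_plain` and the example keys are hypothetical and exist only for this illustration.

```python
import re
from collections import OrderedDict


def subdict_plain(dictionary, key=None):
    """Keep entries under `key.` and strip that prefix from their names."""
    if dictionary is None:
        return None
    if key is None or key == "":
        return dictionary
    key_re = re.compile(r"^{0}\.(.+)".format(re.escape(key)))
    return OrderedDict(
        (key_re.sub(r"\1", k), v) for k, v in dictionary.items() if key_re.match(k) is not None
    )


def leveliter_plain(dictionary, depth=0):
    """Yield entries whose name contains exactly `depth` dots."""
    for k, v in dictionary.items():
        if k.count(".") == depth:
            yield k, v


params = OrderedDict(
    [("mlp.0.weight", "W0"), ("mlp.0.bias", "b0"), ("mlp.1.weight", "W1"), ("scale", "s")]
)
mlp_params = subdict_plain(params, "mlp")
layer0_params = subdict_plain(params, "mlp.0")
root_level = dict(leveliter_plain(params, depth=0))
```

Because the key is escaped and followed by a literal dot, a prefix such as `mlp` does not accidentally match a sibling named `mlp2`.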
================================================
FILE: shap_e/models/nn/ops.py
================================================
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from shap_e.util.collections import AttrDict
from .meta import MetaModule, subdict
from .pointnet2_utils import sample_and_group, sample_and_group_all
def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    return x * torch.sigmoid(x)


def quick_gelu(x):
    return x * torch.sigmoid(1.702 * x)


def torch_gelu(x):
    return torch.nn.functional.gelu(x)


def geglu(x):
    v, gates = x.chunk(2, dim=-1)
    return v * gelu(gates)


class SirenSin:
    def __init__(self, w0=30.0):
        self.w0 = w0

    def __call__(self, x):
        return torch.sin(self.w0 * x)


def get_act(name):
    return {
        "relu": torch.nn.functional.relu,
        "leaky_relu": torch.nn.functional.leaky_relu,
        "swish": swish,
        "tanh": torch.tanh,
        "gelu": gelu,
        "quick_gelu": quick_gelu,
        "torch_gelu": torch_gelu,
        "gelu2": quick_gelu,
        "geglu": geglu,
        "sigmoid": torch.sigmoid,
        "sin": torch.sin,
        "sin30": SirenSin(w0=30.0),
        "softplus": F.softplus,
        "exp": torch.exp,
        "identity": lambda x: x,
    }[name]
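
The table above registers three GELU variants: the tanh approximation (`"gelu"`), the sigmoid approximation (`"quick_gelu"` and `"gelu2"`), and PyTorch's own `F.gelu` (`"torch_gelu"`), which by default uses the exact erf form. The block below is a small scalar sketch comparing them with the standard library only; the function names and sample points are chosen for illustration and are not part of shap_e.

```python
import math


def gelu_tanh(x):
    # Tanh approximation, same formula as gelu() above.
    return 0.5 * x * (1 + math.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))


def quick_gelu_scalar(x):
    # x * sigmoid(1.702 * x), same formula as quick_gelu() above.
    return x / (1 + math.exp(-1.702 * x))


def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1 + math.erf(x / math.sqrt(2)))


samples = [-3.0, -1.0, 0.0, 1.0, 3.0]
max_gap_tanh = max(abs(gelu_tanh(v) - gelu_exact(v)) for v in samples)
max_gap_quick = max(abs(quick_gelu_scalar(v) - gelu_exact(v)) for v in samples)
```

On these sample points the tanh form tracks the exact function more closely than the sigmoid form, so the variants are similar but not interchangeable when loading weights trained with a specific one.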
def zero_init(affine):
    nn.init.constant_(affine.weight, 0.0)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init_first_layer(affine, init_scale: float = 1.0):
    n_input = affine.weight.shape[1]
    u = init_scale / n_input
    nn.init.uniform_(affine.weight, -u, u)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init(affine, coeff=1.0, init_scale: float = 1.0):
    n_input = affine.weight.shape[1]
    u = init_scale * np.sqrt(6.0 / n_input) / coeff
    nn.init.uniform_(affine.weight, -u, u)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)


def siren_init_30(affine, init_scale: float = 1.0):
    siren_init(affine, coeff=30.0, init_scale=init_scale)


def std_init(affine, init_scale: float = 1.0):
    n_in = affine.weight.shape[1]
    stddev = init_scale / math.sqrt(n_in)
    nn.init.normal_(affine.weight, std=stddev)
    if affine.bias is not None:
        nn.init.constant_(affine.bias, 0.0)
def mlp_init(affines, init: Optional[str] = None, init_scale: float = 1.0):
    # `init_fn` is kept separate from the `init` argument so the string that
    # selects the scheme is not overwritten inside the loops.
    if init == "siren30":
        for idx, affine in enumerate(affines):
            init_fn = siren_init_first_layer if idx == 0 else siren_init_30
            init_fn(affine, init_scale=init_scale)
    elif init == "siren":
        for idx, affine in enumerate(affines):
            init_fn = siren_init_first_layer if idx == 0 else siren_init
            init_fn(affine, init_scale=init_scale)
    elif init is None:
        for affine in affines:
            std_init(affine, init_scale=init_scale)
    else:
        raise NotImplementedError(init)
class MetaLinear(MetaModule):
def __init__(
self,
n_in,
n_out,
bias: bool = True,
meta_scale: bool = True,
meta_shift: bool = True,
meta_proj: bool = False,
meta_bias: bool = False,
trainable_meta: bool = False,
**kwargs,
):
super().__init__()
# n_in, n_out, bias=bias)
register_meta_fn = (
self.register_meta_parameter if trainable_meta else self.register_meta_buffer
)
if meta_scale:
register_meta_fn("scale", nn.Parameter(torch.ones(n_out, **kwargs)))
if meta_shift:
register_meta_fn("shift", nn.Parameter(torch.zeros(n_out, **kwargs)))
register_proj_fn = self.register_parameter if not meta_proj else register_meta_fn
register_proj_fn("weight", nn.Parameter(torch.empty((n_out, n_in), **kwargs)))
if not bias:
self.register_parameter("bias", None)
else:
register_bias_fn = self.register_parameter if not meta_bias else register_meta_fn
register_bias_fn("bias", nn.Parameter(torch.empty(n_out, **kwargs)))
self.reset_parameters()
def reset_parameters(self) -> None:
# from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.bias, -bound, bound)
def _bcast(self, op, left, right):
if right.ndim == 2:
# Has dimension [batch x d_output]
right = right.unsqueeze(1)
return op(left, right)
def forward(self, x, params=None):
params = self.update(params)
batch_size, *shape, d_in = x.shape
x = x.view(batch_size, -1, d_in)
if params.weight.ndim == 2:
h = torch.einsum("bni,oi->bno", x, params.weight)
elif params.weight.ndim == 3:
h = torch.einsum("bni,boi->bno", x, params.weight)
if params.bias is not None:
h = self._bcast(torch.add, h, params.bias)
if params.scale is not None:
h = self._bcast(torch.mul, h, params.scale)
if params.shift is not None:
h = self._bcast(torch.add, h, params.shift)
h = h.view(batch_size, *shape, -1)
return h
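# Illustration (not part of the original module): a pure-Python sketch of the
# order of operations in MetaLinear.forward for a single input row, i.e.
# h = ((W x + bias) * scale) + shift. The helper name meta_linear_row is
# hypothetical and plain lists stand in for tensors.

```python
def meta_linear_row(x, weight, bias=None, scale=None, shift=None):
    # h_o = sum_i x_i * W[o][i], matching the "bni,oi->bno" contraction.
    h = [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]
    if bias is not None:
        h = [v + b for v, b in zip(h, bias)]
    if scale is not None:
        h = [v * s for v, s in zip(h, scale)]
    if shift is not None:
        h = [v + t for v, t in zip(h, shift)]
    return h


out = meta_linear_row(
    [1.0, 2.0],
    [[1.0, 0.0], [0.0, 1.0]],
    bias=[1.0, 1.0],
    scale=[2.0, 3.0],
    shift=[0.5, 0.5],
)
```

# Because the scale is applied after the bias, it rescales the bias as well.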
def Conv(n_dim, d_in, d_out, kernel, stride=1, padding=0, dilation=1, **kwargs):
cls = {
1: nn.Conv1d,
2: nn.Conv2d,
3: nn.Conv3d,
}[n_dim]
return cls(d_in, d_out, kernel, stride=stride, padding=padding, dilation=dilation, **kwargs)
def flatten(x):
batch_size, *shape, n_channels = x.shape
n_ctx = np.prod(shape)
return x.view(batch_size, n_ctx, n_channels), AttrDict(
shape=shape, n_ctx=n_ctx, n_channels=n_channels
)
def unflatten(x, info):
batch_size = x.shape[0]
return x.view(batch_size, *info.shape, info.n_channels)
def torchify(x):
extent = list(range(1, x.ndim - 1))
return x.permute([0, x.ndim - 1, *extent])
def untorchify(x):
extent = list(range(2, x.ndim))
return x.permute([0, *extent, 1])
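# Illustration (not part of the original module): torchify moves the channel
# axis from last to second position and untorchify moves it back. The sketch
# below applies the same permutations to shape lists only; the helper names
# are hypothetical.

```python
def torchify_perm(ndim):
    # Same index order as torchify: [0, ndim - 1, 1, ..., ndim - 2]
    return [0, ndim - 1, *range(1, ndim - 1)]


def untorchify_perm(ndim):
    # Same index order as untorchify: [0, 2, ..., ndim - 1, 1]
    return [0, *range(2, ndim), 1]


def permute_shape(shape, perm):
    # Output axis i takes its size from input axis perm[i].
    return [shape[p] for p in perm]


channels_last = [8, 32, 32, 3]
channels_first = permute_shape(channels_last, torchify_perm(4))
round_trip = permute_shape(channels_first, untorchify_perm(4))
```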
class MLP(nn.Module):
def __init__(
self,
d_input: int,
d_hidden: List[int],
d_output: int,
act_name: str = "quick_gelu",
bias: bool = True,
init: Optional[str] = None,
init_scale: float = 1.0,
zero_out: bool = False,
):
"""
Required: d_input, d_hidden, d_output
Optional: act_name, bias
"""
super().__init__()
ds = [d_input] + d_hidden + [d_output]
affines = [nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(ds[:-1], ds[1:])]
self.d = ds
self.affines = nn.ModuleList(affines)
self.act = get_act(act_name)
mlp_init(self.affines, init=init, init_scale=init_scale)
if zero_out:
zero_init(affines[-1])
def forward(self, h, options: Optional[AttrDict] = None, log_prefix: str = ""):
options = AttrDict() if options is None else AttrDict(options)
*hid, out = self.affines
for i, f in enumerate(hid):
h = self.act(f(h))
h = out(h)
return h
class MetaMLP(MetaModule):
def __init__(
self,
d_input: int,
d_hidden: List[int],
d_output: int,
act_name: str = "quick_gelu",
bias: bool = True,
meta_scale: bool = True,
meta_shift: bool = True,
meta_proj: bool = False,
meta_bias: bool = False,
trainable_meta: bool = False,
init: Optional[str] = None,
init_scale: float = 1.0,
zero_out: bool = False,
):
super().__init__()
ds = [d_input] + d_hidden + [d_output]
affines = [
MetaLinear(
d_in,
d_out,
bias=bias,
meta_scale=meta_scale,
meta_shift=meta_shift,
meta_proj=meta_proj,
meta_bias=meta_bias,
trainable_meta=trainable_meta,
)
for d_in, d_out in zip(ds[:-1], ds[1:])
]
self.d = ds
self.affines = nn.ModuleList(affines)
self.act = get_act(act_name)
mlp_init(affines, init=init, init_scale=init_scale)
if zero_out:
zero_init(affines[-1])
def forward(self, h, params=None, options: Optional[AttrDict] = None, log_prefix: str = ""):
options = AttrDict() if options is None else AttrDict(options)
params = self.update(params)
*hid, out = self.affines
for i, layer in enumerate(hid):
h = self.act(layer(h, params=subdict(params, f"{log_prefix}affines.{i}")))
last = len(self.affines) - 1
h = out(h, params=subdict(params, f"{log_prefix}affines.{last}"))
return h
class LayerNorm(nn.LayerNorm):
def __init__(
self, norm_shape: Union[int, Tuple[int]], eps: float = 1e-5, elementwise_affine: bool = True
):
super().__init__(norm_shape, eps=eps, elementwise_affine=elementwise_affine)
self.width = np.prod(norm_shape)
self.max_numel = 65535 * self.width
def forward(self, input):
if input.numel() > self.max_numel:
return F.layer_norm(
input.float(), self.normalized_shape, self.weight, self.bias, self.eps
).type_as(input)
else:
return super(LayerNorm, self).forward(input.float()).type_as(input)
class PointSetEmbedding(nn.Module):
def __init__(
self,
*,
radius: float,
n_point: int,
n_sample: int,
d_input: int,
d_hidden: List[int],
patch_size: int = 1,
stride: int = 1,
activation: str = "swish",
group_all: bool = False,
padding_mode: str = "zeros",
fps_method: str = "fps",
**kwargs,
):
super().__init__()
self.n_point = n_point
self.radius = radius
self.n_sample = n_sample
self.mlp_convs = nn.ModuleList()
self.act = get_act(activation)
self.patch_size = patch_size
self.stride = stride
last_channel = d_input + 3
for out_channel in d_hidden:
self.mlp_convs.append(
nn.Conv2d(
last_channel,
out_channel,
kernel_size=(patch_size, 1),
stride=(stride, 1),
padding=(patch_size // 2, 0),
padding_mode=padding_mode,
**kwargs,
)
)
last_channel = out_channel
self.group_all = group_all
self.fps_method = fps_method
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_points: sample points feature data, [B, d_hidden[-1], n_point]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(
self.n_point,
self.radius,
self.n_sample,
xyz,
points,
deterministic=not self.training,
fps_method=self.fps_method,
)
# new_xyz: sampled points position data, [B, n_point, C]
# new_points: sampled points data, [B, n_point, n_sample, C+D]
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, n_sample, n_point]
for i, conv in enumerate(self.mlp_convs):
new_points = self.act(self.apply_conv(new_points, conv))
new_points = new_points.mean(dim=2)
return new_points
def apply_conv(self, points: torch.Tensor, conv: nn.Module):
batch, channels, n_samples, _ = points.shape
# Shuffle the representations
if self.patch_size > 1:
# TODO shuffle deterministically when not self.training
_, indices = torch.rand(batch, channels, n_samples, 1, device=points.device).sort(dim=2)
points = torch.gather(points, 2, torch.broadcast_to(indices, points.shape))
return conv(points)
================================================
FILE: shap_e/models/nn/pointnet2_utils.py
================================================
"""
Based on https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet2_utils.py
MIT License
Copyright (c) 2019 benny
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
def pc_normalize(pc):
l = pc.shape[0]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
def square_distance(src, dst):
"""
Calculate the squared Euclidean distance between every pair of points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist
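# Illustration (not part of the original module): a pure-Python check of the
# expansion used by square_distance, |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, for a
# single (unbatched) pair of point lists. The helper name
# square_distance_lists is hypothetical.

```python
def square_distance_lists(src, dst):
    """Pairwise squared distances, result[n][m] for src[n] and dst[m]."""
    src_sq = [sum(c * c for c in a) for a in src]
    dst_sq = [sum(c * c for c in b) for b in dst]
    return [
        [
            src_sq[n] + dst_sq[m] - 2 * sum(x * y for x, y in zip(a, b))
            for m, b in enumerate(dst)
        ]
        for n, a in enumerate(src)
    ]


dist = square_distance_lists(
    [(0, 0, 0), (1, 2, 2)],
    [(1, 0, 0), (1, 2, 2), (4, 6, 2)],
)
```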
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points: indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = (
torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
)
new_points = points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint, deterministic=False):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
if deterministic:
farthest = torch.arange(0, B, dtype=torch.long).to(device)  # batch element b starts from point b; requires B <= N
else:
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = torch.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]
return centroids
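# Illustration (not part of the original module): a pure-Python sketch of the
# farthest point sampling loop above for one unbatched point list. The helper
# name farthest_point_sample_list and its start argument are hypothetical, and
# float("inf") stands in for the 1e10 initial distance.

```python
def farthest_point_sample_list(points, npoint, start=0):
    n = len(points)
    # Squared distance from each point to its nearest already-chosen point.
    distance = [float("inf")] * n
    farthest = start
    chosen = []
    for _ in range(npoint):
        chosen.append(farthest)
        cx, cy, cz = points[farthest]
        for j, (x, y, z) in enumerate(points):
            d = (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2
            if d < distance[j]:
                distance[j] = d
        # The next sample is the point farthest from everything chosen so far.
        farthest = max(range(n), key=distance.__getitem__)
    return chosen


pts = [(0, 0, 0), (1, 0, 0), (10, 0, 0), (5, 0, 0)]
idx = farthest_point_sample_list(pts, 3)
```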
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius**2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
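# Illustration (not part of the original module): a pure-Python sketch of the
# ball query above for a single query point. Indices outside the radius are
# replaced by the sentinel N, the list is sorted and truncated to nsample, and
# any remaining sentinels are overwritten with the first in-radius index. The
# helper name query_ball_point_list is hypothetical. As in the code above, if
# no point lies within the radius the sentinel N is left in place.

```python
def query_ball_point_list(radius, nsample, points, query):
    n = len(points)
    idx = []
    for j, p in enumerate(points):
        d2 = sum((a - b) ** 2 for a, b in zip(p, query))
        idx.append(j if d2 <= radius ** 2 else n)
    idx = sorted(idx)[:nsample]
    first = idx[0]
    # Pad short neighbourhoods by repeating the first neighbour.
    return [first if j == n else j for j in idx]


pts = [(0, 0, 0), (0.5, 0, 0), (3, 0, 0), (0, 0.2, 0)]
group = query_ball_point_list(1.0, 4, pts, (0, 0, 0))
```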
def sample_and_group(
npoint,
radius,
nsample,
xyz,
points,
returnfps=False,
deterministic=False,
fps_method: str = "fps",
):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape
S = npoint
if fps_method == "fps":
fps_idx = farthest_point_sample(xyz, npoint, deterministic=deterministic)  # [B, npoint]
elif fps_method == "first":
fps_idx = torch.arange(npoint, device=xyz.device)[None].repeat(B, 1)
else:
raise ValueError(f"Unknown FPS method: {fps_method}")
new_xyz = index_points(xyz, fps_idx)
idx = query_ball_point(radius, nsample, xyz, new_xyz)
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, idx)
new_points = torch.cat(
[grouped_xyz_norm, grouped_points], dim=-1
) # [B, npoint, nsample, C+D]
else:
new_points = grouped_xyz_norm
if returnfps:
return new_xyz, new_points, grouped_xyz, fps_idx
else:
return new_xyz, new_points
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape
new_xyz = torch.zeros(B, 1, C).to(device)
grouped_xyz = xyz.view(B, 1, N, C)
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
else:
new_points = grouped_xyz
return new_xyz, new_points
class PointNetSetAbstraction(nn.Module):
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.group_all = group_all
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(
self.npoint, self.radius, self.nsample, xyz, points, deterministic=not self.training
)
# new_xyz: sampled points position data, [B, npoint, C]
# new_points: sampled points data, [B, npoint, nsample, C+D]
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
new_points = torch.max(new_points, 2)[0]
new_xyz = new_xyz.permute(0, 2, 1)
return new_xyz, new_points
class PointNetSetAbstractionMsg(nn.Module):
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
super(PointNetSetAbstractionMsg, self).__init__()
self.npoint = npoint
self.radius_list = radius_list
self.nsample_list = nsample_list
self.conv_blocks = nn.ModuleList()
self.bn_blocks = nn.ModuleList()
for i in range(len(mlp_list)):
convs = nn.ModuleList()
bns = nn.ModuleList()
last_channel = in_channel + 3
for out_channel in mlp_list[i]:
convs.append(nn.Conv2d(last_channel, out_channel, 1))
bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.conv_blocks.append(convs)
self.bn_blocks.append(bns)
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
B, N, C = xyz.shape
S = self.npoint
new_xyz = index_points(xyz, farthest_point_sample(xyz, S, deterministic=not self.training))
new_points_list = []
for i, radius in enumerate(self.radius_list):
K = self.nsample_list[i]
group_idx = query_ball_point(radius, K, xyz, new_xyz)
grouped_xyz = index_points(xyz, group_idx)
grouped_xyz -= new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, group_idx)
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
else:
grouped_points = grouped_xyz
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
for j in range(len(self.conv_blocks[i])):
conv = self.conv_blocks[i][j]
bn = self.bn_blocks[i][j]
grouped_points = F.relu(bn(conv(grouped_points)))
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
new_points_list.append(new_points)
new_xyz = new_xyz.permute(0, 2, 1)
new_points_concat = torch.cat(new_points_list, dim=1)
return new_xyz, new_points_concat
class PointNetFeaturePropagation(nn.Module):
def __init__(self, in_channel, mlp):
super(PointNetFeaturePropagation, self).__init__()
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
last_channel = out_channel
def forward(self, xyz1, xyz2, points1, points2):
"""
Input:
xyz1: input points position data, [B, C, N]
xyz2: sampled input points position data, [B, C, S]
points1: input points data, [B, D, N]
points2: input points data, [B, D, S]
Return:
new_points: upsampled points data, [B, D', N]
"""
xyz1 = xyz1.permute(0, 2, 1)
xyz2 = xyz2.permute(0, 2, 1)
points2 = points2.permute(0, 2, 1)
B, N, C = xyz1.shape
_, S, _ = xyz2.shape
if S == 1:
interpolated_points = points2.repeat(1, N, 1)
else:
dists = square_distance(xyz1, xyz2)
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_points = torch.sum(
index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2
)
if points1 is not None:
points1 = points1.permute(0, 2, 1)
new_points = torch.cat([points1, interpolated_points], dim=-1)
else:
new_points = interpolated_points
new_points = new_points.permute(0, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
return new_points
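# Illustration (not part of the original module): a pure-Python sketch of the
# inverse-distance interpolation used in PointNetFeaturePropagation.forward
# when S > 1. Each target point takes a weighted average of the features of
# its three nearest source points, with weights proportional to
# 1 / (squared distance + 1e-8). The helper name interpolate_three_nn is
# hypothetical.

```python
def interpolate_three_nn(target, sources, features, eps=1e-8):
    # (squared distance, source index) for the three nearest sources.
    nearest = sorted(
        (sum((a - b) ** 2 for a, b in zip(target, s)), i) for i, s in enumerate(sources)
    )[:3]
    recip = [1.0 / (d + eps) for d, _ in nearest]
    norm = sum(recip)
    dim = len(features[0])
    return [
        sum(r / norm * features[i][c] for r, (_, i) in zip(recip, nearest))
        for c in range(dim)
    ]


sources = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, 0, 10)]
features = [[3.0], [6.0], [9.0], [100.0]]
centre = interpolate_three_nn((0, 0, 0), sources, features)
on_source = interpolate_three_nn((1, 0, 0), sources, features)
```

# A target equidistant from its three neighbours gets their plain average,
# while a target that coincides with a source point essentially copies it.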
================================================
FILE: shap_e/models/nn/utils.py
================================================
from typing import Iterable, Union
import numpy as np
import torch
ArrayType = Union[np.ndarray, Iterable[int], torch.Tensor]
def to_torch(arr: ArrayType, dtype=torch.float):
if isinstance(arr, torch.Tensor):
return arr
return torch.from_numpy(np.array(arr)).to(dtype)
def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
"""
Sample from the given discrete probability distribution with replacement.
The i-th bin is assumed to have mass pmf[i].
:param pmf: [batch_size, *shape, support_size, 1] where (pmf.sum(dim=-2) == 1).all()
:param n_samples: number of samples
:return: indices sampled with replacement
"""
*shape, support_size, last_dim = pmf.shape
assert last_dim == 1
cdf = torch.cumsum(pmf.view(-1, support_size), dim=1)
inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device))
return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1)
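# Illustration (not part of the original module): a pure-Python sketch of the
# inverse-CDF sampling performed by sample_pmf for one distribution.
# bisect.bisect_left plays the role of torch.searchsorted with its default
# right=False. The helper name sample_pmf_list and its rng argument are
# hypothetical.

```python
import bisect
import itertools
import random


def sample_pmf_list(pmf, n_samples, rng=random):
    cdf = list(itertools.accumulate(pmf))
    support_size = len(pmf)
    samples = []
    for _ in range(n_samples):
        u = rng.random()
        i = bisect.bisect_left(cdf, u)
        # Clamp in case rounding leaves cdf[-1] slightly below u.
        samples.append(min(i, support_size - 1))
    return samples


point_mass = sample_pmf_list([0.0, 1.0, 0.0], 100, random.Random(0))
mixed = sample_pmf_list([0.2, 0.3, 0.5], 100, random.Random(0))
```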
def safe_divide(a, b, epsilon=1e-6):
return a / torch.where(b < 0, b - epsilon, b + epsilon)
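# Illustration (not part of the original module): a scalar sketch of
# safe_divide. The denominator is pushed away from zero by epsilon in the
# direction of its sign, and an exact zero is treated as positive. The helper
# name safe_divide_scalar is hypothetical.

```python
def safe_divide_scalar(a, b, epsilon=1e-6):
    return a / (b - epsilon if b < 0 else b + epsilon)


at_zero = safe_divide_scalar(1.0, 0.0)
negative = safe_divide_scalar(1.0, -2.0)
```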
================================================
FILE: shap_e/models/query.py
================================================
from dataclasses import dataclass
from typing import Callable, Optional
import torch
@dataclass
class Query:
# Both of these are of shape [batch_size x ... x 3]
position: torch.Tensor
direction: Optional[torch.Tensor] = None
t_min: Optional[torch.Tensor] = None
t_max: Optional[torch.Tensor] = None
def copy(self) -> "Query":
return Query(
position=self.position,
direction=self.direction,
t_min=self.t_min,
t_max=self.t_max,
)
def map_tensors(self, f: Callable[[torch.Tensor], torch.Tensor]) -> "Query":
return Query(
position=f(self.position),
direction=f(self.direction) if self.direction is not None else None,
t_min=f(self.t_min) if self.t_min is not None else None,
t_max=f(self.t_max) if self.t_max is not None else None,
)
================================================
FILE: shap_e/models/renderer.py
================================================
from abc import abstractmethod
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from shap_e.models.nn.camera import (
DifferentiableCamera,
DifferentiableProjectiveCamera,
get_image_coords,
projective_camera_frame,
)
from shap_e.models.nn.meta import MetaModule
from shap_e.util.collections import AttrDict
class Renderer(MetaModule):
"""
A rendering abstraction that can render rays and views by calling the
appropriate models. The models are instantiated outside but registered in
this module.
"""
@abstractmethod
def render_views(
self,
batch: AttrDict,
params: Optional[Dict] = None,
options: Optional[Dict] = None,
) -> AttrDict:
"""
Returns a backproppable rendering of a view
:param batch: contains
- height: Optional[int]
- width: Optional[int]
- inner_batch_size or ray_batch_size: Optional[int] defaults to 4096 rays
And additionally, to specify poses with a default up direction:
- poses: [batch_size x *shape x 2 x 3] where poses[:, ..., 0, :] are the camera
positions, and poses[:, ..., 1, :] are the z-axis (toward the object) of
the camera frame.
- camera: DifferentiableCamera. Assumes the same camera position
across batch for simplicity. Could eventually support
batched cameras.
or to specify a batch of arbitrary poses:
- cameras: DifferentiableCameraBatch of shape [batch_size x *shape].
:param params: Meta parameters
:param options: Optional[Dict]
"""
class RayRenderer(Renderer):
"""
A rendering abstraction that can render rays and views by calling the
appropriate models. The models are instantiated outside but registered in
this module.
"""
@abstractmethod
def render_rays(
self,
batch: AttrDict,
params: Optional[Dict] = None,
options: Optional[Dict] = None,
) -> AttrDict:
"""
:param batch: has
- rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray.
- radii (optional): [batch_size x ... x 1] the "thickness" of each ray.
:param options: Optional[Dict]
"""
def render_views(
self,
batch: AttrDict,
params: Optional[Dict] = None,
options: Optional[Dict] = None,
device: torch.device = torch.device("cuda"),
) -> AttrDict:
output = render_views_from_rays(
self.render_rays,
batch,
params=params,
options=options,
device=self.device,
)
return output
def forward(
self,
batch: AttrDict,
params: Optional[Dict] = None,
options: Optional[Dict] = None,
) -> AttrDict:
"""
:param batch: must contain either
- rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray.
or
- poses: [batch_size x 2 x 3] where poses[:, 0] are the camera
positions, and poses[:, 1] are the z-axis (toward the object) of
the camera frame.
- camera: an instance of Camera that implements camera_rays
or
- cameras: DifferentiableCameraBatch of shape [batch_size x *shape].
With either of the two options above, the following may also be specified:
- height: Optional[int]
- width: Optional[int]
- ray_batch_size or inner_batch_size: Optional[int] defaults to 4096 rays
:param params: a dictionary of optional meta parameters.
:param options: A Dict of other hyperparameters that could be
related to rendering or debugging
:return: a dictionary containing
- channels: [batch_size, *shape, n_channels]
- distances: [batch_size, *shape, 1]
- transmittance: [batch_size, *shape, 1]
- aux_losses: Dict[str, torch.Tensor]
"""
if "rays" in batch:
for key in ["poses", "camera", "height", "width"]:
assert key not in batch
return self.render_rays(batch, params=params, options=options)
elif "poses" in batch or "cameras" in batch:
assert "rays" not in batch
if "poses" in batch:
assert "camera" in batch
else:
assert "camera" not in batch
return self.render_views(batch, params=params, options=options)
raise NotImplementedError
def get_camera_from_batch(batch: AttrDict) -> Tuple[DifferentiableCamera, int, Tuple[int]]:
if "poses" in batch:
assert "cameras" not in batch
batch_size, *inner_shape, n_vecs, spatial_dim = batch.poses.shape
assert n_vecs == 2 and spatial_dim == 3
inner_batch_size = int(np.prod(inner_shape))
poses = batch.poses.view(batch_size * inner_batch_size, 2, 3)
position, direction = poses[:, 0], poses[:, 1]
camera = projective_camera_frame(position, direction, batch.camera)
elif "cameras" in batch:
assert "camera" not in batch
batch_size, *inner_shape = batch.cameras.shape
camera = batch.cameras.flat_camera
else:
raise ValueError(f'neither "poses" nor "cameras" found in keys: {batch.keys()}')
if "height" in batch and "width" in batch:
camera = camera.resize_image(batch.width, batch.height)
return camera, batch_size, inner_shape
def append_tensor(val_list: Optional[List[torch.Tensor]], output: Optional[torch.Tensor]):
if val_list is None:
return [output]
return val_list + [output]
def render_views_from_rays(
render_rays: Callable[[AttrDict, AttrDict, AttrDict], AttrDict],
batch: AttrDict,
params: Optional[Dict] = None,
options: Optional[Dict] = None,
device: torch.device = torch.device("cuda"),
) -> AttrDict:
camera, batch_size, inner_shape = get_camera_from_batch(batch)
inner_batch_size = int(np.prod(inner_shape))
coords = get_image_coords(camera.width, camera.height).to(device)
coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
rays = camera.camera_rays(coords)
# mip-NeRF radii calculation from: https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/datasets.py#L193-L200
directions = rays.view(batch_size, inner_batch_size, camera.height, camera.width, 2, 3)[
..., 1, :
]
neighbor_dists = torch.linalg.norm(directions[:, :, :, 1:] - directions[:, :, :, :-1], dim=-1)
neighbor_dists = torch.cat([neighbor_dists, neighbor_dists[:, :, :, -2:-1]], dim=3)
radii = (neighbor_dists * 2 / np.sqrt(12)).view(batch_size, -1, 1)
rays = rays.view(batch_size, inner_batch_size * camera.height * camera.width, 2, 3)
if isinstance(camera, DifferentiableProjectiveCamera):
# Compute the camera z direction corresponding to every ray's pixel.
# Used for depth computations below.
z_directions = (
(camera.z / torch.linalg.norm(camera.z, dim=-1, keepdim=True))
.reshape([batch_size, inner_batch_size, 1, 3])
.repeat(1, 1, camera.width * camera.height, 1)
.reshape(batch_size, inner_batch_size * camera.height * camera.width, 3)  # was reshape(1, ...), which only works for batch_size == 1
)
ray_batch_size = batch.get("ray_batch_size", batch.get("inner_batch_size", 4096))
assert rays.shape[1] % ray_batch_size == 0
n_batches = rays.shape[1] // ray_batch_size
output_list = AttrDict(aux_losses=dict())
for idx in range(n_batches):
rays_batch = AttrDict(
rays=rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size],
radii=radii[:, idx * ray_batch_size : (idx + 1) * ray_batch_size],
)
output = render_rays(rays_batch, params=params, options=options)
if isinstance(camera, DifferentiableProjectiveCamera):
z_batch = z_directions[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]
ray_directions = rays_batch.rays[:, :, 1]
z_dots = (ray_directions * z_batch).sum(-1, keepdim=True)
output.depth = output.distances * z_dots
output_list = output_list.combine(output, append_tensor)
def _resize(val_list: List[torch.Tensor]):
val = torch.cat(val_list, dim=1)
assert val.shape[1] == inner_batch_size * camera.height * camera.width
return val.view(batch_size, *inner_shape, camera.height, camera.width, -1)
def _avg(_key: str, loss_list: List[torch.Tensor]):
return sum(loss_list) / n_batches
output = AttrDict(
{name: _resize(val_list) for name, val_list in output_list.items() if name != "aux_losses"}
)
output.aux_losses = output_list.aux_losses.map(_avg)
return output
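# Illustration (not part of the original module): a pure-Python sketch of the
# per-pixel ray radius computed in render_views_from_rays for one image row.
# The radius is the distance between horizontally adjacent ray directions
# scaled by 2 / sqrt(12); the last pixel has no right-hand neighbour and
# reuses the second-to-last difference, mirroring neighbor_dists[..., -2:-1].
# The helper name row_radii is hypothetical and assumes at least three pixels.

```python
import math


def row_radii(directions):
    diffs = [math.dist(a, b) for a, b in zip(directions[1:], directions[:-1])]
    diffs.append(diffs[-2])  # pad so there is one value per pixel
    return [d * 2 / math.sqrt(12) for d in diffs]


radii = row_radii([(0, 0, 1), (1, 0, 1), (3, 0, 1), (6, 0, 1)])
```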
================================================
FILE: shap_e/models/stf/__init__.py
================================================
================================================
FILE: shap_e/models/stf/base.py
================================================
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import torch
from shap_e.models.query import Query
from shap_e.models.renderer import append_tensor
from shap_e.util.collections import AttrDict
class Model(ABC):
@abstractmethod
def forward(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> AttrDict[str, Any]:
"""
Predict an attribute given position
"""
def forward_batched(
self,
query: Query,
query_batch_size: int = 4096,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> AttrDict[str, Any]:
if not query.position.numel():
# Avoid torch.cat() of zero tensors.
return self(query, params=params, options=options)
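if options is None:
# Guard the default options=None, which would otherwise fail on options.cache.
options = AttrDict()
if options.cache is None: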
created_cache = True
options.cache = AttrDict()
else:
created_cache = False
results_list = AttrDict()
for i in range(0, query.position.shape[1], query_batch_size):
out = self(
query=query.map_tensors(lambda x, i=i: x[:, i : i + query_batch_size]),
params=params,
options=options,
)
results_list = results_list.combine(out, append_tensor)
if created_cache:
del options["cache"]
return results_list.map(lambda key, tensor_list: torch.cat(tensor_list, dim=1))
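# Illustration (not part of the original module): a pure-Python sketch of the
# chunking pattern in forward_batched. The query is split into slices of at
# most query_batch_size items, each slice is evaluated separately, and the
# outputs are concatenated in order. The names run_in_chunks and double are
# hypothetical and lists stand in for tensors.

```python
def run_in_chunks(fn, items, chunk_size=4096):
    outputs = []
    for i in range(0, len(items), chunk_size):
        # The final chunk may be shorter than chunk_size.
        outputs.extend(fn(items[i : i + chunk_size]))
    return outputs


calls = []


def double(chunk):
    calls.append(len(chunk))
    return [2 * v for v in chunk]


result = run_in_chunks(double, list(range(10)), chunk_size=4)
```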
================================================
FILE: shap_e/models/stf/mlp.py
================================================
from functools import partial
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, maybe_encode_direction
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict
from .base import Model
class MLPModel(MetaModule, Model):
def __init__(
self,
n_output: int,
output_activation: str,
# Positional encoding parameters
posenc_version: str = "v1",
# Direction related channel prediction
insert_direction_at: Optional[int] = None,
# MLP parameters
d_hidden: int = 256,
n_hidden_layers: int = 4,
activation: str = "relu",
init: Optional[str] = None,
init_scale: float = 1.0,
meta_parameters: bool = False,
trainable_meta: bool = False,
meta_proj: bool = True,
meta_bias: bool = True,
meta_start: int = 0,
meta_stop: Optional[int] = None,
n_meta_layers: Optional[int] = None,
register_freqs: bool = False,
device: torch.device = torch.device("cuda"),
):
super().__init__()
if register_freqs:
self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10))
# Positional encoding
self.posenc_version = posenc_version
dummy = torch.eye(1, 3)
d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1]
d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1]
# Instantiate the MLP
mlp_widths = [d_hidden] * n_hidden_layers
input_widths = [d_posenc_pos, *mlp_widths]
output_widths = mlp_widths + [n_output]
self.meta_parameters = meta_parameters
# When this model is used jointly to express NeRF, it may have to
# process directions as well in which case we simply concatenate
# the direction representation at the specified layer.
self.insert_direction_at = insert_direction_at
if insert_direction_at is not None:
input_widths[self.insert_direction_at] += d_posenc_dir
linear_cls = lambda meta: (
partial(
MetaLinear,
meta_scale=False,
meta_shift=False,
meta_proj=meta_proj,
meta_bias=meta_bias,
trainable_meta=trainable_meta,
)
if meta
else nn.Linear
)
if meta_stop is None:
if n_meta_layers is not None:
assert n_meta_layers > 0
meta_stop = meta_start + n_meta_layers - 1
else:
meta_stop = n_hidden_layers
if meta_parameters:
metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]
else:
metas = [False] * (n_hidden_layers + 1)
self.mlp = nn.ModuleList(
[
linear_cls(meta)(d_in, d_out, device=device)
for meta, d_in, d_out in zip(metas, input_widths, output_widths)
]
)
mlp_init(self.mlp, init=init, init_scale=init_scale)
self.activation = get_act(activation)
self.output_activation = get_act(output_activation)
self.device = device
self.to(device)
def forward(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> AttrDict:
"""
:param query: the query; query.position has shape [batch_size x ... x 3]
:param params: Meta parameters
:param options: Optional hyperparameters
"""
# query.direction is None typically for SDF models and training
h_final, _h_directionless = self._mlp(
query.position, query.direction, params=params, options=options
)
return self.output_activation(h_final)
def _run_mlp(
self, position: torch.Tensor, direction: torch.Tensor, params: AttrDict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:return: the final and directionless activations at the given query
"""
h_preact = h = encode_position(self.posenc_version, position=position)
h_directionless = None
for i, layer in enumerate(self.mlp):
if i == self.insert_direction_at:
h_directionless = h_preact
h_direction = maybe_encode_direction(
self.posenc_version, position=position, direction=direction
)
h = torch.cat([h, h_direction], dim=-1)
if isinstance(layer, MetaLinear):
h = layer(h, params=subdict(params, f"mlp.{i}"))
else:
h = layer(h)
h_preact = h
if i < len(self.mlp) - 1:
h = self.activation(h)
h_final = h
if h_directionless is None:
h_directionless = h_preact
return h_final, h_directionless
def _mlp(
self,
position: torch.Tensor,
direction: Optional[torch.Tensor] = None,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param position: [batch_size x ... x 3]
:param params: Meta parameters
:param options: Optional hyperparameters
:return: the final and directionless activations at the given query
"""
params = self.update(params)
options = AttrDict() if options is None else AttrDict(options)
mlp = partial(self._run_mlp, direction=direction, params=params)
parameters = []
for i, layer in enumerate(self.mlp):
if isinstance(layer, MetaLinear):
parameters.extend(list(subdict(params, f"mlp.{i}").values()))
else:
parameters.extend(layer.parameters())
h_final, h_directionless = checkpoint(
mlp, (position,), parameters, options.checkpoint_stf_model
)
return h_final, h_directionless
class MLPSDFModel(MLPModel):
def __init__(self, initial_bias: float = -0.1, **kwargs):
super().__init__(n_output=1, output_activation="identity", **kwargs)
self.mlp[-1].bias.data.fill_(initial_bias)
def forward(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> AttrDict[str, Any]:
signed_distance = super().forward(query=query, params=params, options=options)
return AttrDict(signed_distance=signed_distance)
class MLPTextureFieldModel(MLPModel):
def __init__(
self,
n_channels: int = 3,
**kwargs,
):
super().__init__(n_output=n_channels, output_activation="sigmoid", **kwargs)
def forward(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
) -> AttrDict[str, Any]:
channels = super().forward(query=query, params=params, options=options)
return AttrDict(channels=channels)
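# Illustrative sketch (not part of the repository): the meta-layer selection
# in MLPModel.__init__ above can be read in isolation as a small pure function.
# The helper name `select_meta_layers` and the defaults shown here are
# assumptions made for the example; only the branching mirrors the source.

```python
from typing import List, Optional


def select_meta_layers(
    n_hidden_layers: int,
    meta_parameters: bool = True,
    meta_start: int = 0,  # assumed default, for illustration only
    meta_stop: Optional[int] = None,
    n_meta_layers: Optional[int] = None,
) -> List[bool]:
    """Return one flag per linear layer: True where a MetaLinear would be used."""
    if meta_stop is None:
        if n_meta_layers is not None:
            assert n_meta_layers > 0
            # The range is inclusive, hence the "- 1".
            meta_stop = meta_start + n_meta_layers - 1
        else:
            meta_stop = n_hidden_layers
    if not meta_parameters:
        return [False] * (n_hidden_layers + 1)
    # There are n_hidden_layers + 1 linear layers (hidden layers plus output).
    return [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]


flags = select_meta_layers(n_hidden_layers=4, meta_start=1, n_meta_layers=2)
print(flags)
```

# With four hidden layers there are five linear layers; starting at layer 1
# with two meta layers marks layers 1 and 2 only.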
================================================
FILE: shap_e/models/stf/renderer.py
================================================
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from shap_e.models.nn.camera import DifferentiableCamera, DifferentiableProjectiveCamera
from shap_e.models.nn.meta import subdict
from shap_e.models.nn.utils import to_torch
from shap_e.models.query import Query
from shap_e.models.renderer import Renderer, get_camera_from_batch
from shap_e.models.volume import BoundingBoxVolume, Volume
from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR
from shap_e.rendering.mc import marching_cubes
from shap_e.rendering.torch_mesh import TorchMesh
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict
from .base import Model
class STFRendererBase(ABC):
@abstractmethod
def get_signed_distance(
self,
        query: Query,
params: Dict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> torch.Tensor:
pass
@abstractmethod
def get_texture(
self,
        query: Query,
params: Dict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> torch.Tensor:
pass
class STFRenderer(Renderer, STFRendererBase):
def __init__(
self,
sdf: Model,
tf: Model,
volume: Volume,
grid_size: int,
texture_channels: Sequence[str] = ("R", "G", "B"),
channel_scale: Sequence[float] = (255.0, 255.0, 255.0),
ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,
diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,
specular_color: Union[float, Tuple[float]] = 0.0,
output_srgb: bool = True,
device: torch.device = torch.device("cuda"),
**kwargs,
):
super().__init__(**kwargs)
assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume"
self.sdf = sdf
self.tf = tf
self.volume = volume
self.grid_size = grid_size
self.texture_channels = texture_channels
self.channel_scale = to_torch(channel_scale).to(device)
self.ambient_color = ambient_color
self.diffuse_color = diffuse_color
self.specular_color = specular_color
self.output_srgb = output_srgb
self.device = device
self.to(device)
def render_views(
self,
batch: Dict,
params: Optional[Dict] = None,
options: Optional[Dict] = None,
) -> AttrDict:
params = self.update(params)
options = AttrDict() if not options else AttrDict(options)
sdf_fn = partial(self.sdf.forward_batched, params=subdict(params, "sdf"))
tf_fn = partial(self.tf.forward_batched, params=subdict(params, "tf"))
nerstf_fn = None
return render_views_from_stf(
batch,
options,
sdf_fn=sdf_fn,
tf_fn=tf_fn,
nerstf_fn=nerstf_fn,
volume=self.volume,
grid_size=self.grid_size,
channel_scale=self.channel_scale,
texture_channels=self.texture_channels,
ambient_color=self.ambient_color,
diffuse_color=self.diffuse_color,
specular_color=self.specular_color,
output_srgb=self.output_srgb,
device=self.device,
)
def get_signed_distance(
self,
query: Query,
params: Dict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> torch.Tensor:
return self.sdf(
query,
params=subdict(params, "sdf"),
options=options,
).signed_distance
def get_texture(
self,
query: Query,
params: Dict[str, torch.Tensor],
options: AttrDict[str, Any],
) -> torch.Tensor:
return self.tf(
query,
params=subdict(params, "tf"),
options=options,
).channels
def render_views_from_stf(
batch: Dict,
options: AttrDict[str, Any],
*,
sdf_fn: Optional[Callable],
tf_fn: Optional[Callable],
nerstf_fn: Optional[Callable],
volume: BoundingBoxVolume,
grid_size: int,
channel_scale: torch.Tensor,
texture_channels: Sequence[str] = ("R", "G", "B"),
ambient_color: Union[float, Tuple[float]] = 0.0,
diffuse_color: Union[float, Tuple[float]] = 1.0,
specular_color: Union[float, Tuple[float]] = 0.2,
output_srgb: bool = False,
device: torch.device = torch.device("cuda"),
) -> AttrDict:
"""
:param batch: contains either ["poses", "camera"], or ["cameras"]. Can
optionally contain any of ["height", "width", "query_batch_size"]
:param options: controls checkpointing, caching, and rendering
    :param sdf_fn: returns [batch_size, query_batch_size, n_output] where
                   n_output >= 1.
    :param tf_fn: returns [batch_size, query_batch_size, n_channels]
    :param nerstf_fn: used in place of sdf_fn and/or tf_fn when either is None
:param volume: AABB volume
:param grid_size: SDF sampling resolution
:param texture_channels: what texture to predict
:param channel_scale: how each channel is scaled
:return: at least
channels: [batch_size, len(cameras), height, width, 3]
transmittance: [batch_size, len(cameras), height, width, 1]
aux_losses: AttrDict[str, torch.Tensor]
"""
camera, batch_size, inner_shape = get_camera_from_batch(batch)
inner_batch_size = int(np.prod(inner_shape))
assert camera.width == camera.height, "only square views are supported"
assert camera.x_fov == camera.y_fov, "only square views are supported"
assert isinstance(camera, DifferentiableProjectiveCamera)
device = camera.origin.device
device_type = device.type
TO_CACHE = ["fields", "raw_meshes", "raw_signed_distance", "raw_density", "mesh_mask", "meshes"]
if options.cache is not None and all(key in options.cache for key in TO_CACHE):
fields = options.cache.fields
raw_meshes = options.cache.raw_meshes
raw_signed_distance = options.cache.raw_signed_distance
raw_density = options.cache.raw_density
mesh_mask = options.cache.mesh_mask
else:
query_batch_size = batch.get("query_batch_size", batch.get("ray_batch_size", 4096))
query_points = volume_query_points(volume, grid_size)
fn = nerstf_fn if sdf_fn is None else sdf_fn
sdf_out = fn(
query=Query(position=query_points[None].repeat(batch_size, 1, 1)),
query_batch_size=query_batch_size,
options=options,
)
raw_signed_distance = sdf_out.signed_distance
raw_density = None
if "density" in sdf_out:
raw_density = sdf_out.density
with torch.autocast(device_type, enabled=False):
fields = sdf_out.signed_distance.float()
raw_signed_distance = sdf_out.signed_distance
assert (
len(fields.shape) == 3 and fields.shape[-1] == 1
), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
fields = fields.reshape(batch_size, *([grid_size] * 3))
# Force a negative border around the SDFs to close off all the models.
full_grid = torch.zeros(
batch_size,
grid_size + 2,
grid_size + 2,
grid_size + 2,
device=fields.device,
dtype=fields.dtype,
)
full_grid.fill_(-1.0)
full_grid[:, 1:-1, 1:-1, 1:-1] = fields
fields = full_grid
raw_meshes = []
mesh_mask = []
for field in fields:
raw_mesh = marching_cubes(field, volume.bbox_min, volume.bbox_max - volume.bbox_min)
if len(raw_mesh.faces) == 0:
# DDP deadlocks when there are unused parameters on some ranks
# and not others, so we make sure the field is a dependency in
# the graph regardless of empty meshes.
vertex_dependency = field.mean()
raw_mesh = TorchMesh(
verts=torch.zeros(3, 3, device=device) + vertex_dependency,
faces=torch.tensor([[0, 1, 2]], dtype=torch.long, device=device),
)
# Make sure we only feed back zero gradients to the field
# by masking out the final renderings of this mesh.
mesh_mask.append(False)
else:
mesh_mask.append(True)
raw_meshes.append(raw_mesh)
mesh_mask = torch.tensor(mesh_mask, device=device)
max_vertices = max(len(m.verts) for m in raw_meshes)
fn = nerstf_fn if tf_fn is None else tf_fn
tf_out = fn(
query=Query(
position=torch.stack(
[m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],
dim=0,
)
),
query_batch_size=query_batch_size,
options=options,
)
if "cache" in options:
options.cache.fields = fields
options.cache.raw_meshes = raw_meshes
options.cache.raw_signed_distance = raw_signed_distance
options.cache.raw_density = raw_density
options.cache.mesh_mask = mesh_mask
if output_srgb:
tf_out.channels = _convert_srgb_to_linear(tf_out.channels)
# Make sure the raw meshes have colors.
with torch.autocast(device_type, enabled=False):
textures = tf_out.channels.float()
assert len(textures.shape) == 3 and textures.shape[-1] == len(
texture_channels
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)]
m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))}
args = dict(
options=options,
texture_channels=texture_channels,
ambient_color=ambient_color,
diffuse_color=diffuse_color,
specular_color=specular_color,
camera=camera,
batch_size=batch_size,
inner_batch_size=inner_batch_size,
inner_shape=inner_shape,
raw_meshes=raw_meshes,
tf_out=tf_out,
)
try:
out = _render_with_pytorch3d(**args)
except ModuleNotFoundError as exc:
warnings.warn(f"exception rendering with PyTorch3D: {exc}")
warnings.warn(
"falling back on native PyTorch renderer, which does not support full gradients"
)
out = _render_with_raycast(**args)
# Apply mask to prevent gradients for empty meshes.
reshaped_mask = mesh_mask.view([-1] + [1] * (len(out.channels.shape) - 1))
out.channels = torch.where(reshaped_mask, out.channels, torch.zeros_like(out.channels))
out.transmittance = torch.where(
reshaped_mask, out.transmittance, torch.ones_like(out.transmittance)
)
if output_srgb:
out.channels = _convert_linear_to_srgb(out.channels)
out.channels = out.channels * (1 - out.transmittance) * channel_scale.view(-1)
# This might be useful information to have downstream
out.raw_meshes = raw_meshes
out.fields = fields
out.mesh_mask = mesh_mask
out.raw_signed_distance = raw_signed_distance
out.aux_losses = AttrDict(cross_entropy=cross_entropy_sdf_loss(fields))
if raw_density is not None:
out.raw_density = raw_density
return out
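# Illustrative sketch (not part of the repository): render_views_from_stf
# batches meshes with different vertex counts by indexing each vertex list
# with `arange(max_vertices) % len(verts)`, then trims each texture result
# back with `texture[: len(m.verts)]`. The same idea with plain Python lists,
# where `pad_cyclic` is a hypothetical helper name:

```python
from typing import List, Sequence, Tuple

Vertex = Tuple[float, float, float]


def pad_cyclic(verts: Sequence[Vertex], length: int) -> List[Vertex]:
    """Repeat vertices cyclically until the list has exactly `length` entries."""
    return [verts[i % len(verts)] for i in range(length)]


meshes: List[List[Vertex]] = [
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)],
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), (1.0, 1.0, 1.0)],
]
max_vertices = max(len(m) for m in meshes)

# Every mesh now has the same length, so the batch can be stacked.
padded = [pad_cyclic(m, max_vertices) for m in meshes]

# After querying per-vertex values, drop the repeated entries again.
trimmed = [p[: len(m)] for p, m in zip(padded, meshes)]
```

# Cyclic padding only repeats real vertices, so the padded queries stay
# inside the mesh rather than introducing placeholder coordinates.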
def _render_with_pytorch3d(
options: AttrDict,
texture_channels: Sequence[str],
ambient_color: Union[float, Tuple[float]],
diffuse_color: Union[float, Tuple[float]],
specular_color: Union[float, Tuple[float]],
camera: DifferentiableCamera,
batch_size: int,
inner_shape: Sequence[int],
inner_batch_size: int,
raw_meshes: List[TorchMesh],
tf_out: AttrDict,
):
_ = tf_out
# Lazy import because pytorch3d is installed lazily.
from shap_e.rendering.pytorch3d_util import (
blender_uniform_lights,
convert_cameras_torch,
convert_meshes,
render_images,
)
n_channels = len(texture_channels)
device = camera.origin.device
device_type = device.type
with torch.autocast(device_type, enabled=False):
meshes = convert_meshes(raw_meshes)
lights = blender_uniform_lights(
batch_size,
device,
ambient_color=ambient_color,
diffuse_color=diffuse_color,
specular_color=specular_color,
)
# Separate camera intrinsics for each view, so that we can
# create a new camera for each batch of views.
cam_shape = [batch_size, inner_batch_size, -1]
position = camera.origin.reshape(cam_shape)
x = camera.x.reshape(cam_shape)
y = camera.y.reshape(cam_shape)
z = camera.z.reshape(cam_shape)
results = []
for i in range(inner_batch_size):
sub_cams = convert_cameras_torch(
position[:, i], x[:, i], y[:, i], z[:, i], fov=camera.x_fov
)
imgs = render_images(
camera.width,
meshes,
sub_cams,
lights,
use_checkpoint=options.checkpoint_render,
**options.get("render_options", {}),
)
results.append(imgs)
views = torch.stack(results, dim=1)
views = views.view(batch_size, *inner_shape, camera.height, camera.width, n_channels + 1)
out = AttrDict(
channels=views[..., :-1], # [batch_size, *inner_shape, height, width, n_channels]
transmittance=1 - views[..., -1:], # [batch_size, *inner_shape, height, width, 1]
meshes=meshes,
)
return out
def _render_with_raycast(
options: AttrDict,
texture_channels: Sequence[str],
ambient_color: Union[float, Tuple[float]],
diffuse_color: Union[float, Tuple[float]],
specular_color: Union[float, Tuple[float]],
camera: DifferentiableCamera,
batch_size: int,
inner_shape: Sequence[int],
inner_batch_size: int,
raw_meshes: List[TorchMesh],
tf_out: AttrDict,
):
assert np.mean(np.array(specular_color)) == 0
from shap_e.rendering.raycast.render import render_diffuse_mesh
from shap_e.rendering.raycast.types import TriMesh as TorchTriMesh
device = camera.origin.device
device_type = device.type
cam_shape = [batch_size, inner_batch_size, -1]
origin = camera.origin.reshape(cam_shape)
x = camera.x.reshape(cam_shape)
y = camera.y.reshape(cam_shape)
z = camera.z.reshape(cam_shape)
with torch.autocast(device_type, enabled=False):
all_meshes = []
for i, mesh in enumerate(raw_meshes):
all_meshes.append(
TorchTriMesh(
faces=mesh.faces.long(),
vertices=mesh.verts.float(),
vertex_colors=tf_out.channels[i, : len(mesh.verts)].float(),
)
)
all_images = []
for i, mesh in enumerate(all_meshes):
for j in range(inner_batch_size):
all_images.append(
render_diffuse_mesh(
camera=ProjectiveCamera(
origin=origin[i, j].detach().cpu().numpy(),
x=x[i, j].detach().cpu().numpy(),
y=y[i, j].detach().cpu().numpy(),
z=z[i, j].detach().cpu().numpy(),
width=camera.width,
height=camera.height,
x_fov=camera.x_fov,
y_fov=camera.y_fov,
),
mesh=mesh,
diffuse=float(np.array(diffuse_color).mean()),
ambient=float(np.array(ambient_color).mean()),
ray_batch_size=16, # low memory usage
checkpoint=options.checkpoint_render,
)
)
n_channels = len(texture_channels)
views = torch.stack(all_images).view(
batch_size, *inner_shape, camera.height, camera.width, n_channels + 1
)
return AttrDict(
channels=views[..., :-1], # [batch_size, *inner_shape, height, width, n_channels]
transmittance=1 - views[..., -1:], # [batch_size, *inner_shape, height, width, 1]
meshes=all_meshes,
)
def _convert_srgb_to_linear(u: torch.Tensor) -> torch.Tensor:
return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4)
def _convert_linear_to_srgb(u: torch.Tensor) -> torch.Tensor:
return torch.where(u <= 0.0031308, 12.92 * u, 1.055 * (u ** (1 / 2.4)) - 0.055)
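# Illustrative sketch (not part of the repository): a scalar, pure-Python
# version of the two conversions above, useful for checking that they are
# approximate inverses of each other. The function names are hypothetical;
# the constants are copied from the torch versions.

```python
def srgb_to_linear(u: float) -> float:
    """Scalar counterpart of _convert_srgb_to_linear."""
    return u / 12.92 if u <= 0.04045 else ((u + 0.055) / 1.055) ** 2.4


def linear_to_srgb(u: float) -> float:
    """Scalar counterpart of _convert_linear_to_srgb."""
    return 12.92 * u if u <= 0.0031308 else 1.055 * (u ** (1 / 2.4)) - 0.055


samples = [0.0, 0.02, 0.04045, 0.5, 1.0]
round_trip = [linear_to_srgb(srgb_to_linear(u)) for u in samples]
for u, r in zip(samples, round_trip):
    print(f"{u:.5f} -> {srgb_to_linear(u):.6f} -> {r:.5f}")
```

# The two branches of each function meet at the thresholds only up to
# rounding of the published constants, so compare with a small tolerance
# rather than exact equality.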
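def cross_entropy_sdf_loss(fields: torch.Tensor):
    """
    For each of the three grid axes, compare every cell's log-sigmoid SDF
    value (passed as the logit) against the sign (> 0) of its neighbor one
    step before and one step after along that axis, and sum the mean
    cross-entropy terms.
    """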
logits = F.logsigmoid(fields)
signs = (fields > 0).float()
losses = []
for dim in range(1, 4):
n = logits.shape[dim]
for (t_start, t_end, p_start, p_end) in [(0, -1, 1, n), (1, n, 0, -1)]:
targets = slice_fields(signs, dim, t_start, t_end)
preds = slice_fields(logits, dim, p_start, p_end)
losses.append(
F.binary_cross_entropy_with_logits(preds, targets, reduction="none")
.flatten(1)
.mean()
)
return torch.stack(losses, dim=-1).sum()
def slice_fields(fields: torch.Tensor, dim: int, start: int, end: int):
if dim == 1:
return fields[:, start:end]
elif dim == 2:
return fields[:, :, start:end]
elif dim == 3:
return fields[:, :, :, start:end]
else:
raise ValueError(f"cannot slice dimension {dim}")
def volume_query_points(
    volume: Volume,
    grid_size: int,
):
    """
    Sample a regular grid_size^3 lattice spanning the volume's bounding box.

    :return: [grid_size ** 3, 3] points, with z varying fastest and x slowest.
    """
    assert isinstance(volume, BoundingBoxVolume)
    indices = torch.arange(grid_size**3, device=volume.bbox_min.device)
    zs = indices % grid_size
    ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size
    xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size
    combined = torch.stack([xs, ys, zs], dim=1)
    return (combined.float() / (grid_size - 1)) * (
        volume.bbox_max - volume.bbox_min
    ) + volume.bbox_min
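# Illustrative sketch (not part of the repository): the index arithmetic in
# volume_query_points, written with plain Python integers so the point
# ordering is easy to inspect. `grid_query_points` is a hypothetical name,
# and the bounding box is passed as two tuples instead of a volume object.

```python
from typing import List, Sequence, Tuple


def grid_query_points(
    bbox_min: Sequence[float], bbox_max: Sequence[float], grid_size: int
) -> List[Tuple[float, ...]]:
    """Enumerate a grid_size^3 lattice over the box, z fastest and x slowest."""
    assert grid_size >= 2, "grid_size - 1 is used as a divisor"
    points = []
    for index in range(grid_size**3):
        z = index % grid_size
        y = (index // grid_size) % grid_size
        x = (index // grid_size**2) % grid_size
        # Map integer coordinates in [0, grid_size - 1] onto the box extents.
        points.append(
            tuple(
                (i / (grid_size - 1)) * (hi - lo) + lo
                for i, lo, hi in zip((x, y, z), bbox_min, bbox_max)
            )
        )
    return points


points = grid_query_points((-1.0, -1.0, -1.0), (1.0, 1.0, 1.0), grid_size=3)
print(len(points), points[0], points[1], points[-1])
```

# Because the lattice includes both box corners, reshaping the flat result
# to [grid_size, grid_size, grid_size] gives axes in (x, y, z) order.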
================================================
FILE: shap_e/models/transmitter/__init__.py
================================================
================================================
FILE: shap_e/models/transmitter/base.py
================================================
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple
import torch.nn as nn
import torch
from shap_e.models.renderer import Renderer
from shap_e.util.collections import AttrDict
from .bottleneck import latent_bottleneck_from_config, latent_warp_from_config
from .params_proj import flatten_param_shapes, params_proj_from_config
class Encoder(nn.Module, ABC):
def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]]):
"""
Instantiate the encoder with information about the renderer's input
parameters. This information can be used to create output layers to
generate the necessary latents.
"""
super().__init__()
self.param_shapes = param_shapes
self.device = device
@abstractmethod
def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
"""
Encode a batch of data into a batch of latent information.
"""
class VectorEncoder(Encoder):
def __init__(
self,
*,
device: torch.device,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
d_latent: int,
latent_bottleneck: Optional[Dict[str, Any]] = None,
latent_warp: Optional[Dict[str, Any]] = None,
):
super().__init__(device=device, param_shapes=param_shapes)
if latent_bottleneck is None:
latent_bottleneck = dict(name="identity")
if latent_warp is None:
latent_warp = dict(name="identity")
self.d_latent = d_latent
self.params_proj = params_proj_from_config(
params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent
)
self.latent_bottleneck = latent_bottleneck_from_config(
latent_bottleneck, device=device, d_latent=d_latent
)
self.latent_warp = latent_warp_from_config(latent_warp, device=device)
def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
h = self.encode_to_bottleneck(batch, options=options)
return self.bottleneck_to_params(h, options=options)
def encode_to_bottleneck(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> torch.Tensor:
return self.latent_warp.warp(
self.latent_bottleneck(self.encode_to_vector(batch, options=options), options=options),
options=options,
)
@abstractmethod
def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
"""
Encode the batch into a single latent vector.
"""
def bottleneck_to_params(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> AttrDict:
_ = options
return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)
class ChannelsEncoder(VectorEncoder):
def __init__(
self,
*,
device: torch.device,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
d_latent: int,
latent_bottleneck: Optional[Dict[str, Any]] = None,
latent_warp: Optional[Dict[str, Any]] = None,
):
super().__init__(
device=device,
param_shapes=param_shapes,
params_proj=params_proj,
d_latent=d_latent,
latent_bottleneck=latent_bottleneck,
latent_warp=latent_warp,
)
self.flat_shapes = flatten_param_shapes(param_shapes)
self.latent_ctx = sum(flat[0] for flat in self.flat_shapes.values())
@abstractmethod
def encode_to_channels(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> torch.Tensor:
"""
Encode the batch into a per-data-point set of latents.
:return: [batch_size, latent_ctx, latent_width]
"""
def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
return self.encode_to_channels(batch, options=options).flatten(1)
def bottleneck_to_channels(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> torch.Tensor:
_ = options
return vector.view(vector.shape[0], self.latent_ctx, -1)
def bottleneck_to_params(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> AttrDict:
_ = options
return self.params_proj(
self.bottleneck_to_channels(self.latent_warp.unwarp(vector)), options=options
)
class Transmitter(nn.Module):
def __init__(self, encoder: Encoder, renderer: Renderer):
super().__init__()
self.encoder = encoder
self.renderer = renderer
def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
"""
Transmit the batch through the encoder and then the renderer.
"""
params = self.encoder(batch, options=options)
return self.renderer(batch, params=params, options=options)
class VectorDecoder(nn.Module):
def __init__(
self,
*,
device: torch.device,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
d_latent: int,
latent_warp: Optional[Dict[str, Any]] = None,
renderer: Renderer,
):
super().__init__()
self.device = device
self.param_shapes = param_shapes
if latent_warp is None:
latent_warp = dict(name="identity")
self.d_latent = d_latent
self.params_proj = params_proj_from_config(
params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent
)
self.latent_warp = latent_warp_from_config(latent_warp, device=device)
self.renderer = renderer
def bottleneck_to_params(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> AttrDict:
_ = options
return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)
class ChannelsDecoder(VectorDecoder):
def __init__(
self,
*,
latent_ctx: int,
**kwargs,
):
super().__init__(**kwargs)
self.latent_ctx = latent_ctx
def bottleneck_to_channels(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> torch.Tensor:
_ = options
return vector.view(vector.shape[0], self.latent_ctx, -1)
def bottleneck_to_params(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> AttrDict:
_ = options
return self.params_proj(
self.bottleneck_to_channels(self.latent_warp.unwarp(vector)), options=options
)
================================================
FILE: shap_e/models/transmitter/bottleneck.py
================================================
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import numpy as np
import torch.nn as nn
import torch
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.util.collections import AttrDict
class LatentBottleneck(nn.Module, ABC):
def __init__(self, *, device: torch.device, d_latent: int):
super().__init__()
self.device = device
self.d_latent = d_latent
@abstractmethod
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
pass
class LatentWarp(nn.Module, ABC):
def __init__(self, *, device: torch.device):
super().__init__()
self.device = device
@abstractmethod
    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
pass
@abstractmethod
    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
pass
class IdentityLatentWarp(LatentWarp):
    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
return x
    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
return x
class Tan2LatentWarp(LatentWarp):
def __init__(self, *, coeff1: float = 1.0, device: torch.device):
super().__init__(device=device)
self.coeff1 = coeff1
self.scale = np.tan(np.tan(1.0) * coeff1)
    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
return ((x.float().tan() * self.coeff1).tan() / self.scale).to(x.dtype)
    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
return ((x.float() * self.scale).arctan() / self.coeff1).arctan().to(x.dtype)
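# Illustrative sketch (not part of the repository): a scalar version of
# Tan2LatentWarp using only the math module, to show that unwarp inverts
# warp and that the scale normalizes warp(1.0) to 1.0. The class name
# `Tan2Warp` is hypothetical. The round trip relies on
# |tan(x) * coeff1| < pi / 2, which holds for |x| <= 1 when coeff1 = 1.

```python
import math


class Tan2Warp:
    """Scalar counterpart of Tan2LatentWarp (no dtype or device handling)."""

    def __init__(self, coeff1: float = 1.0):
        self.coeff1 = coeff1
        # Chosen so that warp(1.0) == 1.0.
        self.scale = math.tan(math.tan(1.0) * coeff1)

    def warp(self, x: float) -> float:
        return math.tan(math.tan(x) * self.coeff1) / self.scale

    def unwarp(self, y: float) -> float:
        return math.atan(math.atan(y * self.scale) / self.coeff1)


warp = Tan2Warp()
inputs = [-1.0, -0.5, 0.0, 0.3, 1.0]
recovered = [warp.unwarp(warp.warp(x)) for x in inputs]
for x, r in zip(inputs, recovered):
    print(f"x={x:+.2f} warped={warp.warp(x):+.6f} recovered={r:+.6f}")
```

# The warp is strongly nonlinear near the ends of [-1, 1]: most of the
# interval maps to values of small magnitude, and only inputs close to
# +/-1 reach the ends of the output range.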
class IdentityLatentBottleneck(LatentBottleneck):
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
return x
class ClampNoiseBottleneck(LatentBottleneck):
def __init__(self, *, device: torch.device, d_latent: int, noise_scale: float):
super().__init__(device=device, d_latent=d_latent)
self.noise_scale = noise_scale
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
x = x.tanh()
if not self.training:
return x
return x + torch.randn_like(x) * self.noise_scale
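# Illustrative sketch (not part of the repository): the behavior of
# ClampNoiseBottleneck.forward on a plain list of floats. Values are squashed
# with tanh, and Gaussian noise is added only in training mode. The function
# name and the explicit `training` and `rng` arguments are assumptions made
# so the example runs without torch.

```python
import math
import random
from typing import List


def clamp_noise(
    x: List[float], noise_scale: float, training: bool, rng: random.Random
) -> List[float]:
    """tanh-clamp each value; add N(0, noise_scale^2) noise when training."""
    clamped = [math.tanh(v) for v in x]
    if not training:
        return clamped
    return [v + rng.gauss(0.0, 1.0) * noise_scale for v in clamped]


rng = random.Random(0)
latent = [-3.0, 0.0, 0.5, 3.0]
clean = clamp_noise(latent, noise_scale=0.1, training=False, rng=rng)
noisy = clamp_noise(latent, noise_scale=0.1, training=True, rng=rng)
print(clean)
print(noisy)
```

# Note that the noise is added after the clamp, so training-time outputs
# can fall slightly outside [-1, 1] while evaluation outputs cannot.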
class ClampDiffusionNoiseBottleneck(LatentBottleneck):
def __init__(
self,
*,
device: torch.device,
d_latent: int,
diffusion: Dict[str, Any],
diffusion_prob: float = 1.0,
):
super().__init__(device=device, d_latent=d_latent)
self.diffusion = diffusion_from_config(diffusion)
self.diffusion_prob = diffusion_prob
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
x = x.tanh()
if not self.training:
return x
t = torch.randint(low=0, high=self.diffusion.num_timesteps, size=(len(x),), device=x.device)
t = torch.where(
torch.rand(len(x), device=x.device) < self.diffusion_prob, t, torch.zeros_like(t)
)
return self.diffusion.q_sample(x, t)
def latent_bottleneck_from_config(config: Dict[str, Any], device: torch.device, d_latent: int):
name = config.pop("name")
if name == "clamp_noise":
return ClampNoiseBottleneck(**config, device=device, d_latent=d_latent)
elif name == "identity":
return IdentityLatentBottleneck(**config, device=device, d_latent=d_latent)
elif name == "clamp_diffusion_noise":
return ClampDiffusionNoiseBottleneck(**config, device=device, d_latent=d_latent)
else:
raise ValueError(f"unknown latent bottleneck: {name}")
def latent_warp_from_config(config: Dict[str, Any], device: torch.device):
name = config.pop("name")
if name == "identity":
return IdentityLatentWarp(**config, device=device)
elif name == "tan2":
return Tan2LatentWarp(**config, device=device)
else:
raise ValueError(f"unknown latent warping function: {name}")
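# Illustrative sketch (not part of the repository): both *_from_config
# helpers above call config.pop("name"), which removes the key from the
# dict that the caller passed in. A variation that copies the config first
# lets the same dict be reused. Everything below (the function name, the
# registry, the stand-in factories) is hypothetical.

```python
from typing import Any, Callable, Dict


def build_from_config(
    config: Dict[str, Any], registry: Dict[str, Callable[..., Any]], **common: Any
) -> Any:
    """Dispatch on config["name"] without mutating the caller's dict."""
    config = dict(config)  # shallow copy: the caller's dict keeps its "name" key
    name = config.pop("name")
    if name not in registry:
        raise ValueError(f"unknown component: {name}")
    return registry[name](**config, **common)


# Stand-ins for the real bottleneck classes; they just echo their arguments.
def make_identity(d_latent: int) -> Dict[str, Any]:
    return {"kind": "identity", "d_latent": d_latent}


def make_clamp_noise(d_latent: int, noise_scale: float) -> Dict[str, Any]:
    return {"kind": "clamp_noise", "d_latent": d_latent, "noise_scale": noise_scale}


REGISTRY = {"identity": make_identity, "clamp_noise": make_clamp_noise}

shared_config = {"name": "clamp_noise", "noise_scale": 0.1}
first = build_from_config(shared_config, REGISTRY, d_latent=512)
second = build_from_config(shared_config, REGISTRY, d_latent=512)
print(first, second, shared_config)
```

# With the in-place pop used in the source, a second call with the same
# dict would raise KeyError because "name" is already gone; copying avoids
# that at the cost of one small allocation per call.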
================================================
FILE: shap_e/models/transmitter/channels_encoder.py
================================================
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torch
from shap_e.models.generation.perceiver import SimplePerceiver
from shap_e.models.generation.transformer import Transformer
from shap_e.models.nn.camera import DifferentiableProjectiveCamera
from shap_e.models.nn.encoding import (
MultiviewPointCloudEmbedding,
MultiviewPoseEmbedding,
PosEmbLinear,
)
from shap_e.models.nn.ops import PointSetEmbedding
from shap_e.rendering.point_cloud import PointCloud
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict
from .base import ChannelsEncoder
class TransformerChannelsEncoder(ChannelsEncoder, ABC):
"""
    Encode inputs using a transformer model with extra learned output
    tokens (latent_ctx of them) used to extract the latent channels.
"""
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
d_latent: int = 512,
latent_bottleneck: Optional[Dict[str, Any]] = None,
latent_warp: Optional[Dict[str, Any]] = None,
n_ctx: int = 1024,
width: int = 512,
layers: int = 12,
heads: int = 8,
init_scale: float = 0.25,
latent_scale: float = 1.0,
):
super().__init__(
device=device,
param_shapes=param_shapes,
params_proj=params_proj,
d_latent=d_latent,
latent_bottleneck=latent_bottleneck,
latent_warp=latent_warp,
)
self.width = width
self.device = device
self.dtype = dtype
self.n_ctx = n_ctx
self.backbone = Transformer(
device=device,
dtype=dtype,
n_ctx=n_ctx + self.latent_ctx,
width=width,
layers=layers,
heads=heads,
init_scale=init_scale,
)
self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.register_parameter(
"output_tokens",
nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),
)
self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype)
self.latent_scale = latent_scale
@abstractmethod
def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
pass
def encode_to_channels(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> torch.Tensor:
h = self.encode_input(batch, options=options)
h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)
h = self.ln_pre(h)
h = self.backbone(h)
h = h[:, -self.latent_ctx :]
h = self.ln_post(h)
h = self.output_proj(h)
return h
class PerceiverChannelsEncoder(ChannelsEncoder, ABC):
"""
    Encode inputs using a perceiver model with extra learned output
    tokens (latent_ctx of them) used to extract the latent channels.
"""
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
min_unrolls: int,
max_unrolls: int,
d_latent: int = 512,
latent_bottleneck: Optional[Dict[str, Any]] = None,
latent_warp: Optional[Dict[str, Any]] = None,
width: int = 512,
layers: int = 12,
xattn_layers: int = 1,
heads: int = 8,
init_scale: float = 0.25,
# Training hparams
inner_batch_size: Union[int, List[int]] = 1,
data_ctx: int = 1,
):
super().__init__(
device=device,
param_shapes=param_shapes,
params_proj=params_proj,
d_latent=d_latent,
latent_bottleneck=latent_bottleneck,
latent_warp=latent_warp,
)
self.width = width
self.device = device
self.dtype = dtype
if isinstance(inner_batch_size, int):
inner_batch_size = [inner_batch_size]
self.inner_batch_size = inner_batch_size
self.data_ctx = data_ctx
self.min_unrolls = min_unrolls
self.max_unrolls = max_unrolls
encoder_fn = lambda inner_batch_size: SimplePerceiver(
device=device,
dtype=dtype,
n_ctx=self.data_ctx + self.latent_ctx,
n_data=inner_batch_size,
width=width,
layers=xattn_layers,
heads=heads,
init_scale=init_scale,
)
self.encoder = (
encoder_fn(self.inner_batch_size[0])
if len(self.inner_batch_size) == 1
else nn.ModuleList([encoder_fn(inner_bsz) for inner_bsz in self.inner_batch_size])
)
self.processor = Transformer(
device=device,
dtype=dtype,
n_ctx=self.data_ctx + self.latent_ctx,
layers=layers - xattn_layers,
width=width,
heads=heads,
init_scale=init_scale,
)
self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.register_parameter(
"output_tokens",
nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),
)
self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype)
@abstractmethod
def get_h_and_iterator(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> Tuple[torch.Tensor, Iterable[Union[torch.Tensor, Tuple]]]:
"""
:return: a tuple of (
the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
an iterator over the given data
)
"""
def encode_to_channels(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> torch.Tensor:
h, it = self.get_h_and_iterator(batch, options=options)
n_unrolls = self.get_n_unrolls()
for _ in range(n_unrolls):
data = next(it)
if isinstance(data, tuple):
for data_i, encoder_i in zip(data, self.encoder):
h = encoder_i(h, data_i)
else:
h = self.encoder(h, data)
h = self.processor(h)
h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :]))
return h
def get_n_unrolls(self):
if self.training:
n_unrolls = torch.randint(
self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device
)
dist.broadcast(n_unrolls, 0)
n_unrolls = n_unrolls.item()
else:
n_unrolls = self.max_unrolls
return n_unrolls
@dataclass
class DatasetIterator:
    embs: torch.Tensor  # [batch_size, dataset_size, *shape]
    batch_size: int
    def __iter__(self):
        self._reset()
        return self
    def __next__(self):
        _outer_batch_size, dataset_size, *_shape = self.embs.shape
        while True:
            start = self.idx
            self.idx += self.batch_size
            end = self.idx
            if end <= dataset_size:
                break
            self._reset()
        return self.embs[:, start:end]
    def _reset(self):
        self._shuffle()
        self.idx = 0  # pylint: disable=attribute-defined-outside-init
    def _shuffle(self):
        outer_batch_size, dataset_size, *shape = self.embs.shape
        idx = torch.stack(
            [
                torch.randperm(dataset_size, device=self.embs.device)
                for _ in range(outer_batch_size)
            ],
            dim=0,
        )
        idx = idx.view(outer_batch_size, dataset_size, *([1] * len(shape)))
        idx = torch.broadcast_to(idx, self.embs.shape)
        self.embs = torch.gather(self.embs, 1, idx)
class PointCloudTransformerChannelsEncoder(TransformerChannelsEncoder):
"""
Encode point clouds using a transformer model with an extra output
token used to extract a latent vector.
"""
def __init__(
self,
*,
input_channels: int = 6,
**kwargs,
):
super().__init__(**kwargs)
self.input_channels = input_channels
self.input_proj = nn.Linear(
input_channels, self.width, device=self.device, dtype=self.dtype
)
def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
points = batch.points
h = self.input_proj(points.permute(0, 2, 1)) # NCL -> NLC
return h
class PointCloudPerceiverChannelsEncoder(PerceiverChannelsEncoder):
"""
Encode point clouds using a perceiver model with an extra output
token used to extract a latent vector.
"""
def __init__(
self,
*,
cross_attention_dataset: str = "pcl",
fps_method: str = "fps",
# point cloud hyperparameters
input_channels: int = 6,
pos_emb: Optional[str] = None,
# multiview hyperparameters
image_size: int = 256,
patch_size: int = 32,
pose_dropout: float = 0.0,
use_depth: bool = False,
max_depth: float = 5.0,
# point conv hyperparameters
pointconv_radius: float = 0.5,
pointconv_samples: int = 32,
pointconv_hidden: Optional[List[int]] = None,
pointconv_patch_size: int = 1,
pointconv_stride: int = 1,
pointconv_padding_mode: str = "zeros",
use_pointconv: bool = False,
# other hyperparameters
**kwargs,
):
super().__init__(**kwargs)
assert cross_attention_dataset in (
"pcl",
"multiview",
"dense_pose_multiview",
"multiview_pcl",
"pcl_and_multiview_pcl",
"incorrect_multiview_pcl",
"pcl_and_incorrect_multiview_pcl",
)
assert fps_method in ("fps", "first")
self.cross_attention_dataset = cross_attention_dataset
self.fps_method = fps_method
self.input_channels = input_channels
self.input_proj = PosEmbLinear(
pos_emb,
input_channels,
self.width,
device=self.device,
dtype=self.dtype,
)
self.use_pointconv = use_pointconv
if use_pointconv:
if pointconv_hidden is None:
pointconv_hidden = [self.width]
self.point_conv = PointSetEmbedding(
n_point=self.data_ctx,
radius=pointconv_radius,
n_sample=pointconv_samples,
d_input=self.input_proj.weight.shape[0],
d_hidden=pointconv_hidden,
patch_size=pointconv_patch_size,
stride=pointconv_stride,
padding_mode=pointconv_padding_mode,
fps_method=fps_method,
device=self.device,
dtype=self.dtype,
)
if self.cross_attention_dataset == "multiview":
self.image_size = image_size
self.patch_size = patch_size
self.pose_dropout = pose_dropout
self.use_depth = use_depth
self.max_depth = max_depth
pos_ctx = (image_size // patch_size) ** 2
self.register_parameter(
"pos_emb",
nn.Parameter(
torch.randn(
pos_ctx * self.inner_batch_size[0],  # inner_batch_size is stored as a list
self.width,
device=self.device,
dtype=self.dtype,
)
),
)
self.patch_emb = nn.Conv2d(
in_channels=3 if not use_depth else 4,
out_channels=self.width,
kernel_size=patch_size,
stride=patch_size,
device=self.device,
dtype=self.dtype,
)
self.camera_emb = nn.Sequential(
nn.Linear(
3 * 4 + 1, self.width, device=self.device, dtype=self.dtype
), # input size is for origin+x+y+z+fov
nn.GELU(),
nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),
)
elif self.cross_attention_dataset == "dense_pose_multiview":
# The number of output features is halved, because a patch_size of
# 32 ends up with a large patch_emb weight.
self.view_pose_width = self.width // 2
self.image_size = image_size
self.patch_size = patch_size
self.use_depth = use_depth
self.max_depth = max_depth
self.mv_pose_embed = MultiviewPoseEmbedding(
posemb_version="nerf",
n_channels=4 if self.use_depth else 3,
out_features=self.view_pose_width,
device=self.device,
dtype=self.dtype,
)
pos_ctx = (image_size // patch_size) ** 2
# Positional embedding is unnecessary because pose information is baked into each pixel
self.patch_emb = nn.Conv2d(
in_channels=self.view_pose_width,
out_channels=self.width,
kernel_size=patch_size,
stride=patch_size,
device=self.device,
dtype=self.dtype,
)
elif (
self.cross_attention_dataset == "multiview_pcl"
or self.cross_attention_dataset == "incorrect_multiview_pcl"
):
self.view_pose_width = self.width // 2
self.image_size = image_size
self.patch_size = patch_size
self.max_depth = max_depth
assert use_depth
self.mv_pcl_embed = MultiviewPointCloudEmbedding(
posemb_version="nerf",
n_channels=3,
out_features=self.view_pose_width,
device=self.device,
dtype=self.dtype,
)
self.patch_emb = nn.Conv2d(
in_channels=self.view_pose_width,
out_channels=self.width,
kernel_size=patch_size,
stride=patch_size,
device=self.device,
dtype=self.dtype,
)
elif (
self.cross_attention_dataset == "pcl_and_multiview_pcl"
or self.cross_attention_dataset == "pcl_and_incorrect_multiview_pcl"
):
self.view_pose_width = self.width // 2
self.image_size = image_size
self.patch_size = patch_size
self.max_depth = max_depth
assert use_depth
self.mv_pcl_embed = MultiviewPointCloudEmbedding(
posemb_version="nerf",
n_channels=3,
out_features=self.view_pose_width,
device=self.device,
dtype=self.dtype,
)
self.patch_emb = nn.Conv2d(
in_channels=self.view_pose_width,
out_channels=self.width,
kernel_size=patch_size,
stride=patch_size,
device=self.device,
dtype=self.dtype,
)
def get_h_and_iterator(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> Tuple[torch.Tensor, Iterable]:
"""
:return: a tuple of (
the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
an iterator over the given data
)
"""
options = AttrDict() if options is None else options
# Build the initial query embeddings
points = batch.points.permute(0, 2, 1) # NCL -> NLC
if self.use_pointconv:
points = self.input_proj(points).permute(0, 2, 1) # NLC -> NCL
xyz = batch.points[:, :3]
data_tokens = self.point_conv(xyz, points).permute(0, 2, 1) # NCL -> NLC
else:
fps_samples = self.sample_pcl_fps(points)
data_tokens = self.input_proj(fps_samples)
batch_size = points.shape[0]
latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)
h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))
assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)
# Build the dataset embedding iterator
dataset_fn = {
"pcl": self.get_pcl_dataset,
"multiview": self.get_multiview_dataset,
"dense_pose_multiview": self.get_dense_pose_multiview_dataset,
"pcl_and_multiview_pcl": self.get_pcl_and_multiview_pcl_dataset,
"multiview_pcl": self.get_multiview_pcl_dataset,
}[self.cross_attention_dataset]
it = dataset_fn(batch, options=options)
return h, it
def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)
def get_pcl_dataset(
self,
batch: AttrDict,
options: Optional[AttrDict[str, Any]] = None,
inner_batch_size: Optional[int] = None,
) -> Iterable:
_ = options
if inner_batch_size is None:
inner_batch_size = self.inner_batch_size[0]
points = batch.points.permute(0, 2, 1) # NCL -> NLC
dataset_emb = self.input_proj(points)
assert dataset_emb.shape[1] >= inner_batch_size
return iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))
def get_multiview_dataset(
self,
batch: AttrDict,
options: Optional[AttrDict] = None,
inner_batch_size: Optional[int] = None,
) -> Iterable:
_ = options
if inner_batch_size is None:
inner_batch_size = self.inner_batch_size[0]
dataset_emb = self.encode_views(batch)
batch_size, num_views, n_patches, width = dataset_emb.shape
assert num_views >= inner_batch_size
it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))
def gen():
while True:
examples = next(it)
assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
views = examples.reshape(batch_size, -1, width) + self.pos_emb
yield views
return gen()
def get_dense_pose_multiview_dataset(
self,
batch: AttrDict,
options: Optional[AttrDict] = None,
inner_batch_size: Optional[int] = None,
) -> Iterable:
_ = options
if inner_batch_size is None:
inner_batch_size = self.inner_batch_size[0]
dataset_emb = self.encode_dense_pose_views(batch)
batch_size, num_views, n_patches, width = dataset_emb.shape
assert num_views >= inner_batch_size
it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))
def gen():
while True:
examples = next(it)
assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
views = examples.reshape(batch_size, -1, width)
yield views
return gen()
def get_pcl_and_multiview_pcl_dataset(
self,
batch: AttrDict,
options: Optional[AttrDict] = None,
use_distance: bool = True,
) -> Iterable:
_ = options
pcl_it = self.get_pcl_dataset(
batch, options=options, inner_batch_size=self.inner_batch_size[0]
)
multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
batch_size, num_views, n_patches, width = multiview_pcl_emb.shape
assert num_views >= self.inner_batch_size[1]
multiview_pcl_it = iter(
DatasetIterator(multiview_pcl_emb, batch_size=self.inner_batch_size[1])
)
def gen():
while True:
pcl = next(pcl_it)
multiview_pcl = next(multiview_pcl_it)
assert multiview_pcl.shape == (
batch_size,
self.inner_batch_size[1],
n_patches,
self.width,
)
yield pcl, multiview_pcl.reshape(batch_size, -1, width)
return gen()
def get_multiview_pcl_dataset(
self,
batch: AttrDict,
options: Optional[AttrDict] = None,
inner_batch_size: Optional[int] = None,
use_distance: bool = True,
) -> Iterable:
_ = options
if inner_batch_size is None:
inner_batch_size = self.inner_batch_size[0]
multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
batch_size, num_views, n_patches, width = multiview_pcl_emb.shape
assert num_views >= inner_batch_size
multiview_pcl_it = iter(DatasetIterator(multiview_pcl_emb, batch_size=inner_batch_size))
def gen():
while True:
multiview_pcl = next(multiview_pcl_it)
assert multiview_pcl.shape == (
batch_size,
inner_batch_size,
n_patches,
self.width,
)
yield multiview_pcl.reshape(batch_size, -1, width)
return gen()
def encode_views(self, batch: AttrDict) -> torch.Tensor:
"""
:return: [batch_size, num_views, n_patches, width]
"""
all_views = self.views_to_tensor(batch.views).to(self.device)
if self.use_depth:
all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)
all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)
batch_size, num_views, _, _, _ = all_views.shape
views_proj = self.patch_emb(
all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
)
views_proj = (
views_proj.reshape([batch_size, num_views, self.width, -1])
.permute(0, 1, 3, 2)
.contiguous()
) # [batch_size x num_views x n_patches x width]
# [batch_size, num_views, 1, 2 * width]
camera_proj = self.camera_emb(all_cameras).reshape(
[batch_size, num_views, 1, self.width * 2]
)
pose_dropout = self.pose_dropout if self.training else 0.0
mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
scale, shift = camera_proj.chunk(2, dim=3)
views_proj = views_proj * (scale + 1.0) + shift
return views_proj
def encode_dense_pose_views(self, batch: AttrDict) -> torch.Tensor:
"""
:return: [batch_size, num_views, n_patches, width]
"""
all_views = self.views_to_tensor(batch.views).to(self.device)
if self.use_depth:
depths = self.depths_to_tensor(batch.depths)
all_views = torch.cat([all_views, depths], dim=2)
dense_poses, _ = self.dense_pose_cameras_to_tensor(batch.cameras)
dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)
position, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
all_view_poses = self.mv_pose_embed(all_views, position, direction)
batch_size, num_views, _, _, _ = all_view_poses.shape
views_proj = self.patch_emb(
all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
)
views_proj = (
views_proj.reshape([batch_size, num_views, self.width, -1])
.permute(0, 1, 3, 2)
.contiguous()
) # [batch_size x num_views x n_patches x width]
return views_proj
def encode_multiview_pcl(self, batch: AttrDict, use_distance: bool = True) -> torch.Tensor:
"""
:return: [batch_size, num_views, n_patches, width]
"""
all_views = self.views_to_tensor(batch.views).to(self.device)
depths = self.raw_depths_to_tensor(batch.depths)
all_view_alphas = self.view_alphas_to_tensor(batch.view_alphas).to(self.device)
mask = all_view_alphas >= 0.999
dense_poses, camera_z = self.dense_pose_cameras_to_tensor(batch.cameras)
dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)
origin, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
if use_distance:
ray_depth_factor = torch.sum(direction * camera_z[..., None, None], dim=2, keepdim=True)
depths = depths / ray_depth_factor
position = origin + depths * direction
all_view_poses = self.mv_pcl_embed(all_views, origin, position, mask)
batch_size, num_views, _, _, _ = all_view_poses.shape
views_proj = self.patch_emb(
all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
)
views_proj = (
views_proj.reshape([batch_size, num_views, self.width, -1])
.permute(0, 1, 3, 2)
.contiguous()
) # [batch_size x num_views x n_patches x width]
return views_proj
def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
"""
Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].
"""
if isinstance(views, torch.Tensor):
return views
tensor_batch = []
num_views = len(views[0])
for inner_list in views:
assert len(inner_list) == num_views
inner_batch = []
for img in inner_list:
img = img.resize((self.image_size,) * 2).convert("RGB")
inner_batch.append(
torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
/ 127.5
- 1
)
tensor_batch.append(torch.stack(inner_batch, dim=0))
return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)
def depths_to_tensor(
self, depths: Union[torch.Tensor, List[List[np.ndarray]]]
) -> torch.Tensor:
"""
Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].
"""
if isinstance(depths, torch.Tensor):
return depths
tensor_batch = []
num_views = len(depths[0])
for inner_list in depths:
assert len(inner_list) == num_views
inner_batch = []
for arr in inner_list:
tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
tensor = tensor * 2 - 1
tensor = F.interpolate(
tensor[None, None],
(self.image_size,) * 2,
mode="nearest",
)
inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
tensor_batch.append(torch.cat(inner_batch, dim=0))
return torch.stack(tensor_batch, dim=0)
def view_alphas_to_tensor(
self, view_alphas: Union[torch.Tensor, List[List[Image.Image]]]
) -> torch.Tensor:
"""
Returns a [batch x num_views x 1 x size x size] tensor in the range [0, 1].
"""
if isinstance(view_alphas, torch.Tensor):
return view_alphas
tensor_batch = []
num_views = len(view_alphas[0])
for inner_list in view_alphas:
assert len(inner_list) == num_views
inner_batch = []
for img in inner_list:
tensor = (
torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
/ 255.0
)
tensor = F.interpolate(
tensor[None, None],
(self.image_size,) * 2,
mode="nearest",
)
inner_batch.append(tensor)
tensor_batch.append(torch.cat(inner_batch, dim=0))
return torch.stack(tensor_batch, dim=0)
def raw_depths_to_tensor(
self, depths: Union[torch.Tensor, List[List[np.ndarray]]]
) -> torch.Tensor:
"""
Returns a [batch x num_views x 1 x size x size] tensor
"""
if isinstance(depths, torch.Tensor):
return depths
tensor_batch = []
num_views = len(depths[0])
for inner_list in depths:
assert len(inner_list) == num_views
inner_batch = []
for arr in inner_list:
tensor = torch.from_numpy(arr).clamp(max=self.max_depth)
tensor = F.interpolate(
tensor[None, None],
(self.image_size,) * 2,
mode="nearest",
)
inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
tensor_batch.append(torch.cat(inner_batch, dim=0))
return torch.stack(tensor_batch, dim=0)
def cameras_to_tensor(
self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
) -> torch.Tensor:
"""
Returns a [batch x num_views x 3*4+1] tensor of camera information.
"""
if isinstance(cameras, torch.Tensor):
return cameras
outer_batch = []
for inner_list in cameras:
inner_batch = []
for camera in inner_list:
inner_batch.append(
np.array(
[
*camera.x,
*camera.y,
*camera.z,
*camera.origin,
camera.x_fov,
]
)
)
outer_batch.append(np.stack(inner_batch, axis=0))
return torch.from_numpy(np.stack(outer_batch, axis=0)).float()
def dense_pose_cameras_to_tensor(
self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns a tuple of (rays, z_directions) where
- rays: [batch, num_views, height, width, 2, 3] tensor of camera information.
- z_directions: [batch, num_views, 3] tensor of camera z directions.
"""
if isinstance(cameras, torch.Tensor):
raise NotImplementedError
for inner_list in cameras:
assert len(inner_list) == len(cameras[0])
camera = cameras[0][0]
flat_camera = DifferentiableProjectiveCamera(
origin=torch.from_numpy(
np.stack(
[cam.origin for inner_list in cameras for cam in inner_list],
axis=0,
)
).to(self.device),
x=torch.from_numpy(
np.stack(
[cam.x for inner_list in cameras for cam in inner_list],
axis=0,
)
).to(self.device),
y=torch.from_numpy(
np.stack(
[cam.y for inner_list in cameras for cam in inner_list],
axis=0,
)
).to(self.device),
z=torch.from_numpy(
np.stack(
[cam.z for inner_list in cameras for cam in inner_list],
axis=0,
)
).to(self.device),
width=camera.width,
height=camera.height,
x_fov=camera.x_fov,
y_fov=camera.y_fov,
)
batch_size = len(cameras) * len(cameras[0])
coords = (
flat_camera.image_coords()
.to(flat_camera.origin.device)
.unsqueeze(0)
.repeat(batch_size, 1, 1)
)
rays = flat_camera.camera_rays(coords)
return (
rays.view(len(cameras), len(cameras[0]), camera.height, camera.width, 2, 3).to(
self.device
),
flat_camera.z.view(len(cameras), len(cameras[0]), 3).to(self.device),
)
def sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = "fps") -> torch.Tensor:
    """
    Run farthest-point sampling on a batch of point clouds.
    :param points: batch of shape [N x num_points x (3 + n_channels)].
    :param data_ctx: subsample count.
    :param method: either 'fps' or 'first'. Using 'first' assumes that the
                   points are already sorted according to FPS sampling.
    :return: batch of shape [N x min(num_points, data_ctx) x (3 + n_channels)].
    """
    n_points = points.shape[1]
    if n_points == data_ctx:
        return points
    if method == "first":
        return points[:, :data_ctx]
    elif method == "fps":
        batch = points.cpu().split(1, dim=0)
        fps = [sample_fps(x, n_samples=data_ctx) for x in batch]
        return torch.cat(fps, dim=0).to(points.device)
    else:
        raise ValueError(f"unsupported farthest-point sampling method: {method}")
def sample_fps(example: torch.Tensor, n_samples: int) -> torch.Tensor:
    """
    :param example: [1, n_points, 3 + n_channels]
    :return: [1, n_samples, 3 + n_channels]
    """
    points = example.cpu().squeeze(0).numpy()
    coords, raw_channels = points[:, :3], points[:, 3:]
    n_points, n_channels = raw_channels.shape
    assert n_samples <= n_points
    channels = {str(idx): raw_channels[:, idx] for idx in range(n_channels)}
    max_points = min(32768, n_points)
    fps_pcl = (
        PointCloud(coords=coords, channels=channels)
        .random_sample(max_points)
        .farthest_point_sample(n_samples)
    )
    fps_channels = np.stack([fps_pcl.channels[str(idx)] for idx in range(n_channels)], axis=1)
    fps = np.concatenate([fps_pcl.coords, fps_channels], axis=1)
    fps = torch.from_numpy(fps).unsqueeze(0)
    assert fps.shape == (1, n_samples, 3 + n_channels)
    return fps
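
Illustrative note (not part of the repository): `sample_fps` above delegates the actual selection to `PointCloud.farthest_point_sample`, which is defined in another file. The block below is a minimal pure-Python sketch of the general farthest-point sampling idea, assuming Euclidean distance and a fixed starting index. The function name `farthest_point_indices` and its `start` parameter are hypothetical and chosen for illustration only; the repository's implementation may differ in its choice of the first point and in tie-breaking.

```python
import math
from typing import List, Sequence

Point = Sequence[float]


def farthest_point_indices(points: Sequence[Point], n_samples: int, start: int = 0) -> List[int]:
    """Greedily pick n_samples indices so each new point is far from those already picked."""
    if not 0 < n_samples <= len(points):
        raise ValueError("n_samples must be between 1 and len(points)")
    selected = [start]
    # Distance from every point to its nearest already-selected point.
    min_dist = [math.dist(p, points[start]) for p in points]
    while len(selected) < n_samples:
        # The next sample is the point farthest from the current selection
        # (ties resolve to the lowest index).
        nxt = max(range(len(points)), key=min_dist.__getitem__)
        selected.append(nxt)
        for i, p in enumerate(points):
            min_dist[i] = min(min_dist[i], math.dist(p, points[nxt]))
    return selected


if __name__ == "__main__":
    line = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
    print(farthest_point_indices(line, 3))
```

Each iteration costs time linear in the number of points, so the whole procedure is O(n_points * n_samples). That cost is one plausible reason `sample_fps` first caps the cloud at 32768 points with `random_sample`.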
================================================
FILE: shap_e/models/transmitter/multiview_encoder.py
================================================
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from shap_e.models.generation.transformer import Transformer
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict
from .base import VectorEncoder
class MultiviewTransformerEncoder(VectorEncoder):
"""
Encode cameras and views using a transformer model with extra output
token(s) used to extract a latent vector.
"""
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
latent_bottleneck: Optional[Dict[str, Any]] = None,
d_latent: int = 512,
latent_ctx: int = 1,
num_views: int = 20,
image_size: int = 256,
patch_size: int = 32,
use_depth: bool = False,
max_depth: float = 5.0,
width: int = 512,
layers: int = 12,
heads: int = 8,
init_scale: float = 0.25,
pos_emb_init_scale: float = 1.0,
):
super().__init__(
device=device,
param_shapes=param_shapes,
params_proj=params_proj,
latent_bottleneck=latent_bottleneck,
d_latent=d_latent,
)
self.num_views = num_views
self.image_size = image_size
self.patch_size = patch_size
self.use_depth = use_depth
self.max_depth = max_depth
self.n_ctx = num_views * (1 + (image_size // patch_size) ** 2)
self.latent_ctx = latent_ctx
self.width = width
assert d_latent % latent_ctx == 0
self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
self.backbone = Transformer(
device=device,
dtype=dtype,
n_ctx=self.n_ctx + latent_ctx,
width=width,
layers=layers,
heads=heads,
init_scale=init_scale,
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.register_parameter(
"output_tokens",
nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)),
)
self.register_parameter(
"pos_emb",
nn.Parameter(
pos_emb_init_scale * torch.randn(self.n_ctx, width, device=device, dtype=dtype)
),
)
self.patch_emb = nn.Conv2d(
in_channels=3 if not use_depth else 4,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
device=device,
dtype=dtype,
)
self.camera_emb = nn.Sequential(
nn.Linear(
3 * 4 + 1, width, device=device, dtype=dtype
), # input size is for origin+x+y+z+fov
nn.GELU(),
nn.Linear(width, width, device=device, dtype=dtype),
)
self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype)
def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
all_views = self.views_to_tensor(batch.views).to(self.device)
if self.use_depth:
all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)
all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)
batch_size, num_views, _, _, _ = all_views.shape
views_proj = self.patch_emb(
all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
)
views_proj = (
views_proj.reshape([batch_size, num_views, self.width, -1])
.permute(0, 1, 3, 2)
.contiguous()
) # [batch_size x num_views x n_patches x width]
cameras_proj = self.camera_emb(all_cameras).reshape([batch_size, num_views, 1, self.width])
h = torch.cat([views_proj, cameras_proj], dim=2).reshape([batch_size, -1, self.width])
h = h + self.pos_emb
h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)
h = self.ln_pre(h)
h = self.backbone(h)
h = self.ln_post(h)
h = h[:, self.n_ctx :]
h = self.output_proj(h).flatten(1)
return h
def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
"""
Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].
"""
if isinstance(views, torch.Tensor):
return views
tensor_batch = []
for inner_list in views:
assert len(inner_list) == self.num_views
inner_batch = []
for img in inner_list:
img = img.resize((self.image_size,) * 2).convert("RGB")
inner_batch.append(
torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
/ 127.5
- 1
)
tensor_batch.append(torch.stack(inner_batch, dim=0))
return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)
def depths_to_tensor(
self, depths: Union[torch.Tensor, List[List[np.ndarray]]]
) -> torch.Tensor:
"""
Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].
"""
if isinstance(depths, torch.Tensor):
return depths
tensor_batch = []
for inner_list in depths:
assert len(inner_list) == self.num_views
inner_batch = []
for arr in inner_list:
tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
tensor = tensor * 2 - 1
tensor = F.interpolate(
tensor[None, None],
(self.image_size,) * 2,
mode="nearest",
)
inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
tensor_batch.append(torch.cat(inner_batch, dim=0))
return torch.stack(tensor_batch, dim=0)
def cameras_to_tensor(
self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
) -> torch.Tensor:
"""
Returns a [batch x num_views x 3*4+1] tensor of camera information.
"""
if isinstance(cameras, torch.Tensor):
return cameras
outer_batch = []
for inner_list in cameras:
inner_batch = []
for camera in inner_list:
inner_batch.append(
np.array(
[
*camera.x,
*camera.y,
*camera.z,
*camera.origin,
camera.x_fov,
]
)
)
outer_batch.append(np.stack(inner_batch, axis=0))
return torch.from_numpy(np.stack(outer_batch, axis=0)).float()
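
Illustrative note (not part of the repository): the sequence length of `MultiviewTransformerEncoder` follows from `n_ctx = num_views * (1 + (image_size // patch_size) ** 2)`, that is, one camera token plus one token per image patch for every view. The sketch below reproduces that arithmetic and the `/ 127.5 - 1` pixel scaling used in `views_to_tensor`. The helper names `multiview_token_count` and `pixel_to_unit_range` are hypothetical and exist only for this example.

```python
def multiview_token_count(num_views: int = 20, image_size: int = 256, patch_size: int = 32) -> int:
    """Number of input tokens before the learned output tokens are appended."""
    patches_per_view = (image_size // patch_size) ** 2
    # One extra token per view carries the embedded camera parameters.
    return num_views * (patches_per_view + 1)


def pixel_to_unit_range(value: int) -> float:
    """Map an 8-bit pixel value in [0, 255] to [-1, 1], as views_to_tensor does."""
    return value / 127.5 - 1


if __name__ == "__main__":
    print(multiview_token_count())  # constructor defaults: 20 views, 256px images, 32px patches
    print(pixel_to_unit_range(0), pixel_to_unit_range(255))
```

With the constructor defaults this gives 20 * (64 + 1) = 1300 tokens, so the backbone is built with `n_ctx = 1300 + latent_ctx`.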
================================================
FILE: shap_e/models/transmitter/params_proj.py
================================================
import math
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch.nn as nn
import torch
from shap_e.util.collections import AttrDict
def flatten_param_shapes(param_shapes: Dict[str, Tuple[int]]):
    flat_shapes = OrderedDict(
        (name, (int(np.prod(shape)) // shape[-1], shape[-1]))
        for name, shape in param_shapes.items()
    )
    return flat_shapes
class ParamsProj(nn.Module, ABC):
def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int):
super().__init__()
self.device = device
self.param_shapes = param_shapes
self.d_latent = d_latent
@abstractmethod
def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
pass
class LinearParamsProj(ParamsProj):
def __init__(
self,
*,
device: torch.device,
param_shapes: Dict[str, Tuple[int]],
d_latent: int,
init_scale: Optional[float] = None,
):
super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
self.param_shapes = param_shapes
self.projections = nn.ModuleDict({})
for k, v in param_shapes.items():
self.projections[_sanitize_name(k)] = nn.Linear(
d_latent, int(np.prod(v)), device=device
)
if init_scale is not None:
scale = init_scale / math.sqrt(d_latent)
mod = self.projections[_sanitize_name(k)]
nn.init.normal_(mod.weight, std=scale)
nn.init.zeros_(mod.bias)
def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
out = AttrDict()
for k in self.param_shapes.keys():
proj = self.projections[_sanitize_name(k)]
out[k] = proj(x).reshape([len(x), *self.param_shapes[k]])
return out
class MLPParamsProj(ParamsProj):
def __init__(
self,
*,
device: torch.device,
param_shapes: Dict[str, Tuple[int]],
d_latent: int,
hidden_size: Optional[int] = None,
):
super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
if hidden_size is None:
hidden_size = d_latent
self.param_shapes = param_shapes
self.projections = nn.ModuleDict({})
for k, v in param_shapes.items():
self.projections[_sanitize_name(k)] = nn.Sequential(
nn.Linear(d_latent, hidden_size, device=device),
nn.GELU(),
nn.Linear(hidden_size, int(np.prod(v)), device=device),
)
def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
out = AttrDict()
for k in self.param_shapes.keys():
proj = self.projections[_sanitize_name(k)]
out[k] = proj(x).reshape([len(x), *self.param_shapes[k]])
return out
class ChannelsProj(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        vectors: int,
        channels: int,
        d_latent: int,
        init_scale: float = 1.0,
        learned_scale: Optional[float] = None,
        use_ln: bool = False,
    ):
        super().__init__()
        self.proj = nn.Linear(d_latent, vectors * channels, device=device)
        self.use_ln = use_ln
        self.learned_scale = learned_scale
        if use_ln:
            self.norm = nn.LayerNorm(normalized_shape=(channels,), device=device)
            if learned_scale is not None:
                self.norm.weight.data.fill_(learned_scale)
            scale = init_scale / math.sqrt(d_latent)
        elif learned_scale is not None:
            gain = torch.ones((channels,), device=device) * learned_scale
            self.register_parameter("gain", nn.Parameter(gain))
            scale = init_scale / math.sqrt(d_latent)
        else:
            scale = init_scale / math.sqrt(d_latent * channels)
        nn.init.normal_(self.proj.weight, std=scale)
        nn.init.zeros_(self.proj.bias)
        self.d_latent = d_latent
        self.vectors = vectors
        self.channels = channels
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_bvd = x
        w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
        b_vc = self.proj.bias.view(1, self.vectors, self.channels)
        h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
        if self.use_ln:
            h = self.norm(h)
        elif self.learned_scale is not None:
            h = h * self.gain.view(1, 1, -1)
        h = h + b_vc
        return h
class ChannelsParamsProj(ParamsProj):
def __init__(
self,
*,
device: torch.device,
param_shapes: Dict[str, Tuple[int]],
d_latent: int,
init_scale: float = 1.0,
learned_scale: Optional[float] = None,
use_ln: bool = False,
):
super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
self.param_shapes = param_shapes
self.projections = nn.ModuleDict({})
self.flat_shapes = flatten_param_shapes(param_shapes)
self.learned_scale = learned_scale
self.use_ln = use_ln
for k, (vectors, channels) in self.flat_shapes.items():
self.projections[_sanitize_name(k)] = ChannelsProj(
device=device,
vectors=vectors,
channels=channels,
d_latent=d_latent,
init_scale=init_scale,
learned_scale=learned_scale,
use_ln=use_ln,
)
def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
out = AttrDict()
start = 0
for k, shape in self.param_shapes.items():
vectors, _ = self.flat_shapes[k]
end = start + vectors
x_bvd = x[:, start:end]
out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
start = end
return out
def params_proj_from_config(
config: Dict[str, Any], device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int
):
name = config.pop("name")
if name == "linear":
return LinearParamsProj(
**config, device=device, param_shapes=param_shapes, d_latent=d_latent
)
elif name == "mlp":
return MLPParamsProj(**config, device=device, param_shapes=param_shapes, d_latent=d_latent)
elif name == "channels":
return ChannelsParamsProj(
**config, device=device, param_shapes=param_shapes, d_latent=d_latent
)
else:
raise ValueError(f"unknown params proj: {name}")
def _sanitize_name(x: str) -> str:
return x.replace(".", "__")
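The row bookkeeping in `ChannelsParamsProj.forward` can be sketched without PyTorch. This is a minimal illustration, not the library's code: `row_ranges` and `sanitize_name` are hypothetical helper names, and the parameter names and shapes in the usage lines are made up. It assumes, as the loop above does, that each parameter consumes `vectors` consecutive rows of the latent in dictionary order.

```python
from typing import Dict, Tuple


def sanitize_name(name: str) -> str:
    # Same replacement as _sanitize_name above; the result is used as a ModuleDict key.
    return name.replace(".", "__")


def row_ranges(flat_shapes: Dict[str, Tuple[int, int]]) -> Dict[str, Tuple[int, int]]:
    """Map each parameter name to the [start, end) latent rows it reads."""
    ranges = {}
    start = 0
    for name, (vectors, _channels) in flat_shapes.items():
        end = start + vectors
        ranges[name] = (start, end)
        start = end  # the next parameter begins where this one stopped
    return ranges


# Hypothetical shapes: (vectors, channels) per flattened parameter.
example_shapes = {"mlp.0.weight": (4, 8), "mlp.0.bias": (1, 8)}
example_ranges = row_ranges(example_shapes)
```

With these shapes the latent needs five rows in total: four for the weight and one for the bias.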
================================================
FILE: shap_e/models/transmitter/pc_encoder.py
================================================
from abc import abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torch
from shap_e.models.generation.perceiver import SimplePerceiver
from shap_e.models.generation.transformer import Transformer
from shap_e.models.nn.encoding import PosEmbLinear
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict
from .base import VectorEncoder
from .channels_encoder import DatasetIterator, sample_pcl_fps
class PointCloudTransformerEncoder(VectorEncoder):
"""
Encode point clouds using a transformer model with an extra output
token used to extract a latent vector.
"""
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
latent_bottleneck: Optional[Dict[str, Any]] = None,
d_latent: int = 512,
latent_ctx: int = 1,
input_channels: int = 6,
n_ctx: int = 1024,
width: int = 512,
layers: int = 12,
heads: int = 8,
init_scale: float = 0.25,
pos_emb: Optional[str] = None,
):
super().__init__(
device=device,
param_shapes=param_shapes,
params_proj=params_proj,
latent_bottleneck=latent_bottleneck,
d_latent=d_latent,
)
self.input_channels = input_channels
self.n_ctx = n_ctx
self.latent_ctx = latent_ctx
assert d_latent % latent_ctx == 0
self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
self.backbone = Transformer(
device=device,
dtype=dtype,
n_ctx=n_ctx + latent_ctx,
width=width,
layers=layers,
heads=heads,
init_scale=init_scale,
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.register_parameter(
"output_tokens",
nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)),
)
self.input_proj = PosEmbLinear(pos_emb, input_channels, width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype)
def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
_ = options
points = batch.points.permute(0, 2, 1) # NCL -> NLC
h = self.input_proj(points)
h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)
h = self.ln_pre(h)
h = self.backbone(h)
h = self.ln_post(h)
h = h[:, self.n_ctx :]
h = self.output_proj(h).flatten(1)
return h
class PerceiverEncoder(VectorEncoder):
"""
Encode point clouds using a perceiver model with an extra output
token used to extract a latent vector.
"""
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
latent_bottleneck: Optional[Dict[str, Any]] = None,
d_latent: int = 512,
latent_ctx: int = 1,
width: int = 512,
layers: int = 12,
xattn_layers: int = 1,
heads: int = 8,
init_scale: float = 0.25,
# Training hparams
inner_batch_size: int = 1,
data_ctx: int = 1,
min_unrolls: int,
max_unrolls: int,
):
super().__init__(
device=device,
param_shapes=param_shapes,
params_proj=params_proj,
latent_bottleneck=latent_bottleneck,
d_latent=d_latent,
)
self.width = width
self.device = device
self.dtype = dtype
self.latent_ctx = latent_ctx
self.inner_batch_size = inner_batch_size
self.data_ctx = data_ctx
self.min_unrolls = min_unrolls
self.max_unrolls = max_unrolls
self.encoder = SimplePerceiver(
device=device,
dtype=dtype,
n_ctx=self.data_ctx + self.latent_ctx,
n_data=self.inner_batch_size,
width=width,
layers=xattn_layers,
heads=heads,
init_scale=init_scale,
)
self.processor = Transformer(
device=device,
dtype=dtype,
n_ctx=self.data_ctx + self.latent_ctx,
layers=layers - xattn_layers,
width=width,
heads=heads,
init_scale=init_scale,
)
self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.register_parameter(
"output_tokens",
nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),
)
self.output_proj = nn.Linear(width, d_latent // self.latent_ctx, device=device, dtype=dtype)
@abstractmethod
def get_h_and_iterator(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> Tuple[torch.Tensor, Iterable]:
"""
:return: a tuple of (
the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
an iterator over the given data
)
"""
def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
h, it = self.get_h_and_iterator(batch, options=options)
n_unrolls = self.get_n_unrolls()
for _ in range(n_unrolls):
data = next(it)
h = self.encoder(h, data)
h = self.processor(h)
h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :]))
return h.flatten(1)
def get_n_unrolls(self):
if self.training:
n_unrolls = torch.randint(
self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device
)
dist.broadcast(n_unrolls, 0)
n_unrolls = n_unrolls.item()
else:
n_unrolls = self.max_unrolls
return n_unrolls
class PointCloudPerceiverEncoder(PerceiverEncoder):
"""
Encode point clouds using a perceiver model that cross-attends to either
the point cloud or multiview renders, with an extra output token used to
extract a latent vector.
"""
def __init__(
self,
*,
cross_attention_dataset: str = "pcl",
fps_method: str = "fps",
# point cloud hyperparameters
input_channels: int = 6,
pos_emb: Optional[str] = None,
# multiview hyperparameters
image_size: int = 256,
patch_size: int = 32,
pose_dropout: float = 0.0,
use_depth: bool = False,
max_depth: float = 5.0,
# other hyperparameters
**kwargs,
):
super().__init__(**kwargs)
assert cross_attention_dataset in ("pcl", "multiview")
assert fps_method in ("fps", "first")
self.cross_attention_dataset = cross_attention_dataset
self.fps_method = fps_method
self.input_channels = input_channels
self.input_proj = PosEmbLinear(
pos_emb, input_channels, self.width, device=self.device, dtype=self.dtype
)
if self.cross_attention_dataset == "multiview":
self.image_size = image_size
self.patch_size = patch_size
self.pose_dropout = pose_dropout
self.use_depth = use_depth
self.max_depth = max_depth
pos_ctx = (image_size // patch_size) ** 2
self.register_parameter(
"pos_emb",
nn.Parameter(
torch.randn(
pos_ctx * self.inner_batch_size,
self.width,
device=self.device,
dtype=self.dtype,
)
),
)
self.patch_emb = nn.Conv2d(
in_channels=3 if not use_depth else 4,
out_channels=self.width,
kernel_size=patch_size,
stride=patch_size,
device=self.device,
dtype=self.dtype,
)
self.camera_emb = nn.Sequential(
nn.Linear(
3 * 4 + 1, self.width, device=self.device, dtype=self.dtype
), # input size is for origin+x+y+z+fov
nn.GELU(),
nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),
)
def get_h_and_iterator(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> Tuple[torch.Tensor, Iterable]:
"""
:return: a tuple of (
the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
an iterator over the given data
)
"""
options = AttrDict() if options is None else options
# Build the initial query embeddings
points = batch.points.permute(0, 2, 1) # NCL -> NLC
fps_samples = self.sample_pcl_fps(points)
batch_size = points.shape[0]
data_tokens = self.input_proj(fps_samples)
latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)
h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))
assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)
# Build the dataset embedding iterator
dataset_fn = {
"pcl": self.get_pcl_dataset,
"multiview": self.get_multiview_dataset,
}[self.cross_attention_dataset]
it = dataset_fn(batch, options=options)
return h, it
def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)
def get_pcl_dataset(
self, batch: AttrDict, options: Optional[AttrDict[str, Any]] = None
) -> Iterable:
_ = options
dataset_emb = self.input_proj(batch.points.permute(0, 2, 1)) # NCL -> NLC
assert dataset_emb.shape[1] >= self.inner_batch_size
return iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size))
def get_multiview_dataset(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> Iterable:
_ = options
dataset_emb = self.encode_views(batch)
batch_size, num_views, n_patches, width = dataset_emb.shape
assert num_views >= self.inner_batch_size
it = iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size))
def gen():
while True:
examples = next(it)
assert examples.shape == (batch_size, self.inner_batch_size, n_patches, self.width)
views = examples.reshape(batch_size, -1, width) + self.pos_emb
yield views
return gen()
def encode_views(self, batch: AttrDict) -> torch.Tensor:
"""
:return: [batch_size, num_views, n_patches, width]
"""
all_views = self.views_to_tensor(batch.views).to(self.device)
if self.use_depth:
all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)
all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)
batch_size, num_views, _, _, _ = all_views.shape
views_proj = self.patch_emb(
all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
)
views_proj = (
views_proj.reshape([batch_size, num_views, self.width, -1])
.permute(0, 1, 3, 2)
.contiguous()
) # [batch_size x num_views x n_patches x width]
# [batch_size, num_views, 1, 2 * width]
camera_proj = self.camera_emb(all_cameras).reshape(
[batch_size, num_views, 1, self.width * 2]
)
pose_dropout = self.pose_dropout if self.training else 0.0
mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
scale, shift = camera_proj.chunk(2, dim=3)
views_proj = views_proj * (scale + 1.0) + shift
return views_proj
def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
"""
Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].
"""
if isinstance(views, torch.Tensor):
return views
tensor_batch = []
num_views = len(views[0])
for inner_list in views:
assert len(inner_list) == num_views
inner_batch = []
for img in inner_list:
img = img.resize((self.image_size,) * 2).convert("RGB")
inner_batch.append(
torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
/ 127.5
- 1
)
tensor_batch.append(torch.stack(inner_batch, dim=0))
return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)
def depths_to_tensor(
self, depths: Union[torch.Tensor, List[List[Image.Image]]]
) -> torch.Tensor:
"""
Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].
"""
if isinstance(depths, torch.Tensor):
return depths
tensor_batch = []
num_views = len(depths[0])
for inner_list in depths:
assert len(inner_list) == num_views
inner_batch = []
for arr in inner_list:
tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
tensor = tensor * 2 - 1
tensor = F.interpolate(
tensor[None, None],
(self.image_size,) * 2,
mode="nearest",
)
inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
tensor_batch.append(torch.cat(inner_batch, dim=0))
return torch.stack(tensor_batch, dim=0)
def cameras_to_tensor(
self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
) -> torch.Tensor:
"""
Returns a [batch x num_views x 3*4+1] tensor of camera information.
"""
if isinstance(cameras, torch.Tensor):
return cameras
outer_batch = []
for inner_list in cameras:
inner_batch = []
for camera in inner_list:
inner_batch.append(
np.array(
[
*camera.x,
*camera.y,
*camera.z,
*camera.origin,
camera.x_fov,
]
)
)
outer_batch.append(np.stack(inner_batch, axis=0))
return torch.from_numpy(np.stack(outer_batch, axis=0)).float()
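The input conventions used by `views_to_tensor` and `cameras_to_tensor` can be checked with a small pure-Python sketch. `ToyCamera` is a hypothetical stand-in for `ProjectiveCamera` that carries only the five attributes read above, and `pack_camera` and `normalize_pixel` are illustrative names, not functions from this repository.

```python
from dataclasses import dataclass
from typing import List, Tuple

Vec3 = Tuple[float, float, float]


@dataclass
class ToyCamera:
    # Hypothetical stand-in for ProjectiveCamera (only the fields used here).
    origin: Vec3
    x: Vec3
    y: Vec3
    z: Vec3
    x_fov: float


def pack_camera(camera: ToyCamera) -> List[float]:
    # Same ordering as cameras_to_tensor: x, y, z axes, then origin, then the FOV.
    return [*camera.x, *camera.y, *camera.z, *camera.origin, camera.x_fov]


def normalize_pixel(value: int) -> float:
    # Same affine map as views_to_tensor: 8-bit [0, 255] -> [-1, 1].
    return value / 127.5 - 1


cam = ToyCamera(origin=(0.0, 0.0, -3.0), x=(1.0, 0.0, 0.0),
                y=(0.0, 1.0, 0.0), z=(0.0, 0.0, 1.0), x_fov=0.7)
packed = pack_camera(cam)
```

The packed vector has 3 * 4 + 1 = 13 entries, which matches the input width of the first `nn.Linear` in `camera_emb`.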
================================================
FILE: shap_e/models/volume.py
================================================
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import torch
from shap_e.models.nn.meta import MetaModule
from shap_e.models.nn.utils import ArrayType, safe_divide, to_torch
@dataclass
class VolumeRange:
t0: torch.Tensor
t1: torch.Tensor
intersected: torch.Tensor
def __post_init__(self):
assert self.t0.shape == self.t1.shape == self.intersected.shape
def next_t0(self):
"""
Given convex volume1 and volume2, where volume1 is contained in
volume2, this function returns the t0 at which rays leave volume1 and
intersect with volume2 \\ volume1.
"""
return self.t1 * self.intersected.float()
def extend(self, another: "VolumeRange") -> "VolumeRange":
"""
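Merge this range with another one. t0 is taken from self where self was
intersected (otherwise from another), t1 is taken from another where
another was intersected (otherwise from self), and a ray counts as
intersected if it hit either range.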
"""
return VolumeRange(
t0=torch.where(self.intersected, self.t0, another.t0),
t1=torch.where(another.intersected, another.t1, self.t1),
intersected=torch.logical_or(self.intersected, another.intersected),
)
def partition(self, ts) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Partitions t0 and t1 into n_samples intervals.
:param ts: [batch_size, *shape, n_samples, 1]
:return: a tuple of (
lower: [batch_size, *shape, n_samples, 1]
upper: [batch_size, *shape, n_samples, 1]
delta: [batch_size, *shape, n_samples, 1]
) where
ts \\in [lower, upper]
delta = upper - lower
"""
mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
delta = upper - lower
assert lower.shape == upper.shape == delta.shape == ts.shape
return lower, upper, delta
class Volume(ABC):
"""
An abstraction of rendering volume.
"""
@abstractmethod
def intersect(
self,
origin: torch.Tensor,
direction: torch.Tensor,
t0_lower: Optional[torch.Tensor] = None,
params: Optional[Dict] = None,
epsilon: float = 1e-6,
) -> VolumeRange:
"""
:param origin: [batch_size, *shape, 3]
:param direction: [batch_size, *shape, 3]
:param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
:param params: Optional meta parameters in case Volume is parametric
:param epsilon: to stabilize calculations
:return: A VolumeRange with fields (t0, t1, intersected) where each has a shape
[batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed
to be on the boundary of the volume.
"""
class BoundingBoxVolume(MetaModule, Volume):
"""
Axis-aligned bounding box defined by the two opposite corners.
"""
def __init__(
self,
*,
bbox_min: ArrayType,
bbox_max: ArrayType,
min_dist: float = 0.0,
min_t_range: float = 1e-3,
device: torch.device = torch.device("cuda"),
):
"""
:param bbox_min: the left/bottommost corner of the bounding box
:param bbox_max: the other corner of the bounding box
:param min_dist: all rays should start at least this distance away from the ray origin.
:param min_t_range: a ray only counts as intersecting if t1 - t0 exceeds this value.
"""
super().__init__()
self.bbox_min = to_torch(bbox_min).to(device)
self.bbox_max = to_torch(bbox_max).to(device)
self.min_dist = min_dist
self.min_t_range = min_t_range
self.bbox = torch.stack([self.bbox_min, self.bbox_max])
assert self.bbox.shape == (2, 3)
assert self.min_dist >= 0.0
assert self.min_t_range > 0.0
self.device = device
def intersect(
self,
origin: torch.Tensor,
direction: torch.Tensor,
t0_lower: Optional[torch.Tensor] = None,
params: Optional[Dict] = None,
epsilon=1e-6,
) -> VolumeRange:
"""
:param origin: [batch_size, *shape, 3]
:param direction: [batch_size, *shape, 3]
:param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
:param params: Optional meta parameters in case Volume is parametric
:param epsilon: to stabilize calculations
:return: A VolumeRange with fields (t0, t1, intersected) where each has a shape
[batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed
to be on the boundary of the volume.
"""
batch_size, *shape, _ = origin.shape
ones = [1] * len(shape)
bbox = self.bbox.view(1, *ones, 2, 3)
ts = safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)
# Cases to think about:
#
# 1. t1 <= t0: the ray does not pass through the AABB.
# 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
# 3. t0 <= 0 <= t1: the ray starts from inside the BB
# 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
#
# 1 and 4 are clearly handled from t0 < t1 below.
# Making t0 at least min_dist (>= 0) takes care of 2 and 3.
t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
assert t0.shape == t1.shape == (batch_size, *shape, 1)
if t0_lower is not None:
assert t0.shape == t0_lower.shape
t0 = torch.maximum(t0, t0_lower)
intersected = t0 + self.min_t_range < t1
t0 = torch.where(intersected, t0, torch.zeros_like(t0))
t1 = torch.where(intersected, t1, torch.ones_like(t1))
return VolumeRange(t0=t0, t1=t1, intersected=intersected)
class UnboundedVolume(MetaModule, Volume):
"""
Originally used in NeRF. Unbounded volume but with a limited visibility
when rendering (e.g. objects that are farther away than the max_dist from
the ray origin are not considered)
"""
def __init__(
self,
*,
max_dist: float,
min_dist: float = 0.0,
min_t_range: float = 1e-3,
device: torch.device = torch.device("cuda"),
):
super().__init__()
self.max_dist = max_dist
self.min_dist = min_dist
self.min_t_range = min_t_range
assert self.min_dist >= 0.0
assert self.min_t_range > 0.0
self.device = device
def intersect(
self,
origin: torch.Tensor,
direction: torch.Tensor,
t0_lower: Optional[torch.Tensor] = None,
params: Optional[Dict] = None,
epsilon: float = 1e-6,  # unused here; kept so the signature matches Volume.intersect
) -> VolumeRange:
"""
:param origin: [batch_size, *shape, 3]
:param direction: [batch_size, *shape, 3]
:param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
:param params: Optional meta parameters in case Volume is parametric
:param epsilon: to stabilize calculations
:return: A VolumeRange with fields (t0, t1, intersected) where each has a shape
[batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed
to be on the boundary of the volume.
"""
batch_size, *shape, _ = origin.shape
t0 = torch.zeros(batch_size, *shape, 1, dtype=origin.dtype, device=origin.device)
if t0_lower is not None:
t0 = torch.maximum(t0, t0_lower)
t1 = t0 + self.max_dist
t0 = t0.clamp(self.min_dist)
return VolumeRange(t0=t0, t1=t1, intersected=t0 + self.min_t_range < t1)
class SphericalVolume(MetaModule, Volume):
"""
Used in NeRF++. Unlikely to be used here unless we want to reproduce
their results, so intersect() is not implemented.
"""
def __init__(
self,
*,
radius: float,
center: ArrayType = (0.0, 0.0, 0.0),
min_dist: float = 0.0,
min_t_range: float = 1e-3,
device: torch.device = torch.device("cuda"),
):
super().__init__()
self.radius = radius
self.center = to_torch(center).to(device)
self.min_dist = min_dist
self.min_t_range = min_t_range
assert self.min_dist >= 0.0
assert self.min_t_range > 0.0
self.device = device
def intersect(
self,
origin: torch.Tensor,
direction: torch.Tensor,
t0_lower: Optional[torch.Tensor] = None,
params: Optional[Dict] = None,
epsilon=1e-6,
) -> VolumeRange:
raise NotImplementedError
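The slab test in `BoundingBoxVolume.intersect` can be followed for a single ray in plain Python. This is a sketch under assumptions: `intersect_aabb` is a hypothetical name, and the local `safe_divide` (which nudges the denominator away from zero by `epsilon`) only approximates what `shap_e.models.nn.utils.safe_divide` may do, since that helper is not shown in this section.

```python
from typing import Sequence, Tuple


def safe_divide(a: float, b: float, epsilon: float = 1e-6) -> float:
    # Assumed behaviour: push the denominator away from zero, keeping its sign.
    return a / (b - epsilon if b < 0 else b + epsilon)


def intersect_aabb(
    origin: Sequence[float],
    direction: Sequence[float],
    bbox_min: Sequence[float],
    bbox_max: Sequence[float],
    min_dist: float = 0.0,
    min_t_range: float = 1e-3,
    epsilon: float = 1e-6,
) -> Tuple[float, float, bool]:
    # Per-axis ray parameters at which the ray crosses the two bounding planes.
    t_lo = [safe_divide(lo - o, d, epsilon) for lo, o, d in zip(bbox_min, origin, direction)]
    t_hi = [safe_divide(hi - o, d, epsilon) for hi, o, d in zip(bbox_max, origin, direction)]
    # Entry is the latest per-axis entry, clamped to min_dist; exit is the earliest exit.
    t0 = max(max(min(a, b) for a, b in zip(t_lo, t_hi)), min_dist)
    t1 = min(max(a, b) for a, b in zip(t_lo, t_hi))
    intersected = t0 + min_t_range < t1
    if not intersected:
        # Same placeholder values as the code above for rays that miss.
        return 0.0, 1.0, False
    return t0, t1, True


box_min, box_max = (-1.0, -1.0, -1.0), (1.0, 1.0, 1.0)
hit = intersect_aabb((0.0, 0.0, -3.0), (0.0, 0.0, 1.0), box_min, box_max)
miss = intersect_aabb((5.0, 0.0, -3.0), (0.0, 0.0, 1.0), box_min, box_max)
inside = intersect_aabb((0.0, 0.0, 0.0), (0.0, 0.0, 1.0), box_min, box_max)
```

A ray that starts inside the box gets `t0` clamped to `min_dist`, which is how cases 2 and 3 in the comment above are handled.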
================================================
FILE: shap_e/rendering/__init__.py
================================================
================================================
FILE: shap_e/rendering/_mc_table.py
================================================
# Treat a cube as a bitmap, and create the index into this array in order of
# ZYX (note Z is the most significant digit).
# The resulting object is an array of triangles, where each triangle is 6
# indices. Each consecutive pair of indices within this triangle represents an
# edge spanning two corners (identified by the indices).
#
# The corners of a cube are indexed as follows
#
# (0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0),
# (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1)
#
# Here is a visualization of the cube indices:
#
#           6 + -----------------------+ 7
#            /|                       /|
#           / |                      / |
#          /  |                     /  |
#       4 +------------------------+ 5 |
#         |   |                    |   |
#         |   |                    |   |
#         |   |                    |   |
#         |   | 2                  |   | 3
#         |   +--------------------|---+
#         |  /                     |  /
#         | /                      | /
#         |/                       |/
#         +------------------------+
#        0                           1
#
# Derived using model3d, in particular this function:
# https://github.com/unixpickle/model3d/blob/7a3adb982c154c80c1a22032b5a0695160a7f96d/model3d/mc.go#L434
#
MC_TABLE = [
[],
[[0, 1, 0, 2, 0, 4]],
[[1, 0, 1, 5, 1, 3]],
[[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2]],
[[2, 0, 2, 3, 2, 6]],
[[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4]],
[[1, 0, 1, 5, 1, 3], [2, 6, 0, 2, 3, 2]],
[[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4]],
[[3, 1, 3, 7, 3, 2]],
[[0, 2, 0, 4, 0, 1], [3, 7, 2, 3, 1, 3]],
[[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0]],
[[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5]],
[[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6]],
[[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6]],
[[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7]],
[[0, 4, 1, 5, 3, 7], [0, 4, 3, 7, 2, 6]],
[[4, 0, 4, 6, 4, 5]],
[[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1]],
[[1, 5, 1, 3, 1, 0], [4, 6, 5, 4, 0, 4]],
[[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2]],
[[2, 0, 2, 3, 2, 6], [4, 5, 0, 4, 6, 4]],
[[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1]],
[[2, 6, 2, 0, 3, 2], [1, 0, 1, 5, 3, 1], [6, 4, 5, 4, 0, 4]],
[[1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6]],
[[3, 1, 3, 7, 3, 2], [6, 4, 5, 4, 0, 4]],
[[4, 5, 0, 1, 4, 6], [0, 1, 0, 2, 4, 6], [7, 3, 2, 3, 1, 3]],
[[3, 2, 1, 0, 3, 7], [1, 0, 1, 5, 3, 7], [6, 4, 5, 4, 0, 4]],
[[3, 7, 3, 2, 1, 5], [3, 2, 6, 4, 1, 5], [1, 5, 6, 4, 5, 4], [3, 2, 2, 0, 6, 4]],
[[3, 7, 2, 6, 3, 1], [2, 6, 2, 0, 3, 1], [5, 4, 0, 4, 6, 4]],
[[1, 0, 1, 3, 5, 4], [1, 3, 2, 6, 5, 4], [1, 3, 3, 7, 2, 6], [5, 4, 2, 6, 4, 6]],
[[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7], [4, 5, 0, 4, 4, 6]],
[[6, 2, 4, 6, 4, 5], [4, 5, 5, 1, 6, 2], [6, 2, 5, 1, 7, 3]],
[[5, 1, 5, 4, 5, 7]],
[[0, 1, 0, 2, 0, 4], [5, 7, 1, 5, 4, 5]],
[[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3]],
[[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3]],
[[2, 0, 2, 3, 2, 6], [7, 5, 1, 5, 4, 5]],
[[2, 6, 0, 4, 2, 3], [0, 4, 0, 1, 2, 3], [7, 5, 1, 5, 4, 5]],
[[5, 7, 1, 3, 5, 4], [1, 3, 1, 0, 5, 4], [6, 2, 0, 2, 3, 2]],
[[3, 1, 3, 2, 7, 5], [3, 2, 0, 4, 7, 5], [3, 2, 2, 6, 0, 4], [7, 5, 0, 4, 5, 4]],
[[3, 7, 3, 2, 3, 1], [5, 4, 7, 5, 1, 5]],
[[0, 4, 0, 1, 2, 0], [3, 1, 3, 7, 2, 3], [4, 5, 7, 5, 1, 5]],
[[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0]],
[[0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7]],
[[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6], [4, 5, 7, 5, 1, 5]],
[[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6], [5, 7, 1, 5, 5, 4]],
[[2, 6, 2, 0, 3, 7], [2, 0, 4, 5, 3, 7], [3, 7, 4, 5, 7, 5], [2, 0, 0, 1, 4, 5]],
[[4, 0, 5, 4, 5, 7], [5, 7, 7, 3, 4, 0], [4, 0, 7, 3, 6, 2]],
[[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0]],
[[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6]],
[[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7]],
[[0, 2, 4, 6, 5, 7], [0, 2, 5, 7, 1, 3]],
[[5, 1, 4, 0, 5, 7], [4, 0, 4, 6, 5, 7], [3, 2, 6, 2, 0, 2]],
[[2, 3, 2, 6, 0, 1], [2, 6, 7, 5, 0, 1], [0, 1, 7, 5, 1, 5], [2, 6, 6, 4, 7, 5]],
[[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7], [2, 6, 0, 2, 2, 3]],
[[3, 1, 2, 3, 2, 6], [2, 6, 6, 4, 3, 1], [3, 1, 6, 4, 7, 5]],
[[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0], [2, 3, 1, 3, 7, 3]],
[[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6], [3, 2, 1, 3, 3, 7]],
[[0, 1, 0, 4, 2, 3], [0, 4, 5, 7, 2, 3], [0, 4, 4, 6, 5, 7], [2, 3, 5, 7, 3, 7]],
[[7, 5, 3, 7, 3, 2], [3, 2, 2, 0, 7, 5], [7, 5, 2, 0, 6, 4]],
[[0, 4, 4, 6, 5, 7], [0, 4, 5, 7, 1, 5], [0, 2, 1, 3, 3, 7], [3, 7, 2, 6, 0, 2]],
[
[3, 1, 7, 3, 6, 2],
[6, 2, 0, 1, 3, 1],
[6, 4, 0, 1, 6, 2],
[6, 4, 5, 1, 0, 1],
[6, 4, 7, 5, 5, 1],
],
[
[4, 0, 6, 4, 7, 5],
[7, 5, 1, 0, 4, 0],
[7, 3, 1, 0, 7, 5],
[7, 3, 2, 0, 1, 0],
[7, 3, 6, 2, 2, 0],
],
[[7, 3, 6, 2, 6, 4], [7, 5, 7, 3, 6, 4]],
[[6, 2, 6, 7, 6, 4]],
[[0, 4, 0, 1, 0, 2], [6, 7, 4, 6, 2, 6]],
[[1, 0, 1, 5, 1, 3], [7, 6, 4, 6, 2, 6]],
[[1, 3, 0, 2, 1, 5], [0, 2, 0, 4, 1, 5], [7, 6, 4, 6, 2, 6]],
[[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0]],
[[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3]],
[[6, 4, 2, 0, 6, 7], [2, 0, 2, 3, 6, 7], [5, 1, 3, 1, 0, 1]],
[[1, 5, 1, 3, 0, 4], [1, 3, 7, 6, 0, 4], [0, 4, 7, 6, 4, 6], [1, 3, 3, 2, 7, 6]],
[[3, 2, 3, 1, 3, 7], [6, 4, 2, 6, 7, 6]],
[[3, 7, 3, 2, 1, 3], [0, 2, 0, 4, 1, 0], [7, 6, 4, 6, 2, 6]],
[[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0], [4, 6, 2, 6, 7, 6]],
[[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5], [6, 4, 2, 6, 6, 7]],
[[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0]],
[[0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7]],
[[0, 2, 0, 1, 4, 6], [0, 1, 3, 7, 4, 6], [0, 1, 1, 5, 3, 7], [4, 6, 3, 7, 6, 7]],
[[7, 3, 6, 7, 6, 4], [6, 4, 4, 0, 7, 3], [7, 3, 4, 0, 5, 1]],
[[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5]],
[[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5]],
[[6, 7, 4, 5, 6, 2], [4, 5, 4, 0, 6, 2], [3, 1, 0, 1, 5, 1]],
[[2, 0, 2, 6, 3, 1], [2, 6, 4, 5, 3, 1], [2, 6, 6, 7, 4, 5], [3, 1, 4, 5, 1, 5]],
[[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7]],
[[0, 1, 2, 3, 6, 7], [0, 1, 6, 7, 4, 5]],
[[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7], [1, 3, 0, 1, 1, 5]],
[[5, 4, 1, 5, 1, 3], [1, 3, 3, 2, 5, 4], [5, 4, 3, 2, 7, 6]],
[[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5], [1, 3, 7, 3, 2, 3]],
[[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5], [3, 7, 2, 3, 3, 1]],
[[0, 1, 1, 5, 3, 7], [0, 1, 3, 7, 2, 3], [0, 4, 2, 6, 6, 7], [6, 7, 4, 5, 0, 4]],
[
[6, 2, 7, 6, 5, 4],
[5, 4, 0, 2, 6, 2],
[5, 1, 0, 2, 5, 4],
[5, 1, 3, 2, 0, 2],
[5, 1, 7, 3, 3, 2],
],
[[3, 1, 3, 7, 2, 0], [3, 7, 5, 4, 2, 0], [2, 0, 5, 4, 0, 4], [3, 7, 7, 6, 5, 4]],
[[1, 0, 3, 1, 3, 7], [3, 7, 7, 6, 1, 0], [1, 0, 7, 6, 5, 4]],
[
[1, 0, 5, 1, 7, 3],
[7, 3, 2, 0, 1, 0],
[7, 6, 2, 0, 7, 3],
[7, 6, 4, 0, 2, 0],
[7, 6, 5, 4, 4, 0],
],
[[7, 6, 5, 4, 5, 1], [7, 3, 7, 6, 5, 1]],
[[5, 7, 5, 1, 5, 4], [6, 2, 7, 6, 4, 6]],
[[0, 2, 0, 4, 1, 0], [5, 4, 5, 7, 1, 5], [2, 6, 7, 6, 4, 6]],
[[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3], [2, 6, 7, 6, 4, 6]],
[[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3], [6, 7, 4, 6, 6, 2]],
[[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0], [1, 5, 4, 5, 7, 5]],
[[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3], [5, 1, 4, 5, 5, 7]],
[[0, 2, 2, 3, 6, 7], [0, 2, 6, 7, 4, 6], [0, 1, 4, 5, 5, 7], [5, 7, 1, 3, 0, 1]],
[
[5, 4, 7, 5, 3, 1],
[3, 1, 0, 4, 5, 4],
[3, 2, 0, 4, 3, 1],
[3, 2, 6, 4, 0, 4],
[3, 2, 7, 6, 6, 4],
],
[[5, 4, 5, 7, 1, 5], [3, 7, 3, 2, 1, 3], [4, 6, 2, 6, 7, 6]],
[[1, 0, 0, 2, 0, 4], [1, 5, 5, 4, 5, 7], [3, 2, 1, 3, 3, 7], [2, 6, 7, 6, 4, 6]],
[[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0], [6, 2, 7, 6, 6, 4]],
[
[0, 4, 2, 3, 0, 2],
[0, 4, 3, 7, 2, 3],
[0, 4, 4, 5, 3, 7],
[4, 5, 5, 7, 3, 7],
[6, 7, 4, 6, 2, 6],
],
[[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0], [5, 4, 7, 5, 5, 1]],
[
[0, 1, 4, 6, 0, 4],
[0, 1, 6, 7, 4, 6],
[0, 1, 1, 3, 6, 7],
[1, 3, 3, 7, 6, 7],
[5, 7, 1, 5, 4, 5],
],
[
[6, 7, 4, 6, 0, 2],
[0, 2, 3, 7, 6, 7],
[0, 1, 3, 7, 0, 2],
[0, 1, 5, 7, 3, 7],
[0, 1, 4, 5, 5, 7],
],
[[4, 0, 6, 7, 4, 6], [4, 0, 7, 3, 6, 7], [4, 0, 5, 7, 7, 3], [4, 5, 5, 7, 4, 0]],
[[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0]],
[[0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7]],
[[1, 3, 1, 0, 5, 7], [1, 0, 2, 6, 5, 7], [5, 7, 2, 6, 7, 6], [1, 0, 0, 4, 2, 6]],
[[2, 0, 6, 2, 6, 7], [6, 7, 7, 5, 2, 0], [2, 0, 7, 5, 3, 1]],
[[0, 4, 0, 2, 1, 5], [0, 2, 6, 7, 1, 5], [0, 2, 2, 3, 6, 7], [1, 5, 6, 7, 5, 7]],
[[7, 6, 5, 7, 5, 1], [5, 1, 1, 0, 7, 6], [7, 6, 1, 0, 3, 2]],
[
[2, 0, 3, 2, 7, 6],
[7, 6, 4, 0, 2, 0],
[7, 5, 4, 0, 7, 6],
[7, 5, 1, 0, 4, 0],
[7, 5, 3, 1, 1, 0],
],
[[7, 5, 3, 1, 3, 2], [7, 6, 7, 5, 3, 2]],
[[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0], [3, 1, 7, 3, 3, 2]],
[
[0, 2, 1, 5, 0, 1],
[0, 2, 5, 7, 1, 5],
[0, 2, 2, 6, 5, 7],
[2, 6, 6, 7, 5, 7],
[3, 7, 2, 3, 1, 3],
],
[
[3, 7, 2, 3, 0, 1],
[0, 1, 5, 7, 3, 7],
[0, 4, 5, 7, 0, 1],
[0, 4, 6, 7, 5, 7],
[0, 4, 2, 6, 6, 7],
],
[[2, 0, 3, 7, 2, 3], [2, 0, 7, 5, 3, 7], [2, 0, 6, 7, 7, 5], [2, 6, 6, 7, 2, 0]],
[
[5, 7, 1, 5, 0, 4],
[0, 4, 6, 7, 5, 7],
[0, 2, 6, 7, 0, 4],
[0, 2, 3, 7, 6, 7],
[0, 2, 1, 3, 3, 7],
],
[[1, 0, 5, 7, 1, 5], [1, 0, 7, 6, 5, 7], [1, 0, 3, 7, 7, 6], [1, 3, 3, 7, 1, 0]],
[[0, 2, 0, 1, 0, 4], [3, 7, 6, 7, 5, 7]],
[[7, 5, 7, 3, 7, 6]],
[[7, 3, 7, 5, 7, 6]],
[[0, 1, 0, 2, 0, 4], [6, 7, 3, 7, 5, 7]],
[[1, 3, 1, 0, 1, 5], [7, 6, 3, 7, 5, 7]],
[[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2], [6, 7, 3, 7, 5, 7]],
[[2, 6, 2, 0, 2, 3], [7, 5, 6, 7, 3, 7]],
[[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4], [5, 7, 6, 7, 3, 7]],
[[1, 5, 1, 3, 0, 1], [2, 3, 2, 6, 0, 2], [5, 7, 6, 7, 3, 7]],
[[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4], [7, 6, 3, 7, 7, 5]],
[[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2]],
[[7, 6, 3, 2, 7, 5], [3, 2, 3, 1, 7, 5], [4, 0, 1, 0, 2, 0]],
[[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2]],
[[2, 3, 2, 0, 6, 7], [2, 0, 1, 5, 6, 7], [2, 0, 0, 4, 1, 5], [6, 7, 1, 5, 7, 5]],
[[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1]],
[[0, 4, 0, 1, 2, 6], [0, 1, 5, 7, 2, 6], [2, 6, 5, 7, 6, 7], [0, 1, 1, 3, 5, 7]],
[[1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6]],
[[5, 1, 7, 5, 7, 6], [7, 6, 6, 2, 5, 1], [5, 1, 6, 2, 4, 0]],
[[4, 5, 4, 0, 4, 6], [7, 3, 5, 7, 6, 7]],
[[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1], [3, 7, 5, 7, 6, 7]],
[[4, 6, 4, 5, 0, 4], [1, 5, 1, 3, 0, 1], [6, 7, 3, 7, 5, 7]],
[[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2], [7, 3, 5, 7, 7, 6]],
[[2, 3, 2, 6, 0, 2], [4, 6, 4, 5, 0, 4], [3, 7, 5, 7, 6, 7]],
[[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1], [7, 5, 6, 7, 7, 3]],
[[0, 1, 1, 5, 1, 3], [0, 2, 2, 3, 2, 6], [4, 5, 0, 4, 4, 6], [5, 7, 6, 7, 3, 7]],
[
[1, 3, 5, 4, 1, 5],
[1, 3, 4, 6, 5, 4],
[1, 3, 3, 2, 4, 6],
[3, 2, 2, 6, 4, 6],
[7, 6, 3, 7, 5, 7],
],
[[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2], [0, 4, 6, 4, 5, 4]],
[[1, 0, 0, 2, 4, 6], [1, 0, 4, 6, 5, 4], [1, 3, 5, 7, 7, 6], [7, 6, 3, 2, 1, 3]],
[[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2], [4, 6, 5, 4, 4, 0]],
[
[7, 5, 6, 7, 2, 3],
[2, 3, 1, 5, 7, 5],
[2, 0, 1, 5, 2, 3],
[2, 0, 4, 5, 1, 5],
[2, 0, 6, 4, 4, 5],
],
[[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1], [4, 0, 6, 4, 4, 5]],
[
[4, 6, 5, 4, 1, 0],
[1, 0, 2, 6, 4, 6],
[1, 3, 2, 6, 1, 0],
[1, 3, 7, 6, 2, 6],
[1, 3, 5, 7, 7, 6],
],
[
[1, 5, 0, 2, 1, 0],
[1, 5, 2, 6, 0, 2],
[1, 5, 5, 7, 2, 6],
[5, 7, 7, 6, 2, 6],
[4, 6, 5, 4, 0, 4],
],
[[5, 1, 4, 6, 5, 4], [5, 1, 6, 2, 4, 6], [5, 1, 7, 6, 6, 2], [5, 7, 7, 6, 5, 1]],
[[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1]],
[[7, 3, 5, 1, 7, 6], [5, 1, 5, 4, 7, 6], [2, 0, 4, 0, 1, 0]],
[[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4]],
[[0, 2, 0, 4, 1, 3], [0, 4, 6, 7, 1, 3], [1, 3, 6, 7, 3, 7], [0, 4, 4, 5, 6, 7]],
[[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1], [0, 2, 3, 2, 6, 2]],
[[1, 5, 5, 4, 7, 6], [1, 5, 7, 6, 3, 7], [1, 0, 3, 2, 2, 6], [2, 6, 0, 4, 1, 0]],
[[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4], [2, 0, 3, 2, 2, 6]],
[
[2, 3, 6, 2, 4, 0],
[4, 0, 1, 3, 2, 3],
[4, 5, 1, 3, 4, 0],
[4, 5, 7, 3, 1, 3],
[4, 5, 6, 7, 7, 3],
],
[[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6]],
[[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6], [0, 4, 1, 0, 0, 2]],
[[1, 0, 5, 4, 7, 6], [1, 0, 7, 6, 3, 2]],
[[2, 3, 0, 2, 0, 4], [0, 4, 4, 5, 2, 3], [2, 3, 4, 5, 6, 7]],
[[1, 3, 1, 5, 0, 2], [1, 5, 7, 6, 0, 2], [1, 5, 5, 4, 7, 6], [0, 2, 7, 6, 2, 6]],
[
[5, 1, 4, 5, 6, 7],
[6, 7, 3, 1, 5, 1],
[6, 2, 3, 1, 6, 7],
[6, 2, 0, 1, 3, 1],
[6, 2, 4, 0, 0, 1],
],
[[6, 7, 2, 6, 2, 0], [2, 0, 0, 1, 6, 7], [6, 7, 0, 1, 4, 5]],
[[6, 2, 4, 0, 4, 5], [6, 7, 6, 2, 4, 5]],
[[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1]],
[[1, 5, 1, 0, 3, 7], [1, 0, 4, 6, 3, 7], [1, 0, 0, 2, 4, 6], [3, 7, 4, 6, 7, 6]],
[[1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6]],
[[6, 4, 7, 6, 7, 3], [7, 3, 3, 1, 6, 4], [6, 4, 3, 1, 2, 0]],
[[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1], [2, 3, 6, 2, 2, 0]],
[
[7, 6, 3, 7, 1, 5],
[1, 5, 4, 6, 7, 6],
[1, 0, 4, 6, 1, 5],
[1, 0, 2, 6, 4, 6],
[1, 0, 3, 2, 2, 6],
],
[
[1, 0, 3, 7, 1, 3],
[1, 0, 7, 6, 3, 7],
[1, 0, 0, 4, 7, 6],
[0, 4, 4, 6, 7, 6],
[2, 6, 0, 2, 3, 2],
],
[[3, 1, 7, 6, 3, 7], [3, 1, 6, 4, 7, 6], [3, 1, 2, 6, 6, 4], [3, 2, 2, 6, 3, 1]],
[[3, 2, 3, 1, 7, 6], [3, 1, 0, 4, 7, 6], [7, 6, 0, 4, 6, 4], [3, 1, 1, 5, 0, 4]],
[
[0, 1, 2, 0, 6, 4],
[6, 4, 5, 1, 0, 1],
[6, 7, 5, 1, 6, 4],
[6, 7, 3, 1, 5, 1],
[6, 7, 2, 3, 3, 1],
],
[[0, 1, 4, 0, 4, 6], [4, 6, 6, 7, 0, 1], [0, 1, 6, 7, 2, 3]],
[[6, 7, 2, 3, 2, 0], [6, 4, 6, 7, 2, 0]],
[
[2, 6, 0, 2, 1, 3],
[1, 3, 7, 6, 2, 6],
[1, 5, 7, 6, 1, 3],
[1, 5, 4, 6, 7, 6],
[1, 5, 0, 4, 4, 6],
],
[[1, 5, 1, 0, 1, 3], [4, 6, 7, 6, 2, 6]],
[[0, 1, 2, 6, 0, 2], [0, 1, 6, 7, 2, 6], [0, 1, 4, 6, 6, 7], [0, 4, 4, 6, 0, 1]],
[[6, 7, 6, 2, 6, 4]],
[[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4]],
[[7, 5, 6, 4, 7, 3], [6, 4, 6, 2, 7, 3], [1, 0, 2, 0, 4, 0]],
[[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4], [0, 1, 5, 1, 3, 1]],
[[2, 0, 0, 4, 1, 5], [2, 0, 1, 5, 3, 1], [2, 6, 3, 7, 7, 5], [7, 5, 6, 4, 2, 6]],
[[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4]],
[[3, 2, 3, 7, 1, 0], [3, 7, 6, 4, 1, 0], [3, 7, 7, 5, 6, 4], [1, 0, 6, 4, 0, 4]],
[[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4], [1, 5, 3, 1, 1, 0]],
[
[7, 3, 5, 7, 4, 6],
[4, 6, 2, 3, 7, 3],
[4, 0, 2, 3, 4, 6],
[4, 0, 1, 3, 2, 3],
[4, 0, 5, 1, 1, 3],
],
[[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5]],
[[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5], [0, 1, 2, 0, 0, 4]],
[[1, 0, 1, 5, 3, 2], [1, 5, 4, 6, 3, 2], [3, 2, 4, 6, 2, 6], [1, 5, 5, 7, 4, 6]],
[
[0, 2, 4, 0, 5, 1],
[5, 1, 3, 2, 0, 2],
[5, 7, 3, 2, 5, 1],
[5, 7, 6, 2, 3, 2],
[5, 7, 4, 6, 6, 2],
],
[[2, 0, 3, 1, 7, 5], [2, 0, 7, 5, 6, 4]],
[[4, 6, 0, 4, 0, 1], [0, 1, 1, 3, 4, 6], [4, 6, 1, 3, 5, 7]],
[[0, 2, 1, 0, 1, 5], [1, 5, 5, 7, 0, 2], [0, 2, 5, 7, 4, 6]],
[[5, 7, 4, 6, 4, 0], [5, 1, 5, 7, 4, 0]],
[[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2]],
[[0, 1, 0, 2, 4, 5], [0, 2, 3, 7, 4, 5], [4, 5, 3, 7, 5, 7], [0, 2, 2, 6, 3, 7]],
[[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2], [1, 0, 5, 1, 1, 3]],
[
[1, 5, 3, 1, 2, 0],
[2, 0, 4, 5, 1, 5],
[2, 6, 4, 5, 2, 0],
[2, 6, 7, 5, 4, 5],
[2, 6, 3, 7, 7, 5],
],
[[2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5]],
[[3, 2, 7, 3, 7, 5], [7, 5, 5, 4, 3, 2], [3, 2, 5, 4, 1, 0]],
[
[2, 3, 0, 4, 2, 0],
[2, 3, 4, 5, 0, 4],
[2, 3, 3, 7, 4, 5],
[3, 7, 7, 5, 4, 5],
[1, 5, 3, 1, 0, 1],
],
[[3, 2, 1, 5, 3, 1], [3, 2, 5, 4, 1, 5], [3, 2, 7, 5, 5, 4], [3, 7, 7, 5, 3, 2]],
[[2, 6, 2, 3, 0, 4], [2, 3, 7, 5, 0, 4], [2, 3, 3, 1, 7, 5], [0, 4, 7, 5, 4, 5]],
[
[3, 2, 1, 3, 5, 7],
[5, 7, 6, 2, 3, 2],
[5, 4, 6, 2, 5, 7],
[5, 4, 0, 2, 6, 2],
[5, 4, 1, 0, 0, 2],
],
[
[4, 5, 0, 4, 2, 6],
[2, 6, 7, 5, 4, 5],
[2, 3, 7, 5, 2, 6],
[2, 3, 1, 5, 7, 5],
[2, 3, 0, 1, 1, 5],
],
[[2, 3, 2, 0, 2, 6], [1, 5, 7, 5, 4, 5]],
[[5, 7, 4, 5, 4, 0], [4, 0, 0, 2, 5, 7], [5, 7, 0, 2, 1, 3]],
[[5, 4, 1, 0, 1, 3], [5, 7, 5, 4, 1, 3]],
[[0, 2, 4, 5, 0, 4], [0, 2, 5, 7, 4, 5], [0, 2, 1, 5, 5, 7], [0, 1, 1, 5, 0, 2]],
[[5, 4, 5, 1, 5, 7]],
[[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3]],
[[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3], [0, 2, 4, 0, 0, 1]],
[[3, 7, 3, 1, 2, 6], [3, 1, 5, 4, 2, 6], [3, 1, 1, 0, 5, 4], [2, 6, 5, 4, 6, 4]],
[
[6, 4, 2, 6, 3, 7],
[3, 7, 5, 4, 6, 4],
[3, 1, 5, 4, 3, 7],
[3, 1, 0, 4, 5, 4],
[3, 1, 2, 0, 0, 4],
],
[[2, 0, 2, 3, 6, 4], [2, 3, 1, 5, 6, 4], [6, 4, 1, 5, 4, 5], [2, 3, 3, 7, 1, 5]],
[
[0, 4, 1, 0, 3, 2],
[3, 2, 6, 4, 0, 4],
[3, 7, 6, 4, 3, 2],
[3, 7, 5, 4, 6, 4],
[3, 7, 1, 5, 5, 4],
],
[
[1, 3, 0, 1, 4, 5],
[4, 5, 7, 3, 1, 3],
[4, 6, 7, 3, 4, 5],
[4, 6, 2, 3, 7, 3],
[4, 6, 0, 2, 2, 3],
],
[[3, 7, 3, 1, 3, 2], [5, 4, 6, 4, 0, 4]],
[[3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4]],
[
[3, 1, 2, 6, 3, 2],
[3, 1, 6, 4, 2, 6],
[3, 1, 1, 5, 6, 4],
[1, 5, 5, 4, 6, 4],
[0, 4, 1, 0, 2, 0],
],
[[4, 5, 6, 4, 6, 2], [6, 2, 2, 3, 4, 5], [4, 5, 2, 3, 0, 1]],
[[2, 3, 6, 4, 2, 6], [2, 3, 4, 5, 6, 4], [2, 3, 0, 4, 4, 5], [2, 0, 0, 4, 2, 3]],
[[1, 3, 5, 1, 5, 4], [5, 4, 4, 6, 1, 3], [1, 3, 4, 6, 0, 2]],
[[1, 3, 0, 4, 1, 0], [1, 3, 4, 6, 0, 4], [1, 3, 5, 4, 4, 6], [1, 5, 5, 4, 1, 3]],
[[4, 6, 0, 2, 0, 1], [4, 5, 4, 6, 0, 1]],
[[4, 6, 4, 0, 4, 5]],
[[4, 0, 6, 2, 7, 3], [4, 0, 7, 3, 5, 1]],
[[1, 5, 0, 1, 0, 2], [0, 2, 2, 6, 1, 5], [1, 5, 2, 6, 3, 7]],
[[3, 7, 1, 3, 1, 0], [1, 0, 0, 4, 3, 7], [3, 7, 0, 4, 2, 6]],
[[3, 1, 2, 0, 2, 6], [3, 7, 3, 1, 2, 6]],
[[0, 4, 2, 0, 2, 3], [2, 3, 3, 7, 0, 4], [0, 4, 3, 7, 1, 5]],
[[3, 7, 1, 5, 1, 0], [3, 2, 3, 7, 1, 0]],
[[0, 4, 1, 3, 0, 1], [0, 4, 3, 7, 1, 3], [0, 4, 2, 3, 3, 7], [0, 2, 2, 3, 0, 4]],
[[3, 7, 3, 1, 3, 2]],
[[2, 6, 3, 2, 3, 1], [3, 1, 1, 5, 2, 6], [2, 6, 1, 5, 0, 4]],
[[1, 5, 3, 2, 1, 3], [1, 5, 2, 6, 3, 2], [1, 5, 0, 2, 2, 6], [1, 0, 0, 2, 1, 5]],
[[2, 3, 0, 1, 0, 4], [2, 6, 2, 3, 0, 4]],
[[2, 3, 2, 0, 2, 6]],
[[1, 5, 0, 4, 0, 2], [1, 3, 1, 5, 0, 2]],
[[1, 5, 1, 0, 1, 3]],
[[0, 2, 0, 1, 0, 4]],
[],
]
================================================
FILE: shap_e/rendering/blender/__init__.py
================================================
from .render import render_mesh, render_model
from .view_data import BlenderViewData
__all__ = ["BlenderViewData", "render_mesh", "render_model"]
================================================
FILE: shap_e/rendering/blender/blender_script.py
================================================
"""
Script to run within blender.
Provide arguments after `--`.
For example: `blender -b -P blender_script.py -- --help`
"""
import argparse
import json
import math
import os
import random
import sys
import bpy
from mathutils import Vector
from mathutils.noise import random_unit_vector
MAX_DEPTH = 5.0
FORMAT_VERSION = 6
# Set by main(), these constants are passed to the script to avoid
# duplicating them across multiple files.
UNIFORM_LIGHT_DIRECTION = None
BASIC_AMBIENT_COLOR = None
BASIC_DIFFUSE_COLOR = None
def clear_scene():
bpy.ops.object.select_all(action="SELECT")
bpy.ops.object.delete()
def clear_lights():
bpy.ops.object.select_all(action="DESELECT")
for obj in bpy.context.scene.objects.values():
if isinstance(obj.data, bpy.types.Light):
obj.select_set(True)
bpy.ops.object.delete()
def import_model(path):
clear_scene()
_, ext = os.path.splitext(path)
ext = ext.lower()
if ext == ".obj":
bpy.ops.import_scene.obj(filepath=path)
elif ext in [".glb", ".gltf"]:
bpy.ops.import_scene.gltf(filepath=path)
elif ext == ".stl":
bpy.ops.import_mesh.stl(filepath=path)
elif ext == ".fbx":
bpy.ops.import_scene.fbx(filepath=path)
elif ext == ".dae":
bpy.ops.wm.collada_import(filepath=path)
elif ext == ".ply":
bpy.ops.import_mesh.ply(filepath=path)
else:
raise RuntimeError(f"unexpected extension: {ext}")
def scene_root_objects():
for obj in bpy.context.scene.objects.values():
if not obj.parent:
yield obj
def scene_bbox(single_obj=None, ignore_matrix=False):
bbox_min = (math.inf,) * 3
bbox_max = (-math.inf,) * 3
found = False
for obj in scene_meshes() if single_obj is None else [single_obj]:
found = True
for coord in obj.bound_box:
coord = Vector(coord)
if not ignore_matrix:
coord = obj.matrix_world @ coord
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
if not found:
raise RuntimeError("no objects in scene to compute bounding box for")
return Vector(bbox_min), Vector(bbox_max)
def scene_meshes():
for obj in bpy.context.scene.objects.values():
if isinstance(obj.data, bpy.types.Mesh):
yield obj
def normalize_scene():
if len(list(scene_root_objects())) > 1:
# Create an empty object to be used as a parent for all root objects
parent_empty = bpy.data.objects.new("ParentEmpty", None)
bpy.context.scene.collection.objects.link(parent_empty)
# Parent all root objects to the empty object
for obj in scene_root_objects():
if obj != parent_empty:
obj.parent = parent_empty
bbox_min, bbox_max = scene_bbox()
scale = 1 / max(bbox_max - bbox_min)
for obj in scene_root_objects():
obj.scale = obj.scale * scale
# Apply scale to matrix_world.
bpy.context.view_layer.update()
bbox_min, bbox_max = scene_bbox()
offset = -(bbox_min + bbox_max) / 2
for obj in scene_root_objects():
obj.matrix_world.translation += offset
bpy.ops.object.select_all(action="DESELECT")
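# Illustrative sketch, not part of the Blender script: the two-step
# normalization in normalize_scene() (scale the longest bounding-box side to 1,
# then center the box at the origin) mirrored on a plain NumPy vertex array.
# `normalize_points` is a hypothetical helper, and NumPy stands in for the
# Blender object transforms.

```python
import numpy as np


def normalize_points(points: np.ndarray) -> np.ndarray:
    """Scale so the longest bounding-box side is 1, then center at the origin."""
    bbox_min, bbox_max = points.min(axis=0), points.max(axis=0)
    scale = 1.0 / np.max(bbox_max - bbox_min)
    scaled = points * scale
    # Recompute the bounding box after scaling, as normalize_scene() does.
    bbox_min, bbox_max = scaled.min(axis=0), scaled.max(axis=0)
    offset = -(bbox_min + bbox_max) / 2
    return scaled + offset


points = np.array([[0.0, 0.0, 0.0], [4.0, 2.0, 1.0]])
normalized = normalize_points(points)
```

# After this step every coordinate lies in [-0.5, 0.5], which matches the
# `scale=0.5` entry written to info.json further down in this file.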
def create_camera():
# https://b3d.interplanety.org/en/how-to-create-camera-through-the-blender-python-api/
camera_data = bpy.data.cameras.new(name="Camera")
camera_object = bpy.data.objects.new("Camera", camera_data)
bpy.context.scene.collection.objects.link(camera_object)
bpy.context.scene.camera = camera_object
def set_camera(direction, camera_dist=2.0):
camera_pos = -camera_dist * direction
bpy.context.scene.camera.location = camera_pos
# https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically
rot_quat = direction.to_track_quat("-Z", "Y")
bpy.context.scene.camera.rotation_euler = rot_quat.to_euler()
bpy.context.view_layer.update()
def randomize_camera(camera_dist=2.0):
direction = random_unit_vector()
set_camera(direction, camera_dist=camera_dist)
def pan_camera(time, axis="Z", camera_dist=2.0, elevation=0.1):
angle = time * math.pi * 2
direction = [-math.cos(angle), -math.sin(angle), elevation]
assert axis in ["X", "Y", "Z"]
if axis == "X":
direction = [direction[2], *direction[:2]]
elif axis == "Y":
direction = [direction[0], elevation, direction[1]]
direction = Vector(direction).normalized()
set_camera(direction, camera_dist=camera_dist)
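# Illustrative sketch, not part of the Blender script: the direction math of
# pan_camera() and the position math of set_camera() in pure Python, so it can
# be checked without bpy. `pan_direction` and `camera_position` are
# hypothetical names; plain lists stand in for mathutils.Vector.

```python
import math


def pan_direction(time: float, axis: str = "Z", elevation: float = 0.1):
    """Unit viewing direction for a fraction `time` in [0, 1] of a full circle."""
    assert axis in ["X", "Y", "Z"]
    angle = time * math.pi * 2
    direction = [-math.cos(angle), -math.sin(angle), elevation]
    if axis == "X":
        direction = [direction[2], *direction[:2]]
    elif axis == "Y":
        direction = [direction[0], elevation, direction[1]]
    norm = math.sqrt(sum(x * x for x in direction))
    return [x / norm for x in direction]


def camera_position(direction, camera_dist: float = 2.0):
    # The camera sits opposite to the viewing direction, so it looks at the origin.
    return [-camera_dist * x for x in direction]


pos = camera_position(pan_direction(0.0, elevation=0.0))
quarter = camera_position(pan_direction(0.25, elevation=0.0))
```

# With zero elevation the camera starts on the +X axis at time 0 and reaches
# the +Y axis a quarter of the way around, always at distance camera_dist.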
def place_camera(time, camera_pose_mode="random", camera_dist_min=2.0, camera_dist_max=2.0):
camera_dist = random.uniform(camera_dist_min, camera_dist_max)
if camera_pose_mode == "random":
randomize_camera(camera_dist=camera_dist)
elif camera_pose_mode == "z-circular":
pan_camera(time, axis="Z", camera_dist=camera_dist)
elif camera_pose_mode == "z-circular-elevated":
pan_camera(time, axis="Z", camera_dist=camera_dist, elevation=-0.2617993878)
else:
raise ValueError(f"Unknown camera pose mode: {camera_pose_mode}")
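# Illustrative sketch, not part of the Blender script: what the
# "z-circular-elevated" constant means geometrically. Numerically,
# -0.2617993878 equals -pi/12, but pan_camera() uses it as the z component of
# an unnormalized direction rather than as an angle. Under that reading the
# camera's elevation above the horizontal plane is atan(pi/12), slightly less
# than 15 degrees. `camera_elevation_degrees` is a hypothetical helper.

```python
import math

ELEVATION_Z = -0.2617993878  # value passed by the "z-circular-elevated" mode


def camera_elevation_degrees(elevation_z: float) -> float:
    # pan_camera() builds direction = (-cos, -sin, elevation_z) before
    # normalizing, and set_camera() places the camera at -dist * direction.
    # The camera height is therefore -elevation_z per unit horizontal radius.
    return math.degrees(math.atan2(-elevation_z, 1.0))


angle = camera_elevation_degrees(ELEVATION_Z)
```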
def create_light(location, energy=1.0, angle=0.5 * math.pi / 180):
# https://blender.stackexchange.com/questions/215624/how-to-create-a-light-with-the-python-api-in-blender-2-92
light_data = bpy.data.lights.new(name="Light", type="SUN")
light_data.energy = energy
light_data.angle = angle
light_object = bpy.data.objects.new(name="Light", object_data=light_data)
direction = -location
rot_quat = direction.to_track_quat("-Z", "Y")
light_object.rotation_euler = rot_quat.to_euler()
bpy.context.view_layer.update()
bpy.context.collection.objects.link(light_object)
light_object.location = location
def create_random_lights(count=4, distance=2.0, energy=1.5):
clear_lights()
for _ in range(count):
create_light(random_unit_vector() * distance, energy=energy)
def create_camera_light():
clear_lights()
create_light(bpy.context.scene.camera.location, energy=5.0)
def create_uniform_light(backend):
clear_lights()
# A fixed, randomly chosen direction (UNIFORM_LIGHT_DIRECTION) to decorrelate axis-aligned sides.
pos = Vector(UNIFORM_LIGHT_DIRECTION)
angle = 0.0092 if backend == "CYCLES" else math.pi
create_light(pos, energy=5.0, angle=angle)
create_light(-pos, energy=5.0, angle=angle)
def create_vertex_color_shaders():
# By default, Blender will ignore vertex colors in both the
# Eevee and Cycles backends, since these colors aren't
# associated with a material.
#
# What we do here is create a simple material shader and link
# the vertex color to the material color.
for obj in bpy.context.scene.objects.values():
if not isinstance(obj.data, bpy.types.Mesh):
continue
if len(obj.data.materials):
# We don't want to override any existing materials.
continue
color_keys = (obj.data.vertex_colors or {}).keys()
if not len(color_keys):
# Many objects will have no materials *or* vertex colors.
continue
mat = bpy.data.materials.new(name="VertexColored")
mat.use_nodes = True
# There should be a Principled BSDF by default.
bsdf_node = None
for node in mat.node_tree.nodes:
if node.type == "BSDF_PRINCIPLED":
bsdf_node = node
assert bsdf_node is not None, "material has no Principled BSDF node to modify"
# Map socket names to sockets (avoids shadowing the builtin `input`).
socket_map = {socket.name: socket for socket in bsdf_node.inputs}
# Make sure nothing lights the object except for the diffuse color.
socket_map["Specular"].default_value = 0.0
socket_map["Roughness"].default_value = 1.0
v_color = mat.node_tree.nodes.new("ShaderNodeVertexColor")
v_color.layer_name = color_keys[0]
mat.node_tree.links.new(v_color.outputs[0], socket_map["Base Color"])
obj.data.materials.append(mat)
def create_default_materials():
for obj in bpy.context.scene.objects.values():
if isinstance(obj.data, bpy.types.Mesh):
if not len(obj.data.materials):
mat = bpy.data.materials.new(name="DefaultMaterial")
mat.use_nodes = True
obj.data.materials.append(mat)
def find_materials():
all_materials = set()
for obj in bpy.context.scene.objects.values():
if not isinstance(obj.data, bpy.types.Mesh):
continue
for mat in obj.data.materials:
all_materials.add(mat)
return all_materials
def delete_all_materials():
for obj in bpy.context.scene.objects.values():
if isinstance(obj.data, bpy.types.Mesh):
# https://blender.stackexchange.com/questions/146714/removing-all-material-slots-in-one-go
obj.data.materials.clear()
def setup_material_extraction_shaders(capturing_material_alpha: bool):
"""
Change every material to emit texture colors (or alpha) rather than having
an actual reflective color. Returns a function to undo the changes to the
materials.
"""
# Objects can share materials, so we first find all of the
# materials in the project, and then modify them each once.
undo_fns = []
for mat in find_materials():
undo_fn = setup_material_extraction_shader_for_material(mat, capturing_material_alpha)
if undo_fn is not None:
undo_fns.append(undo_fn)
return lambda: [undo_fn() for undo_fn in undo_fns]
def setup_material_extraction_shader_for_material(mat, capturing_material_alpha: bool):
mat.use_nodes = True
# By default, most imported models should use the regular
# "Principled BSDF" material, so we should always find this.
# If not, this shader manipulation logic won't work.
bsdf_node = None
for node in mat.node_tree.nodes:
if node.type == "BSDF_PRINCIPLED":
bsdf_node = node
assert bsdf_node is not None, "material has no Principled BSDF node to modify"
# Map socket names to sockets (avoids shadowing the builtin `input`).
socket_map = {socket.name: socket for socket in bsdf_node.inputs}
for name in ["Base Color", "Emission", "Emission Strength", "Alpha", "Specular"]:
assert name in socket_map.keys(), f"{name} not in {list(socket_map.keys())}"
old_base_color = get_socket_value(mat.node_tree, socket_map["Base Color"])
old_alpha = get_socket_value(mat.node_tree, socket_map["Alpha"])
old_emission = get_socket_value(mat.node_tree, socket_map["Emission"])
old_emission_strength = get_socket_value(mat.node_tree, socket_map["Emission Strength"])
old_specular = get_socket_value(mat.node_tree, socket_map["Specular"])
# Make sure the base color of all objects is black and the opacity
# is 1, so that we are effectively just telling the shader what color
# to make the pixels.
clear_socket_input(mat.node_tree, socket_map["Base Color"])
socket_map["Base Color"].default_value = [0, 0, 0, 1]
clear_socket_input(mat.node_tree, socket_map["Alpha"])
socket_map["Alpha"].default_value = 1
clear_socket_input(mat.node_tree, socket_map["Specular"])
socket_map["Specular"].default_value = 0.0
old_blend_method = mat.blend_method
mat.blend_method = "OPAQUE"
if capturing_material_alpha:
set_socket_value(mat.node_tree, socket_map["Emission"], old_alpha)
else:
set_socket_value(mat.node_tree, socket_map["Emission"], old_base_color)
clear_socket_input(mat.node_tree, socket_map["Emission Strength"])
socket_map["Emission Strength"].default_value = 1.0
def undo_fn():
mat.blend_method = old_blend_method
set_socket_value(mat.node_tree, socket_map["Base Color"], old_base_color)
set_socket_value(mat.node_tree, socket_map["Alpha"], old_alpha)
set_socket_value(mat.node_tree, socket_map["Emission"], old_emission)
set_socket_value(mat.node_tree, socket_map["Emission Strength"], old_emission_strength)
set_socket_value(mat.node_tree, socket_map["Specular"], old_specular)
return undo_fn
def get_socket_value(tree, socket):
default = socket.default_value
if not isinstance(default, float):
default = list(default)
for link in tree.links:
if link.to_socket == socket:
return (link.from_socket, default)
return (None, default)
def clear_socket_input(tree, socket):
for link in list(tree.links):
if link.to_socket == socket:
tree.links.remove(link)
def set_socket_value(tree, socket, socket_and_default):
clear_socket_input(tree, socket)
old_source_socket, default = socket_and_default
if isinstance(default, float) and not isinstance(socket.default_value, float):
# Codepath for setting Emission to a previous alpha value.
socket.default_value = [default] * 3 + [1.0]
else:
socket.default_value = default
if old_source_socket is not None:
tree.links.new(old_source_socket, socket)
def setup_nodes(output_path, capturing_material_alpha: bool = False, basic_lighting: bool = False):
tree = bpy.context.scene.node_tree
links = tree.links
for node in tree.nodes:
tree.nodes.remove(node)
# Helpers to perform math on links and constants.
def node_op(op: str, *args, clamp=False):
node = tree.nodes.new(type="CompositorNodeMath")
node.operation = op
if clamp:
node.use_clamp = True
for i, arg in enumerate(args):
if isinstance(arg, (int, float)):
node.inputs[i].default_value = arg
else:
links.new(arg, node.inputs[i])
return node.outputs[0]
def node_clamp(x, maximum=1.0):
return node_op("MINIMUM", x, maximum)
def node_mul(x, y, **kwargs):
return node_op("MULTIPLY", x, y, **kwargs)
def node_add(x, y, **kwargs):
return node_op("ADD", x, y, **kwargs)
def node_abs(x, **kwargs):
return node_op("ABSOLUTE", x, **kwargs)
input_node = tree.nodes.new(type="CompositorNodeRLayers")
input_node.scene = bpy.context.scene
input_sockets = {}
for output in input_node.outputs:
input_sockets[output.name] = output
if capturing_material_alpha:
color_socket = input_sockets["Image"]
else:
raw_color_socket = input_sockets["Image"]
if basic_lighting:
# Compute diffuse lighting
normal_xyz = tree.nodes.new(type="CompositorNodeSeparateXYZ")
tree.links.new(input_sockets["Normal"], normal_xyz.inputs[0])
normal_x, normal_y, normal_z = [normal_xyz.outputs[i] for i in range(3)]
dot = node_add(
node_mul(UNIFORM_LIGHT_DIRECTION[0], normal_x),
node_add(
node_mul(UNIFORM_LIGHT_DIRECTION[1], normal_y),
node_mul(UNIFORM_LIGHT_DIRECTION[2], normal_z),
),
)
diffuse = node_abs(dot)
# Compute ambient + diffuse lighting
brightness = node_add(BASIC_AMBIENT_COLOR, node_mul(BASIC_DIFFUSE_COLOR, diffuse))
# Modulate the RGB channels using the total brightness.
rgba_node = tree.nodes.new(type="CompositorNodeSepRGBA")
tree.links.new(raw_color_socket, rgba_node.inputs[0])
combine_node = tree.nodes.new(type="CompositorNodeCombRGBA")
for i in range(3):
tree.links.new(node_mul(rgba_node.outputs[i], brightness), combine_node.inputs[i])
tree.links.new(rgba_node.outputs[3], combine_node.inputs[3])
raw_color_socket = combine_node.outputs[0]
# We apply sRGB here so that our fixed-point depth map and material
# alpha values are not sRGB, and so that we perform ambient+diffuse
# lighting in linear RGB space.
color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace")
color_node.from_color_space = "Linear"
color_node.to_color_space = "sRGB"
tree.links.new(raw_color_socket, color_node.inputs[0])
color_socket = color_node.outputs[0]
split_node = tree.nodes.new(type="CompositorNodeSepRGBA")
tree.links.new(color_socket, split_node.inputs[0])
# Create separate file output nodes for every channel we care about.
# The process calling this script must decide how to recombine these
# channels, possibly into a single image.
for i, channel in enumerate("rgba") if not capturing_material_alpha else [(0, "MatAlpha")]:
output_node = tree.nodes.new(type="CompositorNodeOutputFile")
output_node.base_path = f"{output_path}_{channel}"
links.new(split_node.outputs[i], output_node.inputs[0])
if capturing_material_alpha:
# No need to re-write depth here.
return
depth_out = node_clamp(node_mul(input_sockets["Depth"], 1 / MAX_DEPTH))
output_node = tree.nodes.new(type="CompositorNodeOutputFile")
output_node.base_path = f"{output_path}_depth"
links.new(depth_out, output_node.inputs[0])
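# Illustrative sketch, not part of the Blender script: the "basic" lighting
# that setup_nodes() assembles from compositor math nodes, written with NumPy.
# Brightness is ambient + diffuse * |dot(light_direction, normal)|, and the
# absolute value makes the light two-sided. The constants are copied from
# constants.py; `basic_brightness` is a hypothetical helper.

```python
import numpy as np

UNIFORM_LIGHT_DIRECTION = np.array([0.09387503, -0.63953443, -0.7630093])
BASIC_AMBIENT_COLOR = 0.3
BASIC_DIFFUSE_COLOR = 0.7


def basic_brightness(normals: np.ndarray) -> np.ndarray:
    """Per-normal brightness factor for an [N, 3] array of unit normals."""
    dot = normals @ UNIFORM_LIGHT_DIRECTION
    return BASIC_AMBIENT_COLOR + BASIC_DIFFUSE_COLOR * np.abs(dot)


# A normal facing the light and one facing directly away are lit equally.
normals = np.stack([UNIFORM_LIGHT_DIRECTION, -UNIFORM_LIGHT_DIRECTION])
brightness = basic_brightness(normals)
```

# In the script this factor multiplies the linear RGB channels before the
# sRGB conversion, and alpha is passed through unchanged.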
def render_scene(output_path, fast_mode: bool, extract_material: bool, basic_lighting: bool):
use_workbench = bpy.context.scene.render.engine == "BLENDER_WORKBENCH"
if use_workbench:
# We must use a different engine to compute depth maps.
bpy.context.scene.render.engine = "BLENDER_EEVEE"
bpy.context.scene.eevee.taa_render_samples = 1 # faster, since we discard image.
if fast_mode:
if bpy.context.scene.render.engine == "BLENDER_EEVEE":
bpy.context.scene.eevee.taa_render_samples = 1
elif bpy.context.scene.render.engine == "CYCLES":
bpy.context.scene.cycles.samples = 256
else:
if bpy.context.scene.render.engine == "CYCLES":
# We should still impose a per-frame time limit
# so that the render does not time out completely.
bpy.context.scene.cycles.time_limit = 40
bpy.context.view_layer.update()
bpy.context.scene.use_nodes = True
bpy.context.scene.view_layers["ViewLayer"].use_pass_z = True
if basic_lighting:
bpy.context.scene.view_layers["ViewLayer"].use_pass_normal = True
bpy.context.scene.view_settings.view_transform = "Raw" # sRGB done in graph nodes
bpy.context.scene.render.film_transparent = True
bpy.context.scene.render.resolution_x = 512
bpy.context.scene.render.resolution_y = 512
bpy.context.scene.render.image_settings.file_format = "PNG"
bpy.context.scene.render.image_settings.color_mode = "BW"
bpy.context.scene.render.image_settings.color_depth = "16"
bpy.context.scene.render.filepath = output_path
if extract_material:
for do_alpha in [False, True]:
undo_fn = setup_material_extraction_shaders(capturing_material_alpha=do_alpha)
setup_nodes(output_path, capturing_material_alpha=do_alpha)
bpy.ops.render.render(write_still=True)
undo_fn()
else:
setup_nodes(output_path, basic_lighting=basic_lighting)
bpy.ops.render.render(write_still=True)
# The output images must be moved from their own sub-directories, or
# discarded if we are using workbench for the color.
for channel_name in ["r", "g", "b", "a", "depth", *(["MatAlpha"] if extract_material else [])]:
sub_dir = f"{output_path}_{channel_name}"
image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0])
name, ext = os.path.splitext(output_path)
if channel_name == "depth" or not use_workbench:
os.rename(image_path, f"{name}_{channel_name}{ext}")
else:
os.remove(image_path)
os.removedirs(sub_dir)
if use_workbench:
# Re-render RGBA using workbench with texture mode, since this seems
# to show the most reasonable colors when lighting is broken.
bpy.context.scene.use_nodes = False
bpy.context.scene.render.engine = "BLENDER_WORKBENCH"
bpy.context.scene.render.image_settings.color_mode = "RGBA"
bpy.context.scene.render.image_settings.color_depth = "8"
bpy.context.scene.display.shading.color_type = "TEXTURE"
bpy.context.scene.display.shading.light = "FLAT"
if fast_mode:
# Single pass anti-aliasing.
bpy.context.scene.display.render_aa = "FXAA"
os.remove(output_path)
bpy.ops.render.render(write_still=True)
bpy.context.scene.render.image_settings.color_mode = "BW"
bpy.context.scene.render.image_settings.color_depth = "16"
def scene_fov():
x_fov = bpy.context.scene.camera.data.angle_x
y_fov = bpy.context.scene.camera.data.angle_y
width = bpy.context.scene.render.resolution_x
height = bpy.context.scene.render.resolution_y
if bpy.context.scene.camera.data.angle == x_fov:
y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width)
else:
x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height)
return x_fov, y_fov
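# Illustrative sketch, not part of the Blender script: the field-of-view
# conversion used in scene_fov(), isolated as a pure function. Given the FOV
# along one image axis, the FOV along the other follows from the aspect
# ratio. `other_axis_fov` is a hypothetical name.

```python
import math


def other_axis_fov(fov: float, size_along: float, size_other: float) -> float:
    """Convert a FOV (radians) along one image axis to the FOV along the other."""
    return 2 * math.atan(math.tan(fov / 2) * size_other / size_along)


# For the square 512x512 renders configured in render_scene(), both FOVs agree.
square = other_axis_fov(math.radians(40.0), 512, 512)
# For a 2:1 image with a 90 degree horizontal FOV, the vertical FOV is narrower.
wide = other_axis_fov(math.radians(90.0), 1024, 512)
```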
def write_camera_metadata(path):
x_fov, y_fov = scene_fov()
bbox_min, bbox_max = scene_bbox()
matrix = bpy.context.scene.camera.matrix_world
with open(path, "w") as f:
json.dump(
dict(
format_version=FORMAT_VERSION,
max_depth=MAX_DEPTH,
bbox=[list(bbox_min), list(bbox_max)],
origin=list(matrix.col[3])[:3],
x_fov=x_fov,
y_fov=y_fov,
x=list(matrix.col[0])[:3],
y=list(-matrix.col[1])[:3],
z=list(-matrix.col[2])[:3],
),
f,
)
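# Illustrative sketch, not part of the Blender script: how a consumer might
# recover metric depth from the 16-bit depth PNG using the `max_depth` value
# stored in the camera metadata. The compositor writes min(depth / MAX_DEPTH, 1),
# and this sketch assumes the PNG maps [0, 1] linearly onto [0, 65535].
# `encode_depth` and `decode_depth` are hypothetical helpers, not functions
# from this repository.

```python
import numpy as np

MAX_DEPTH = 5.0  # same constant as at the top of blender_script.py


def encode_depth(depth: np.ndarray) -> np.ndarray:
    # Mirrors node_clamp(node_mul(depth, 1 / MAX_DEPTH)) followed by a
    # 16-bit quantization (the quantization rule is an assumption).
    normalized = np.minimum(depth / MAX_DEPTH, 1.0)
    return np.round(normalized * 65535).astype(np.uint16)


def decode_depth(pixels: np.ndarray, max_depth: float) -> np.ndarray:
    return pixels.astype(np.float32) / 65535.0 * max_depth


depth = np.array([0.0, 2.5, 10.0])
decoded = decode_depth(encode_depth(depth), MAX_DEPTH)
```

# Depths beyond MAX_DEPTH saturate at MAX_DEPTH, so a decoded value equal to
# max_depth should be read as "at least this far".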
def save_rendering_dataset(
input_path: str,
output_path: str,
num_images: int,
backend: str,
light_mode: str,
camera_pose: str,
camera_dist_min: float,
camera_dist_max: float,
fast_mode: bool,
extract_material: bool,
delete_material: bool,
):
assert light_mode in ["random", "uniform", "camera", "basic"]
assert camera_pose in ["random", "z-circular", "z-circular-elevated"]
basic_lighting = light_mode == "basic"
assert not (basic_lighting and extract_material), "cannot extract material with basic lighting"
assert not (delete_material and extract_material), "cannot extract material and delete it"
import_model(input_path)
bpy.context.scene.render.engine = backend
normalize_scene()
if light_mode == "random":
create_random_lights()
elif light_mode == "uniform":
create_uniform_light(backend)
create_camera()
create_vertex_color_shaders()
if delete_material:
delete_all_materials()
if extract_material or basic_lighting:
create_default_materials()
if basic_lighting:
# Make sure materials are uniformly lit, so that we can light
# them in the output shader.
setup_material_extraction_shaders(capturing_material_alpha=False)
for i in range(num_images):
t = i / max(num_images - 1, 1) # same as np.linspace(0, 1, num_images)
place_camera(
t,
camera_pose_mode=camera_pose,
camera_dist_min=camera_dist_min,
camera_dist_max=camera_dist_max,
)
if light_mode == "camera":
create_camera_light()
render_scene(
os.path.join(output_path, f"{i:05}.png"),
fast_mode=fast_mode,
extract_material=extract_material,
basic_lighting=basic_lighting,
)
write_camera_metadata(os.path.join(output_path, f"{i:05}.json"))
with open(os.path.join(output_path, "info.json"), "w") as f:
info = dict(
backend=backend,
light_mode=light_mode,
fast_mode=fast_mode,
extract_material=extract_material,
format_version=FORMAT_VERSION,
channels=["R", "G", "B", "A", "D", *(["MatAlpha"] if extract_material else [])],
scale=0.5, # The scene is bounded by [-scale, scale].
)
json.dump(info, f)
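# Illustrative sketch, not part of the Blender script: a check of the comment
# in the render loop that `i / max(num_images - 1, 1)` matches
# np.linspace(0, 1, num_images), including the num_images == 1 case where the
# max() guards against division by zero. `camera_times` is a hypothetical name.

```python
import numpy as np


def camera_times(num_images: int):
    """Per-view time values in [0, 1], as computed in save_rendering_dataset()."""
    return [i / max(num_images - 1, 1) for i in range(num_images)]


times = camera_times(5)
```

# Because both endpoints are included, the "z-circular" modes render the
# first and last views from the same angle (t=0 and t=1 are a full turn apart).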
def main():
global UNIFORM_LIGHT_DIRECTION, BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR
try:
dash_index = sys.argv.index("--")
except ValueError as exc:
raise ValueError("arguments must be preceded by '--'") from exc
raw_args = sys.argv[dash_index + 1 :]
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", required=True, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--num_images", required=True, type=int)
parser.add_argument("--backend", type=str, default="BLENDER_EEVEE")
parser.add_argument("--light_mode", type=str, default="random")
parser.add_argument("--camera_pose", type=str, default="random")
parser.add_argument("--camera_dist_min", type=float, default=2.0)
parser.add_argument("--camera_dist_max", type=float, default=2.0)
parser.add_argument("--fast_mode", action="store_true")
parser.add_argument("--extract_material", action="store_true")
parser.add_argument("--delete_material", action="store_true")
# Prevent constants from being repeated.
parser.add_argument("--uniform_light_direction", required=True, type=float, nargs="+")
parser.add_argument("--basic_ambient", required=True, type=float)
parser.add_argument("--basic_diffuse", required=True, type=float)
args = parser.parse_args(raw_args)
UNIFORM_LIGHT_DIRECTION = args.uniform_light_direction
BASIC_AMBIENT_COLOR = args.basic_ambient
BASIC_DIFFUSE_COLOR = args.basic_diffuse
save_rendering_dataset(
input_path=args.input_path,
output_path=args.output_path,
num_images=args.num_images,
backend=args.backend,
light_mode=args.light_mode,
camera_pose=args.camera_pose,
camera_dist_min=args.camera_dist_min,
camera_dist_max=args.camera_dist_max,
fast_mode=args.fast_mode,
extract_material=args.extract_material,
delete_material=args.delete_material,
)
main()
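# Illustrative sketch, not part of the Blender script: how main() separates
# Blender's own arguments from the script's arguments at the `--` marker.
# Only a subset of the real flags is declared here, and `parse_script_args`
# is a hypothetical name. The negative floats are accepted as values because
# the parser defines no options that look like negative numbers.

```python
import argparse


def parse_script_args(argv):
    try:
        dash_index = argv.index("--")
    except ValueError as exc:
        raise ValueError("arguments must be preceded by '--'") from exc
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_images", required=True, type=int)
    parser.add_argument("--uniform_light_direction", required=True, type=float, nargs="+")
    parser.add_argument("--fast_mode", action="store_true")
    # Everything before `--` belongs to Blender and is ignored here.
    return parser.parse_args(argv[dash_index + 1 :])


args = parse_script_args(
    [
        "blender", "-b", "-P", "blender_script.py", "--",
        "--num_images", "4",
        "--uniform_light_direction", "0.1", "-0.6", "-0.8",
    ]
)
```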
================================================
FILE: shap_e/rendering/blender/constants.py
================================================
UNIFORM_LIGHT_DIRECTION = [0.09387503, -0.63953443, -0.7630093]
BASIC_AMBIENT_COLOR = 0.3
BASIC_DIFFUSE_COLOR = 0.7
================================================
FILE: shap_e/rendering/blender/render.py
================================================
import os
import platform
import subprocess
import tempfile
import zipfile
import blobfile as bf
import numpy as np
from PIL import Image
from shap_e.rendering.mesh import TriMesh
from .constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION
SCRIPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "blender_script.py")
def render_model(
model_path: str,
output_path: str,
num_images: int,
backend: str = "BLENDER_EEVEE",
light_mode: str = "random",
camera_pose: str = "random",
camera_dist_min: float = 2.0,
camera_dist_max: float = 2.0,
fast_mode: bool = False,
extract_material: bool = False,
delete_material: bool = False,
verbose: bool = False,
timeout: float = 15 * 60,
):
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_in = model_path
tmp_out = os.path.join(tmp_dir, "out")
zip_out = tmp_out + ".zip"
os.mkdir(tmp_out)
args = []
if platform.system() == "Linux":
# Needed to enable the Eevee backend on headless Linux.
args = ["xvfb-run", "-a"]
args.extend(
[
_blender_binary_path(),
"-b",
"-P",
SCRIPT_PATH,
"--",
"--input_path",
tmp_in,
"--output_path",
tmp_out,
"--num_images",
str(num_images),
"--backend",
backend,
"--light_mode",
light_mode,
"--camera_pose",
camera_pose,
"--camera_dist_min",
str(camera_dist_min),
"--camera_dist_max",
str(camera_dist_max),
"--uniform_light_direction",
*[str(x) for x in UNIFORM_LIGHT_DIRECTION],
"--basic_ambient",
str(BASIC_AMBIENT_COLOR),
"--basic_diffuse",
str(BASIC_DIFFUSE_COLOR),
]
)
if fast_mode:
args.append("--fast_mode")
if extract_material:
args.append("--extract_material")
if delete_material:
args.append("--delete_material")
if verbose:
subprocess.check_call(args, timeout=timeout)
else:
try:
output = subprocess.check_output(args, stderr=subprocess.STDOUT, timeout=timeout)
except subprocess.CalledProcessError as exc:
raise RuntimeError(f"{exc}: {exc.output}") from exc
if not os.path.exists(os.path.join(tmp_out, "info.json")):
if verbose:
# There is no output available, since it was
# logged directly to stdout/stderr.
raise RuntimeError("render failed: output file missing")
else:
raise RuntimeError(f"render failed: output file missing. Output: {output}")
_combine_rgba(tmp_out)
with zipfile.ZipFile(zip_out, mode="w") as zf:
for name in os.listdir(tmp_out):
zf.write(os.path.join(tmp_out, name), name)
bf.copy(zip_out, output_path, overwrite=True)
def render_mesh(
mesh: TriMesh,
output_path: str,
num_images: int,
backend: str = "BLENDER_EEVEE",
**kwargs,
):
if mesh.has_vertex_colors() and backend not in ["BLENDER_EEVEE", "CYCLES"]:
raise ValueError(f"backend does not support vertex colors: {backend}")
with tempfile.TemporaryDirectory() as tmp_dir:
ply_path = os.path.join(tmp_dir, "out.ply")
with open(ply_path, "wb") as f:
mesh.write_ply(f)
render_model(
ply_path, output_path=output_path, num_images=num_images, backend=backend, **kwargs
)
def _combine_rgba(out_dir: str):
i = 0
while True:
paths = [os.path.join(out_dir, f"{i:05}_{ch}.png") for ch in "rgba"]
if not os.path.exists(paths[0]):
break
joined = np.stack(
[(np.array(Image.open(path)) >> 8).astype(np.uint8) for path in paths], axis=-1
)
Image.fromarray(joined).save(os.path.join(out_dir, f"{i:05}.png"))
for path in paths:
os.remove(path)
i += 1
def _blender_binary_path() -> str:
path = os.getenv("BLENDER_PATH", None)
if path is not None:
return path
if os.path.exists("/Applications/Blender.app/Contents/MacOS/Blender"):
return "/Applications/Blender.app/Contents/MacOS/Blender"
raise EnvironmentError(
"To render 3D models, install Blender version 3.3.1 or higher and "
"set the environment variable `BLENDER_PATH` to the path of the Blender executable."
)
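# A minimal sketch of the channel merge done by `_combine_rgba` above, assuming
# each per-channel PNG decodes to a 16-bit array. The helper name
# `combine_planes_16bit` and the sample planes are hypothetical; no image files
# are read here.

```python
import numpy as np


def combine_planes_16bit(planes):
    """Stack per-channel 16-bit planes into one 8-bit [H x W x C] image."""
    # Keep only the top 8 bits of every 16-bit sample, as `>> 8` does above.
    return np.stack([(p >> 8).astype(np.uint8) for p in planes], axis=-1)


# Four hypothetical 2x2 planes standing in for the r/g/b/a PNG files.
r = np.full((2, 2), 0xFFFF, dtype=np.uint16)
g = np.full((2, 2), 0x8000, dtype=np.uint16)
b = np.zeros((2, 2), dtype=np.uint16)
a = np.full((2, 2), 0x01FF, dtype=np.uint16)
rgba = combine_planes_16bit([r, g, b, a])
```

# The shift discards the low byte, so 0x01FF becomes 1 rather than rounding up.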
================================================
FILE: shap_e/rendering/blender/view_data.py
================================================
import itertools
import json
import zipfile
from typing import BinaryIO, List, Tuple
import numpy as np
from PIL import Image
from shap_e.rendering.view_data import Camera, ProjectiveCamera, ViewData
class BlenderViewData(ViewData):
"""
Interact with a dataset zipfile exported by view_data.py.
"""
def __init__(self, f_obj: BinaryIO):
self.zipfile = zipfile.ZipFile(f_obj, mode="r")
self.infos = []
with self.zipfile.open("info.json", "r") as f:
self.info = json.load(f)
self.channels = list(self.info.get("channels", "RGBAD"))
assert set("RGBA").issubset(
set(self.channels)
), "The blender output should at least have RGBA images."
names = set(x.filename for x in self.zipfile.infolist())
for i in itertools.count():
name = f"{i:05}.json"
if name not in names:
break
with self.zipfile.open(name, "r") as f:
self.infos.append(json.load(f))
@property
def num_views(self) -> int:
return len(self.infos)
@property
def channel_names(self) -> List[str]:
return list(self.channels)
def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]:
for ch in channels:
if ch not in self.channel_names:
raise ValueError(f"unsupported channel: {ch}")
# Gather (a superset of) the requested channels.
channel_map = {}
if any(x in channels for x in "RGBA"):
with self.zipfile.open(f"{index:05}.png", "r") as f:
rgba = np.array(Image.open(f)).astype(np.float32) / 255.0
channel_map.update(zip("RGBA", rgba.transpose([2, 0, 1])))
if "D" in channels:
with self.zipfile.open(f"{index:05}_depth.png", "r") as f:
# Decode a 16-bit fixed-point number.
fp = np.array(Image.open(f))
inf_dist = fp == 0xFFFF
channel_map["D"] = np.where(
inf_dist,
np.inf,
self.infos[index]["max_depth"] * (fp.astype(np.float32) / 65536),
)
if "MatAlpha" in channels:
with self.zipfile.open(f"{index:05}_MatAlpha.png", "r") as f:
channel_map["MatAlpha"] = np.array(Image.open(f)).astype(np.float32) / 65536
# The order of channels is user-specified.
combined = np.stack([channel_map[k] for k in channels], axis=-1)
h, w, _ = combined.shape
return self.camera(index, w, h), combined
def camera(self, index: int, width: int, height: int) -> ProjectiveCamera:
info = self.infos[index]
return ProjectiveCamera(
origin=np.array(info["origin"], dtype=np.float32),
x=np.array(info["x"], dtype=np.float32),
y=np.array(info["y"], dtype=np.float32),
z=np.array(info["z"], dtype=np.float32),
width=width,
height=height,
x_fov=info["x_fov"],
y_fov=info["y_fov"],
)
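# A small sketch of the 16-bit fixed-point depth decoding used in `load_view`
# above, written against plain NumPy arrays instead of a zip file. The function
# name `decode_depth` is hypothetical; the 0xFFFF sentinel and the 65536 divisor
# are taken from the code above.

```python
import numpy as np


def decode_depth(fp, max_depth):
    """Map 16-bit depth codes to distances; 0xFFFF marks infinite depth."""
    fp = np.asarray(fp)
    scaled = max_depth * (fp.astype(np.float32) / 65536)
    return np.where(fp == 0xFFFF, np.inf, scaled)


codes = np.array([0, 32768, 65535], dtype=np.uint16)
depth = decode_depth(codes, max_depth=4.0)
```

# With this encoding the largest finite code (0xFFFE) decodes to just under
# `max_depth`, and the sentinel never collides with a real distance.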
================================================
FILE: shap_e/rendering/mc.py
================================================
from dataclasses import dataclass
from functools import lru_cache
from typing import Tuple
import torch
from ._mc_table import MC_TABLE
from .torch_mesh import TorchMesh
def marching_cubes(
field: torch.Tensor,
min_point: torch.Tensor,
size: torch.Tensor,
) -> TorchMesh:
"""
For a signed distance field, produce a mesh using marching cubes.
:param field: a 3D tensor of field values, where negative values correspond
to the outside of the shape. The dimensions correspond to the
x, y, and z directions, respectively.
:param min_point: a tensor of shape [3] containing the point corresponding
to (0, 0, 0) in the field.
:param size: a tensor of shape [3] containing the per-axis distance between
the (0, 0, 0) field corner and the (-1, -1, -1) field corner.
"""
assert len(field.shape) == 3, "input must be a 3D scalar field"
dev = field.device
grid_size = field.shape
grid_size_tensor = torch.tensor(grid_size).to(size)
lut = _lookup_table(dev)
# Create bitmasks between 0 and 255 (inclusive) indicating the state
# of the eight corners of each cube.
bitmasks = (field > 0).to(torch.uint8)
bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)
bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)
bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)
# Compute corner coordinates across the entire grid.
corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)
corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(
grid_size[0], device=dev, dtype=field.dtype
)[:, None, None]
corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(
grid_size[1], device=dev, dtype=field.dtype
)[:, None]
corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(
grid_size[2], device=dev, dtype=field.dtype
)
# Compute all vertices across all edges in the grid, even though we will
# throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
# These are all midpoints, and don't account for interpolation (which is
# done later based on the used edge midpoints).
edge_midpoints = torch.cat(
[
((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
],
dim=0,
)
# Create a flat array of [X, Y, Z] indices for each cube.
cube_indices = torch.zeros(
grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long
)
cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[
:, None, None
]
cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[
:, None
]
cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev)
flat_cube_indices = cube_indices.reshape(-1, 3)
# Create a flat array mapping each cube to 12 global edge indices.
edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)
# Apply the LUT to figure out the triangles.
flat_bitmasks = bitmasks.reshape(
-1
).long() # cast to long so that indexing treats this as indices, not a boolean mask
local_tris = lut.cases[flat_bitmasks]
local_masks = lut.masks[flat_bitmasks]
# Compute the global edge indices for the triangles.
global_tris = torch.gather(
edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)
).reshape(local_tris.shape)
# Select the used triangles for each cube.
selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]
# Now we have a bunch of indices into the full list of possible vertices,
# but we want to reduce this list to only the used vertices.
used_vertex_indices = torch.unique(selected_tris.view(-1))
used_edge_midpoints = edge_midpoints[used_vertex_indices]
old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)
old_index_to_new_index[used_vertex_indices] = torch.arange(
len(used_vertex_indices), device=dev, dtype=torch.long
)
# Rewrite the triangles to use the new indices
selected_tris = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(
selected_tris.shape
)
# Compute the actual interpolated coordinates corresponding to edge midpoints.
v1 = torch.floor(used_edge_midpoints).to(torch.long)
v2 = torch.ceil(used_edge_midpoints).to(torch.long)
s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point
# The signs of s1 and s2 should be different. We want to find
# t such that t*s2 + (1-t)*s1 = 0.
t = (s1 / (s1 - s2))[:, None]
verts = t * p2 + (1 - t) * p1
return TorchMesh(verts=verts, faces=selected_tris)
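# A worked sketch of the interpolation at the end of `marching_cubes`: given
# field values s1 and s2 of opposite sign at the two endpoints of an edge, solve
# t*s2 + (1-t)*s1 = 0 for t and place the vertex at t*p2 + (1-t)*p1. The helper
# name `zero_crossing` is hypothetical and NumPy is used instead of torch to
# keep the example dependency-light.

```python
import numpy as np


def zero_crossing(p1, p2, s1, s2):
    """Return the point on segment p1-p2 where the linear field is zero."""
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    # t*s2 + (1-t)*s1 = 0  =>  t = s1 / (s1 - s2)
    t = s1 / (s1 - s2)
    return t * p2 + (1 - t) * p1


# The field is -1 at p1 and +3 at p2, so the surface sits a quarter of the way along.
vertex = zero_crossing([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], s1=-1.0, s2=3.0)
```

# The division is only well defined when s1 != s2, which holds whenever the
# signs differ as the comment in the function assumes.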
def _create_flat_edge_indices(
flat_cube_indices: torch.Tensor, grid_size: Tuple[int, int, int]
) -> torch.Tensor:
num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2]
y_offset = num_xs
num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2]
z_offset = num_xs + num_ys
return torch.stack(
[
# Edges spanning x-axis.
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2],
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
+ flat_cube_indices[:, 2],
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1,
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1,
# Edges spanning y-axis.
(
y_offset
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
),
(
y_offset
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
),
(
y_offset
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1
),
(
y_offset
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1
),
# Edges spanning z-axis.
(
z_offset
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
(
z_offset
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
(
z_offset
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
(
z_offset
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
],
dim=-1,
)
@dataclass
class McLookupTable:
# Coordinates in triangles are represented as edge indices from 0-11.
# Here is an MC cell with both corner and edge indices marked.
#       6 +---------- 3 ----------+ 7
#        /|                      /|
#       6 |                     7 |
#      /  |                    /  |
#   4 +---------- 2 ----------+ 5 |
#     |  10                   |   |
#     |   |                   |  11
#     |   |                   |   |
#     8   | 2                 9   | 3
#     |   +-------- 1 --------|---+
#     |  /                    |  /
#     | 4                     | 5
#     |/                      |/
#     +---------- 0 ----------+
#   0                           1
cases: torch.Tensor # [256 x 5 x 3] long tensor
masks: torch.Tensor # [256 x 5] bool tensor
@lru_cache(maxsize=9) # if there's more than 8 GPUs and a CPU, don't bother caching
def _lookup_table(device: torch.device) -> McLookupTable:
cases = torch.zeros(256, 5, 3, device=device, dtype=torch.long)
masks = torch.zeros(256, 5, device=device, dtype=torch.bool)
edge_to_index = {
(0, 1): 0,
(2, 3): 1,
(4, 5): 2,
(6, 7): 3,
(0, 2): 4,
(1, 3): 5,
(4, 6): 6,
(5, 7): 7,
(0, 4): 8,
(1, 5): 9,
(2, 6): 10,
(3, 7): 11,
}
for i, case in enumerate(MC_TABLE):
for j, tri in enumerate(case):
for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])):
cases[i, j, k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)]
masks[i, j] = True
return McLookupTable(cases=cases, masks=masks)
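# A sketch of the corner-bitmask step from `marching_cubes`, reproduced with
# NumPy on a single 2x2x2 cell. The function name `cube_bitmasks` is
# hypothetical. The corner at offset (x, y, z) within a cube contributes bit
# x + 2*y + 4*z, which matches the corner numbering in the cell diagram above.

```python
import numpy as np


def cube_bitmasks(field):
    """Return one 0..255 case index per cube of a 3D scalar field."""
    bits = (field > 0).astype(np.uint8)
    bits = bits[:-1, :, :] | (bits[1:, :, :] << 1)  # fold in the +x corner
    bits = bits[:, :-1, :] | (bits[:, 1:, :] << 2)  # fold in the +y corners
    bits = bits[:, :, :-1] | (bits[:, :, 1:] << 4)  # fold in the +z corners
    return bits


field = -np.ones((2, 2, 2))
field[1, 0, 0] = 1.0  # only corner 1 is inside the shape
single = cube_bitmasks(field)

inside = cube_bitmasks(np.ones((2, 2, 2)))  # every corner inside
```

# A grid of shape [X, Y, Z] yields (X-1)*(Y-1)*(Z-1) case indices, one per cube.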
================================================
FILE: shap_e/rendering/mesh.py
================================================
from dataclasses import dataclass, field
from typing import BinaryIO, Dict, Optional, Union
import blobfile as bf
import numpy as np
from .ply_util import write_ply
@dataclass
class TriMesh:
"""
A 3D triangle mesh with optional data at the vertices and faces.
"""
# [N x 3] array of vertex coordinates.
verts: np.ndarray
# [M x 3] array of triangles, pointing to indices in verts.
faces: np.ndarray
# [P x 3] array of normal vectors per face.
normals: Optional[np.ndarray] = None
# Extra data per vertex and face.
vertex_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict)
face_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict)
@classmethod
def load(cls, f: Union[str, BinaryIO]) -> "TriMesh":
"""
Load the mesh from a .npz file.
"""
if isinstance(f, str):
with bf.BlobFile(f, "rb") as reader:
return cls.load(reader)
else:
obj = np.load(f)
keys = list(obj.keys())
verts = obj["verts"]
faces = obj["faces"]
normals = obj["normals"] if "normals" in keys else None
vertex_channels = {}
face_channels = {}
for key in keys:
if key.startswith("v_"):
vertex_channels[key[2:]] = obj[key]
elif key.startswith("f_"):
face_channels[key[2:]] = obj[key]
return cls(
verts=verts,
faces=faces,
normals=normals,
vertex_channels=vertex_channels,
face_channels=face_channels,
)
def save(self, f: Union[str, BinaryIO]):
"""
Save the mesh to a .npz file.
"""
if isinstance(f, str):
with bf.BlobFile(f, "wb") as writer:
self.save(writer)
else:
obj_dict = dict(verts=self.verts, faces=self.faces)
if self.normals is not None:
obj_dict["normals"] = self.normals
for k, v in self.vertex_channels.items():
obj_dict[f"v_{k}"] = v
for k, v in self.face_channels.items():
obj_dict[f"f_{k}"] = v
np.savez(f, **obj_dict)
def has_vertex_colors(self) -> bool:
return self.vertex_channels is not None and all(x in self.vertex_channels for x in "RGB")
def write_ply(self, raw_f: BinaryIO):
write_ply(
raw_f,
coords=self.verts,
rgb=(
np.stack([self.vertex_channels[x] for x in "RGB"], axis=1)
if self.has_vertex_colors()
else None
),
faces=self.faces,
)
def write_obj(self, raw_f: BinaryIO):
if self.has_vertex_colors():
vertex_colors = np.stack([self.vertex_channels[x] for x in "RGB"], axis=1)
vertices = [
"{} {} {} {} {} {}".format(*coord, *color)
for coord, color in zip(self.verts.tolist(), vertex_colors.tolist())
]
else:
vertices = ["{} {} {}".format(*coord) for coord in self.verts.tolist()]
faces = [
"f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1))
for tri in self.faces.tolist()
]
combined_data = ["v " + vertex for vertex in vertices] + faces
raw_f.write("\n".join(combined_data))
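# A minimal sketch of the OBJ text layout produced by `write_obj` above for a
# mesh without vertex colors. The helper `obj_text` is hypothetical. Note that
# OBJ face indices are 1-based, and that the output is text, so this sketch
# writes to a text-mode handle.

```python
import io

import numpy as np


def obj_text(verts, faces):
    """Format vertices as `v x y z` and triangles as 1-indexed `f a b c`."""
    lines = ["v {} {} {}".format(*v) for v in verts.tolist()]
    lines += ["f {} {} {}".format(*(i + 1 for i in tri)) for tri in faces.tolist()]
    return "\n".join(lines)


verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2]])
buf = io.StringIO()
buf.write(obj_text(verts, faces))
obj_lines = buf.getvalue().splitlines()
```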
================================================
FILE: shap_e/rendering/ply_util.py
================================================
import struct
from typing import BinaryIO, Optional
import numpy as np
from shap_e.util.io import buffered_writer
def write_ply(
raw_f: BinaryIO,
coords: np.ndarray,
rgb: Optional[np.ndarray] = None,
faces: Optional[np.ndarray] = None,
):
"""
Write a PLY file for a mesh or a point cloud.
:param coords: an [N x 3] array of floating point coordinates.
:param rgb: an [N x 3] array of vertex colors, in the range [0.0, 1.0].
:param faces: an [M x 3] array of triangles encoded as integer indices.
"""
with buffered_writer(raw_f) as f:
f.write(b"ply\n")
f.write(b"format binary_little_endian 1.0\n")
f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
if rgb is not None:
f.write(b"property uchar red\n")
f.write(b"property uchar green\n")
f.write(b"property uchar blue\n")
if faces is not None:
f.write(bytes(f"element face {len(faces)}\n", "ascii"))
f.write(b"property list uchar int vertex_index\n")
f.write(b"end_header\n")
if rgb is not None:
rgb = (rgb * 255.499).round().astype(int)
vertices = [
(*coord, *rgb)
for coord, rgb in zip(
coords.tolist(),
rgb.tolist(),
)
]
format = struct.Struct("<3f3B")
for item in vertices:
f.write(format.pack(*item))
else:
format = struct.Struct("<3f")
for vertex in coords.tolist():
f.write(format.pack(*vertex))
if faces is not None:
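format = struct.Struct("<B3I")
for tri in faces.tolist():
f.write(format.pack(len(tri), *tri))
================================================
FILE: shap_e/rendering/point_cloud.py
================================================
import random
from collections import defaultdict
from dataclasses import dataclass
from typing import BinaryIO, Dict, List, Optional, Union
import blobfile as bf
import numpy as np
from .ply_util import write_ply
from .view_data import ViewData
COLORS = frozenset(["R", "G", "B", "A"])
def preprocess(data, channel):
if channel in COLORS:
return np.round(data * 255.0)
return data
@dataclass
class PointCloud:
"""
An array of points sampled on a surface. Each point may have zero or more
channel attributes.
:param coords: an [N x 3] array of point coordinates.
:param channels: a dict mapping names to [N] arrays of channel values.
"""
coords: np.ndarray
channels: Dict[str, np.ndarray]
@classmethod
def from_rgbd(cls, vd: ViewData, num_views: Optional[int] = None) -> "PointCloud":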
"""
Construct a point cloud from the given view data.
The data must have a depth channel. All other channels will be stored
in the `channels` attribute of the result.
Pixels in the rendered views are not converted into points in the cloud
if they have infinite depth or less than 1.0 alpha.
"""
channel_names = vd.channel_names
if "D" not in channel_names:
raise ValueError("view data must have depth channel")
depth_index = channel_names.index("D")
all_coords = []
all_channels = defaultdict(list)
if num_views is None:
num_views = vd.num_views
for i in range(num_views):
camera, channel_values = vd.load_view(i, channel_names)
flat_values = channel_values.reshape([-1, len(channel_names)])
# Create an array of integer (x, y) image coordinates for Camera methods.
image_coords = camera.image_coords()
# Select subset of pixels that have meaningful depth/color.
image_mask = np.isfinite(flat_values[:, depth_index])
if "A" in channel_names:
image_mask = image_mask & (flat_values[:, channel_names.index("A")] >= 1 - 1e-5)
image_coords = image_coords[image_mask]
flat_values = flat_values[image_mask]
# Use the depth and camera information to compute the coordinates
# corresponding to every visible pixel.
camera_rays = camera.camera_rays(image_coords)
camera_origins = camera_rays[:, 0]
camera_directions = camera_rays[:, 1]
depth_dirs = camera.depth_directions(image_coords)
ray_scales = flat_values[:, depth_index] / np.sum(
camera_directions * depth_dirs, axis=-1
)
coords = camera_origins + camera_directions * ray_scales[:, None]
all_coords.append(coords)
for j, name in enumerate(channel_names):
if name != "D":
all_channels[name].append(flat_values[:, j])
if len(all_coords) == 0:
return cls(coords=np.zeros([0, 3], dtype=np.float32), channels={})
return cls(
coords=np.concatenate(all_coords, axis=0),
channels={k: np.concatenate(v, axis=0) for k, v in all_channels.items()},
)
@classmethod
def load(cls, f: Union[str, BinaryIO]) -> "PointCloud":
"""
Load the point cloud from a .npz file.
"""
if isinstance(f, str):
with bf.BlobFile(f, "rb") as reader:
return cls.load(reader)
else:
obj = np.load(f)
keys = list(obj.keys())
return PointCloud(
coords=obj["coords"],
channels={k: obj[k] for k in keys if k != "coords"},
)
def save(self, f: Union[str, BinaryIO]):
"""
Save the point cloud to a .npz file.
"""
if isinstance(f, str):
with bf.BlobFile(f, "wb") as writer:
self.save(writer)
else:
np.savez(f, coords=self.coords, **self.channels)
def write_ply(self, raw_f: BinaryIO):
write_ply(
raw_f,
coords=self.coords,
rgb=(
np.stack([self.channels[x] for x in "RGB"], axis=1)
if all(x in self.channels for x in "RGB")
else None
),
)
def random_sample(self, num_points: int, **subsample_kwargs) -> "PointCloud":
"""
Sample a random subset of this PointCloud.
:param num_points: maximum number of points to sample.
:param subsample_kwargs: arguments to self.subsample().
:return: a reduced PointCloud, or self if num_points is not less than
the current number of points.
"""
if len(self.coords) <= num_points:
return self
indices = np.random.choice(len(self.coords), size=(num_points,), replace=False)
return self.subsample(indices, **subsample_kwargs)
def farthest_point_sample(
self, num_points: int, init_idx: Optional[int] = None, **subsample_kwargs
) -> "PointCloud":
"""
Sample a subset of the point cloud that is evenly distributed in space.
First, a random point is selected. Then each successive point is chosen
such that it is furthest from the currently selected points.
The time complexity of this operation is O(NM), where N is the original
number of points and M is the reduced number. Therefore, performance
can be improved by randomly subsampling points with random_sample()
before running farthest_point_sample().
:param num_points: maximum number of points to sample.
:param init_idx: if specified, the first point to sample.
:param subsample_kwargs: arguments to self.subsample().
:return: a reduced PointCloud, or self if num_points is not less than
the current number of points.
"""
if len(self.coords) <= num_points:
return self
init_idx = random.randrange(len(self.coords)) if init_idx is None else init_idx
indices = np.zeros([num_points], dtype=np.int64)
indices[0] = init_idx
sq_norms = np.sum(self.coords**2, axis=-1)
def compute_dists(idx: int):
# Utilize equality: ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B).
return sq_norms + sq_norms[idx] - 2 * (self.coords @ self.coords[idx])
cur_dists = compute_dists(init_idx)
for i in range(1, num_points):
idx = np.argmax(cur_dists)
indices[i] = idx
# Without this line, we may duplicate an index more than once if
# there are duplicate points, due to rounding errors.
cur_dists[idx] = -1
cur_dists = np.minimum(cur_dists, compute_dists(idx))
return self.subsample(indices, **subsample_kwargs)
def subsample(self, indices: np.ndarray, average_neighbors: bool = False) -> "PointCloud":
if not average_neighbors:
return PointCloud(
coords=self.coords[indices],
channels={k: v[indices] for k, v in self.channels.items()},
)
new_coords = self.coords[indices]
neighbor_indices = PointCloud(coords=new_coords, channels={}).nearest_points(self.coords)
# Make sure every point points to itself, which might not
# be the case if points are duplicated or there is rounding
# error.
neighbor_indices[indices] = np.arange(len(indices))
new_channels = {}
for k, v in self.channels.items():
v_sum = np.zeros_like(v[: len(indices)])
v_count = np.zeros_like(v[: len(indices)])
np.add.at(v_sum, neighbor_indices, v)
np.add.at(v_count, neighbor_indices, 1)
new_channels[k] = v_sum / v_count
return PointCloud(coords=new_coords, channels=new_channels)
def select_channels(self, channel_names: List[str]) -> np.ndarray:
data = np.stack([preprocess(self.channels[name], name) for name in channel_names], axis=-1)
return data
def nearest_points(self, points: np.ndarray, batch_size: int = 16384) -> np.ndarray:
"""
For each point in another set of points, compute the point in this
pointcloud which is closest.
:param points: an [N x 3] array of points.
:param batch_size: the number of neighbor distances to compute at once.
Smaller values save memory, while larger values may
make the computation faster.
:return: an [N] array of indices into self.coords.
"""
norms = np.sum(self.coords**2, axis=-1)
all_indices = []
for i in range(0, len(points), batch_size):
batch = points[i : i + batch_size]
dists = norms + np.sum(batch**2, axis=-1)[:, None] - 2 * (batch @ self.coords.T)
all_indices.append(np.argmin(dists, axis=-1))
return np.concatenate(all_indices, axis=0)
def combine(self, other: "PointCloud") -> "PointCloud":
assert self.channels.keys() == other.channels.keys()
return PointCloud(
coords=np.concatenate([self.coords, other.coords], axis=0),
channels={
k: np.concatenate([v, other.channels[k]], axis=0) for k, v in self.channels.items()
},
)
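# A standalone sketch of the greedy loop in `farthest_point_sample` above,
# operating on a bare coordinate array rather than a PointCloud. The function
# name `farthest_point_indices` is hypothetical, and the sketch assumes
# num_points does not exceed the number of input points (the method above
# returns early in that case).

```python
import numpy as np


def farthest_point_indices(coords, num_points, init_idx=0):
    """Greedily pick indices so each new point is farthest from those chosen."""
    coords = np.asarray(coords, dtype=np.float64)
    sq_norms = np.sum(coords**2, axis=-1)

    def sq_dists(idx):
        # ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B)
        return sq_norms + sq_norms[idx] - 2 * (coords @ coords[idx])

    indices = np.zeros([num_points], dtype=np.int64)
    indices[0] = init_idx
    cur = sq_dists(init_idx)
    for i in range(1, num_points):
        idx = int(np.argmax(cur))
        indices[i] = idx
        cur[idx] = -1  # never pick the same index twice
        cur = np.minimum(cur, sq_dists(idx))
    return indices


pts = np.array([[0.0, 0, 0], [1.0, 0, 0], [10.0, 0, 0], [0.1, 0, 0]])
picked = farthest_point_indices(pts, 3)
```

# Each iteration costs O(N), which is where the O(NM) total in the docstring
# above comes from.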
================================================
FILE: shap_e/rendering/pytorch3d_util.py
================================================
import copy
import inspect
from typing import Any, Callable, List, Sequence, Tuple, Union
import numpy as np
import torch
from pytorch3d.renderer import (
BlendParams,
DirectionalLights,
FoVPerspectiveCameras,
MeshRasterizer,
MeshRenderer,
RasterizationSettings,
SoftPhongShader,
TexturesVertex,
)
from pytorch3d.renderer.utils import TensorProperties
from pytorch3d.structures import Meshes
from shap_e.models.nn.checkpoint import checkpoint
from .blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION
from .torch_mesh import TorchMesh
from .view_data import ProjectiveCamera
# Using a larger value like 1e-4 seems to result in weird issues
# for our high-poly meshes.
DEFAULT_RENDER_SIGMA = 1e-5
DEFAULT_RENDER_GAMMA = 1e-4
def render_images(
image_size: int,
meshes: Meshes,
cameras: Any,
lights: Any,
sigma: float = DEFAULT_RENDER_SIGMA,
gamma: float = DEFAULT_RENDER_GAMMA,
max_faces_per_bin=100000,
faces_per_pixel=50,
bin_size=None,
use_checkpoint: bool = False,
) -> torch.Tensor:
if use_checkpoint:
# Decompose all of our arguments into a bunch of tensor lists
# so that autograd can keep track of what the op depends on.
verts_list = meshes.verts_list()
faces_list = meshes.faces_list()
assert isinstance(meshes.textures, TexturesVertex)
assert isinstance(lights, BidirectionalLights)
textures = meshes.textures.verts_features_padded()
light_vecs, light_fn = _deconstruct_tensor_props(lights)
camera_vecs, camera_fn = _deconstruct_tensor_props(cameras)
def ckpt_fn(
*args: torch.Tensor,
num_verts=len(verts_list),
num_light_vecs=len(light_vecs),
num_camera_vecs=len(camera_vecs),
light_fn=light_fn,
camera_fn=camera_fn,
faces_list=faces_list
):
args = list(args)
verts_list = args[:num_verts]
del args[:num_verts]
light_vecs = args[:num_light_vecs]
del args[:num_light_vecs]
camera_vecs = args[:num_camera_vecs]
del args[:num_camera_vecs]
textures = args.pop(0)
meshes = Meshes(verts=verts_list, faces=faces_list, textures=TexturesVertex(textures))
lights = light_fn(light_vecs)
cameras = camera_fn(camera_vecs)
return render_images(
image_size=image_size,
meshes=meshes,
cameras=cameras,
lights=lights,
sigma=sigma,
gamma=gamma,
max_faces_per_bin=max_faces_per_bin,
faces_per_pixel=faces_per_pixel,
bin_size=bin_size,
use_checkpoint=False,
)
result = checkpoint(ckpt_fn, (*verts_list, *light_vecs, *camera_vecs, textures), (), True)
else:
raster_settings_soft = RasterizationSettings(
image_size=image_size,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,
faces_per_pixel=faces_per_pixel,
max_faces_per_bin=max_faces_per_bin,
bin_size=bin_size,
perspective_correct=False,
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings_soft),
shader=SoftPhongShader(
device=meshes.device,
cameras=cameras,
lights=lights,
blend_params=BlendParams(sigma=sigma, gamma=gamma, background_color=(0, 0, 0)),
),
)
result = renderer(meshes)
return result
def _deconstruct_tensor_props(
props: TensorProperties,
) -> Tuple[List[torch.Tensor], Callable[[List[torch.Tensor]], TensorProperties]]:
vecs = []
names = []
other_props = {}
for k in dir(props):
if k.startswith("__"):
continue
v = getattr(props, k)
if inspect.ismethod(v):
continue
if torch.is_tensor(v):
vecs.append(v)
names.append(k)
else:
other_props[k] = v
def recreate_fn(vecs_arg):
other = type(props)(device=props.device)
for k, v in other_props.items():
setattr(other, k, copy.deepcopy(v))
for name, vec in zip(names, vecs_arg):
setattr(other, name, vec)
return other
return vecs, recreate_fn
def convert_meshes(raw_meshes: Sequence[TorchMesh], default_brightness=0.8) -> Meshes:
meshes = Meshes(
verts=[mesh.verts for mesh in raw_meshes], faces=[mesh.faces for mesh in raw_meshes]
)
rgbs = []
for mesh in raw_meshes:
if mesh.vertex_channels and all(k in mesh.vertex_channels for k in "RGB"):
rgbs.append(torch.stack([mesh.vertex_channels[k] for k in "RGB"], axis=-1))
else:
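rgbs.append(
# Fill an [N x 3] color tensor with the default brightness. The size
# must be an integer, so the brightness is the fill value, not a factor
# on the vertex count.
torch.full(
(len(mesh.verts), 3),
default_brightness,
device=mesh.verts.device,
dtype=mesh.verts.dtype,
)
)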
meshes.textures = TexturesVertex(verts_features=rgbs)
return meshes
def convert_cameras(
cameras: Sequence[ProjectiveCamera], device: torch.device
) -> FoVPerspectiveCameras:
Rs = []
Ts = []
for camera in cameras:
assert (
camera.width == camera.height and camera.x_fov == camera.y_fov
), "viewports must be square"
assert camera.x_fov == cameras[0].x_fov, "all cameras must have same field-of-view"
R = np.stack([-camera.x, -camera.y, camera.z], axis=0).T
T = -R.T @ camera.origin
Rs.append(R)
Ts.append(T)
return FoVPerspectiveCameras(
R=np.stack(Rs, axis=0),
T=np.stack(Ts, axis=0),
fov=cameras[0].x_fov,
degrees=False,
device=device,
)
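# A NumPy sanity check of the R/T construction in `convert_cameras` above,
# assuming PyTorch3D's row-vector convention X_view = X_world @ R + T. The
# helper name `to_pytorch3d_rt` and the axis-aligned test camera are
# hypothetical. Under that assumption the camera origin should map to the view
# space origin, and a point along the camera's z axis should land on +z.

```python
import numpy as np


def to_pytorch3d_rt(origin, x, y, z):
    """Build (R, T) from a camera origin and its x/y/z basis vectors."""
    R = np.stack([-x, -y, z], axis=0).T
    T = -R.T @ origin
    return R, T


origin = np.array([1.0, 2.0, 3.0])
x = np.array([1.0, 0.0, 0.0])
y = np.array([0.0, 1.0, 0.0])
z = np.array([0.0, 0.0, 1.0])
R, T = to_pytorch3d_rt(origin, x, y, z)

view_origin = origin @ R + T               # camera position in view space
point_ahead = (origin + 2.0 * z) @ R + T   # two units in front of the camera
```

# The sign flips on x and y reflect the differing axis conventions between the
# two camera models; only z keeps its sign.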
def convert_cameras_torch(
origins: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor, zs: torch.Tensor, fov: float
) -> FoVPerspectiveCameras:
Rs = []
Ts = []
for origin, x, y, z in zip(origins, xs, ys, zs):
R = torch.stack([-x, -y, z], axis=0).T
T = -R.T @ origin
Rs.append(R)
Ts.append(T)
return FoVPerspectiveCameras(
R=torch.stack(Rs, dim=0),
T=torch.stack(Ts, dim=0),
fov=fov,
degrees=False,
device=origins.device,
)
def blender_uniform_lights(
batch_size: int,
device: torch.device,
ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,
diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,
specular_color: Union[float, Tuple[float]] = 0.0,
) -> "BidirectionalLights":
"""
Create a light that attempts to match the light used by the Blender
renderer when run with `--light_mode basic`.
"""
if isinstance(ambient_color, float):
ambient_color = (ambient_color,) * 3
if isinstance(diffuse_color, float):
diffuse_color = (diffuse_color,) * 3
if isinstance(specular_color, float):
specular_color = (specular_color,) * 3
return BidirectionalLights(
ambient_color=(ambient_color,) * batch_size,
diffuse_color=(diffuse_color,) * batch_size,
specular_color=(specular_color,) * batch_size,
direction=(UNIFORM_LIGHT_DIRECTION,) * batch_size,
device=device,
)
class BidirectionalLights(DirectionalLights):
"""
Adapted from here, but effectively shines the light in both positive and negative directions:
https://github.com/facebookresearch/pytorch3d/blob/efea540bbcab56fccde6f4bc729d640a403dac56/pytorch3d/renderer/lighting.py#L159
"""
def diffuse(self, normals, points=None) -> torch.Tensor:
return torch.maximum(
super().diffuse(normals, points=points), super().diffuse(-normals, points=points)
)
def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
return torch.maximum(
super().specular(normals, points, camera_position, shininess),
super().specular(-normals, points, camera_position, shininess),
)
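# A sketch of what taking the maximum over `normals` and `-normals` achieves in
# `BidirectionalLights` above, assuming the base diffuse term is a cosine
# clamped at zero. The functions `one_sided` and `two_sided` are hypothetical
# and ignore light color. Under that assumption the two-sided term equals the
# absolute value of the cosine, so back-facing surfaces are lit too.

```python
import numpy as np


def one_sided(normals, direction):
    """Clamped cosine between each normal and the light direction."""
    return np.maximum(normals @ direction, 0.0)


def two_sided(normals, direction):
    """Light the surface from both the positive and negative direction."""
    return np.maximum(one_sided(normals, direction), one_sided(-normals, direction))


direction = np.array([0.0, 0.0, 1.0])
normals = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, -1.0], [1.0, 0.0, 0.0]])
front_only = one_sided(normals, direction)
both_sides = two_sided(normals, direction)
```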
================================================
FILE: shap_e/rendering/raycast/__init__.py
================================================
================================================
FILE: shap_e/rendering/raycast/_utils.py
================================================
import torch
def normalize(v: torch.Tensor) -> torch.Tensor:
return v / torch.linalg.norm(v, dim=-1, keepdim=True)
def cross_product(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
return torch.stack(
[
v1[..., 1] * v2[..., 2] - v2[..., 1] * v1[..., 2],
-(v1[..., 0] * v2[..., 2] - v2[..., 0] * v1[..., 2]),
v1[..., 0] * v2[..., 1] - v2[..., 0] * v1[..., 1],
],
dim=-1,
)
================================================
FILE: shap_e/rendering/raycast/cast.py
================================================
from typing import Iterator, Optional, Tuple
import numpy as np
import torch
from shap_e.rendering.view_data import ProjectiveCamera
from ._utils import cross_product
from .types import RayCollisions, Rays, TriMesh
def cast_camera(
camera: ProjectiveCamera,
mesh: TriMesh,
ray_batch_size: Optional[int] = None,
checkpoint: Optional[bool] = None,
) -> Iterator[RayCollisions]:
pixel_indices = np.arange(camera.width * camera.height)
image_coords = np.stack([pixel_indices % camera.width, pixel_indices // camera.width], axis=1)
rays = camera.camera_rays(image_coords)
batch_size = ray_batch_size or len(rays)
checkpoint = checkpoint if checkpoint is not None else batch_size < len(rays)
for i in range(0, len(rays), batch_size):
sub_rays = rays[i : i + batch_size]
origins = torch.from_numpy(sub_rays[:, 0]).to(mesh.vertices)
directions = torch.from_numpy(sub_rays[:, 1]).to(mesh.vertices)
yield cast_rays(Rays(origins=origins, directions=directions), mesh, checkpoint=checkpoint)
def cast_rays(rays: Rays, mesh: TriMesh, checkpoint: bool = False) -> RayCollisions:
"""
Cast a batch of rays onto a mesh.
"""
if checkpoint:
collides, ray_dists, tri_indices, barycentric, normals = RayCollisionFunction.apply(
rays.origins, rays.directions, mesh.faces, mesh.vertices
)
return RayCollisions(
collides=collides,
ray_dists=ray_dists,
tri_indices=tri_indices,
barycentric=barycentric,
normals=normals,
)
# https://github.com/unixpickle/vae-textures/blob/2968549ddd4a3487f9437d4db00793324453cd59/vae_textures/render.py#L98
normals = mesh.normals() # [N x 3]
directions = rays.directions # [M x 3]
collides = (directions @ normals.T).abs() > 1e-8 # [N x M]
tris = mesh.vertices[mesh.faces] # [N x 3 x 3]
v1 = tris[:, 1] - tris[:, 0]
v2 = tris[:, 2] - tris[:, 0]
cross1 = cross_product(directions[:, None], v2[None]) # [N x M x 3]
det = torch.sum(cross1 * v1[None], dim=-1) # [N x M]
collides = torch.logical_and(collides, det.abs() > 1e-8)
invDet = 1 / det # [N x M]
o = rays.origins[:, None] - tris[None, :, 0] # [N x M x 3]
bary1 = invDet * torch.sum(o * cross1, dim=-1) # [N x M]
collides = torch.logical_and(collides, torch.logical_and(bary1 >= 0, bary1 <= 1))
cross2 = cross_product(o, v1[None]) # [N x M x 3]
bary2 = invDet * torch.sum(directions[:, None] * cross2, dim=-1) # [N x M]
# The hit lies inside the triangle only if bary1 + bary2 <= 1 (equivalently bary0 >= 0).
collides = torch.logical_and(collides, torch.logical_and(bary2 >= 0, bary1 + bary2 <= 1))
bary0 = 1 - (bary1 + bary2)
# Make sure this is in the positive part of the ray.
scale = invDet * torch.sum(v2 * cross2, dim=-1)
collides = torch.logical_and(collides, scale > 0)
# Select the nearest collision
ray_dists, tri_indices = torch.min(
torch.where(collides, scale, torch.tensor(torch.inf).to(scale)), dim=-1
) # [N]
nearest_bary = torch.stack(
[
bary0[range(len(tri_indices)), tri_indices],
bary1[range(len(tri_indices)), tri_indices],
bary2[range(len(tri_indices)), tri_indices],
],
dim=-1,
)
return RayCollisions(
collides=torch.any(collides, dim=-1),
ray_dists=ray_dists,
tri_indices=tri_indices,
barycentric=nearest_bary,
normals=normals[tri_indices],
)
class RayCollisionFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx, origins, directions, faces, vertices
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
ctx.save_for_backward(origins, directions, faces, vertices)
with torch.no_grad():
res = cast_rays(
Rays(origins=origins, directions=directions),
TriMesh(faces=faces, vertices=vertices),
checkpoint=False,
)
return (res.collides, res.ray_dists, res.tri_indices, res.barycentric, res.normals)
@staticmethod
def backward(
ctx, _collides_grad, ray_dists_grad, _tri_indices_grad, barycentric_grad, normals_grad
):
# Tensors stored with ctx.save_for_backward() are retrieved via ctx.saved_tensors.
origins, directions, faces, vertices = ctx.saved_tensors
origins = origins.detach().requires_grad_(True)
directions = directions.detach().requires_grad_(True)
vertices = vertices.detach().requires_grad_(True)
with torch.enable_grad():
outputs = cast_rays(
Rays(origins=origins, directions=directions),
TriMesh(faces=faces, vertices=vertices),
checkpoint=False,
)
origins_grad, directions_grad, vertices_grad = torch.autograd.grad(
(outputs.ray_dists, outputs.barycentric, outputs.normals),
(origins, directions, vertices),
(ray_dists_grad, barycentric_grad, normals_grad),
)
return (origins_grad, directions_grad, None, vertices_grad)
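# Illustrative sketch (not part of the library): the batched math in cast_rays()
# follows the Moller-Trumbore ray/triangle test. The pure-Python version below
# handles one ray and one triangle and reuses the names v1, v2, cross1, cross2,
# bary1, bary2 and scale from above. The helper names (sub, dot, cross,
# intersect_ray_triangle) and the TRIANGLE constant are assumptions made for
# this example only. It uses the standard inside test bary1 + bary2 <= 1.

```python
from typing import Optional, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def sub(a: Sequence[float], b: Sequence[float]) -> Vec3:
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def cross(a: Sequence[float], b: Sequence[float]) -> Vec3:
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def intersect_ray_triangle(
    origin: Vec3, direction: Vec3, triangle: Sequence[Vec3], eps: float = 1e-8
) -> Optional[Tuple[float, Vec3]]:
    """Return (distance, (bary0, bary1, bary2)) for a hit, or None for a miss."""
    p0, p1, p2 = triangle
    v1 = sub(p1, p0)
    v2 = sub(p2, p0)
    cross1 = cross(direction, v2)
    det = dot(cross1, v1)
    if abs(det) < eps:
        return None  # the ray is parallel to the triangle's plane
    inv_det = 1.0 / det
    o = sub(origin, p0)
    bary1 = inv_det * dot(o, cross1)
    if bary1 < 0 or bary1 > 1:
        return None
    cross2 = cross(o, v1)
    bary2 = inv_det * dot(direction, cross2)
    if bary2 < 0 or bary1 + bary2 > 1:
        return None
    # Distance along the ray; only the positive part of the ray counts.
    scale = inv_det * dot(v2, cross2)
    if scale <= 0:
        return None
    return scale, (1 - (bary1 + bary2), bary1, bary2)


# A right triangle in the z = 0 plane, used only for demonstration.
TRIANGLE = ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0))
```

# Usage: intersect_ray_triangle((0.25, 0.25, -1.0), (0.0, 0.0, 1.0), TRIANGLE)
# reports a hit at distance 1 with barycentric weights (0.5, 0.25, 0.25).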
================================================
FILE: shap_e/rendering/raycast/render.py
================================================
from typing import Optional, Sequence
import torch
from shap_e.rendering.blender.constants import (
BASIC_AMBIENT_COLOR,
BASIC_DIFFUSE_COLOR,
UNIFORM_LIGHT_DIRECTION,
)
from shap_e.rendering.view_data import ProjectiveCamera
from .cast import cast_camera
from .types import RayCollisions, TriMesh
def render_diffuse_mesh(
camera: ProjectiveCamera,
mesh: TriMesh,
light_direction: Sequence[float] = tuple(UNIFORM_LIGHT_DIRECTION),
diffuse: float = BASIC_DIFFUSE_COLOR,
ambient: float = BASIC_AMBIENT_COLOR,
ray_batch_size: Optional[int] = None,
checkpoint: Optional[bool] = None,
) -> torch.Tensor:
"""
Return an [H x W x 4] RGBA tensor of the rendered image.
The pixels are floating points, with alpha in the range [0, 1] and the
other colors matching the scale used by the mesh's vertex colors.
"""
light_direction = torch.tensor(
light_direction, device=mesh.vertices.device, dtype=mesh.vertices.dtype
)
all_collisions = RayCollisions.collect(
cast_camera(
camera=camera,
mesh=mesh,
ray_batch_size=ray_batch_size,
checkpoint=checkpoint,
)
)
if mesh.vertex_colors is None:
# Fall back to a uniform light gray, one color per vertex.
vertex_colors = torch.tensor([[0.8, 0.8, 0.8]]).to(mesh.vertices).repeat(len(mesh.vertices), 1)
else:
vertex_colors = mesh.vertex_colors
light_coeffs = ambient + (
diffuse * torch.sum(all_collisions.normals * light_direction, dim=-1).abs()
)
# Index the local vertex_colors so the fallback is used when the mesh has no colors.
vertex_colors = vertex_colors[mesh.faces[all_collisions.tri_indices]]
bary_products = torch.sum(vertex_colors * all_collisions.barycentric[..., None], axis=-2)
out_colors = bary_products * light_coeffs[..., None]
res = torch.where(all_collisions.collides[:, None], out_colors, torch.zeros_like(out_colors))
return torch.cat([res, all_collisions.collides[:, None].float()], dim=-1).view(
camera.height, camera.width, 4
)
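# Illustrative sketch (not part of the library): the per-pixel shading in
# render_diffuse_mesh() is ambient + diffuse * |normal . light|, multiplied by
# the barycentric blend of the three corner colors. The function name
# shade_hit and the default ambient/diffuse values are assumptions chosen for
# this example; the real defaults come from shap_e.rendering.blender.constants.

```python
from typing import Sequence, Tuple


def shade_hit(
    normal: Sequence[float],
    light_direction: Sequence[float],
    barycentric: Sequence[float],
    corner_colors: Sequence[Sequence[float]],
    ambient: float = 0.25,  # assumed value, for illustration only
    diffuse: float = 0.75,  # assumed value, for illustration only
) -> Tuple[float, float, float]:
    """Shade one ray hit with two-sided diffuse lighting."""
    # The absolute value lights both sides of a face equally.
    n_dot_l = abs(sum(n * l for n, l in zip(normal, light_direction)))
    light_coeff = ambient + diffuse * n_dot_l
    # Blend the three corner colors with the barycentric weights, per channel.
    return tuple(
        light_coeff * sum(w * color[ch] for w, color in zip(barycentric, corner_colors))
        for ch in range(3)
    )
```

# Usage: a face hit head-on by the light keeps its blended color at full
# brightness, while a face parallel to the light only receives the ambient term.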
================================================
FILE: shap_e/rendering/raycast/types.py
================================================
from dataclasses import dataclass
from typing import Iterable, Optional
import numpy as np
import torch
import shap_e.rendering.mesh
from ._utils import cross_product, normalize
@dataclass
class Rays:
"""
A batch of N rays for ray casting.
"""
origins: torch.Tensor # [N x 3] float tensor
directions: torch.Tensor # [N x 3] float tensor
def normalized_directions(self) -> torch.Tensor:
return normalize(self.directions)
@dataclass
class RayCollisions:
"""
The result of casting N rays onto a mesh.
"""
collides: torch.Tensor # [N] boolean tensor
ray_dists: torch.Tensor # [N] float tensor
tri_indices: torch.Tensor # [N] long tensor
barycentric: torch.Tensor # [N x 3] float tensor
normals: torch.Tensor # [N x 3] float tensor
@classmethod
def collect(cls, it: Iterable["RayCollisions"]) -> "RayCollisions":
res = None
for x in it:
if res is None:
res = x
else:
res = cls(
collides=torch.cat([res.collides, x.collides]),
ray_dists=torch.cat([res.ray_dists, x.ray_dists]),
tri_indices=torch.cat([res.tri_indices, x.tri_indices]),
barycentric=torch.cat([res.barycentric, x.barycentric]),
normals=torch.cat([res.normals, x.normals]),
)
if res is None:
raise ValueError("cannot collect an empty iterable of RayCollisions")
return res
@dataclass
class TriMesh:
faces: torch.Tensor # [N x 3] long tensor
vertices: torch.Tensor # [N x 3] float tensor
vertex_colors: Optional[torch.Tensor] = None
def normals(self) -> torch.Tensor:
"""
Returns an [N x 3] batch of normal vectors per triangle assuming the
right-hand rule.
"""
tris = self.vertices[self.faces]
v1 = tris[:, 1] - tris[:, 0]
v2 = tris[:, 2] - tris[:, 0]
return normalize(cross_product(v1, v2))
@classmethod
def from_numpy(cls, x: shap_e.rendering.mesh.TriMesh) -> "TriMesh":
vertex_colors = None
if all(ch in x.vertex_channels for ch in "RGB"):
vertex_colors = torch.from_numpy(
np.stack([x.vertex_channels[ch] for ch in "RGB"], axis=-1)
)
return cls(
faces=torch.from_numpy(x.faces),
vertices=torch.from_numpy(x.verts),
vertex_colors=vertex_colors,
)
def to(self, *args, **kwargs) -> "TriMesh":
return TriMesh(
faces=self.faces.to(*args, **kwargs),
vertices=self.vertices.to(*args, **kwargs),
vertex_colors=None
if self.vertex_colors is None
else self.vertex_colors.to(*args, **kwargs),
)
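# Illustrative sketch (not part of the library): TriMesh.normals() takes the
# cross product of two edge vectors, so the normal direction follows the
# right-hand rule on the vertex order. The pure-Python helper triangle_normal
# below is a hypothetical name, and it assumes normalize() divides by the
# Euclidean length (its definition lives in _utils and is not shown here).

```python
import math
from typing import Sequence, Tuple


def triangle_normal(
    p0: Sequence[float], p1: Sequence[float], p2: Sequence[float]
) -> Tuple[float, float, float]:
    """Unit normal of a triangle, following the right-hand rule on (p0, p1, p2)."""
    v1 = [b - a for a, b in zip(p0, p1)]
    v2 = [b - a for a, b in zip(p0, p2)]
    n = (
        v1[1] * v2[2] - v1[2] * v2[1],
        v1[2] * v2[0] - v1[0] * v2[2],
        v1[0] * v2[1] - v1[1] * v2[0],
    )
    length = math.sqrt(sum(c * c for c in n))
    # A degenerate (zero-area) triangle has length 0 and is not handled here.
    return tuple(c / length for c in n)
```

# Usage: swapping two vertices flips the normal, which is why the winding
# order of mesh.faces matters for shading.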
================================================
FILE: shap_e/rendering/torch_mesh.py
================================================
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
from .mesh import TriMesh
@dataclass
class TorchMesh:
"""
A 3D triangle mesh with optional data at the vertices and faces.
"""
# [N x 3] array of vertex coordinates.
verts: torch.Tensor
# [M x 3] array of triangles, pointing to indices in verts.
faces: torch.Tensor
# Extra data per vertex and face.
vertex_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict)
face_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict)
def tri_mesh(self) -> TriMesh:
"""
Create a CPU version of the mesh.
"""
return TriMesh(
verts=self.verts.detach().cpu().numpy(),
faces=self.faces.cpu().numpy(),
vertex_channels=(
{k: v.detach().cpu().numpy() for k, v in self.vertex_channels.items()}
if self.vertex_channels is not None
else None
),
face_channels=(
{k: v.detach().cpu().numpy() for k, v in self.face_channels.items()}
if self.face_channels is not None
else None
),
)
================================================
FILE: shap_e/rendering/view_data.py
================================================
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
@dataclass
class Camera(ABC):
"""
An object describing how a camera corresponds to pixels in an image.
"""
@abstractmethod
def image_coords(self) -> np.ndarray:
"""
:return: a [(self.height * self.width) x 2] array of (x, y) image
coordinates, flattened in row-major order.
"""
@abstractmethod
def camera_rays(self, coords: np.ndarray) -> np.ndarray:
"""
For every (x, y) coordinate in a rendered image, compute the ray of the
corresponding pixel.
:param coords: an [N x 2] integer array of 2D image coordinates.
:return: an [N x 2 x 3] array of [2 x 3] (origin, direction) tuples.
The direction should always be unit length.
"""
def depth_directions(self, coords: np.ndarray) -> np.ndarray:
"""
For every (x, y) coordinate in a rendered image, get the direction that
corresponds to "depth" in an RGBD rendering.
This may raise an exception if there is no "D" channel in the
corresponding ViewData.
:param coords: an [N x 2] integer array of 2D image coordinates.
:return: an [N x 3] array of normalized depth directions.
"""
_ = coords
raise NotImplementedError
@abstractmethod
def center_crop(self) -> "Camera":
"""
Creates a new camera with the same intrinsics and direction as this one,
but with a center crop to a square of the smaller dimension.
"""
@abstractmethod
def resize_image(self, width: int, height: int) -> "Camera":
"""
Creates a new camera with the same intrinsics and direction as this one,
but with resized image dimensions.
"""
@abstractmethod
def scale_scene(self, factor: float) -> "Camera":
"""
Creates a new camera with the same intrinsics and direction as this one,
but with the scene rescaled by the given factor.
"""
@dataclass
class ProjectiveCamera(Camera):
"""
A Camera implementation for a standard pinhole camera.
The camera rays shoot away from the origin in the z direction, with the x
and y directions corresponding to the positive horizontal and vertical axes
in image space.
"""
origin: np.ndarray
x: np.ndarray
y: np.ndarray
z: np.ndarray
width: int
height: int
x_fov: float
y_fov: float
def image_coords(self) -> np.ndarray:
ind = np.arange(self.width * self.height)
coords = np.stack([ind % self.width, ind // self.width], axis=1).astype(np.float32)
return coords
def camera_rays(self, coords: np.ndarray) -> np.ndarray:
fracs = (coords / (np.array([self.width, self.height], dtype=np.float32) - 1)) * 2 - 1
fracs = fracs * np.tan(np.array([self.x_fov, self.y_fov]) / 2)
directions = self.z + self.x * fracs[:, :1] + self.y * fracs[:, 1:]
directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
return np.stack([np.broadcast_to(self.origin, directions.shape), directions], axis=1)
def depth_directions(self, coords: np.ndarray) -> np.ndarray:
return np.tile((self.z / np.linalg.norm(self.z))[None], [len(coords), 1])
def resize_image(self, width: int, height: int) -> "ProjectiveCamera":
"""
Creates a new camera for the resized view assuming the aspect ratio does not change.
"""
assert width * self.height == height * self.width, "The aspect ratio should not change."
return ProjectiveCamera(
origin=self.origin,
x=self.x,
y=self.y,
z=self.z,
width=width,
height=height,
x_fov=self.x_fov,
y_fov=self.y_fov,
)
def center_crop(self) -> "ProjectiveCamera":
"""
Creates a new camera for the center-cropped view
"""
size = min(self.width, self.height)
fov = min(self.x_fov, self.y_fov)
return ProjectiveCamera(
origin=self.origin,
x=self.x,
y=self.y,
z=self.z,
width=size,
height=size,
x_fov=fov,
y_fov=fov,
)
def scale_scene(self, factor: float) -> "ProjectiveCamera":
"""
Creates a new camera with the same intrinsics and direction as this one,
but with the camera frame rescaled by the given factor.
"""
return ProjectiveCamera(
origin=self.origin * factor,
x=self.x,
y=self.y,
z=self.z,
width=self.width,
height=self.height,
x_fov=self.x_fov,
y_fov=self.y_fov,
)
class ViewData(ABC):
"""
A collection of rendered camera views of a scene or object.
This is a generalization of a NeRF dataset, since NeRF datasets only encode
RGB or RGBA data, whereas this dataset supports arbitrary channels.
"""
@property
@abstractmethod
def num_views(self) -> int:
"""
The number of rendered views.
"""
@property
@abstractmethod
def channel_names(self) -> List[str]:
"""
Get all of the supported channels available for the views.
This can be arbitrary, but there are some standard names:
"R", "G", "B", "A" (alpha), and "D" (depth).
"""
@abstractmethod
def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]:
"""
Load the given channels from the view at the given index.
:return: a tuple (camera_view, data), where data is a float array of
shape [height x width x num_channels].
"""
class MemoryViewData(ViewData):
"""
A ViewData that is implemented in memory.
"""
def __init__(self, channels: Dict[str, np.ndarray], cameras: List[Camera]):
assert all(v.shape[0] == len(cameras) for v in channels.values())
self.channels = channels
self.cameras = cameras
@property
def num_views(self) -> int:
return len(self.cameras)
@property
def channel_names(self) -> List[str]:
return list(self.channels.keys())
def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]:
outputs = [self.channels[channel][index] for channel in channels]
return self.cameras[index], np.stack(outputs, axis=-1)
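# Illustrative sketch (not part of the library): ProjectiveCamera.camera_rays()
# maps a pixel to [-1, 1] on each axis, scales by tan(fov / 2), and combines the
# camera's x, y and z axes into a unit direction. The pure-Python function
# below does the same for a single pixel; pixel_ray_direction is a hypothetical
# name used only for this example.

```python
import math
from typing import Sequence, Tuple


def pixel_ray_direction(
    px: float,
    py: float,
    width: int,
    height: int,
    x_fov: float,
    y_fov: float,
    x_axis: Sequence[float],
    y_axis: Sequence[float],
    z_axis: Sequence[float],
) -> Tuple[float, float, float]:
    """Unit ray direction through pixel (px, py) of a pinhole camera."""
    # Map pixel indices 0..width-1 and 0..height-1 onto [-1, 1].
    frac_x = (px / (width - 1)) * 2 - 1
    frac_y = (py / (height - 1)) * 2 - 1
    # Scale by the half-angle tangent so the image edges match the field of view.
    frac_x *= math.tan(x_fov / 2)
    frac_y *= math.tan(y_fov / 2)
    d = [z + x * frac_x + y * frac_y for x, y, z in zip(x_axis, y_axis, z_axis)]
    norm = math.sqrt(sum(c * c for c in d))
    return tuple(c / norm for c in d)
```

# Usage: in a 3 x 3 image the center pixel (1, 1) looks straight down the z
# axis, while corner pixels are tilted by half the field of view on each axis.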
================================================
FILE: shap_e/util/__init__.py
================================================
================================================
FILE: shap_e/util/collections.py
================================================
# typing.OrderedDict is the subscriptable generic used as the base class below,
# so the collections.OrderedDict import it shadowed is not needed.
from typing import Any, Callable, Dict, Generic, List, Optional, OrderedDict, TypeVar
K = TypeVar('K')
V = TypeVar('V')
class AttrDict(OrderedDict[K, V], Generic[K, V]):
"""
An attribute dictionary that automatically handles nested keys joined by "/".
Originally copied from: https://stackoverflow.com/questions/3031219/recursively-access-dict-via-attributes-as-well-as-index-access
"""
MARKER = object()
# pylint: disable=super-init-not-called
def __init__(self, *args, **kwargs):
if len(args) == 0:
for key, value in kwargs.items():
self.__setitem__(key, value)
else:
assert len(args) == 1
assert isinstance(args[0], (dict, AttrDict))
for key, value in args[0].items():
self.__setitem__(key, value)
def __contains__(self, key):
if "/" in key:
keys = key.split("/")
key, next_key = keys[0], "/".join(keys[1:])
return key in self and next_key in self[key]
return super(AttrDict, self).__contains__(key)
def __setitem__(self, key, value):
if "/" in key:
keys = key.split("/")
key, next_key = keys[0], "/".join(keys[1:])
if key not in self:
self[key] = AttrDict()
self[key].__setitem__(next_key, value)
return
if isinstance(value, dict) and not isinstance(value, AttrDict):
value = AttrDict(**value)
if isinstance(value, list):
value = [AttrDict(val) if isinstance(val, dict) else val for val in value]
super(AttrDict, self).__setitem__(key, value)
def __getitem__(self, key):
if "/" in key:
keys = key.split("/")
key, next_key = keys[0], "/".join(keys[1:])
val = self[key]
if not isinstance(val, AttrDict):
raise ValueError
return val.__getitem__(next_key)
return self.get(key, None)
def all_keys(
self,
leaves_only: bool = False,
parent: Optional[str] = None,
) -> List[str]:
keys = []
for key in self.keys():
cur = key if parent is None else f"{parent}/{key}"
if not leaves_only or not isinstance(self[key], dict):
keys.append(cur)
if isinstance(self[key], dict):
keys.extend(self[key].all_keys(leaves_only=leaves_only, parent=cur))
return keys
def dumpable(self, strip=True):
"""
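Converts into a plain dict. When strip is True, keys starting with "_"
(internal attributes) are dropped; otherwise their values are replaced
by their repr().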
"""
def _dump(val):
if isinstance(val, AttrDict):
return val.dumpable()
elif isinstance(val, list):
return [_dump(v) for v in val]
return val
if strip:
return {k: _dump(v) for k, v in self.items() if not k.startswith("_")}
return {k: _dump(v if not k.startswith("_") else repr(v)) for k, v in self.items()}
def map(
self,
map_fn: Callable[[Any, Any], Any],
should_map: Optional[Callable[[Any, Any], bool]] = None,
) -> "AttrDict":
"""
Creates a copy of self where some or all values are transformed by
map_fn.
:param should_map: If provided, only those values that evaluate to true
are converted; otherwise, all values are mapped.
"""
def _apply(key, val):
if isinstance(val, AttrDict):
return val.map(map_fn, should_map)
elif should_map is None or should_map(key, val):
return map_fn(key, val)
return val
return AttrDict({k: _apply(k, v) for k, v in self.items()})
def __eq__(self, other):
return self.keys() == other.keys() and all(self[k] == other[k] for k in self.keys())
def combine(
self,
other: Dict[str, Any],
combine_fn: Callable[[Optional[Any], Optional[Any]], Any],
) -> "AttrDict":
"""
Some values may be missing, but the dictionary structures must be the
same.
:param combine_fn: a (possibly non-commutative) function to combine the
values
"""
def _apply(val, other_val):
if val is not None and isinstance(val, AttrDict):
assert isinstance(other_val, AttrDict)
return val.combine(other_val, combine_fn)
return combine_fn(val, other_val)
# TODO nit: this changes the ordering..
keys = self.keys() | other.keys()
return AttrDict({k: _apply(self[k], other[k]) for k in keys})
__setattr__, __getattr__ = __setitem__, __getitem__
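# Illustrative sketch (not part of the library): AttrDict resolves keys joined
# by "/" by splitting off the first component and recursing into the child
# dictionary. The two functions below show the same idea on plain dicts;
# set_nested, get_nested and the example key names are hypothetical.

```python
from typing import Any, Dict


def set_nested(d: Dict[str, Any], key: str, value: Any) -> None:
    """Store value under a "/"-joined key, creating child dicts as needed."""
    head, _, rest = key.partition("/")
    if rest:
        set_nested(d.setdefault(head, {}), rest, value)
    else:
        d[head] = value


def get_nested(d: Dict[str, Any], key: str) -> Any:
    """Look up a "/"-joined key; a missing leaf yields None, as in AttrDict."""
    head, _, rest = key.partition("/")
    value = d.get(head)
    if not rest:
        return value
    if not isinstance(value, dict):
        raise ValueError(f"{head!r} is not a nested dictionary")
    return get_nested(value, rest)
```

# Usage: set_nested(cfg, "renderer/nerf/n_layers", 6) builds two levels of
# nesting, and get_nested(cfg, "renderer/nerf/n_layers") reads the value back.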
================================================
FILE: shap_e/util/data_util.py
================================================
import tempfile
from contextlib import contextmanager
from typing import Iterator, Optional, Union
import blobfile as bf
import numpy as np
import torch
from PIL import Image
from shap_e.rendering.blender.render import render_mesh, render_model
from shap_e.rendering.blender.view_data import BlenderViewData
from shap_e.rendering.mesh import TriMesh
from shap_e.rendering.point_cloud import PointCloud
from shap_e.rendering.view_data import ViewData
from shap_e.util.collections import AttrDict
from shap_e.util.image_util import center_crop, get_alpha, remove_alpha, resize
def load_or_create_multimodal_batch(
device: torch.device,
*,
mesh_path: Optional[str] = None,
model_path: Optional[str] = None,
cache_dir: Optional[str] = None,
point_count: int = 2**14,
random_sample_count: int = 2**19,
pc_num_views: int = 40,
mv_light_mode: Optional[str] = None,
mv_num_views: int = 20,
mv_image_size: int = 512,
mv_alpha_removal: str = "black",
verbose: bool = False,
) -> AttrDict:
if verbose:
print("creating point cloud...")
pc = load_or_create_pc(
mesh_path=mesh_path,
model_path=model_path,
cache_dir=cache_dir,
random_sample_count=random_sample_count,
point_count=point_count,
num_views=pc_num_views,
verbose=verbose,
)
raw_pc = np.concatenate([pc.coords, pc.select_channels(["R", "G", "B"])], axis=-1)
encode_me = torch.from_numpy(raw_pc).float().to(device)
batch = AttrDict(points=encode_me.t()[None])
if mv_light_mode:
if verbose:
print("creating multiview...")
with load_or_create_multiview(
mesh_path=mesh_path,
model_path=model_path,
cache_dir=cache_dir,
num_views=mv_num_views,
extract_material=False,
light_mode=mv_light_mode,
verbose=verbose,
) as mv:
cameras, views, view_alphas, depths = [], [], [], []
for view_idx in range(mv.num_views):
camera, view = mv.load_view(
view_idx,
["R", "G", "B", "A"] if "A" in mv.channel_names else ["R", "G", "B"],
)
depth = None
if "D" in mv.channel_names:
_, depth = mv.load_view(view_idx, ["D"])
depth = process_depth(depth, mv_image_size)
view, alpha = process_image(
np.round(view * 255.0).astype(np.uint8), mv_alpha_removal, mv_image_size
)
camera = camera.center_crop().resize_image(mv_image_size, mv_image_size)
cameras.append(camera)
views.append(view)
view_alphas.append(alpha)
depths.append(depth)
batch.depths = [depths]
batch.views = [views]
batch.view_alphas = [view_alphas]
batch.cameras = [cameras]
return normalize_input_batch(batch, pc_scale=2.0, color_scale=1.0 / 255.0)
def load_or_create_pc(
*,
mesh_path: Optional[str],
model_path: Optional[str],
cache_dir: Optional[str],
random_sample_count: int,
point_count: int,
num_views: int,
verbose: bool = False,
) -> PointCloud:
assert (model_path is not None) ^ (
mesh_path is not None
), "must specify exactly one of model_path or mesh_path"
path = model_path if model_path is not None else mesh_path
if cache_dir is not None:
cache_path = bf.join(
cache_dir,
f"pc_{bf.basename(path)}_mat_{num_views}_{random_sample_count}_{point_count}.npz",
)
if bf.exists(cache_path):
return PointCloud.load(cache_path)
else:
cache_path = None
with load_or_create_multiview(
mesh_path=mesh_path,
model_path=model_path,
cache_dir=cache_dir,
num_views=num_views,
verbose=verbose,
) as mv:
if verbose:
print("extracting point cloud from multiview...")
pc = mv_to_pc(
multiview=mv, random_sample_count=random_sample_count, point_count=point_count
)
if cache_path is not None:
pc.save(cache_path)
return pc
@contextmanager
def load_or_create_multiview(
*,
mesh_path: Optional[str],
model_path: Optional[str],
cache_dir: Optional[str],
num_views: int = 20,
extract_material: bool = True,
light_mode: Optional[str] = None,
verbose: bool = False,
) -> Iterator[BlenderViewData]:
assert (model_path is not None) ^ (
mesh_path is not None
), "must specify exactly one of model_path or mesh_path"
path = model_path if model_path is not None else mesh_path
if extract_material:
assert light_mode is None, "light_mode is ignored when extract_material=True"
else:
assert light_mode is not None, "must specify light_mode when extract_material=False"
if cache_dir is not None:
if extract_material:
cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_mat_{num_views}.zip")
else:
cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_{light_mode}_{num_views}.zip")
if bf.exists(cache_path):
with bf.BlobFile(cache_path, "rb") as f:
yield BlenderViewData(f)
return
else:
cache_path = None
common_kwargs = dict(
fast_mode=True,
extract_material=extract_material,
camera_pose="random",
light_mode=light_mode or "uniform",
verbose=verbose,
)
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = bf.join(tmp_dir, "out.zip")
if mesh_path is not None:
mesh = TriMesh.load(mesh_path)
render_mesh(
mesh=mesh,
output_path=tmp_path,
num_images=num_views,
backend="BLENDER_EEVEE",
**common_kwargs,
)
elif model_path is not None:
render_model(
model_path,
output_path=tmp_path,
num_images=num_views,
backend="BLENDER_EEVEE",
**common_kwargs,
)
if cache_path is not None:
bf.copy(tmp_path, cache_path)
with bf.BlobFile(tmp_path, "rb") as f:
yield BlenderViewData(f)
def mv_to_pc(multiview: ViewData, random_sample_count: int, point_count: int) -> PointCloud:
pc = PointCloud.from_rgbd(multiview)
# Handle empty samples.
if len(pc.coords) == 0:
pc = PointCloud(
coords=np.zeros([1, 3]),
channels=dict(zip("RGB", np.zeros([3, 1]))),
)
while len(pc.coords) < point_count:
pc = pc.combine(pc)
# Prevent duplicate points; some models may not like it.
pc.coords += np.random.normal(size=pc.coords.shape) * 1e-4
pc = pc.random_sample(random_sample_count)
pc = pc.farthest_point_sample(point_count, average_neighbors=True)
return pc
def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict:
res = batch.copy()
scale_vec = torch.tensor([*([pc_scale] * 3), *([color_scale] * 3)], device=batch.points.device)
res.points = res.points * scale_vec[:, None]
if "cameras" in res:
res.cameras = [[cam.scale_scene(pc_scale) for cam in cams] for cams in res.cameras]
if "depths" in res:
res.depths = [[depth * pc_scale for depth in depths] for depths in res.depths]
return res
def process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray:
depth_img = center_crop(depth_img)
depth_img = resize(depth_img, width=image_size, height=image_size)
return np.squeeze(depth_img)
def process_image(
img_or_img_arr: Union[Image.Image, np.ndarray], alpha_removal: str, image_size: int
):
if isinstance(img_or_img_arr, np.ndarray):
img = Image.fromarray(img_or_img_arr)
img_arr = img_or_img_arr
else:
img = img_or_img_arr
img_arr = np.array(img)
if len(img_arr.shape) == 2:
# Grayscale
rgb = Image.new("RGB", img.size)
rgb.paste(img)
img = rgb
img_arr = np.array(img)
img = center_crop(img)
alpha = get_alpha(img)
img = remove_alpha(img, mode=alpha_removal)
alpha = alpha.resize((image_size,) * 2, resample=Image.BILINEAR)
img = img.resize((image_size,) * 2, resample=Image.BILINEAR)
return img, alpha
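# Illustrative sketch (not part of the library): load_or_create_multiview()
# derives its cache file name from the input's base name, a tag ("mat" when
# materials are extracted, otherwise the light mode) and the view count. The
# helper below reproduces that naming with posixpath instead of blobfile;
# multiview_cache_path and the example paths are hypothetical.

```python
import posixpath
from typing import Optional


def multiview_cache_path(
    cache_dir: str,
    path: str,
    num_views: int,
    extract_material: bool = True,
    light_mode: Optional[str] = None,
) -> str:
    """Build the cache file name for a rendered multiview archive."""
    if extract_material:
        assert light_mode is None, "light_mode is ignored when extract_material=True"
        tag = "mat"
    else:
        assert light_mode is not None, "must specify light_mode when extract_material=False"
        tag = light_mode
    return posixpath.join(cache_dir, f"mv_{posixpath.basename(path)}_{tag}_{num_views}.zip")
```

# Design note: because the tag and the view count are part of the file name,
# changing either one produces a new cache entry instead of reusing stale renders.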
================================================
FILE: shap_e/util/image_util.py
================================================
import random
from typing import Any, List, Optional, Union
import blobfile as bf
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
def center_crop(
img: Union[Image.Image, torch.Tensor, np.ndarray]
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
"""
Center crops an image.
"""
if isinstance(img, (np.ndarray, torch.Tensor)):
height, width = img.shape[:2]
else:
width, height = img.size
size = min(width, height)
left, top = (width - size) // 2, (height - size) // 2
right, bottom = left + size, top + size
if isinstance(img, (np.ndarray, torch.Tensor)):
img = img[top:bottom, left:right]
else:
img = img.crop((left, top, right, bottom))
return img
def resize(
img: Union[Image.Image, torch.Tensor, np.ndarray],
*,
height: int,
width: int,
min_value: Optional[Any] = None,
max_value: Optional[Any] = None,
) -> Union[Image.Image, torch.Tensor, np.ndarray]:
"""
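:param img: image in HWC order (or HW for a single channel).
:return: the resized image, with the same type as the input. Area
interpolation is used, so this is intended for downsampling.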
"""
orig, cls = img, type(img)
if isinstance(img, Image.Image):
img = np.array(img)
dtype = img.dtype
if isinstance(img, np.ndarray):
img = torch.from_numpy(img)
ndim = img.ndim
if img.ndim == 2:
img = img.unsqueeze(-1)
if min_value is None and max_value is None:
# .clamp throws an error when both are None
min_value = -np.inf
img = img.permute(2, 0, 1)
size = (height, width)
img = (
F.interpolate(img[None].float(), size=size, mode="area")[0]
.clamp(min_value, max_value)
.to(img.dtype)
.permute(1, 2, 0)
)
if ndim < img.ndim:
img = img.squeeze(-1)
if not isinstance(orig, torch.Tensor):
img = img.numpy()
img = img.astype(dtype)
if isinstance(orig, Image.Image):
img = Image.fromarray(img)
return img
def get_alpha(img: Image.Image) -> Image.Image:
"""
:return: the alpha channel separated out as a grayscale image
"""
img_arr = np.asarray(img)
if img_arr.shape[2] == 4:
alpha = img_arr[:, :, 3]
else:
alpha = np.full(img_arr.shape[:2], 255, dtype=np.uint8)
alpha = Image.fromarray(alpha)
return alpha
def remove_alpha(img: Image.Image, mode: str = "random") -> Image.Image:
"""
No-op if the image doesn't have an alpha channel.
:param mode: "random" (the default) composites the image over a randomly
chosen background; "black" or "white" composites it over a solid
background of that color.
:return: image with alpha removed
"""
img_arr = np.asarray(img)
if img_arr.shape[2] == 4:
# Add bg to get rid of alpha channel
if mode == "random":
height, width = img_arr.shape[:2]
bg = Image.fromarray(
random.choice([_black_bg, _gray_bg, _checker_bg, _noise_bg])(height, width)
)
bg.paste(img, mask=img)
img = bg
elif mode == "black" or mode == "white":
img_arr = img_arr.astype(float)
rgb, alpha = img_arr[:, :, :3], img_arr[:, :, -1:] / 255
background = np.zeros((1, 1, 3)) if mode == "black" else np.full((1, 1, 3), 255)
rgb = rgb * alpha + background * (1 - alpha)
img = Image.fromarray(np.round(rgb).astype(np.uint8))
return img
def _black_bg(h: int, w: int) -> np.ndarray:
return np.zeros([h, w, 3], dtype=np.uint8)
def _gray_bg(h: int, w: int) -> np.ndarray:
return (np.zeros([h, w, 3]) + np.random.randint(low=0, high=256)).astype(np.uint8)
def _checker_bg(h: int, w: int) -> np.ndarray:
checker_size = np.ceil(np.exp(np.random.uniform() * np.log(min(h, w))))
c1 = np.random.randint(low=0, high=256)
c2 = np.random.randint(low=0, high=256)
xs = np.arange(w)[None, :, None] + np.random.randint(low=0, high=checker_size + 1)
ys = np.arange(h)[:, None, None] + np.random.randint(low=0, high=checker_size + 1)
fields = np.logical_xor((xs // checker_size) % 2 == 0, (ys // checker_size) % 2 == 0)
return np.where(fields, np.array([c1] * 3), np.array([c2] * 3)).astype(np.uint8)
def _noise_bg(h: int, w: int) -> np.ndarray:
return np.random.randint(low=0, high=256, size=[h, w, 3]).astype(np.uint8)
def load_image(image_path: str) -> Image.Image:
with bf.BlobFile(image_path, "rb") as thefile:
img = Image.open(thefile)
img.load()
return img
def make_tile(images: List[Union[np.ndarray, Image.Image]], columns=8) -> Image.Image:
"""
to test, run
>>> display(make_tile([(np.zeros((128, 128, 3)) + c).astype(np.uint8) for c in np.linspace(0, 255, 15)]))
"""
images = list(map(np.array, images))
size = images[0].shape[0]
n = round_up(len(images), columns)
n_blanks = n - len(images)
images.extend([np.zeros((size, size, 3), dtype=np.uint8)] * n_blanks)
images = (
np.array(images)
.reshape(n // columns, columns, size, size, 3)
.transpose([0, 2, 1, 3, 4])
.reshape(n // columns * size, columns * size, 3)
)
return Image.fromarray(images)
def round_up(n: int, b: int) -> int:
return (n + b - 1) // b * b
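# Illustrative sketch (not part of the library): for the "black" and "white"
# modes, remove_alpha() computes rgb * alpha + background * (1 - alpha) with
# alpha scaled to [0, 1]. The single-pixel helper below shows the arithmetic;
# composite_over_background is a hypothetical name used only here.

```python
from typing import Tuple


def composite_over_background(
    rgba: Tuple[int, int, int, int], mode: str = "black"
) -> Tuple[int, int, int]:
    """Blend one 8-bit RGBA pixel over a solid black or white background."""
    if mode not in ("black", "white"):
        raise ValueError(f"unsupported mode: {mode!r}")
    r, g, b, a = rgba
    alpha = a / 255
    background = 0 if mode == "black" else 255
    return tuple(round(c * alpha + background * (1 - alpha)) for c in (r, g, b))
```

# Usage: a roughly half-transparent orange pixel (200, 100, 0, 128) darkens
# over black and lightens over white; fully opaque pixels are unchanged.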
================================================
FILE: shap_e/util/io.py
================================================
import io
from contextlib import contextmanager
from typing import Any, BinaryIO, Iterator, Union
import blobfile as bf
import yaml
from shap_e.util.collections import AttrDict
def read_config(path_or_file: Union[str, io.IOBase]) -> Any:
if isinstance(path_or_file, io.IOBase):
obj = yaml.load(path_or_file, Loader=yaml.SafeLoader)
else:
with bf.BlobFile(path_or_file, "rb") as f:
try:
obj = yaml.load(f, Loader=yaml.SafeLoader)
except Exception as exc:
with bf.BlobFile(path_or_file, "rb") as f:
print(f.read())
raise exc
if isinstance(obj, dict):
return AttrDict(obj)
return obj
@contextmanager
def buffered_writer(raw_f: BinaryIO) -> Iterator[io.BufferedIOBase]:
if isinstance(raw_f, io.BufferedIOBase):
yield raw_f
else:
f = io.BufferedWriter(raw_f)
yield f
f.flush()
================================================
FILE: shap_e/util/notebooks.py
================================================
import base64
import io
from typing import Union
import ipywidgets as widgets
import numpy as np
import torch
from PIL import Image
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.rendering.torch_mesh import TorchMesh
from shap_e.util.collections import AttrDict
def create_pan_cameras(size: int, device: torch.device) -> DifferentiableCameraBatch:
origins = []
xs = []
ys = []
zs = []
for theta in np.linspace(0, 2 * np.pi, num=20):
z = np.array([np.sin(theta), np.cos(theta), -0.5])
z /= np.sqrt(np.sum(z**2))
origin = -z * 4
x = np.array([np.cos(theta), -np.sin(theta), 0.0])
y = np.cross(z, x)
origins.append(origin)
xs.append(x)
ys.append(y)
zs.append(z)
return DifferentiableCameraBatch(
shape=(1, len(xs)),
flat_camera=DifferentiableProjectiveCamera(
origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
width=size,
height=size,
x_fov=0.7,
y_fov=0.7,
),
)
@torch.no_grad()
def decode_latent_images(
xm: Union[Transmitter, VectorDecoder],
latent: torch.Tensor,
cameras: DifferentiableCameraBatch,
rendering_mode: str = "stf",
):
decoded = xm.renderer.render_views(
AttrDict(cameras=cameras),
params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
latent[None]
),
options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
)
arr = decoded.channels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
return [Image.fromarray(x) for x in arr]
@torch.no_grad()
def decode_latent_mesh(
xm: Union[Transmitter, VectorDecoder],
latent: torch.Tensor,
) -> TorchMesh:
decoded = xm.renderer.render_views(
AttrDict(cameras=create_pan_cameras(2, latent.device)), # lowest resolution possible
params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
latent[None]
),
options=AttrDict(rendering_mode="stf", render_with_direction=False),
)
return decoded.raw_meshes[0]
def gif_widget(images):
writer = io.BytesIO()
images[0].save(
writer, format="GIF", save_all=True, append_images=images[1:], duration=100, loop=0
)
writer.seek(0)
data = base64.b64encode(writer.read()).decode("ascii")
return widgets.HTML(f'<img src="data:image/gif;base64,{data}" />')
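# Illustrative sketch (not part of the library): create_pan_cameras() places
# each camera on a circle around the object, looking slightly downward at the
# origin from a distance of 4. The pure-Python helper below computes the frame
# for one angle; pan_camera_frame is a hypothetical name used only here.

```python
import math
from typing import Tuple

Vec3 = Tuple[float, float, float]


def pan_camera_frame(theta: float, distance: float = 4.0) -> Tuple[Vec3, Vec3, Vec3, Vec3]:
    """Return (origin, x, y, z) for the orbit camera at angle theta."""
    # Viewing direction: horizontal heading plus a fixed downward tilt.
    z = (math.sin(theta), math.cos(theta), -0.5)
    norm = math.sqrt(sum(c * c for c in z))
    z = tuple(c / norm for c in z)
    # Step back along the viewing direction so the camera faces the origin.
    origin = tuple(-c * distance for c in z)
    x = (math.cos(theta), -math.sin(theta), 0.0)
    # y = z cross x completes the camera frame.
    y = (
        z[1] * x[2] - z[2] * x[1],
        z[2] * x[0] - z[0] * x[2],
        z[0] * x[1] - z[1] * x[0],
    )
    return origin, x, y, z
```

# Usage: every frame sits at the same distance from the origin, and x, y and z
# are mutually perpendicular, so sweeping theta over [0, 2 * pi] yields a
# turntable-style sequence of views.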