Repository: openai/shap-e Branch: main Commit: 50131012ee11 Files: 79 Total size: 17.7 MB Directory structure: gitextract__yazl4ub/ ├── .gitignore ├── LICENSE ├── README.md ├── model-card.md ├── samples.md ├── setup.py └── shap_e/ ├── __init__.py ├── diffusion/ │ ├── __init__.py │ ├── gaussian_diffusion.py │ ├── k_diffusion.py │ └── sample.py ├── examples/ │ ├── encode_model.ipynb │ ├── example_data/ │ │ └── cactus/ │ │ ├── material.mtl │ │ └── object.obj │ ├── sample_image_to_3d.ipynb │ └── sample_text_to_3d.ipynb ├── models/ │ ├── __init__.py │ ├── configs.py │ ├── download.py │ ├── generation/ │ │ ├── __init__.py │ │ ├── latent_diffusion.py │ │ ├── perceiver.py │ │ ├── pooled_mlp.py │ │ ├── pretrained_clip.py │ │ ├── transformer.py │ │ └── util.py │ ├── nerf/ │ │ ├── __init__.py │ │ ├── model.py │ │ ├── ray.py │ │ └── renderer.py │ ├── nerstf/ │ │ ├── mlp.py │ │ └── renderer.py │ ├── nn/ │ │ ├── __init__.py │ │ ├── camera.py │ │ ├── checkpoint.py │ │ ├── encoding.py │ │ ├── meta.py │ │ ├── ops.py │ │ ├── pointnet2_utils.py │ │ └── utils.py │ ├── query.py │ ├── renderer.py │ ├── stf/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── mlp.py │ │ └── renderer.py │ ├── transmitter/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── bottleneck.py │ │ ├── channels_encoder.py │ │ ├── multiview_encoder.py │ │ ├── params_proj.py │ │ └── pc_encoder.py │ └── volume.py ├── rendering/ │ ├── __init__.py │ ├── _mc_table.py │ ├── blender/ │ │ ├── __init__.py │ │ ├── blender_script.py │ │ ├── constants.py │ │ ├── render.py │ │ └── view_data.py │ ├── mc.py │ ├── mesh.py │ ├── ply_util.py │ ├── point_cloud.py │ ├── pytorch3d_util.py │ ├── raycast/ │ │ ├── __init__.py │ │ ├── _utils.py │ │ ├── cast.py │ │ ├── render.py │ │ └── types.py │ ├── torch_mesh.py │ └── view_data.py └── util/ ├── __init__.py ├── collections.py ├── data_util.py ├── image_util.py ├── io.py └── notebooks.py ================================================ FILE CONTENTS ================================================ 
================================================ FILE: .gitignore ================================================ __pycache__/ .DS_Store *.egg-info/ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2023 OpenAI Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Shap-E This is the official code and model release for [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463). * See [Usage](#usage) for guidance on how to use this repository. * See [Samples](#samples) for examples of what our text-conditional model can generate. # Samples Here are some highlighted samples from our text-conditional model. For random samples on selected prompts, see [samples.md](samples.md).
| | | |
|---|---|---|
| A chair that looks like an avocado | An airplane that looks like a banana | A spaceship |
| A birthday cupcake | A chair that looks like a tree | A green boot |
| A penguin | Ube ice cream cone | A bowl of vegetables |
# Usage Install with `pip install -e .`. To get started with examples, see the following notebooks: * [sample_text_to_3d.ipynb](shap_e/examples/sample_text_to_3d.ipynb) - sample a 3D model, conditioned on a text prompt. * [sample_image_to_3d.ipynb](shap_e/examples/sample_image_to_3d.ipynb) - sample a 3D model, conditioned on a synthetic view image. To get the best result, you should remove the background from the input image. * [encode_model.ipynb](shap_e/examples/encode_model.ipynb) - loads a 3D model or a trimesh, creates a batch of multiview renders and a point cloud, encodes them into a latent, and renders it back. For this to work, install Blender version 3.3.1 or higher, and set the environment variable `BLENDER_PATH` to the path of the Blender executable. ================================================ FILE: model-card.md ================================================ # Model Card: Shap-E This is the official codebase for running the latent diffusion models described in [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463). These models were trained and released by OpenAI. Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about how the models were trained and evaluated. # Model Details Shap-E includes two kinds of models: an encoder and a latent diffusion model. 1. **The encoder** converts 3D assets into the parameters of small neural networks which represent the 3D shape and texture as an implicit function. The resulting implicit function can be rendered from arbitrary viewpoints or imported into downstream applications as a mesh. 2. **The latent diffusion model** generates novel implicit functions conditioned on either images or text descriptions. As above, these samples can be rendered or exported as a mesh. Specifically, these models produce latents which must be linearly projected to get the final implicit function parameters. 
The final projection layer of the encoder is used for this purpose. Like [Point-E](https://github.com/openai/point-e/blob/main/model-card.md), Shap-E can often generate coherent 3D objects when conditioned on a rendering from a single viewpoint. When conditioned on text prompts directly, Shap-E is also often capable of producing recognizable objects, although it sometimes struggles to combine multiple objects or concepts. Samples from Shap-E are typically lower fidelity than professional 3D assets and often have rough edges, holes, or blurry surface textures. # Model Date April 2023 # Model Versions The following model checkpoints are available in this repository: * `transmitter` - the encoder and corresponding projection layers for converting encoder outputs into implicit neural representations. * `decoder` - just the final projection layer component of `transmitter`. This is a smaller checkpoint than `transmitter` since it does not include parameters for encoding 3D assets. This is the minimum required model to convert diffusion outputs into implicit neural representations. * `text300M` - the text-conditional latent diffusion model. * `image300M` - the image-conditional latent diffusion model. # Paper & Samples [Paper link](https://arxiv.org/abs/2305.02463) / [Samples](samples.md) # Training data The encoder and image-conditional diffusion models are trained on the [same dataset as Point-E](https://github.com/openai/point-e/blob/main/model-card.md#training-data). However, a few changes to the post-processing were made: * We rendered 60 views (instead of 20) of each model when computing point clouds, to avoid small cracks. * We produced 16K points in each point cloud instead of 4K. * We simplified the lighting and material setup to only include diffuse materials. For our text-conditional diffusion model, we expanded our dataset with roughly a million more 3D assets. 
Additionally, we collected 120K captions from human annotators for a high-quality subset of our 3D assets. # Evaluated Use We release these models with the intention of furthering progress in the field of generative modeling. However, we acknowledge that our models have certain constraints and biases, which is why we advise against employing them for commercial purposes at this time. We are aware that the utilization of our models could extend to areas beyond our expectations, and defining specific criteria for what is considered suitable for "research" purposes presents a challenge. Specifically, we advise caution when using these models in contexts that demand high accuracy, where minor imperfections in the generated 3D assets could have adverse consequences. Specifically, these models have been evaluated on the following tasks for research purposes: * Generating 3D renderings or meshes conditioned on single, synthetic images * Generating 3D renderings or meshes conditioned on text descriptions # Performance & Limitations Our image-conditional model has only been evaluated on a highly specific distribution of synthetic renderings. Even in these cases, the model still sometimes fails to infer the correct occluded parts of an object or produces geometry that is inconsistent with the given rendered images. These failure modes are similar to those of Point-E. The resulting 3D assets often have rough edges, holes, or blurry surface textures. Our text-conditional model can also produce a somewhat large and diverse vocabulary of objects. This model is often capable of producing objects with requested colors and textures, and sometimes even combining multiple objects. However, it often fails for more complex prompts that require placing multiple objects in a scene or binding attributes to objects. It also typically fails to produce a desired number of objects when a certain quantity is requested. 
We find that our text-conditional model can sometimes produce samples which reflect gender biases. For example, samples for "a nurse" typically have a different body shape than samples for "a doctor". When probing for potential misuses, we also found that our text-conditional model is capable of producing 3D assets related to violence, such as guns or tanks. However, the resulting quality of these samples is poor enough that they look unrealistic and toy-like. As with Point-E, our dataset consists of many simple, cartoonish 3D assets, and our generative models are prone to imitating this style. We believe our models will have many potential use cases. For example, our text-conditional model could enable users to quickly produce many 3D assets, allowing for rapid prototyping for computer graphics applications or 3D printing. The use of 3D printing in concert with our models could potentially be harmful, for example if used to create dangerous objects or fabricate tools or parts that are deployed without external validation. Generative 3D models share many challenges and constraints with image generation models. This includes the tendency to generate content that may be biased or detrimental, as well as the potential for dual-use applications. As the capabilities of these models evolve, further investigation is required to gain a clearer understanding of how these risks manifest. ================================================ FILE: samples.md ================================================ # Samples Here is a collection of prompts and four random text-conditional samples for each prompt. Samples are rendered at 128x128 resolution with NeRF.
| Prompt | | | | |
|---|---|---|---|---|
| a penguin | a penguin | a penguin | a penguin | a penguin |
| a campfire | a campfire | a campfire | a campfire | a campfire |
| an elephant | an elephant | an elephant | an elephant | an elephant |
| a donut with pink icing | a donut with pink icing | a donut with pink icing | a donut with pink icing | a donut with pink icing |
| a voxelized dog | a voxelized dog | a voxelized dog | a voxelized dog | a voxelized dog |
| ube ice cream cone | ube ice cream cone | ube ice cream cone | ube ice cream cone | ube ice cream cone |
| a birthday cupcake | a birthday cupcake | a birthday cupcake | a birthday cupcake | a birthday cupcake |
| shepherds pie | shepherds pie | shepherds pie | shepherds pie | shepherds pie |
| a bowl of vegetables | a bowl of vegetables | a bowl of vegetables | a bowl of vegetables | a bowl of vegetables |
| a cheeseburger | a cheeseburger | a cheeseburger | a cheeseburger | a cheeseburger |
| a plate of mushy green peas | a plate of mushy green peas | a plate of mushy green peas | a plate of mushy green peas | a plate of mushy green peas |
| a traffic cone | a traffic cone | a traffic cone | a traffic cone | a traffic cone |
| a car that looks like an avocado | a car that looks like an avocado | a car that looks like an avocado | a car that looks like an avocado | a car that looks like an avocado |
| an airplane that looks like a banana | an airplane that looks like a banana | an airplane that looks like a banana | an airplane that looks like a banana | an airplane that looks like a banana |
| a stop sign | a stop sign | a stop sign | a stop sign | a stop sign |
| a spaceship | a spaceship | a spaceship | a spaceship | a spaceship |
| a race car | a race car | a race car | a race car | a race car |
| a schoolbus | a schoolbus | a schoolbus | a schoolbus | a schoolbus |
| a firetruck | a firetruck | a firetruck | a firetruck | a firetruck |
| a rusty old car | a rusty old car | a rusty old car | a rusty old car | a rusty old car |
| a fast car | a fast car | a fast car | a fast car | a fast car |
| a chair that looks like an avocado | a chair that looks like an avocado | a chair that looks like an avocado | a chair that looks like an avocado | a chair that looks like an avocado |
| a chair that looks like fruit | a chair that looks like fruit | a chair that looks like fruit | a chair that looks like fruit | a chair that looks like fruit |
| a chair that looks like a tree | a chair that looks like a tree | a chair that looks like a tree | a chair that looks like a tree | a chair that looks like a tree |
| a chair that looks like a zebra | a chair that looks like a zebra | a chair that looks like a zebra | a chair that looks like a zebra | a chair that looks like a zebra |
| a chair that looks like a swimming pool | a chair that looks like a swimming pool | a chair that looks like a swimming pool | a chair that looks like a swimming pool | a chair that looks like a swimming pool |
| the person is running | the person is running | the person is running | the person is running | the person is running |
| the person is sitting | the person is sitting | the person is sitting | the person is sitting | the person is sitting |
| the person is lying down | the person is lying down | the person is lying down | the person is lying down | the person is lying down |
| a person that looks like a zebra | a person that looks like a zebra | a person that looks like a zebra | a person that looks like a zebra | a person that looks like a zebra |
| a person that looks like a leopard | a person that looks like a leopard | a person that looks like a leopard | a person that looks like a leopard | a person that looks like a leopard |
| a pair of shorts | a pair of shorts | a pair of shorts | a pair of shorts | a pair of shorts |
| a designer dress | a designer dress | a designer dress | a designer dress | a designer dress |
| banana shoes | banana shoes | banana shoes | banana shoes | banana shoes |
| a green boot | a green boot | a green boot | a green boot | a green boot |
| a pair of sunglasses | a pair of sunglasses | a pair of sunglasses | a pair of sunglasses | a pair of sunglasses |
================================================ FILE: setup.py ================================================ from setuptools import setup setup( name="shap-e", packages=[ "shap_e", "shap_e.diffusion", "shap_e.models", "shap_e.models.generation", "shap_e.models.nerf", "shap_e.models.nerstf", "shap_e.models.nn", "shap_e.models.stf", "shap_e.models.transmitter", "shap_e.rendering", "shap_e.rendering.blender", "shap_e.rendering.raycast", "shap_e.util", ], install_requires=[ "filelock", "Pillow", "torch", "fire", "humanize", "requests", "tqdm", "matplotlib", "scikit-image", "scipy", "numpy", "blobfile", "clip @ git+https://github.com/openai/CLIP.git", ], author="OpenAI", ) ================================================ FILE: shap_e/__init__.py ================================================ ================================================ FILE: shap_e/diffusion/__init__.py ================================================ ================================================ FILE: shap_e/diffusion/gaussian_diffusion.py ================================================ """ Based on https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py """ import math from typing import Any, Dict, Iterable, Optional, Sequence, Union import blobfile as bf import numpy as np import torch as th import yaml def diffusion_from_config(config: Union[str, Dict[str, Any]]) -> "GaussianDiffusion": if isinstance(config, str): with bf.BlobFile(config, "rb") as f: obj = yaml.load(f, Loader=yaml.SafeLoader) return diffusion_from_config(obj) schedule = config["schedule"] steps = config["timesteps"] respace = config.get("respacing", None) mean_type = config.get("mean_type", "epsilon") betas = get_named_beta_schedule(schedule, steps, **config.get("schedule_args", {})) channel_scales = config.get("channel_scales", None) channel_biases = config.get("channel_biases", None) if channel_scales is not None: channel_scales = np.array(channel_scales) if channel_biases is not None: 
channel_biases = np.array(channel_biases) kwargs = dict( betas=betas, model_mean_type=mean_type, model_var_type="learned_range", loss_type="mse", channel_scales=channel_scales, channel_biases=channel_biases, ) if respace is None: return GaussianDiffusion(**kwargs) else: return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs) def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): """ This is the deprecated API for creating beta schedules. See get_named_beta_schedule() for the new library of schedules. """ if beta_schedule == "linear": betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) else: raise NotImplementedError(beta_schedule) assert betas.shape == (num_diffusion_timesteps,) return betas def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, **extra_args: float): """ Get a pre-defined beta schedule for the given name. The beta schedule library consists of beta schedules which remain similar in the limit of num_diffusion_timesteps. Beta schedules may be added, but should not be removed or changed once they are committed to maintain backwards compatibility. """ if schedule_name == "linear": # Linear schedule from Ho et al, extended to work for any number of # diffusion steps. 
scale = 1000 / num_diffusion_timesteps return get_beta_schedule( "linear", beta_start=scale * 0.0001, beta_end=scale * 0.02, num_diffusion_timesteps=num_diffusion_timesteps, ) elif schedule_name == "cosine": return betas_for_alpha_bar( num_diffusion_timesteps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, ) elif schedule_name == "inv_parabola": exponent = extra_args.get("power", 2.0) return betas_for_alpha_bar( num_diffusion_timesteps, lambda t: 1 - t**exponent, ) elif schedule_name == "translated_parabola": exponent = extra_args.get("power", 2.0) return betas_for_alpha_bar( num_diffusion_timesteps, lambda t: (1 - t) ** exponent, ) elif schedule_name == "exp": coefficient = extra_args.get("coefficient", -12.0) return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: math.exp(t * coefficient)) else: raise NotImplementedError(f"unknown beta schedule: {schedule_name}") def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product of (1-beta) up to that part of the diffusion process. :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities. """ betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) return np.array(betas) def space_timesteps(num_timesteps, section_counts): """ Create a list of timesteps to use from an original diffusion process, given the number of timesteps we want to take from equally-sized portions of the original process. 
For example, if there's 300 timesteps and the section counts are [10,15,20] then the first 100 timesteps are strided to be 10 timesteps, the second 100 are strided to be 15 timesteps, and the final 100 are strided to be 20. :param num_timesteps: the number of diffusion steps in the original process to divide up. :param section_counts: either a list of numbers, or a string containing comma-separated numbers, indicating the step count per section. As a special case, use "ddimN" where N is a number of steps to use the striding from the DDIM paper. :return: a set of diffusion steps from the original process to use. """ if isinstance(section_counts, str): if section_counts.startswith("ddim"): desired_count = int(section_counts[len("ddim") :]) for i in range(1, num_timesteps): if len(range(0, num_timesteps, i)) == desired_count: return set(range(0, num_timesteps, i)) raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") elif section_counts.startswith("exact"): res = set(int(x) for x in section_counts[len("exact") :].split(",")) for x in res: if x < 0 or x >= num_timesteps: raise ValueError(f"timestep out of bounds: {x}") return res section_counts = [int(x) for x in section_counts.split(",")] size_per = num_timesteps // len(section_counts) extra = num_timesteps % len(section_counts) start_idx = 0 all_steps = [] for i, section_count in enumerate(section_counts): size = size_per + (1 if i < extra else 0) if size < section_count: raise ValueError(f"cannot divide section of {size} steps into {section_count}") if section_count <= 1: frac_stride = 1 else: frac_stride = (size - 1) / (section_count - 1) cur_idx = 0.0 taken_steps = [] for _ in range(section_count): taken_steps.append(start_idx + round(cur_idx)) cur_idx += frac_stride all_steps += taken_steps start_idx += size return set(all_steps) class GaussianDiffusion: """ Utilities for training and sampling diffusion models. 
Ported directly from here: https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D array of betas for each diffusion timestep from T to 1. :param model_mean_type: a string determining what the model outputs. :param model_var_type: a string determining how variance is output. :param loss_type: a string determining the loss function to use. :param discretized_t0: if True, use discrete gaussian loss for t=0. Only makes sense for images. :param channel_scales: a multiplier to apply to x_start in training_losses and sampling functions. """ def __init__( self, *, betas: Sequence[float], model_mean_type: str, model_var_type: str, loss_type: str, discretized_t0: bool = False, channel_scales: Optional[np.ndarray] = None, channel_biases: Optional[np.ndarray] = None, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type self.loss_type = loss_type self.discretized_t0 = discretized_t0 self.channel_scales = channel_scales self.channel_biases = channel_biases # Use float64 for accuracy. 
betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" assert (betas > 0).all() and (betas <= 1).all() self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) def get_sigmas(self, t): return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape) def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
""" mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the data for a given number of diffusion steps. In other words, sample from q(x_t | x_0). :param x_start: the initial data batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. """ if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def q_posterior_mean_variance(self, x_start, x_t, t): """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance( self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. 
:param clip_denoised: if True, clip the denoised signal into [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. - 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. """ if model_kwargs is None: model_kwargs = {} B, C = x.shape[:2] assert t.shape == (B,) model_output = model(x, t, **model_kwargs) if isinstance(model_output, tuple): model_output, extra = model_output else: extra = None if self.model_var_type in ["learned", "learned_range"]: assert model_output.shape == (B, C * 2, *x.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) if self.model_var_type == "learned": model_log_variance = model_var_values model_variance = th.exp(model_log_variance) else: min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) # The model_var_values is [-1, 1] for [min_var, max_var]. frac = (model_var_values + 1) / 2 model_log_variance = frac * max_log + (1 - frac) * min_log model_variance = th.exp(model_log_variance) else: model_variance, model_log_variance = { # for fixedlarge, we set the initial (log-)variance like so # to get a better decoder log likelihood. 
"fixed_large": ( np.append(self.posterior_variance[1], self.betas[1:]), np.log(np.append(self.posterior_variance[1], self.betas[1:])), ), "fixed_small": ( self.posterior_variance, self.posterior_log_variance_clipped, ), }[self.model_var_type] model_variance = _extract_into_tensor(model_variance, t, x.shape) model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) def process_xstart(x): if denoised_fn is not None: x = denoised_fn(x) if clip_denoised: return x.clamp(-1, 1) return x if self.model_mean_type == "x_prev": pred_xstart = process_xstart( self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) ) model_mean = model_output elif self.model_mean_type in ["x_start", "epsilon"]: if self.model_mean_type == "x_start": pred_xstart = process_xstart(model_output) else: pred_xstart = process_xstart( self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) ) model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) else: raise NotImplementedError(self.model_mean_type) assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape return { "mean": model_mean, "variance": model_variance, "log_variance": model_log_variance, "pred_xstart": pred_xstart, "extra": extra, } def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_xstart_from_xprev(self, x_t, t, xprev): assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - _extract_into_tensor( self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape ) * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / 
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=False,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Sample x_{t-1} from the model at the given timestep.

    :param model: the model to sample from.
    :param x: the current tensor at timestep t (i.e. x_t).
    :param t: the value of t, starting at 0 for the first diffusion step.
    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict containing the following keys:
             - 'sample': a random sample from the model.
             - 'pred_xstart': a prediction of x_0.
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    noise = th.randn_like(x)
    # Per-sample mask shaped (N, 1, ..., 1) so it broadcasts against x.
    nonzero_mask = (
        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    )  # no noise when t == 0
    if cond_fn is not None:
        # Guidance shifts only the mean; the variance from the model is kept.
        out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
    # exp(0.5 * log_variance) is the standard deviation.
    sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
    return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=False,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    temp=1.0,
):
    """
    Run ancestral sampling and yield the result of every step.

    Takes the same arguments as p_sample_loop(). Timesteps are visited from
    ``num_timesteps - 1`` down to 0; each yielded dict is the return value
    of p_sample() passed through unscale_out_dict().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))

    # Start from the supplied noise, or draw fresh noise scaled by `temp`.
    current = noise if noise is not None else th.randn(*shape, device=device) * temp

    steps = list(reversed(range(self.num_timesteps)))
    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        steps = tqdm(steps)

    step_kwargs = dict(
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
    )
    for step in steps:
        t = th.tensor([step] * shape[0], device=device)
        with th.no_grad():
            out = self.p_sample(model, current, t, **step_kwargs)
            yield self.unscale_out_dict(out)
            # The next iteration continues from the still-scaled sample.
            current = out["sample"]
def ddim_reverse_sample(
    self,
    model,
    x,
    t,
    clip_denoised=False,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Sample x_{t+1} from the model using DDIM reverse ODE.

    :param model: the model to evaluate.
    :param x: the current tensor at timestep t.
    :param t: a batch of timestep indices.
    :param clip_denoised: forwarded to p_mean_variance().
    :param denoised_fn: forwarded to p_mean_variance().
    :param cond_fn: if not None, applied through condition_score().
    :param model_kwargs: extra keyword arguments for the model.
    :param eta: must be 0.0; only the deterministic path is supported.
    :return: a dict with 'sample' (the x_{t+1} estimate) and 'pred_xstart'.
    """
    assert eta == 0.0, "Reverse ODE only for deterministic path"
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.  Use the shared helper,
    # exactly as ddim_sample() does, instead of repeating the formula.
    eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

    # Equation 12. reversed
    mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

    return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=False,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
    temp=1.0,
):
    """
    Run DDIM sampling and yield the result of every step.

    Same usage as p_sample_loop_progressive(), with `eta` forwarded to
    ddim_sample(). Each yielded dict is passed through unscale_out_dict().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))

    # Start from the supplied noise, or draw fresh noise scaled by `temp`.
    current = noise if noise is not None else th.randn(*shape, device=device) * temp

    steps = list(reversed(range(self.num_timesteps)))
    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        steps = tqdm(steps)

    step_kwargs = dict(
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        eta=eta,
    )
    for step in steps:
        t = th.tensor([step] * shape[0], device=device)
        with th.no_grad():
            out = self.ddim_sample(model, current, t, **step_kwargs)
            yield self.unscale_out_dict(out)
            # The next iteration continues from the still-scaled sample.
            current = out["sample"]
def training_losses(
    self, model, x_start, t, model_kwargs=None, noise=None
) -> Dict[str, th.Tensor]:
    """
    Compute training losses for a single timestep.

    The input is first passed through scale_channels(), then noised with
    q_sample(). Depending on ``self.loss_type`` the loss is either the
    variational-bound term ("kl" / "rescaled_kl") or an MSE against the
    target selected by ``self.model_mean_type`` ("mse" / "rescaled_mse"),
    optionally plus a variational-bound term when the variance is learned.

    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param t: a batch of timestep indices.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             Some mean or variance settings may also have other keys.
    :raises NotImplementedError: if ``self.loss_type`` is not one of the
        four values handled below.
    """
    x_start = self.scale_channels(x_start)
    if model_kwargs is None:
        model_kwargs = {}
    if noise is None:
        noise = th.randn_like(x_start)
    x_t = self.q_sample(x_start, t, noise=noise)

    terms = {}

    if self.loss_type == "kl" or self.loss_type == "rescaled_kl":
        vb_terms = self._vb_terms_bpd(
            model=model,
            x_start=x_start,
            x_t=x_t,
            t=t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )
        terms["loss"] = vb_terms["output"]
        if self.loss_type == "rescaled_kl":
            terms["loss"] *= self.num_timesteps
        # NOTE(review): `extra` is whatever _vb_terms_bpd() reports; if that can
        # be None, the `"losses" in extra` test below raises TypeError — confirm.
        extra = vb_terms["extra"]
    elif self.loss_type == "mse" or self.loss_type == "rescaled_mse":
        model_output = model(x_t, t, **model_kwargs)
        # A model may return (output, extra); otherwise there are no extras.
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = {}

        if self.model_var_type in [
            "learned",
            "learned_range",
        ]:
            # The model emits 2*C channels: C for the mean, C for the variance.
            B, C = x_t.shape[:2]
            assert model_output.shape == (
                B,
                C * 2,
                *x_t.shape[2:],
            ), f"{model_output.shape} != {(B, C * 2, *x_t.shape[2:])}"
            model_output, model_var_values = th.split(model_output, C, dim=1)
            # Learn the variance using the variational bound, but don't let
            # it affect our mean prediction.
            frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
            terms["vb"] = self._vb_terms_bpd(
                # The lambda ignores its arguments and returns the output that
                # was already computed above, so the model is not run twice.
                model=lambda *args, r=frozen_out: r,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
            )["output"]
            if self.loss_type == "rescaled_mse":
                # Divide by 1000 for equivalence with initial implementation.
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= self.num_timesteps / 1000.0

        # Regression target depends on what the model is trained to predict.
        target = {
            "x_prev": self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
            "x_start": x_start,
            "epsilon": noise,
        }[self.model_mean_type]
        assert model_output.shape == target.shape == x_start.shape
        terms["mse"] = mean_flat((target - model_output) ** 2)
        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
        else:
            terms["loss"] = terms["mse"]
    else:
        raise NotImplementedError(self.loss_type)

    if "losses" in extra:
        # extra["losses"] maps a name to a (loss, scale) pair: each raw loss is
        # reported under its name and added to the total weighted by its scale.
        terms.update({k: loss for k, (loss, _scale) in extra["losses"].items()})
        for loss, scale in extra["losses"].values():
            terms["loss"] = terms["loss"] + loss * scale

    return terms
def scale_channels(self, x: th.Tensor) -> th.Tensor:
    """
    Apply the optional per-channel affine map ``x * channel_scales + channel_biases``.

    ``channel_scales`` and ``channel_biases`` are numpy arrays (or None to
    skip that step); each is converted to x's dtype and device and reshaped
    to ``(1, C, 1, ..., 1)`` so it broadcasts along dimension 1 of x.
    """
    # Broadcast shape: batch axis, channel axis, then one axis per extra dim.
    per_channel = [1, -1] + [1] * (len(x.shape) - 2)
    scales = self.channel_scales
    biases = self.channel_biases
    if scales is not None:
        x = x * th.from_numpy(scales).to(x).reshape(per_channel)
    if biases is not None:
        x = x + th.from_numpy(biases).to(x).reshape(per_channel)
    return x
:param use_timesteps: (unordered) timesteps from the original diffusion process to retain. :param kwargs: the kwargs to create the base diffusion process. """ def __init__(self, use_timesteps: Iterable[int], **kwargs): self.use_timesteps = set(use_timesteps) self.timestep_map = [] self.original_num_steps = len(kwargs["betas"]) base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa last_alpha_cumprod = 1.0 new_betas = [] for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): if i in self.use_timesteps: new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) last_alpha_cumprod = alpha_cumprod self.timestep_map.append(i) kwargs["betas"] = np.array(new_betas) super().__init__(**kwargs) def p_mean_variance(self, model, *args, **kwargs): return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) def training_losses(self, model, *args, **kwargs): return super().training_losses(self._wrap_model(model), *args, **kwargs) def condition_mean(self, cond_fn, *args, **kwargs): return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) def condition_score(self, cond_fn, *args, **kwargs): return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) def _wrap_model(self, model): if isinstance(model, _WrappedModel): return model return _WrappedModel(model, self.timestep_map, self.original_num_steps) class _WrappedModel: def __init__(self, model, timestep_map, original_num_steps): self.model = model self.timestep_map = timestep_map self.original_num_steps = original_num_steps def __call__(self, x, ts, **kwargs): map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) new_ts = map_tensor[ts] return self.model(x, new_ts, **kwargs) def _extract_into_tensor(arr, timesteps, broadcast_shape): """ Extract values from a 1-D numpy array for a batch of indices. :param arr: the 1-D numpy array. :param timesteps: a tensor of indices into the array to extract. 
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Elementwise KL divergence between two diagonal Gaussians.

    Computes KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))). Arguments
    broadcast against each other and may be Python scalars, but at least one
    of them must be a Tensor.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # th.exp() does not accept plain scalars, so promote them to a Tensor
        # matching the reference tensor's dtype and device.
        if isinstance(value, th.Tensor):
            return value
        return th.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)
    squared_diff = (mean1 - mean2) ** 2

    return 0.5 * (
        -1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + squared_diff * th.exp(-logvar2)
    )
def mean_flat(tensor):
    """
    Average a tensor over every dimension except the first (batch) one.

    :param tensor: a tensor with at least two dimensions.
    :return: a 1-D tensor with one mean per batch element.
    """
    per_sample = tensor.flatten(start_dim=1)
    return per_sample.mean(dim=1)
class KarrasDenoiser:
    """
    Denoiser parameterisation with sigma-dependent input/output scalings.

    The wrapped model sees ``c_in * x_t`` and its output is combined as
    ``c_out * model_output + c_skip * x_t`` to form the denoised estimate.

    :param sigma_data: the data standard deviation used by the scalings.
    """

    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def get_snr(self, sigmas):
        """Return the signal-to-noise ratio ``sigmas ** -2``."""
        return sigmas**-2

    def get_sigmas(self, sigmas):
        """Return `sigmas` unchanged."""
        return sigmas

    def get_scalings(self, sigma):
        """
        Compute the ``(c_skip, c_out, c_in)`` scalings for noise level `sigma`.
        """
        # Shared denominator: sigma^2 + sigma_data^2.
        variance_sum = sigma**2 + self.sigma_data**2
        c_skip = self.sigma_data**2 / variance_sum
        c_out = sigma * self.sigma_data / variance_sum**0.5
        c_in = 1 / variance_sum**0.5
        return c_skip, c_out, c_in

    def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None):
        """
        Compute denoising losses for a batch of noise levels.

        :param model: the model to evaluate.
        :param x_start: the clean inputs.
        :param sigmas: one noise level per batch element.
        :param model_kwargs: if not None, extra keyword arguments for the model.
        :param noise: if specified, the Gaussian noise to add (before scaling
                      by `sigmas`).
        :return: a dict with "mse" (error of the raw model output against its
                 scaled target), "xs_mse" (error of the denoised estimate
                 against x_start) and "loss" (equal to "mse").
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)

        dims = x_start.ndim
        x_t = x_start + noise * append_dims(sigmas, dims)
        c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)]
        model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)
        # Target for the raw output such that c_out * target + c_skip * x_t == x_start.
        target = (x_start - c_skip * x_t) / c_out

        terms = {
            "mse": mean_flat((model_output - target) ** 2),
            "xs_mse": mean_flat((denoised - x_start) ** 2),
        }
        # No variational-bound term is ever produced here, so the loss is the MSE.
        terms["loss"] = terms["mse"]
        return terms

    def denoise(self, model, x_t, sigmas, **model_kwargs):
        """
        Run the model on the scaled input and return ``(model_output, denoised)``.
        """
        c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)]
        # The model is conditioned on a log-sigma based value; the tiny offset
        # keeps log() finite when sigma is exactly zero.
        rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
        model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
        denoised = c_out * model_output + c_skip * x_t
        return model_output, denoised
def karras_sample_progressive(
    diffusion,
    model,
    shape,
    steps,
    clip_denoised=True,
    progress=False,
    model_kwargs=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,  # higher for highres?
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    guidance_scale=0.0,
):
    """
    Sample with a Karras-style sampler, yielding the state after every step.

    :param diffusion: a KarrasDenoiser or a GaussianDiffusion.
    :param model: the model to sample from.
    :param shape: the shape of the initial noise tensor.
    :param steps: number of noise levels passed to get_sigmas_karras().
    :param clip_denoised: if True, clamp denoised predictions to [-1, 1].
    :param progress: forwarded to the sampler (tqdm progress bar).
    :param model_kwargs: if not None, a dict of extra keyword arguments for
                         the model.
    :param device: device for the noise schedule and initial noise.
    :param sigma_min: smallest noise level of the schedule.
    :param sigma_max: largest noise level; also scales the initial noise.
    :param rho: schedule exponent.
    :param sampler: one of "heun", "dpm" or "ancestral".
    :param s_churn: forwarded to the "heun"/"dpm" samplers.
    :param s_tmin: forwarded to the "heun"/"dpm" samplers.
    :param s_tmax: forwarded to the "heun"/"dpm" samplers.
    :param s_noise: forwarded to the "heun"/"dpm" samplers.
    :param guidance_scale: if neither 0 nor 1, the batch is duplicated and the
                           two halves of the denoiser output are blended as
                           ``uncond + guidance_scale * (cond - uncond)``.
    :return: a generator over the dicts produced by the chosen sampler; for a
             GaussianDiffusion each dict is passed through unscale_out_dict().
    :raises KeyError: if `sampler` is not a known sampler name.
    :raises NotImplementedError: if `diffusion` is of an unsupported type.
    """
    sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
    x_T = th.randn(*shape, device=device) * sigma_max
    sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
        sampler
    ]

    # The ancestral sampler does not take the churn parameters.
    if sampler != "ancestral":
        sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
    else:
        sampler_args = {}

    if isinstance(diffusion, KarrasDenoiser):

        def denoiser(x_t, sigma):
            # model_kwargs defaults to None, which cannot be **-unpacked.
            _, denoised = diffusion.denoise(model, x_t, sigma, **(model_kwargs or {}))
            if clip_denoised:
                denoised = denoised.clamp(-1, 1)
            return denoised

    elif isinstance(diffusion, GaussianDiffusion):
        model = GaussianToKarrasDenoiser(model, diffusion)

        def denoiser(x_t, sigma):
            _, denoised = model.denoise(
                x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs
            )
            return denoised

    else:
        raise NotImplementedError

    if guidance_scale != 0 and guidance_scale != 1:

        def guided_denoiser(x_t, sigma):
            # First half of the doubled batch is treated as conditional, the
            # second half as unconditional.
            x_t = th.cat([x_t, x_t], dim=0)
            sigma = th.cat([sigma, sigma], dim=0)
            x_0 = denoiser(x_t, sigma)
            cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)
            x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)
            return x_0

    else:
        guided_denoiser = denoiser

    for obj in sample_fn(
        guided_denoiser,
        x_T,
        sigmas,
        progress=progress,
        **sampler_args,
    ):
        if isinstance(diffusion, GaussianDiffusion):
            yield diffusion.unscale_out_dict(obj)
        else:
            yield obj
@th.no_grad()
def sample_dpm(
    denoiser,
    x,
    sigmas,
    progress=False,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """
    A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    :param denoiser: callable ``(x, sigma_batch) -> denoised``.
    :param x: the initial noisy tensor.
    :param sigmas: 1-D tensor of noise levels, visited in order.
    :param progress: if True, show a tqdm progress bar.
    :param s_churn: amount of extra noise; split evenly over the steps.
    :param s_tmin: lower sigma bound for applying churn.
    :param s_tmax: upper sigma bound for applying churn.
    :param s_noise: scale of the churn noise.
    :return: a generator. Each step yields a dict with "x", "i", "sigma",
             "sigma_hat", "denoised" and "pred_xstart" (the last two hold the
             same tensor); the final item holds the finished "x" and the last
             "pred_xstart".
    """
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        )
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Raise the noise level from sigmas[i] to sigma_hat.
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        # "pred_xstart" matches the key the other samplers in this file yield;
        # "denoised" is kept so existing consumers keep working.
        yield {
            "x": x,
            "i": i,
            "sigma": sigmas[i],
            "sigma_hat": sigma_hat,
            "denoised": denoised,
            "pred_xstart": denoised,
        }
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigmas[i + 1] - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = denoiser(x_2, sigma_mid * s_in)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    yield {"x": x, "pred_xstart": denoised}
def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
) -> torch.Tensor:
    """
    Sample a batch of latents from `model` using `diffusion`.

    :param batch_size: number of latents to sample.
    :param model: the model; must expose ``d_latent``. If it has a
                  ``cached_model_kwargs`` method, that is used to preprocess
                  `model_kwargs`.
    :param diffusion: the diffusion process to sample with.
    :param model_kwargs: conditioning tensors for the model. The dict passed
                         in is never modified.
    :param guidance_scale: classifier-free guidance scale; 0.0 and 1.0 both
                           mean "no guidance".
    :param clip_denoised: forwarded to the sampler.
    :param use_fp16: if True, run sampling under autocast.
    :param use_karras: if True use karras_sample(), else p_sample_loop().
    :param karras_steps: number of Karras steps.
    :param sigma_min: smallest Karras noise level.
    :param sigma_max: largest Karras noise level.
    :param s_churn: Karras churn parameter.
    :param device: device to sample on; defaults to the device of the model's
                   first parameter.
    :param progress: if True, show a progress bar.
    :return: the sampled latents.
    """
    if device is None:
        device = next(model.parameters()).device
    # Accept anything torch.device() accepts; `.type` is needed for autocast.
    device = torch.device(device)

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    # One definition of "guidance is on", shared by the kwargs doubling below
    # and by both sampling paths, so the batch and its conditioning never
    # disagree in size.
    use_guidance = guidance_scale != 1.0 and guidance_scale != 0.0
    if use_guidance:
        # Append zeroed conditioning for the second half of the doubled batch.
        # Build a new dict so the caller's dict is left untouched.
        model_kwargs = {
            k: torch.cat([v, torch.zeros_like(v)], dim=0) for k, v in model_kwargs.items()
        }

    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
            )
        else:
            internal_batch_size = batch_size
            if use_guidance:
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            # NOTE(review): with guidance on, this returns internal_batch_size
            # (2 * batch_size) rows — confirm whether callers expect
            # samples[:batch_size].
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )

    return samples
load_model\n", "from shap_e.util.data_util import load_or_create_multimodal_batch\n", "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xm = load_model('transmitter', device=device)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model_path = \"example_data/cactus/object.obj\"\n", "\n", "# This may take a few minutes, since it requires rendering the model twice\n", "# in two different modes.\n", "batch = load_or_create_multimodal_batch(\n", " device,\n", " model_path=model_path,\n", " mv_light_mode=\"basic\",\n", " mv_image_size=256,\n", " cache_dir=\"example_data/cactus/cached\",\n", " verbose=True, # this will show Blender output during renders\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " latent = xm.encoder.encode_to_bottleneck(batch)\n", "\n", " render_mode = 'stf' # you can change this to 'nerf'\n", " size = 128 # recommended that you lower resolution when using nerf\n", "\n", " cameras = create_pan_cameras(size, device)\n", " images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n", " display(gif_widget(images))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: shap_e/examples/example_data/cactus/material.mtl 
================================================ newmtl mat0 Ka 0.0000 0.7000 0.0000 Kd 0.0000 0.7000 0.0000 Ks 0.0000 0.0000 0.0000 newmtl mat1 Ka 0.6600 0.4400 0.2000 Kd 0.6600 0.4400 0.2000 Ks 0.0000 0.0000 0.0000 newmtl mat2 Ka 0.3000 0.3000 0.3000 Kd 0.3000 0.3000 0.3000 Ks 0.0000 0.0000 0.0000 newmtl mat3 Ka 0.0000 0.5000 0.0000 Kd 0.0000 0.5000 0.0000 Ks 0.0000 0.0000 0.0000 newmtl mat4 Ka 0.0000 0.5667 0.0000 Kd 0.0000 0.5667 0.0000 Ks 0.0000 0.0000 0.0000 newmtl mat5 Ka 0.5400 0.3933 0.2333 Kd 0.5400 0.3933 0.2333 Ks 0.0000 0.0000 0.0000 newmtl mat6 Ka 0.0000 0.6333 0.0000 Kd 0.0000 0.6333 0.0000 Ks 0.0000 0.0000 0.0000 newmtl mat7 Ka 0.2000 0.3667 0.2000 Kd 0.2000 0.3667 0.2000 Ks 0.0000 0.0000 0.0000 newmtl mat8 Ka 0.4200 0.3467 0.2667 Kd 0.4200 0.3467 0.2667 Ks 0.0000 0.0000 0.0000 newmtl mat9 Ka 0.1000 0.4333 0.1000 Kd 0.1000 0.4333 0.1000 Ks 0.0000 0.0000 0.0000 newmtl mat10 Ka 0.1000 0.5667 0.1000 Kd 0.1000 0.5667 0.1000 Ks 0.0000 0.0000 0.0000 newmtl mat11 Ka 0.2000 0.4333 0.2000 Kd 0.2000 0.4333 0.2000 Ks 0.0000 0.0000 0.0000 newmtl mat12 Ka 0.1000 0.5000 0.1000 Kd 0.1000 0.5000 0.1000 Ks 0.0000 0.0000 0.0000 ================================================ FILE: shap_e/examples/example_data/cactus/object.obj ================================================ [File too large to display: 17.2 MB] ================================================ FILE: shap_e/examples/sample_image_to_3d.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "964ccced", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from shap_e.diffusion.sample import sample_latents\n", "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n", "from shap_e.models.download import load_model, load_config\n", "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget\n", "from shap_e.util.image_util import load_image" ] }, { "cell_type": "code", 
"execution_count": null, "id": "8eed3a76", "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": null, "id": "2d922637", "metadata": {}, "outputs": [], "source": [ "xm = load_model('transmitter', device=device)\n", "model = load_model('image300M', device=device)\n", "diffusion = diffusion_from_config(load_config('diffusion'))" ] }, { "cell_type": "code", "execution_count": null, "id": "53d329d0", "metadata": {}, "outputs": [], "source": [ "batch_size = 4\n", "guidance_scale = 3.0\n", "\n", "# To get the best result, you should remove the background and show only the object of interest to the model.\n", "image = load_image(\"example_data/corgi.png\")\n", "\n", "latents = sample_latents(\n", " batch_size=batch_size,\n", " model=model,\n", " diffusion=diffusion,\n", " guidance_scale=guidance_scale,\n", " model_kwargs=dict(images=[image] * batch_size),\n", " progress=True,\n", " clip_denoised=True,\n", " use_fp16=True,\n", " use_karras=True,\n", " karras_steps=64,\n", " sigma_min=1e-3,\n", " sigma_max=160,\n", " s_churn=0,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "633da2ec", "metadata": {}, "outputs": [], "source": [ "render_mode = 'nerf' # you can change this to 'stf' for mesh rendering\n", "size = 64 # this is the size of the renders; higher values take longer to render.\n", "\n", "cameras = create_pan_cameras(size, device)\n", "for i, latent in enumerate(latents):\n", " images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n", " display(gif_widget(images))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" } }, 
"nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: shap_e/examples/sample_text_to_3d.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "964ccced", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from shap_e.diffusion.sample import sample_latents\n", "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n", "from shap_e.models.download import load_model, load_config\n", "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget" ] }, { "cell_type": "code", "execution_count": null, "id": "8eed3a76", "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": null, "id": "2d922637", "metadata": {}, "outputs": [], "source": [ "xm = load_model('transmitter', device=device)\n", "model = load_model('text300M', device=device)\n", "diffusion = diffusion_from_config(load_config('diffusion'))" ] }, { "cell_type": "code", "execution_count": null, "id": "53d329d0", "metadata": {}, "outputs": [], "source": [ "batch_size = 4\n", "guidance_scale = 15.0\n", "prompt = \"a shark\"\n", "\n", "latents = sample_latents(\n", " batch_size=batch_size,\n", " model=model,\n", " diffusion=diffusion,\n", " guidance_scale=guidance_scale,\n", " model_kwargs=dict(texts=[prompt] * batch_size),\n", " progress=True,\n", " clip_denoised=True,\n", " use_fp16=True,\n", " use_karras=True,\n", " karras_steps=64,\n", " sigma_min=1e-3,\n", " sigma_max=160,\n", " s_churn=0,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "633da2ec", "metadata": {}, "outputs": [], "source": [ "render_mode = 'nerf' # you can change this to 'stf'\n", "size = 64 # this is the size of the renders; higher values take longer to render.\n", "\n", "cameras = create_pan_cameras(size, device)\n", "for i, latent in enumerate(latents):\n", " 
images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n", " display(gif_widget(images))" ] }, { "cell_type": "code", "execution_count": null, "id": "85a4dce4", "metadata": {}, "outputs": [], "source": [ "# Example of saving the latents as meshes.\n", "from shap_e.util.notebooks import decode_latent_mesh\n", "\n", "for i, latent in enumerate(latents):\n", " t = decode_latent_mesh(xm, latent).tri_mesh()\n", " with open(f'example_mesh_{i}.ply', 'wb') as f:\n", " t.write_ply(f)\n", " with open(f'example_mesh_{i}.obj', 'w') as f:\n", " t.write_obj(f)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: shap_e/models/__init__.py ================================================ ================================================ FILE: shap_e/models/configs.py ================================================ from typing import Any, Dict, Union import blobfile as bf import torch import torch.nn as nn import yaml from shap_e.models.generation.latent_diffusion import SplitVectorDiffusion from shap_e.models.generation.perceiver import PointDiffusionPerceiver from shap_e.models.generation.pooled_mlp import PooledMLP from shap_e.models.generation.transformer import ( CLIPImageGridPointDiffusionTransformer, CLIPImageGridUpsamplePointDiffusionTransformer, CLIPImagePointDiffusionTransformer, PointDiffusionTransformer, UpsamplePointDiffusionTransformer, ) from shap_e.models.nerf.model import MLPNeRFModel, VoidNeRFModel from shap_e.models.nerf.renderer import OneStepNeRFRenderer, TwoStepNeRFRenderer from shap_e.models.nerstf.mlp import MLPDensitySDFModel, MLPNeRSTFModel from 
def model_from_config(config: Union[str, Dict[str, Any]], device: torch.device) -> nn.Module:
    """
    Build a model from a config dict, or from a path to a YAML config file.

    The config's "name" key selects the class to construct; the remaining keys
    are passed to its constructor. Nested sub-configs (renderers, encoders,
    sub-models, volumes) are built recursively.

    :param config: a config dict containing a "name" key, or a path that is
                   opened with blobfile and parsed as YAML.
    :param device: the device passed to every constructed module.
    :return: the constructed module.
    :raises KeyError: if "name" is missing or does not match a known model.
    """
    if isinstance(config, str):
        with bf.BlobFile(config, "rb") as f:
            obj = yaml.load(f, Loader=yaml.SafeLoader)
        return model_from_config(obj, device=device)

    # Shallow copy: top-level keys can be popped/replaced freely below, but
    # nested dicts are still shared with the caller and must be copied before
    # they are written to.
    config = config.copy()
    name = config.pop("name")

    if name == "PointCloudTransformerEncoder":
        return PointCloudTransformerEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudPerceiverEncoder":
        return PointCloudPerceiverEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudTransformerChannelsEncoder":
        return PointCloudTransformerChannelsEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudPerceiverChannelsEncoder":
        return PointCloudPerceiverChannelsEncoder(device=device, dtype=torch.float32, **config)
    elif name == "MultiviewTransformerEncoder":
        return MultiviewTransformerEncoder(device=device, dtype=torch.float32, **config)
    elif name == "Transmitter":
        renderer = model_from_config(config.pop("renderer"), device=device)
        # Per-parameter shapes with the leading batch dimension stripped.
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        encoder_config = config.pop("encoder").copy()
        encoder_config["param_shapes"] = param_shapes
        encoder = model_from_config(encoder_config, device=device)
        return Transmitter(encoder=encoder, renderer=renderer, **config)
    elif name == "VectorDecoder":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        return VectorDecoder(param_shapes=param_shapes, renderer=renderer, device=device, **config)
    elif name == "ChannelsDecoder":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        return ChannelsDecoder(
            param_shapes=param_shapes, renderer=renderer, device=device, **config
        )
    elif name == "OneStepNeRFRenderer":
        for field in [
            # Required
            "void_model",
            "foreground_model",
            "volume",
            # Optional to use NeRF++
            "background_model",
            "outer_volume",
        ]:
            if field in config:
                config[field] = model_from_config(config.pop(field).copy(), device)
        return OneStepNeRFRenderer(device=device, **config)
    elif name == "TwoStepNeRFRenderer":
        for field in [
            # Required
            "void_model",
            "coarse_model",
            "fine_model",
            "volume",
            # Optional to use NeRF++
            "coarse_background_model",
            "fine_background_model",
            "outer_volume",
        ]:
            if field in config:
                config[field] = model_from_config(config.pop(field).copy(), device)
        return TwoStepNeRFRenderer(device=device, **config)
    elif name == "PooledMLP":
        return PooledMLP(device, **config)
    elif name == "PointDiffusionTransformer":
        return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "PointDiffusionPerceiver":
        return PointDiffusionPerceiver(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImagePointDiffusionTransformer":
        return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridPointDiffusionTransformer":
        return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "UpsamplePointDiffusionTransformer":
        return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
        return CLIPImageGridUpsamplePointDiffusionTransformer(
            device=device, dtype=torch.float32, **config
        )
    elif name == "SplitVectorDiffusion":
        # Copy before writing the derived channel counts, so the caller's
        # nested "inner" dict is left untouched.
        inner_config = config.pop("inner").copy()
        d_latent = config.pop("d_latent")
        latent_ctx = config.pop("latent_ctx", 1)
        inner_config["input_channels"] = d_latent // latent_ctx
        inner_config["n_ctx"] = latent_ctx
        inner_config["output_channels"] = d_latent // latent_ctx * 2
        inner_model = model_from_config(inner_config, device)
        return SplitVectorDiffusion(
            device=device, wrapped=inner_model, n_ctx=latent_ctx, d_latent=d_latent
        )
    elif name == "STFRenderer":
        for field in ["sdf", "tf", "volume"]:
            config[field] = model_from_config(config.pop(field), device)
        return STFRenderer(device=device, **config)
    elif name == "NeRSTFRenderer":
        for field in ["sdf", "tf", "nerstf", "void", "volume"]:
            if field not in config:
                continue
            config[field] = model_from_config(config.pop(field), device)
        config.setdefault("sdf", None)
        config.setdefault("tf", None)
        config.setdefault("nerstf", None)
        return NeRSTFRenderer(device=device, **config)

    # Remaining models take only `device` plus their config keys.
    simple_models = {
        "MLPSDFModel": MLPSDFModel,
        "MLPTextureFieldModel": MLPTextureFieldModel,
        "MLPNeRFModel": MLPNeRFModel,
        "MLPDensitySDFModel": MLPDensitySDFModel,
        "MLPNeRSTFModel": MLPNeRSTFModel,
        "VoidNeRFModel": VoidNeRFModel,
        "BoundingBoxVolume": BoundingBoxVolume,
        "SphericalVolume": SphericalVolume,
        "UnboundedVolume": UnboundedVolume,
    }
    try:
        model_cls = simple_models[name]
    except KeyError:
        # Same exception type as the bare lookup, with a readable message.
        raise KeyError(f"unknown model name in config: {name!r}") from None
    return model_cls(device=device, **config)
lru_cache from typing import Dict, Optional import requests import torch import yaml from filelock import FileLock from tqdm.auto import tqdm MODEL_PATHS = { "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt", "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt", "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt", "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt", } CONFIG_PATHS = { "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml", "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml", "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml", "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml", "diffusion": "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml", } URL_HASHES = { "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt": "af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b", "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt": "d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98", "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt": "e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4", "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt": "cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa", "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml": "ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e", "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml": "e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c", "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml": "f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1", "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml": 
def fetch_file_cached(
    url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
) -> str:
    """
    Download the file at the given URL into a local file and return the path.

    If cache_dir is specified, it will be used to download the files.
    Otherwise, default_cache_dir() is used.

    :param url: the URL to fetch; it must be a key of URL_HASHES.
    :param progress: if True, show a tqdm progress bar while downloading.
    :param cache_dir: directory holding cached downloads.
    :param chunk_size: number of bytes requested per streamed chunk.
    :return: the path of the cached local file.
    :raises KeyError: if the URL has no entry in URL_HASHES.
    :raises requests.HTTPError: if the server answers with an error status.
    :raises RuntimeError: if the local file does not match the expected hash.
    """
    expected_hash = URL_HASHES[url]

    if cache_dir is None:
        cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(local_path):
        check_hash(local_path, expected_hash)
        return local_path

    with FileLock(local_path + ".lock"):
        # Another process may have completed the download while we were
        # waiting for the lock; re-check instead of fetching the file again.
        if os.path.exists(local_path):
            check_hash(local_path, expected_hash)
            return local_path

        tmp_path = local_path + ".tmp"
        with requests.get(url, stream=True) as response:
            # Fail on HTTP errors up front, rather than saving an error page
            # and reporting it later as a hash mismatch.
            response.raise_for_status()
            size = int(response.headers.get("content-length", "0"))
            with tqdm(total=size, unit="iB", unit_scale=True, disable=not progress) as pbar:
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size):
                        pbar.update(len(chunk))
                        f.write(chunk)
        # Publish the finished file in one step; unlike os.rename, os.replace
        # also succeeds on Windows when the destination already exists.
        os.replace(tmp_path, local_path)
        check_hash(local_path, expected_hash)
        return local_path
def hash_file(path: str) -> str:
    """
    Return the hexadecimal SHA-256 digest of the file at `path`.

    The file is read in 4096-byte chunks, so it is never held in memory whole.
    """
    sha256_hash = hashlib.sha256()
    with open(path, "rb") as file:
        # iter(callable, sentinel) stops at the first empty read (EOF).
        for data in iter(lambda: file.read(4096), b""):
            sha256_hash.update(data)
    return sha256_hash.hexdigest()


def load_config(
    config_name: str,
    progress: bool = False,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
):
    """
    Fetch (or reuse from cache) the named YAML config and return it parsed.

    :param config_name: a key of CONFIG_PATHS.
    :param progress: forwarded to fetch_file_cached.
    :param cache_dir: forwarded to fetch_file_cached.
    :param chunk_size: forwarded to fetch_file_cached.
    :return: the result of yaml.safe_load on the downloaded file.
    :raises ValueError: if config_name is not a known config.
    """
    if config_name not in CONFIG_PATHS:
        raise ValueError(
            f"Unknown config name {config_name}. Known names are: {CONFIG_PATHS.keys()}."
        )
    path = fetch_file_cached(
        CONFIG_PATHS[config_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
    )
    with open(path, "r") as f:
        return yaml.safe_load(f)


def load_checkpoint(
    checkpoint_name: str,
    device: torch.device,
    progress: bool = True,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
) -> Dict[str, torch.Tensor]:
    """
    Fetch (or reuse from cache) the named checkpoint and load it onto `device`.

    :param checkpoint_name: a key of MODEL_PATHS.
    :param device: passed to torch.load as map_location.
    :param progress: forwarded to fetch_file_cached.
    :param cache_dir: forwarded to fetch_file_cached.
    :param chunk_size: forwarded to fetch_file_cached.
    :return: the object produced by torch.load.
    :raises ValueError: if checkpoint_name is not a known checkpoint.
    """
    if checkpoint_name not in MODEL_PATHS:
        raise ValueError(
            f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
        )
    path = fetch_file_cached(
        MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
    )
    # NOTE(review): torch.load unpickles the file. fetch_file_cached is called
    # here to obtain the path; confirm it verifies the file against a pinned
    # hash, and keep that guarantee if the download logic ever changes.
    return torch.load(path, map_location=device)


def load_model(
    model_name: str,
    device: torch.device,
    **kwargs,
) -> torch.nn.Module:
    """
    Build the named model from its config and load its pretrained weights.

    :param model_name: name used for both the config and the checkpoint.
    :param device: device on which the model is built and weights are mapped.
    :param kwargs: forwarded to both load_config and load_checkpoint.
    :return: the constructed module, switched to eval mode.
    """
    # Imported here rather than at module top level.
    from .configs import model_from_config

    model = model_from_config(load_config(model_name, **kwargs), device=device)
    model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs))
    model.eval()
    return model
class QKVMultiheadCrossAttention(nn.Module):
    """
    Parameter-free multi-head cross attention.

    Queries attend over a packed key/value tensor whose last dimension holds,
    for every head, the key channels followed by the value channels.
    """

    def __init__(
        self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, n_data: int
    ):
        super().__init__()
        self.device, self.dtype = device, dtype
        self.heads = heads
        self.n_ctx, self.n_data = n_ctx, n_data

    def forward(self, q, kv):
        """
        :param q: queries, unpacked as [batch x n_ctx x width].
        :param kv: packed keys and values, unpacked as [batch x n_data x 2 * width].
        :return: the attended values, reshaped to [batch x n_ctx x -1].
        """
        query_len = q.shape[1]
        batch, data_len, packed_width = kv.shape
        head_channels = packed_width // self.heads // 2

        # Scale queries and keys each by the fourth root of the per-head
        # channel count before the product; this is more stable with f16
        # than dividing the product afterwards.
        fourth_root_scale = 1 / math.sqrt(math.sqrt(head_channels))

        q_heads = q.view(batch, query_len, self.heads, -1)
        kv_heads = kv.view(batch, data_len, self.heads, -1)
        keys, values = torch.split(kv_heads, head_channels, dim=-1)

        logits = torch.einsum(
            "bthc,bshc->bhts", q_heads * fourth_root_scale, keys * fourth_root_scale
        )
        # Softmax in float32, then return to the dtype of the logits.
        logits_dtype = logits.dtype
        probs = torch.softmax(logits.float(), dim=-1).type(logits_dtype)

        attended = torch.einsum("bhts,bshc->bthc", probs, values)
        return attended.reshape(batch, query_len, -1)
n_ctx=n_ctx, n_data=n_data, width=width, heads=heads, init_scale=init_scale, data_width=data_width, ) for _ in range(layers) ] ) def forward(self, x: torch.Tensor, data: torch.Tensor): for block in self.resblocks: x = block(x, data) return x class PointDiffusionPerceiver(nn.Module): def __init__( self, *, device: torch.device, dtype: torch.dtype, input_channels: int = 3, output_channels: int = 3, n_ctx: int = 1024, n_latent: int = 128, width: int = 512, encoder_layers: int = 12, latent_layers: int = 12, decoder_layers: int = 12, heads: int = 8, init_scale: float = 0.25, ): super().__init__() self.time_embed = MLP( device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width) ) self.latent_embed = MLP( device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width) ) self.n_latent = n_latent self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) self.encoder = SimplePerceiver( device=device, dtype=dtype, n_ctx=n_latent, n_data=n_ctx, width=width, layers=encoder_layers, heads=heads, init_scale=init_scale, ) self.processor = Transformer( device=device, dtype=dtype, n_ctx=n_latent, width=width, layers=latent_layers, heads=heads, init_scale=init_scale, ) self.decoder = SimplePerceiver( device=device, dtype=dtype, n_ctx=n_ctx, n_data=n_latent, width=width, layers=decoder_layers, heads=heads, init_scale=init_scale, ) self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype) self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype) with torch.no_grad(): self.output_proj.weight.zero_() self.output_proj.bias.zero_() def forward(self, x: torch.Tensor, t: torch.Tensor): """ :param x: an [N x C x T] tensor. :param t: an [N] tensor. :return: an [N x C' x T] tensor. 
class PooledMLP(nn.Module):
    """
    A stack of residual blocks between a 1x1 input convolution and a
    zero-initialised 1x1 output convolution, with an added timestep embedding.
    """

    def __init__(
        self,
        device: torch.device,
        *,
        input_channels: int = 3,
        output_channels: int = 6,
        hidden_size: int = 256,
        resblocks: int = 4,
        pool_op: str = "max",
    ):
        super().__init__()
        self.input_embed = nn.Conv1d(input_channels, hidden_size, kernel_size=1, device=device)
        self.time_embed = nn.Linear(hidden_size, hidden_size, device=device)

        self.sequence = nn.Sequential(
            *[ResBlock(hidden_size, pool_op, device=device) for _ in range(resblocks)]
        )

        self.out = nn.Conv1d(hidden_size, output_channels, kernel_size=1, device=device)
        # Zero the output projection so the freshly built network outputs zeros.
        with torch.no_grad():
            for param in (self.out.bias, self.out.weight):
                param.zero_()

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        :param x: input passed to the 1x1 convolution.
        :param t: timesteps passed to timestep_embedding.
        :return: the output of the final 1x1 convolution.
        """
        embedded = self.input_embed(x)
        hidden_size = embedded.shape[1]
        time_bias = self.time_embed(timestep_embedding(t, hidden_size))
        # Trailing axis lets the per-sample embedding broadcast over the last
        # dimension of the convolution output.
        hidden = embedded + time_bias.unsqueeze(-1)
        return self.out(self.sequence(hidden))
def pool(op_name: str, x: torch.Tensor) -> torch.Tensor:
    """
    Reduce `x` over its last dimension.

    :param op_name: "max" or "mean".
    :param x: the tensor to reduce.
    :return: `x` with its last dimension reduced away.
    :raises ValueError: if op_name is not a supported operation.
    """
    if op_name == "max":
        # torch.max over a dim returns (values, indices); keep the values.
        pooled, _ = torch.max(x, dim=-1)
    elif op_name == "mean":
        # torch.mean returns a single tensor, so it must not be unpacked:
        # doing so iterates over the leading dimension instead.
        pooled = torch.mean(x, dim=-1)
    else:
        raise ValueError(f"unknown pool op: {op_name}")
    return pooled
def forward(
    self,
    batch_size: int,
    images: Optional[Iterable[Optional[ImageType]]] = None,
    texts: Optional[Iterable[Optional[str]]] = None,
    embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
    """
    Embed a batch in which every element is described by at most one of an
    image, a text prompt, or a precomputed embedding.

    Modalities cannot be mixed within a single batch element. An element
    for which no modality is given receives an all-zero embedding.

    :param batch_size: the number of batch elements.
    :param images: optional per-element images (None entries allowed).
    :param texts: optional per-element prompts (None entries allowed).
    :param embeddings: optional per-element embeddings (None entries allowed).
    :return: a tensor holding one embedding per batch element.
    """

    def _materialize(values):
        # A missing modality behaves like a batch of None entries.
        return [None] * batch_size if values is None else list(values)

    image_seq = _materialize(images)
    text_seq = _materialize(texts)
    embedding_seq = _materialize(embeddings)
    assert len(image_seq) == batch_size, "number of images should match batch size"
    assert len(text_seq) == batch_size, "number of texts should match batch size"
    assert len(embedding_seq) == batch_size, "number of embeddings should match batch size"

    if self.ensure_used_params:
        # Delegate to the variant that always runs every encoder.
        return self._static_multimodal_embed(
            images=image_seq, texts=text_seq, embeddings=embedding_seq
        )

    result = torch.zeros((batch_size, self.feature_dim), device=self.device)
    image_rows, image_items = [], []
    text_rows, text_items = [], []
    for row, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):
        n_given = sum(int(value is not None) for value in (image, text, emb))
        assert n_given < 2, "only one modality may be non-None per batch element"
        if image is not None:
            image_rows.append(row)
            image_items.append(image)
        elif text is not None:
            text_rows.append(row)
            text_items.append(text)
        elif emb is not None:
            # Precomputed embeddings are copied straight into the output.
            result[row] = emb.to(result)

    # Encode each modality in one batched call, then scatter the rows back.
    if image_rows:
        for row, emb in zip(image_rows, self.embed_images(image_items)):
            result[row] = emb.to(result)
    if text_rows:
        for row, emb in zip(text_rows, self.embed_text(text_items)):
            result[row] = emb.to(result)
    return result
""" image_emb = self.embed_images(images) text_emb = self.embed_text(t if t else "" for t in texts) joined_embs = torch.stack( [ emb.to(device=self.device, dtype=torch.float32) if emb is not None else torch.zeros(self.feature_dim, device=self.device) for emb in embeddings ], dim=0, ) image_flag = torch.tensor([x is not None for x in images], device=self.device)[ :, None ].expand_as(image_emb) text_flag = torch.tensor([x is not None for x in texts], device=self.device)[ :, None ].expand_as(image_emb) emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[ :, None ].expand_as(image_emb) return ( image_flag.float() * image_emb + text_flag.float() * text_emb + emb_flag.float() * joined_embs + self.clip_model.logit_scale * 0 # avoid unused parameters ) def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: """ :param xs: N images, stored as numpy arrays, tensors, or PIL images. :return: an [N x D] tensor of features. """ clip_inputs = self.images_to_tensor(xs) results = self.clip_model.encode_image(clip_inputs).float() return results / torch.linalg.norm(results, dim=-1, keepdim=True) def embed_text(self, prompts: Iterable[str]) -> torch.Tensor: """ Embed text prompts as an [N x D] tensor. """ enc = self.clip_model.encode_text( self._tokenize(list(prompts), truncate=True).to(self.device) ).float() return enc / torch.linalg.norm(enc, dim=-1, keepdim=True) def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: """ Embed images into latent grids. :param xs: an iterable of images to embed. :return: a tensor of shape [N x C x L], where L = self.grid_size**2. 
""" if self.ensure_used_params: extra_value = 0.0 for p in self.parameters(): extra_value = extra_value + p.mean() * 0.0 else: extra_value = 0.0 x = self.images_to_tensor(xs).to(self.clip_model.dtype) # https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225 vt = self.clip_model.visual x = vt.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = torch.cat( [ vt.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x, ], dim=1, ) # shape = [*, grid ** 2 + 1, width] x = x + vt.positional_embedding.to(x.dtype) x = vt.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND x = vt.transformer(x) x = x.permute(1, 2, 0) # LND -> NDL return x[..., 1:].contiguous().float() + extra_value def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device) class FrozenImageCLIP: def __init__(self, device: torch.device, **kwargs): self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs) for parameter in self.model.parameters(): parameter.requires_grad_(False) @property def feature_dim(self) -> int: return self.model.feature_dim @property def grid_size(self) -> int: return self.model.grid_size @property def grid_feature_dim(self) -> int: return self.model.grid_feature_dim def __call__( self, batch_size: int, images: Optional[Iterable[Optional[ImageType]]] = None, texts: Optional[Iterable[Optional[str]]] = None, embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None, ) -> torch.Tensor: # We don't do a no_grad() here so that gradients could still # flow to the input embeddings argument. # This behavior is currently not used, but it could be. 
def init_linear(l, stddev):
    """
    Initialize a linear layer in place: weights are drawn from a zero-mean
    normal distribution with the given standard deviation, and the bias
    (when the layer has one) is set to zero.

    :param l: the layer to initialize; must expose `weight` and `bias`.
    :param stddev: standard deviation of the weight distribution.
    """
    nn.init.normal_(l.weight, mean=0.0, std=stddev)
    bias = l.bias
    if bias is None:
        # Layers built with bias=False have nothing more to initialize.
        return
    nn.init.zeros_(bias)
class QKVMultiheadAttention(nn.Module):
    """
    Parameter-free multi-head attention over a packed query/key/value tensor.
    """

    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_ctx = n_ctx

    def forward(self, qkv):
        """
        :param qkv: an [N x T x W] tensor. It is viewed as
                    [N x T x heads x (W / heads)] and every head's slice is
                    split into query, key and value channels, in that order.
        :return: an [N x T x (W / 3)] tensor of attended values.
        """
        batch, seq_len, packed_width = qkv.shape
        head_ch = packed_width // self.heads // 3
        # The fourth-root scale is applied to both q and k, which gives the
        # usual 1 / sqrt(head_ch) factor on their product.
        scale = 1 / math.sqrt(math.sqrt(head_ch))
        per_head = qkv.view(batch, seq_len, self.heads, -1)
        q, k, v = per_head.split(head_ch, dim=-1)
        # More stable with f16 than dividing afterwards
        logits = torch.einsum("bthc,bshc->bhts", q * scale, k * scale)
        # Softmax runs in float32 and is cast back to the logits' dtype.
        probs = logits.float().softmax(dim=-1).type(logits.dtype)
        attended = torch.einsum("bhts,bshc->bthc", probs, v)
        return attended.reshape(batch, seq_len, -1)
class PointDiffusionTransformer(nn.Module):
    """
    Transformer that maps an [N x C x T] tensor plus a batch of timesteps to
    an [N x C' x T] tensor. Conditioning embeddings are either added to every
    token or prepended to the sequence as extra tokens.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        input_channels: int = 3,
        output_channels: int = 3,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        time_token_cond: bool = False,
        use_pos_emb: bool = False,
        pos_emb_init_scale: float = 1.0,
        pos_emb_n_ctx: Optional[int] = None,
    ):
        super().__init__()
        tensor_kwargs = dict(device=device, dtype=dtype)

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.n_ctx = n_ctx
        self.time_token_cond = time_token_cond
        self.use_pos_emb = use_pos_emb

        # Modules are created in this fixed order so that random
        # initialization and state_dict ordering stay stable.
        self.time_embed = MLP(
            width=width, init_scale=init_scale * math.sqrt(1.0 / width), **tensor_kwargs
        )
        self.ln_pre = nn.LayerNorm(width, **tensor_kwargs)
        self.backbone = Transformer(
            # One extra context slot when the timestep is fed in as a token.
            n_ctx=n_ctx + int(time_token_cond),
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
            **tensor_kwargs,
        )
        self.ln_post = nn.LayerNorm(width, **tensor_kwargs)
        self.input_proj = nn.Linear(input_channels, width, **tensor_kwargs)
        self.output_proj = nn.Linear(width, output_channels, **tensor_kwargs)
        # Zero the output projection so a fresh model outputs zeros.
        for param in (self.output_proj.weight, self.output_proj.bias):
            nn.init.zeros_(param)

        if self.use_pos_emb:
            n_positions = pos_emb_n_ctx or self.n_ctx
            self.pos_emb = nn.Parameter(
                pos_emb_init_scale * torch.randn(n_positions, width, **tensor_kwargs)
            )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])

    def _forward_with_cond(
        self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
    ) -> torch.Tensor:
        """
        Run the backbone on `x` with a list of conditioning embeddings.

        :param x: an [N x C x T] tensor.
        :param cond_as_token: (embedding, as_token) pairs. When as_token is
                              False the embedding is added to every token;
                              otherwise it is prepended to the sequence (a
                              2-D embedding becomes a single token).
        :return: an [N x C' x T] tensor; prepended tokens are removed again.
        """
        tokens = self.input_proj(x.permute(0, 2, 1))  # NCL -> NLC

        prefix = []
        for emb, as_token in cond_as_token:
            if as_token:
                prefix.append(emb[:, None] if emb.dim() == 2 else emb)
            else:
                tokens = tokens + emb[:, None]

        if self.use_pos_emb:
            tokens = tokens + self.pos_emb

        n_prefix = sum(tok.shape[1] for tok in prefix)
        if prefix:
            tokens = torch.cat(prefix + [tokens], dim=1)

        tokens = self.ln_post(self.backbone(self.ln_pre(tokens)))

        if prefix:
            # Keep only the positions that correspond to the input sequence.
            tokens = tokens[:, n_prefix:]
        return self.output_proj(tokens).permute(0, 2, 1)
class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
    """
    Point diffusion transformer conditioned on a grid of CLIP image
    features, which are normalized, projected to the backbone width and
    passed on as extra conditioning tokens.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        # The CLIP wrapper is built before the parent constructor runs
        # because its grid size determines the backbone context length.
        clip_cls = FrozenImageCLIP if frozen_clip else ImageCLIP
        clip = clip_cls(device)
        n_grid_tokens = clip.grid_size**2
        super().__init__(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + n_grid_tokens,
            pos_emb_n_ctx=n_ctx,
            **kwargs,
        )
        # n_ctx on this class counts only the point tokens.
        self.n_ctx = n_ctx
        self.clip = clip
        grid_dim = self.clip.grid_feature_dim
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(normalized_shape=(grid_dim,), device=device, dtype=dtype),
            nn.Linear(grid_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Precompute the CLIP latent grids for model_kwargs["images"] without
        tracking gradients, so they can be passed to forward() as
        `embeddings`.
        """
        del batch_size  # unused
        with torch.no_grad():
            grids = self.clip.embed_images_grid(model_kwargs["images"])
        return {"embeddings": grids}

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
        :return: an [N x C' x T] tensor.
        """
        assert images is not None or embeddings is not None, "must specify images or embeddings"
        assert images is None or embeddings is None, "cannot specify both images and embeddings"
        assert x.shape[-1] == self.n_ctx

        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        clip_out = embeddings if images is None else self.clip.embed_images_grid(images)

        if self.training:
            # Zero the conditioning of randomly selected batch elements.
            keep = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * keep[:, None, None].to(clip_out)

        clip_embed = self.clip_embed(clip_out.permute(0, 2, 1))  # NCL -> NLC
        return self._forward_with_cond(
            x, [(t_embed, self.time_token_cond), (clip_embed, True)]
        )
""" assert x.shape[-1] == self.n_ctx t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) low_res_embed = self._embed_low_res(low_res) cond = [(t_embed, self.time_token_cond), (low_res_embed, True)] return self._forward_with_cond(x, cond) def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor: if self.channel_scales is not None: x = x * self.channel_scales[None, :, None] if self.channel_biases is not None: x = x + self.channel_biases[None, :, None] return self.cond_point_proj(x.permute(0, 2, 1)) class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer): def __init__( self, *, device: torch.device, dtype: torch.dtype, n_ctx: int = 4096 - 1024, cond_drop_prob: float = 0.0, frozen_clip: bool = True, **kwargs, ): clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device) super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs) self.n_ctx = n_ctx self.clip = clip self.clip_embed = nn.Sequential( nn.LayerNorm( normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype ), nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype), ) self.cond_drop_prob = cond_drop_prob def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: _ = batch_size with torch.no_grad(): return dict( embeddings=self.clip.embed_images_grid(model_kwargs["images"]), low_res=model_kwargs["low_res"], ) def forward( self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor, images: Optional[Iterable[ImageType]] = None, embeddings: Optional[Iterable[torch.Tensor]] = None, ): """ :param x: an [N x C1 x T] tensor. :param t: an [N] tensor. :param low_res: an [N x C2 x T'] tensor of conditioning points. :param images: a batch of images to condition on. :param embeddings: a batch of CLIP latent grids to condition on. :return: an [N x C3 x T] tensor. 
""" assert x.shape[-1] == self.n_ctx t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) low_res_embed = self._embed_low_res(low_res) if images is not None: clip_out = self.clip.embed_images_grid(images) else: clip_out = embeddings if self.training: mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob clip_out = clip_out * mask[:, None, None].to(clip_out) clip_out = clip_out.permute(0, 2, 1) # NCL -> NLC clip_embed = self.clip_embed(clip_out) cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)] return self._forward_with_cond(x, cond) ================================================ FILE: shap_e/models/generation/util.py ================================================ import math import torch def timestep_embedding(timesteps, dim, max_period=10000): """ Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. 
""" half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device) args = timesteps[:, None].to(timesteps.dtype) * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding ================================================ FILE: shap_e/models/nerf/__init__.py ================================================ ================================================ FILE: shap_e/models/nerf/model.py ================================================ from abc import ABC, abstractmethod from functools import partial from typing import Any, Dict, Optional, Tuple import numpy as np import torch import torch.nn as nn from shap_e.models.nn.checkpoint import checkpoint from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis from shap_e.models.nn.meta import MetaModule, subdict from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init from shap_e.models.nn.utils import ArrayType from shap_e.models.query import Query from shap_e.util.collections import AttrDict class NeRFModel(ABC): """ Parametric scene representation whose outputs are integrated by NeRFRenderer """ @abstractmethod def forward( self, query: Query, params: Optional[Dict[str, torch.Tensor]] = None, options: Optional[Dict[str, Any]] = None, ) -> AttrDict: """ :param query: the points in the field to query. :param params: Meta parameters :param options: Optional hyperparameters :return: An AttrDict containing at least - density: [batch_size x ... x 1] - channels: [batch_size x ... x n_channels] - aux_losses: [batch_size x ... x 1] """ class VoidNeRFModel(MetaModule, NeRFModel): """ Implements the default empty space model where all queries are rendered as background. 
""" def __init__( self, background: ArrayType, trainable: bool = False, channel_scale: float = 255.0, device: torch.device = torch.device("cuda"), ): super().__init__() background = nn.Parameter( torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device) / channel_scale ) if trainable: self.register_parameter("background", background) else: self.register_buffer("background", background) def forward( self, query: Query, params: Optional[Dict[str, torch.Tensor]] = None, options: Optional[Dict[str, Any]] = None, ) -> AttrDict: _ = params default_bg = self.background[None] background = options.get("background", default_bg) if options is not None else default_bg shape = query.position.shape[:-1] ones = [1] * (len(shape) - 1) n_channels = background.shape[-1] background = torch.broadcast_to( background.view(background.shape[0], *ones, n_channels), [*shape, n_channels] ) return background class MLPNeRFModel(MetaModule, NeRFModel): def __init__( self, # Positional encoding parameters n_levels: int = 10, # MLP parameters d_hidden: int = 256, n_density_layers: int = 4, n_channel_layers: int = 1, n_channels: int = 3, sh_degree: int = 4, activation: str = "relu", density_activation: str = "exp", init: Optional[str] = None, init_scale: float = 1.0, output_activation: str = "sigmoid", meta_parameters: bool = False, trainable_meta: bool = False, zero_out: bool = True, register_freqs: bool = True, posenc_version: str = "v1", device: torch.device = torch.device("cuda"), ): super().__init__() # Positional encoding if register_freqs: # not used anymore self.register_buffer( "freqs", 2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels), ) self.posenc_version = posenc_version dummy = torch.eye(1, 3) d_input = encode_position(posenc_version, position=dummy).shape[-1] self.n_levels = n_levels self.sh_degree = sh_degree d_sh_coeffs = sh_degree**2 self.meta_parameters = meta_parameters mlp_cls = ( partial( MetaMLP, meta_scale=False, 
def encode_position(self, query: Query):
    """
    Apply the positional encoding selected by `self.posenc_version` to the
    positions of `query`.

    :param query: the query whose `position` field is encoded.
    :return: the result of the module-level encode_position() helper.
    """
    return encode_position(self.posenc_version, position=query.position)
def maybe_get_spherical_harmonics_basis(
    sh_degree: int,
    coords_shape: Tuple[int],
    coords: Optional[torch.Tensor] = None,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
    """
    Return the spherical harmonics basis of `coords`, or zeros of the
    matching shape when no coordinates are given.

    :param sh_degree: Spherical harmonics degree
    :param coords_shape: [*shape, 3]
    :param coords: optional coordinate tensor of coords_shape
    :param device: where the zero tensor is placed when coords is None
    :return: spherical_harmonics_basis(coords, sh_degree) when coords is
             given; otherwise a [*shape, sh_degree**2] tensor of zeros.
    """
    if coords is not None:
        return spherical_harmonics_basis(coords, sh_degree)
    n_coeffs = sh_degree**2
    return torch.zeros((*coords_shape[:-1], n_coeffs), device=device)
transmittance(t[i]) * integrate( lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]], ) for i in range(len(parts)) ) + transmittance(t[-1]) * void_model(t[-1]).channels where 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). :param rays: [batch_size x ... x 2 x 3] origin and direction. :param parts: disjoint volume integrals. :param void_model: use this model to integrate over the empty space :param shared: All RayVolumeIntegrals are calculated with the same model. 
:param prev_raw_outputs: Raw outputs from the previous rendering step :return: A tuple of - AttrDict containing the rendered `channels`, `distances`, and the `aux_losses` - A list of importance samplers for additional fine-grained rendering - A list of raw output for each interval """ if importance_sampling_options is None: importance_sampling_options = {} origin, direc = rays[..., 0, :], rays[..., 1, :] if prev_raw_outputs is None: prev_raw_outputs = [None] * len(parts) samplers = [] raw_outputs = [] t0 = None results = None for part_i, prev_raw_i in zip(parts, prev_raw_outputs): # Integrate over [t[i], t[i + 1]] results_i = part_i.render_rays( origin, direc, t0=t0, prev_raw=prev_raw_i, shared=shared, render_with_direction=render_with_direction, ) # Create an importance sampler for (optional) fine rendering samplers.append( ImportanceRaySampler( results_i.volume_range, results_i.raw, **importance_sampling_options ) ) raw_outputs.append(results_i.raw) # Pass t[i + 1] as the start of integration for the next interval. t0 = results_i.volume_range.next_t0() # Combine the results from [t[0], t[i]] and [t[i], t[i+1]] results = results_i if results is None else results.combine(results_i) # While integrating out [t[-1], math.inf] is the correct thing to do, this # erases a lot of useful information. Also, void_model is meant to predict # the channels at t=math.inf. # # Add the void background over [t[-1], math.inf] to complete integration. # results = results.combine( # RayVolumeIntegralResults( # output=AttrDict( # channels=void_model(origin, direc), # distances=torch.zeros_like(t0), # aux_losses=AttrDict(), # ), # volume_range=VolumeRange( # t0=t0, # t1=torch.full_like(t0, math.inf), # intersected=torch.full_like(results.volume_range.intersected, True), # ), # # Void space extends to infinity. It is assumed that no light # # passes beyond the void. 
@dataclass
class RayVolumeIntegralResults:
    """
    Holds the outcome of one volume-rendering integral

        integrate(
            lambda t: density(t) * channels(t) * transmittance(t),
            [t0, t1],
        )

    plus the state needed to chain it with the integral over the interval
    that follows (see `combine`).
    """

    # Rendered output and auxiliary losses.
    # output.channels has shape [batch_size, *inner_shape, n_channels]
    output: AttrDict

    # ---- Optional values ----

    # Raw per-sample values: the sampled `ts`, `density`, `channels`, etc.
    raw: Optional[AttrDict] = None

    # Integration range.
    volume_range: Optional[VolumeRange] = None

    # If a ray intersects, the transmittance from t0 to t1 (e.g. the
    # probability that the ray passes through this volume).
    # Has shape [batch_size, *inner_shape, 1]
    transmittance: Optional[torch.Tensor] = None

    def combine(self, cur: "RayVolumeIntegralResults") -> "RayVolumeIntegralResults":
        """
        Merge the results of `self` over [t0, t1] with those of `cur` over
        [t1, t2] into results over [t0, t2], using an equation similar to
        (4) in NeRF++:

            integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t0, t2]
            )

          = integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t0, t1]
            ) + transmittance(t1) * integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t1, t2]
            )

        :param cur: results for the interval that starts where `self` ends.
        :return: a new RayVolumeIntegralResults covering both intervals. Its
                 `raw` field is left at the default (None).
        """
        # The two intervals have to be adjacent: `cur` must start at the t0
        # that `self` hands to the next interval.
        assert torch.allclose(self.volume_range.next_t0(), cur.volume_range.t0)

        carried = self.transmittance

        def _accumulate(prev_val: Optional[torch.Tensor], cur_val: Optional[torch.Tensor]):
            assert prev_val is not None
            if cur_val is not None:
                # Light reaching `cur` has already been attenuated by `self`.
                return prev_val + carried * cur_val
            # cur_output.aux_losses are empty for the void_model.
            return prev_val

        return RayVolumeIntegralResults(
            output=self.output.combine(cur.output, combine_fn=_accumulate),
            volume_range=self.volume_range.extend(cur.volume_range),
            transmittance=carried * cur.transmittance,
        )
def integrate_samples(
    self,
    volume_range: VolumeRange,
    raw: AttrDict,
) -> Tuple[AttrDict, torch.Tensor]:
    """
    Integrate the raw.channels along with other aux_losses and values to
    produce the final output dictionary containing rendered `channels`,
    estimated `distances` and `aux_losses`.

    Every tensor entry of `raw` (except `density`) is reduced to a weighted
    sum over the sample axis (dim=-2), where the weight of sample i is
    alpha[i] * T[i]: the opacity of the sample times the transmittance
    accumulated before it.

    :param volume_range: Specifies the integral range [t0, t1]
    :param raw: Contains a dict of function evaluations at ts. Should have

        ts: the sample positions, read via `raw.ts`
        density: torch.Tensor [batch_size, *shape, n_samples, 1]
        channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
        aux_losses: {key: torch.Tensor [batch_size, *shape, n_samples, 1] for each key}
        no_weight_grad_aux_losses: an optional set of losses for which the weights
                                   should be detached before integration.

        This argument is mutated: after the call, integrate_samples has
        stored an intermediate calculation on it for later use, namely

        weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density *
            transmittance)[i] weight for each rgb output at [..., i, :].
    :returns: a tuple of (
        a dictionary of rendered outputs and aux_losses,
        transmittance of this volume,
    )
    """

    # 1. Calculate the weights
    # dt is the third value returned by partition(); it is used as the
    # per-sample interval length that scales the density.
    _, _, dt = volume_range.partition(raw.ts)
    ddensity = raw.density * dt

    # Running optical depth along the ray; the final entry covers the whole
    # range, so its negative exponential is the transmittance of the volume.
    mass = torch.cumsum(ddensity, dim=-2)
    transmittance = torch.exp(-mass[..., -1, :])

    alphas = 1.0 - torch.exp(-ddensity)
    # Transmittance *before* each sample: an exclusive cumulative sum, i.e.
    # the mass shifted by one with a leading zero.
    Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
    # This is the probability of light hitting and reflecting off of
    # something at depth [..., i, :].
    weights = alphas * Ts

    # 2. Integrate all results
    def _integrate(key: str, samples: torch.Tensor, weights: torch.Tensor):
        if key == "density":
            # Omit integrating the density, because we don't need it
            return None
        return torch.sum(samples * weights, dim=-2)

    def _is_tensor(_key: str, value: Any):
        return isinstance(value, torch.Tensor)

    if raw.no_weight_grad_aux_losses:
        # These losses are integrated with detached weights, so no gradient
        # flows back into the weights through them.
        extra_aux_losses = raw.no_weight_grad_aux_losses.map(
            partial(_integrate, weights=weights.detach()), should_map=_is_tensor
        )
    else:
        extra_aux_losses = {}
    output = raw.map(partial(_integrate, weights=weights), should_map=_is_tensor)
    if "no_weight_grad_aux_losses" in output:
        # Drop this entry and use the detached-weight version computed above.
        del output["no_weight_grad_aux_losses"]
    output.aux_losses.update(extra_aux_losses)

    # Integrating the ts yields the distance away from the origin; rename the variable.
    output.distances = output.ts
    del output["ts"]
    del output["density"]

    assert output.distances.shape == (*output.channels.shape[:-1], 1)
    assert output.channels.shape[:-1] == raw.channels.shape[:-2]
    assert output.channels.shape[-1] == raw.channels.shape[-1]

    # 3. Reduce loss
    # Each aux loss is summed over everything but the leading (batch) axis.
    def _reduce_loss(_key: str, loss: torch.Tensor):
        return loss.view(loss.shape[0], -1).sum(dim=-1)

    # 4. Store other useful calculations
    # ImportanceRaySampler reads `raw.weights` when it is constructed.
    raw.weights = weights

    output.aux_losses = output.aux_losses.map(_reduce_loss)

    return output, transmittance
def sample(
    self,
    t0: torch.Tensor,
    t1: torch.Tensor,
    n_samples: int,
    epsilon: float = 1e-3,
) -> torch.Tensor:
    """
    Split [t0, t1] into `n_samples` bins according to `self.depth_mode` and
    draw one uniformly random t from each bin.

    :param t0: start time has shape [batch_size, *shape, 1]
    :param t1: finish time has shape [batch_size, *shape, 1]
    :param n_samples: number of ts to sample
    :param epsilon: lower clamp applied to t0 and t1 in the "geometric" and
                    "harmonic" modes, which take a log / reciprocal of them.
    :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
    """
    # Fractions in [0, 1], shaped [1, ..., 1, n_samples] so that they
    # broadcast against t0 / t1.
    frac_shape = [1] * (t0.dim() - 1) + [n_samples]
    fracs = torch.linspace(0, 1, n_samples).view(frac_shape).to(t0.dtype).to(t0.device)

    def _lerp(start, stop):
        return start * (1.0 - fracs) + stop * fracs

    mode = self.depth_mode
    if mode == "linear":
        anchors = _lerp(t0, t1)
    elif mode == "geometric":
        # Interpolate linearly in log-depth.
        anchors = _lerp(t0.clamp(epsilon).log(), t1.clamp(epsilon).log()).exp()
    elif mode == "harmonic":
        # The original NeRF recommends this interpolation scheme for
        # spherical scenes, but there could be some weird edge cases when
        # the observer crosses from the inner to outer volume.
        anchors = 1.0 / _lerp(1.0 / t0.clamp(epsilon), 1.0 / t1.clamp(epsilon))
    else:
        anchors = fracs

    # Bin edges are the midpoints between neighbouring anchors, closed off
    # by t0 and t1 at the two ends.
    midpoints = 0.5 * (anchors[..., 1:] + anchors[..., :-1])
    bin_hi = torch.cat([midpoints, t1], dim=-1)
    bin_lo = torch.cat([t0, midpoints], dim=-1)
    jitter = torch.rand_like(anchors)

    return (bin_lo + (bin_hi - bin_lo) * jitter).unsqueeze(-1)
@torch.no_grad()
def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
    """
    Draw `n_samples` new ts per ray, choosing bins in proportion to the
    weights stored from the earlier rendering step.

    The bins come from partitioning the stored volume range at the stored
    ts. `t0` and `t1` are part of the RaySampler interface but are not
    read by this implementation.

    :param t0: start time has shape [batch_size, *shape, 1]
    :param t1: finish time has shape [batch_size, *shape, 1]
    :param n_samples: number of ts to sample
    :return: sampled ts of shape [batch_size, *shape, n_samples, 1],
             sorted in ascending order along the sample axis.
    """
    bin_lo, bin_hi, _ = self.volume_range.partition(self.ts)
    batch_size, *inner_shape, n_bins, _ = self.ts.shape

    bin_weights = self.weights
    if self.blur_pool:
        # 2-tap max followed by a 2-tap average, with the two end weights
        # replicated so the result keeps one weight per bin.
        edge_padded = torch.cat(
            [bin_weights[..., :1, :], bin_weights, bin_weights[..., -1:, :]], dim=-2
        )
        pooled = torch.maximum(edge_padded[..., :-1, :], edge_padded[..., 1:, :])
        bin_weights = 0.5 * (pooled[..., :-1, :] + pooled[..., 1:, :])

    # Add the small constant alpha to every weight, then normalise over the
    # bin axis to obtain a probability mass function.
    bin_weights = bin_weights + self.alpha
    pmf = bin_weights / bin_weights.sum(dim=-2, keepdim=True)

    picked = sample_pmf(pmf, n_samples)
    assert picked.shape == (batch_size, *inner_shape, n_samples, 1)
    assert (picked >= 0).all() and (picked < n_bins).all()

    # Place each sample uniformly at random inside the bin that was picked.
    jitter = torch.rand(picked.shape, device=picked.device)
    picked_lo = torch.gather(bin_lo, -2, picked)
    picked_hi = torch.gather(bin_hi, -2, picked)

    return torch.sort(picked_lo + (picked_hi - picked_lo) * jitter, dim=-2).values
def render_rays(
    self,
    batch: Dict,
    params: Optional[Dict] = None,
    options: Optional[Dict] = None,
) -> AttrDict:
    """
    Render a batch of rays in two passes: a coarse pass that uses stratified
    samples, followed by a fine pass that uses the importance samplers
    produced by the coarse pass.

    :param batch: mapping that provides `rays`; it is wrapped in an AttrDict
        and `batch.rays` is handed to the ray integrator.
    :param params: optional parameters, passed through `self.update`. Each
        model receives its own sub-dictionary (e.g. "fine_model").
    :param options: optional rendering options. Missing entries are filled
        in from the renderer's defaults (`render_background`,
        `render_with_direction`, `n_coarse_samples`, `n_fine_samples` and
        the two stratified depth sampling modes). A plain dict is accepted
        and wrapped in an AttrDict; an AttrDict is updated in place.
    :return: AttrDict with the fine `channels`, `distances`,
        `transmittance`, `t0`, `t1` and `intersected`, the coarse
        `channels_coarse` and `transmittance_coarse`, and `aux_losses`
        holding the fine losses plus the coarse ones under a "_coarse"
        key suffix. Channels are multiplied by `self.channel_scale`.
    """
    params = self.update(params)
    batch = AttrDict(batch)

    # Options are read by attribute below, so a plain dict (which the
    # signature allows) has to be wrapped first. An AttrDict is kept as-is
    # so that callers still see the defaults that get filled in.
    if options is None:
        options = AttrDict()
    elif not isinstance(options, AttrDict):
        options = AttrDict(options)

    defaults = {
        "render_background": True,
        "render_with_direction": True,
        "n_coarse_samples": self.n_coarse_samples,
        "n_fine_samples": self.n_fine_samples,
        "foreground_stratified_depth_sampling_mode": (
            self.foreground_stratified_depth_sampling_mode
        ),
        "background_stratified_depth_sampling_mode": (
            self.background_stratified_depth_sampling_mode
        ),
    }
    for key, value in defaults.items():
        options.setdefault(key, value)

    # Without a dedicated coarse model, the fine model does both passes.
    shared = self.coarse_model is None

    # Decide once whether the background is rendered: the fine pass indexes
    # `samplers` by part, so it must use the same parts as the coarse pass.
    render_background = options.render_background and self.outer_volume is not None

    # First, render rays using the coarse models with stratified ray samples.
    coarse_model, coarse_key = (
        (self.fine_model, "fine_model") if shared else (self.coarse_model, "coarse_model")
    )
    coarse_model = partial(
        coarse_model,
        params=subdict(params, coarse_key),
        options=options,
    )
    parts = [
        RayVolumeIntegral(
            model=coarse_model,
            volume=self.volume,
            sampler=StratifiedRaySampler(
                depth_mode=options.foreground_stratified_depth_sampling_mode,
            ),
            n_samples=options.n_coarse_samples,
        ),
    ]
    if render_background:
        coarse_background_model, coarse_background_key = (
            (self.fine_background_model, "fine_background_model")
            if shared
            else (self.coarse_background_model, "coarse_background_model")
        )
        coarse_background_model = partial(
            coarse_background_model,
            params=subdict(params, coarse_background_key),
            options=options,
        )
        parts.append(
            RayVolumeIntegral(
                model=coarse_background_model,
                volume=self.outer_volume,
                sampler=StratifiedRaySampler(
                    depth_mode=options.background_stratified_depth_sampling_mode,
                ),
                n_samples=options.n_coarse_samples,
            )
        )
    coarse_results, samplers, coarse_raw_outputs = render_rays(
        batch.rays,
        parts,
        partial(self.void_model, options=options),
        shared=shared,
        render_with_direction=options.render_with_direction,
        importance_sampling_options=AttrDict(self.importance_sampling_options),
    )

    # Then, render rays using the fine models with importance-weighted ray samples.
    fine_model = partial(
        self.fine_model,
        params=subdict(params, "fine_model"),
        options=options,
    )
    parts = [
        RayVolumeIntegral(
            model=fine_model,
            volume=self.volume,
            sampler=samplers[0],
            n_samples=options.n_fine_samples,
        ),
    ]
    if render_background:
        fine_background_model = partial(
            self.fine_background_model,
            params=subdict(params, "fine_background_model"),
            options=options,
        )
        parts.append(
            RayVolumeIntegral(
                model=fine_background_model,
                volume=self.outer_volume,
                sampler=samplers[1],
                n_samples=options.n_fine_samples,
            )
        )
    fine_results, *_ = render_rays(
        batch.rays,
        parts,
        partial(self.void_model, options=options),
        shared=shared,
        prev_raw_outputs=coarse_raw_outputs,
        render_with_direction=options.render_with_direction,
    )

    # Combine results: fine losses keep their keys, coarse losses get a suffix.
    aux_losses = fine_results.output.aux_losses.copy()
    for key, val in coarse_results.output.aux_losses.items():
        aux_losses[key + "_coarse"] = val

    return AttrDict(
        channels=fine_results.output.channels * self.channel_scale,
        channels_coarse=coarse_results.output.channels * self.channel_scale,
        distances=fine_results.output.distances,
        transmittance=fine_results.transmittance,
        transmittance_coarse=coarse_results.transmittance,
        t0=fine_results.volume_range.t0,
        t1=fine_results.volume_range.t1,
        intersected=fine_results.volume_range.intersected,
        aux_losses=aux_losses,
    )
def render_rays(
    self,
    batch: Dict,
    params: Optional[Dict] = None,
    options: Optional[Dict] = None,
) -> AttrDict:
    """
    Render a batch of rays in a single pass using stratified samples for
    the foreground volume and, optionally, the outer (background) volume.

    :param batch: mapping that provides `rays`; it is wrapped in an AttrDict
        and `batch.rays` is handed to the ray integrator.
    :param params: optional parameters, passed through `self.update`. Each
        model receives its own sub-dictionary ("foreground_model" /
        "background_model").
    :param options: optional rendering options. Missing entries are filled
        in from the renderer's defaults (`render_background`,
        `render_with_direction`, `n_samples` and the two stratified depth
        sampling modes). A plain dict is accepted and wrapped in an
        AttrDict; an AttrDict is updated in place.
    :return: AttrDict with `channels` (multiplied by `self.channel_scale`),
        `distances`, `transmittance`, `t0`, `t1`, `intersected` and
        `aux_losses`.
    """
    params = self.update(params)
    batch = AttrDict(batch)

    # Options are read by attribute below, so a plain dict (which the
    # signature allows) has to be wrapped first. An AttrDict is kept as-is
    # so that callers still see the defaults that get filled in.
    if options is None:
        options = AttrDict()
    elif not isinstance(options, AttrDict):
        options = AttrDict(options)

    defaults = {
        "render_background": True,
        "render_with_direction": True,
        "n_samples": self.n_samples,
        "foreground_stratified_depth_sampling_mode": (
            self.foreground_stratified_depth_sampling_mode
        ),
        "background_stratified_depth_sampling_mode": (
            self.background_stratified_depth_sampling_mode
        ),
    }
    for key, value in defaults.items():
        options.setdefault(key, value)

    foreground_model = partial(
        self.foreground_model,
        params=subdict(params, "foreground_model"),
        options=options,
    )
    parts = [
        RayVolumeIntegral(
            model=foreground_model,
            volume=self.volume,
            sampler=StratifiedRaySampler(
                depth_mode=options.foreground_stratified_depth_sampling_mode
            ),
            n_samples=options.n_samples,
        ),
    ]
    # The background part is only added when an outer volume was configured.
    if options.render_background and self.outer_volume is not None:
        background_model = partial(
            self.background_model,
            params=subdict(params, "background_model"),
            options=options,
        )
        parts.append(
            RayVolumeIntegral(
                model=background_model,
                volume=self.outer_volume,
                sampler=StratifiedRaySampler(
                    depth_mode=options.background_stratified_depth_sampling_mode
                ),
                n_samples=options.n_samples,
            )
        )

    # Only the combined results are needed; samplers and raw outputs are dropped.
    results, *_ = render_rays(
        batch.rays,
        parts,
        self.void_model,
        render_with_direction=options.render_with_direction,
    )

    return AttrDict(
        channels=results.output.channels * self.channel_scale,
        distances=results.output.distances,
        transmittance=results.transmittance,
        t0=results.volume_range.t0,
        t1=results.volume_range.t1,
        intersected=results.volume_range.intersected,
        aux_losses=results.output.aux_losses,
    )
def forward(
    self,
    query: Query,
    params: Optional[Dict[str, torch.Tensor]] = None,
    options: Optional[Dict[str, Any]] = None,
) -> AttrDict[str, Any]:
    """
    Evaluate the MLP at the query and split its outputs into density,
    signed distance and channels.

    :param query: provides `position` and `direction` for the MLP.
    :param params: optional parameters forwarded to the MLP.
    :param options: optional options. `nerf_level` equal to "coarse" picks
        the coarse density (and coarse NeRF channels); any other value picks
        the fine ones. `rendering_mode` (default "stf") equal to "nerf"
        picks the NeRF channels; any other value picks the STF channels.
    :return: AttrDict with activated `density`, `signed_distance` and
        `channels`.
    """
    options = AttrDict(options) if options is not None else AttrDict()
    h, h_directionless = self._mlp(
        query.position, query.direction, params=params, options=options
    )

    # Slice both MLP outputs into named pieces using the index maps that
    # were chosen at construction time.
    activations = map_indices_to_keys(self.h_map, h)
    activations.update(map_indices_to_keys(self.h_directionless_map, h_directionless))

    use_coarse = options.nerf_level == "coarse"
    h_density = activations.density_coarse if use_coarse else activations.density_fine

    if options.get("rendering_mode", "stf") != "nerf":
        h_channels = activations.stf
    elif use_coarse:
        h_channels = activations.nerf_coarse
    else:
        h_channels = activations.nerf_fine

    return AttrDict(
        density=self.density_activation(h_density),
        signed_distance=self.sdf_activation(activations.sdf),
        channels=self.channel_activation(h_channels),
    )
def __init__(
    self,
    sdf: Optional[Model],
    tf: Optional[Model],
    nerstf: Optional[Model],
    void: NeRFModel,
    volume: Volume,
    grid_size: int,
    n_coarse_samples: int,
    n_fine_samples: int,
    importance_sampling_options: Optional[Dict[str, Any]] = None,
    separate_shared_samples: bool = False,
    texture_channels: Sequence[str] = ("R", "G", "B"),
    channel_scale: Sequence[float] = (255.0, 255.0, 255.0),
    ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,
    diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,
    specular_color: Union[float, Tuple[float]] = 0.0,
    output_srgb: bool = True,
    device: torch.device = torch.device("cuda"),
    **kwargs,
):
    """
    Store the models and rendering settings and move the renderer to
    `device`.

    Exactly one of the two model configurations must be given: either
    `nerstf` alone, or both `sdf` and `tf`.

    :param sdf: model queried for density and signed distance (used when
        `nerstf` is None).
    :param tf: model queried for channels (used when `nerstf` is None).
    :param nerstf: single model queried for density, signed distance and
        channels.
    :param void: model passed to the ray integrator as the void model.
    :param volume: must be a BoundingBoxVolume.
    :param grid_size: forwarded to the STF view renderer.
    :param n_coarse_samples: number of stratified samples in the first
        (coarse) ray rendering pass.
    :param n_fine_samples: number of importance samples in the second
        (fine) ray rendering pass.
    :param importance_sampling_options: keyword options for the importance
        samplers; None is treated as an empty dict.
    :param separate_shared_samples: when True, the two ray rendering passes
        are run with shared=False and the coarse outputs are also returned.
    :param texture_channels: names of the texture channels; the count is
        used to reshape `channel_scale`.
    :param channel_scale: per-channel multiplier applied to rendered
        channels; converted to a tensor on `device`.
    :param ambient_color: forwarded to the STF view renderer.
    :param diffuse_color: forwarded to the STF view renderer.
    :param specular_color: forwarded to the STF view renderer.
    :param output_srgb: forwarded to the STF view renderer.
    :param device: device the renderer is moved to.
    :param kwargs: forwarded to the parent constructors.
    """
    super().__init__(**kwargs)

    assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume"
    # XOR: either the combined model, or the (sdf, tf) pair -- never both.
    assert (nerstf is not None) ^ (sdf is not None and tf is not None)
    self.sdf = sdf
    self.tf = tf
    self.nerstf = nerstf
    self.void = void
    self.volume = volume
    self.grid_size = grid_size
    self.n_coarse_samples = n_coarse_samples
    self.n_fine_samples = n_fine_samples
    # Wrapped in an AttrDict; None becomes an empty dict.
    self.importance_sampling_options = AttrDict(importance_sampling_options or {})
    self.separate_shared_samples = separate_shared_samples
    self.texture_channels = texture_channels
    # Held as a tensor on the target device so it can be broadcast against
    # rendered channels.
    self.channel_scale = to_torch(channel_scale).to(device)
    self.ambient_color = ambient_color
    self.diffuse_color = diffuse_color
    self.specular_color = specular_color
    self.output_srgb = output_srgb
    self.device = device
    self.to(device)
def render_views(
    self,
    batch: AttrDict,
    params: Optional[Dict] = None,
    options: Optional[AttrDict] = None,
) -> AttrDict:
    """
    Returns a backproppable rendering of a view

    :param batch: contains either ["poses", "camera"], or ["cameras"]. Can
        optionally contain any of ["height", "width", "query_batch_size"]
    :param params: Meta parameters, passed through `self.update`; each model
        receives its own sub-dictionary ("nerstf", or "sdf" and "tf").
    :param options: controls checkpointing, caching, and rendering.
        Can provide a `rendering_mode` in ["stf", "nerf"] (default "stf").
        The options are copied, so the caller's object is not modified.
    :raises NotImplementedError: for any other `rendering_mode`.
    """
    params = self.update(params)
    options = AttrDict(options) if options is not None else AttrDict()

    # Provide a cache for the duration of this call if the caller gave none.
    created_cache = options.cache is None
    if created_cache:
        options.cache = AttrDict()

    mode = options.get("rendering_mode", "stf")

    if mode == "nerf":
        output = render_views_from_rays(
            self.render_rays,
            batch,
            params=params,
            options=options,
            device=self.device,
        )
    elif mode == "stf":
        # Either the combined model provides everything, or the sdf / tf
        # pair does; the unused entries stay None.
        if self.nerstf is not None:
            field_fns = dict(
                sdf_fn=None,
                tf_fn=None,
                nerstf_fn=partial(
                    self.nerstf.forward_batched,
                    params=subdict(params, "nerstf"),
                    options=options,
                ),
            )
        else:
            field_fns = dict(
                sdf_fn=partial(
                    self.sdf.forward_batched,
                    params=subdict(params, "sdf"),
                    options=options,
                ),
                tf_fn=partial(
                    self.tf.forward_batched,
                    params=subdict(params, "tf"),
                    options=options,
                ),
                nerstf_fn=None,
            )
        output = render_views_from_stf(
            batch,
            options,
            volume=self.volume,
            grid_size=self.grid_size,
            channel_scale=self.channel_scale,
            texture_channels=self.texture_channels,
            ambient_color=self.ambient_color,
            diffuse_color=self.diffuse_color,
            specular_color=self.specular_color,
            output_srgb=self.output_srgb,
            device=self.device,
            **field_fns,
        )
    else:
        raise NotImplementedError

    # Remove the cache again if it was created here.
    if created_cache:
        del options["cache"]

    return output
@dataclass
class DifferentiableCamera(ABC):
    """
    An object describing how a camera corresponds to pixels in an image.
    """

    @abstractmethod
    def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        """
        For every (x, y) coordinate in a rendered image, compute the ray of the
        corresponding pixel.

        :param coords: an [N x ... x 2] integer array of 2D image coordinates.
        :return: an [N x ... x 2 x 3] array of [2 x 3] (origin, direction) tuples.
                 The direction should always be unit length.
        """

    @abstractmethod
    def resize_image(self, width: int, height: int) -> "DifferentiableCamera":
        """
        Creates a new camera with the same intrinsics and direction as this one,
        but with resized image dimensions.
        """


@dataclass
class DifferentiableProjectiveCamera(DifferentiableCamera):
    """
    Implements a batch, differentiable, standard pinhole camera
    """

    origin: torch.Tensor  # [batch_size x 3]
    x: torch.Tensor  # [batch_size x 3]
    y: torch.Tensor  # [batch_size x 3]
    z: torch.Tensor  # [batch_size x 3]
    width: int
    height: int
    x_fov: float
    y_fov: float

    def __post_init__(self):
        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
        assert (
            len(self.x.shape)
            == len(self.y.shape)
            == len(self.z.shape)
            == len(self.origin.shape)
            == 2
        )

    def resolution(self):
        """
        :return: float32 tensor [width, height]
        """
        return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))

    def fov(self):
        """
        :return: float32 tensor [x_fov, y_fov]
        """
        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))

    def image_coords(self) -> torch.Tensor:
        """
        :return: coords of shape (width * height, 2)
        """
        # Same row-major enumeration as the module-level helper; share it.
        return get_image_coords(self.width, self.height)

    def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Compute one (origin, direction) ray per pixel coordinate.

        :param coords: a [batch_size x ... x 2] array of (x, y) pixel coordinates.
        :return: a [batch_size x ... x 2 x 3] array; index 0 of the second to
                 last axis is the ray origin, index 1 the unit direction.
        """
        batch_size, *shape, n_coords = coords.shape
        assert n_coords == 2
        assert batch_size == self.origin.shape[0]

        # reshape (rather than view) so non-contiguous coords are accepted too.
        flat = coords.reshape(batch_size, -1, 2)
        res = self.resolution().to(flat.device)
        fov = self.fov().to(flat.device)

        # Map pixel index 0 -> -1 and index (res - 1) -> +1. A one-pixel axis
        # would divide by zero, so its only pixel is placed at the centre (0).
        denom = (res - 1).clamp(min=1)
        fracs = (flat.float() / denom) * 2 - 1
        fracs = torch.where(res > 1, fracs, torch.zeros_like(fracs))
        fracs = fracs * torch.tan(fov / 2)

        directions = normalize(
            self.z.view(batch_size, 1, 3)
            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
        )
        origins = torch.broadcast_to(
            self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]
        )
        rays = torch.stack([origins, directions], dim=2)
        return rays.view(batch_size, *shape, 2, 3)

    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
        """
        Creates a new camera for the resized view assuming the aspect ratio does not change.
        """
        assert width * self.height == height * self.width, "The aspect ratio should not change."
        return DifferentiableProjectiveCamera(
            origin=self.origin,
            x=self.x,
            y=self.y,
            z=self.z,
            width=width,
            height=height,
            x_fov=self.x_fov,
            y_fov=self.y_fov,
        )


@dataclass
class DifferentiableCameraBatch(ABC):
    """
    Annotate a differentiable camera with a multi-dimensional batch shape.
    """

    shape: Tuple[int, ...]
    flat_camera: DifferentiableCamera


def normalize(vec: torch.Tensor) -> torch.Tensor:
    """
    Scale `vec` to unit length along its last dimension.
    """
    return vec / vec.norm(dim=-1, keepdim=True)


def project_out(vec1: torch.Tensor, vec2: torch.Tensor) -> torch.Tensor:
    """
    Removes the vec2 component from vec1
    """
    vec2 = normalize(vec2)
    proj = (vec1 * vec2).sum(dim=-1, keepdim=True)
    return vec1 - proj * vec2


def camera_orientation(toward: torch.Tensor, up: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    :param toward: [batch_size x 3] unit vector from camera position to the object
    :param up: Optional [batch_size x 3] specifying the physical up direction in the world frame.
               Defaults to the +z axis.
    :return: [batch_size x 3 x 3], the (x, y, z) camera axes stacked along dim 1
    """
    if up is None:
        up = torch.zeros_like(toward)
        up[:, 2] = 1

    assert len(toward.shape) == 2
    assert toward.shape[1] == 3

    assert len(up.shape) == 2
    assert up.shape[1] == 3

    z = normalize(toward)
    # y is the negated part of `up` that is orthogonal to the viewing direction.
    y = -normalize(project_out(up, toward))
    x = torch.cross(y, z, dim=1)
    return torch.stack([x, y, z], dim=1)


def projective_camera_frame(
    origin: torch.Tensor,
    toward: torch.Tensor,
    camera_params: "Union[ProjectiveCamera, DifferentiableProjectiveCamera]",
) -> DifferentiableProjectiveCamera:
    """
    Given the origin and the direction of a view, return a differentiable
    projective camera with the given parameters.

    Only `width`, `height`, `x_fov` and `y_fov` are read from `camera_params`.

    TODO: We need to support the rotation of the camera frame about the
    `toward` vector to fully implement 6 degrees of freedom.
    """
    rot = camera_orientation(toward)
    camera = DifferentiableProjectiveCamera(
        origin=origin,
        x=rot[:, 0],
        y=rot[:, 1],
        z=rot[:, 2],
        width=camera_params.width,
        height=camera_params.height,
        x_fov=camera_params.x_fov,
        y_fov=camera_params.y_fov,
    )
    return camera


@torch.no_grad()
def get_image_coords(width, height) -> torch.Tensor:
    """
    :return: integer (x, y) pixel coordinates of shape (width * height, 2),
             enumerated row by row.
    """
    pixel_indices = torch.arange(height * width)
    # torch throws warnings for pixel_indices // width
    pixel_indices_div = torch.div(pixel_indices, width, rounding_mode="trunc")
    coords = torch.stack([pixel_indices % width, pixel_indices_div], dim=1)
    return coords
def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    :return: whatever `func(*inputs)` returns.
    """
    if flag:
        # inputs and params travel as one flat argument list; len(inputs)
        # tells CheckpointFunction where to split them again.
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    """
    Runs `run_function` under `torch.no_grad()` in the forward pass and
    delegates the backward pass to `CheckpointFunctionGradFunction`, which
    re-runs the function to obtain gradients.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, run_function, length, *args):
        # `length` is the number of leading entries of `args` that are inputs;
        # the remaining entries are parameters.
        ctx.run_function = run_function
        ctx.length = length
        input_tensors = list(args[:length])
        input_params = list(args[length:])
        ctx.save_for_backward(*input_tensors, *input_params)
        with torch.no_grad():
            output_tensors = ctx.run_function(*input_tensors)
        return output_tensors

    @staticmethod
    @custom_bwd
    def backward(ctx, *output_grads):
        inputs = ctx.saved_tensors
        input_tensors = inputs[: ctx.length]
        input_params = inputs[ctx.length :]
        res = CheckpointFunctionGradFunction.apply(
            ctx.run_function,
            len(input_tensors),
            len(input_params),
            *input_tensors,
            *input_params,
            *output_grads
        )
        # No gradients for the two non-tensor arguments (run_function, length).
        return (None, None) + res


class CheckpointFunctionGradFunction(torch.autograd.Function):
    """
    Computes the gradients for `CheckpointFunction` by re-running
    `run_function` with gradients enabled. Its own `backward` repeats the
    computation with `create_graph=True` so the gradient computation can
    itself be differentiated.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, run_function, length_1, length_2, *args):
        # args = [length_1 inputs] + [length_2 params] + [output grads]
        ctx.run_function = run_function
        ctx.length_1 = length_1
        ctx.length_2 = length_2
        input_tensors = [x.detach().requires_grad_(True) for x in args[:length_1]]
        input_params = list(args[length_1 : length_1 + length_2])
        output_grads = list(args[length_1 + length_2 :])
        ctx.save_for_backward(*input_tensors, *input_params, *output_grads)

        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            input_tensors + input_params,
            output_grads,
            allow_unused=True,
        )
        return input_grads

    @staticmethod
    @custom_bwd
    def backward(ctx, *all_output_grads):
        args = ctx.saved_tensors
        input_tensors = [x.detach().requires_grad_(True) for x in args[: ctx.length_1]]
        input_params = list(args[ctx.length_1 : ctx.length_1 + ctx.length_2])
        # The saved output grads are differentiated against as well, so they
        # are detached and marked as requiring grad like the inputs.
        output_grads = [
            x.detach().requires_grad_(True) for x in args[ctx.length_1 + ctx.length_2 :]
        ]

        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
            # create_graph=True keeps the graph of this gradient computation
            # so it can be differentiated by the second grad call below.
            input_grads = torch.autograd.grad(
                output_tensors,
                input_tensors + input_params,
                output_grads,
                allow_unused=True,
                create_graph=True,
                retain_graph=True,
            )
        input_grads_grads = torch.autograd.grad(
            input_grads,
            input_tensors + input_params + output_grads,
            all_output_grads,
            allow_unused=True,
        )
        del input_grads
        # No gradients for run_function, length_1 and length_2.
        return (None, None, None) + input_grads_grads
def _encode_features(version: str, values: torch.Tensor) -> torch.Tensor:
    """
    Shared implementation of `encode_position` and `encode_channels`.

    "v1" emits cos/sin features at the 10 scales 2**0 .. 2**9 for every entry
    of the last dimension; "nerf" defers to `posenc_nerf` with degrees 0..15.

    :raises ValueError: for any other `version`.
    """
    if version == "v1":
        freqs = get_scales(0, 10, values.dtype, values.device).view(1, -1)
        freqs = values.reshape(-1, 1) * freqs
        return torch.cat([freqs.cos(), freqs.sin()], dim=1).reshape(*values.shape[:-1], -1)
    elif version == "nerf":
        return posenc_nerf(values, min_deg=0, max_deg=15)
    else:
        raise ValueError(version)


def encode_position(version: str, *, position: torch.Tensor):
    """
    Encode `position` ([..., dim]) along its last dimension; see `_encode_features`.
    """
    return _encode_features(version, position)


def encode_channels(version: str, *, channels: torch.Tensor):
    """
    Encode `channels` ([..., dim]) along its last dimension; see `_encode_features`.
    """
    return _encode_features(version, channels)


def position_encoding_channels(version: Optional[str] = None) -> int:
    """
    :return: number of output features `encode_position` produces per input
             coordinate (1 when no encoding is applied).
    """
    if version is None:
        return 1
    return encode_position(version, position=torch.zeros(1, 1)).shape[-1]


def channel_encoding_channels(version: Optional[str] = None) -> int:
    """
    :return: number of output features `encode_channels` produces per input
             channel (1 when no encoding is applied).
    """
    if version is None:
        return 1
    return encode_channels(version, channels=torch.zeros(1, 1)).shape[-1]


def _direction_encoding_channels(version: Optional[str] = None) -> int:
    """
    :return: number of features `maybe_encode_direction` produces for one
             3-vector (3 when no encoding is applied).
    """
    if version is None:
        return 3
    return maybe_encode_direction(version, position=torch.zeros(1, 3)).shape[-1]


class PosEmbLinear(nn.Linear):
    """
    A linear layer that optionally applies `encode_position` to its input first.
    """

    def __init__(
        self, posemb_version: Optional[str], in_features: int, out_features: int, **kwargs
    ):
        super().__init__(
            in_features * position_encoding_channels(posemb_version),
            out_features,
            **kwargs,
        )
        self.posemb_version = posemb_version

    def forward(self, x: torch.Tensor):
        if self.posemb_version is not None:
            x = encode_position(self.posemb_version, position=x)
        return super().forward(x)


class MultiviewPoseEmbedding(nn.Conv2d):
    """
    A 3x3 convolution over the concatenation of (optionally encoded) channels,
    ray positions and ray directions.
    """

    def __init__(
        self,
        posemb_version: Optional[str],
        n_channels: int,
        out_features: int,
        stride: int = 1,
        **kwargs,
    ):
        # Directions go through maybe_encode_direction in forward(), whose
        # width differs from the position encoding, so it is sized separately.
        in_features = (
            n_channels * channel_encoding_channels(version=posemb_version)
            + 3 * position_encoding_channels(version=posemb_version)
            + _direction_encoding_channels(version=posemb_version)
        )
        super().__init__(
            in_features,
            out_features,
            kernel_size=3,
            stride=stride,
            padding=1,
            **kwargs,
        )
        self.posemb_version = posemb_version

    def forward(
        self, channels: torch.Tensor, position: torch.Tensor, direction: torch.Tensor
    ) -> torch.Tensor:
        """
        :param channels: [batch_shape, inner_batch_shape, n_channels, height, width]
        :param position: [batch_shape, inner_batch_shape, 3, height, width]
        :param direction: [batch_shape, inner_batch_shape, 3, height, width]
        :return: [*batch_shape, out_features, height, width]
        """

        if self.posemb_version is not None:
            # The encoders work on the last dimension, so move features last,
            # encode, and move them back in front of (height, width).
            channels = channels.permute(0, 1, 3, 4, 2)
            position = position.permute(0, 1, 3, 4, 2)
            direction = direction.permute(0, 1, 3, 4, 2)
            channels = encode_channels(self.posemb_version, channels=channels).permute(
                0, 1, 4, 2, 3
            )
            direction = maybe_encode_direction(
                self.posemb_version, position=position, direction=direction
            ).permute(0, 1, 4, 2, 3)
            position = encode_position(self.posemb_version, position=position).permute(
                0, 1, 4, 2, 3
            )
        x = torch.cat([channels, position, direction], dim=-3)
        *batch_shape, in_features, height, width = x.shape
        return (
            super()
            .forward(x.view(-1, in_features, height, width))
            .view(*batch_shape, -1, height, width)
        )


class MultiviewPointCloudEmbedding(nn.Conv2d):
    """
    A 3x3 convolution over the concatenation of (optionally encoded) channels,
    ray origins and positions, with a learned `unk_token` substituted wherever
    `mask` is False.
    """

    def __init__(
        self,
        posemb_version: Optional[str],
        n_channels: int,
        out_features: int,
        stride: int = 1,
        **kwargs,
    ):
        in_features = (
            n_channels * channel_encoding_channels(version=posemb_version)
            + 3 * position_encoding_channels(version=posemb_version)
            + 3 * position_encoding_channels(version=posemb_version)
        )
        super().__init__(
            in_features,
            out_features,
            kernel_size=3,
            stride=stride,
            padding=1,
            **kwargs,
        )
        self.posemb_version = posemb_version
        # Only tensor-factory options may be forwarded to torch.randn; other
        # Conv2d keyword arguments (e.g. bias=False) would make it raise.
        factory_kwargs = {k: kwargs[k] for k in ("device", "dtype") if k in kwargs}
        self.register_parameter(
            "unk_token", nn.Parameter(torch.randn(in_features, **factory_kwargs) * 0.01)
        )
        self.unk_token: torch.Tensor

    def forward(
        self,
        channels: torch.Tensor,
        origin: torch.Tensor,
        position: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param channels: [batch_shape, inner_batch_shape, n_channels, height, width]
        :param origin: [batch_shape, inner_batch_shape, 3, height, width]
        :param position: [batch_shape, inner_batch_shape, 3, height, width]
        :param mask: boolean tensor broadcastable against the concatenated
                     features; entries that are False are replaced by `unk_token`.
        :return: [*batch_shape, out_features, height, width]
        """

        if self.posemb_version is not None:
            channels = channels.permute(0, 1, 3, 4, 2)
            origin = origin.permute(0, 1, 3, 4, 2)
            position = position.permute(0, 1, 3, 4, 2)
            channels = encode_channels(self.posemb_version, channels=channels).permute(
                0, 1, 4, 2, 3
            )
            origin = encode_position(self.posemb_version, position=origin).permute(0, 1, 4, 2, 3)
            position = encode_position(self.posemb_version, position=position).permute(
                0, 1, 4, 2, 3
            )
        x = torch.cat([channels, origin, position], dim=-3)
        unk_token = torch.broadcast_to(self.unk_token.view(1, 1, -1, 1, 1), x.shape)
        x = torch.where(mask, x, unk_token)
        *batch_shape, in_features, height, width = x.shape
        return (
            super()
            .forward(x.view(-1, in_features, height, width))
            .view(*batch_shape, -1, height, width)
        )


def maybe_encode_direction(
    version: str,
    *,
    position: torch.Tensor,
    direction: Optional[torch.Tensor] = None,
):
    """
    Encode `direction`, or return zeros of the matching shape when it is None.

    "v1" uses a degree-4 spherical harmonics basis (16 features); "nerf" uses
    `posenc_nerf` with degrees 0..8. `position` only supplies the leading
    shape, dtype and device of the zero placeholder.

    :raises ValueError: for any other `version`.
    """
    if version == "v1":
        sh_degree = 4
        if direction is None:
            return torch.zeros(*position.shape[:-1], sh_degree**2).to(position)
        return spherical_harmonics_basis(direction, sh_degree=sh_degree)
    elif version == "nerf":
        if direction is None:
            return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
        return posenc_nerf(direction, min_deg=0, max_deg=8)
    else:
        raise ValueError(version)


def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
    """
    Concatenate x and its positional encodings, following NeRF.

    Reference: https://arxiv.org/pdf/2210.04628.pdf

    :param x: [..., dim]
    :return: [..., dim * (1 + 2 * (max_deg - min_deg))], laid out as
             x, then sin(2**k * x) per degree, then sin(2**k * x + pi / 2) per
             degree. When min_deg == max_deg, x itself is returned.
    """
    if min_deg == max_deg:
        return x
    scales = get_scales(min_deg, max_deg, x.dtype, x.device)
    *shape, dim = x.shape
    xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
    assert xb.shape[-1] == dim * (max_deg - min_deg)
    # sin(t + pi / 2) supplies the cosine half of the encoding.
    emb = torch.cat([xb, xb + math.pi / 2.0], dim=-1).sin()
    return torch.cat([x, emb], dim=-1)


@lru_cache
def get_scales(
    min_deg: int,
    max_deg: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """
    :return: the tensor [2**min_deg, ..., 2**(max_deg - 1)]. The result is
             cached per argument tuple, so callers must not modify it in place.
    """
    return 2.0 ** torch.arange(min_deg, max_deg, device=device, dtype=dtype)
z.squeeze(dim=-1) xy, xz, yz = x * y, x * z, y * z x2, y2, z2 = x * x, y * y, z * z x4, y4, z4 = x2 * x2, y2 * y2, z2 * z2 x6, y6, z6 = x4 * x2, y4 * y2, z4 * z2 xyz = xy * z # https://github.com/NVlabs/tiny-cuda-nn/blob/8575542682cb67cddfc748cc3d3cfc12593799aa/include/tiny-cuda-nn/encodings/spherical_harmonics.h#L76 out = torch.zeros(x.shape[0], sh_degree**2, dtype=x.dtype, device=x.device) def _sh(): out[:, 0] = 0.28209479177387814 # 1/(2*sqrt(pi)) if sh_degree <= 1: return out[:, 1] = -0.48860251190291987 * y # -sqrt(3)*y/(2*sqrt(pi)) out[:, 2] = 0.48860251190291987 * z # sqrt(3)*z/(2*sqrt(pi)) out[:, 3] = -0.48860251190291987 * x # -sqrt(3)*x/(2*sqrt(pi)) if sh_degree <= 2: return out[:, 4] = 1.0925484305920792 * xy # sqrt(15)*xy/(2*sqrt(pi)) out[:, 5] = -1.0925484305920792 * yz # -sqrt(15)*yz/(2*sqrt(pi)) out[:, 6] = ( 0.94617469575755997 * z2 - 0.31539156525251999 ) # sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) out[:, 7] = -1.0925484305920792 * xz # -sqrt(15)*xz/(2*sqrt(pi)) out[:, 8] = ( 0.54627421529603959 * x2 - 0.54627421529603959 * y2 ) # sqrt(15)*(x2 - y2)/(4*sqrt(pi)) if sh_degree <= 3: return out[:, 9] = ( 0.59004358992664352 * y * (-3.0 * x2 + y2) ) # sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) out[:, 10] = 2.8906114426405538 * xy * z # sqrt(105)*xy*z/(2*sqrt(pi)) out[:, 11] = ( 0.45704579946446572 * y * (1.0 - 5.0 * z2) ) # sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) out[:, 12] = 0.3731763325901154 * z * (5.0 * z2 - 3.0) # sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) out[:, 13] = ( 0.45704579946446572 * x * (1.0 - 5.0 * z2) ) # sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) out[:, 14] = 1.4453057213202769 * z * (x2 - y2) # sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) out[:, 15] = ( 0.59004358992664352 * x * (-x2 + 3.0 * y2) ) # sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) if sh_degree <= 4: return out[:, 16] = 2.5033429417967046 * xy * (x2 - y2) # 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) out[:, 17] = ( 1.7701307697799304 * yz * (-3.0 * x2 + y2) ) # 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) out[:, 18] = ( 
0.94617469575756008 * xy * (7.0 * z2 - 1.0) ) # 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) out[:, 19] = ( 0.66904654355728921 * yz * (3.0 - 7.0 * z2) ) # 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) out[:, 20] = ( -3.1735664074561294 * z2 + 3.7024941420321507 * z4 + 0.31735664074561293 ) # 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) out[:, 21] = ( 0.66904654355728921 * xz * (3.0 - 7.0 * z2) ) # 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) out[:, 22] = ( 0.47308734787878004 * (x2 - y2) * (7.0 * z2 - 1.0) ) # 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) out[:, 23] = ( 1.7701307697799304 * xz * (-x2 + 3.0 * y2) ) # 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) out[:, 24] = ( -3.7550144126950569 * x2 * y2 + 0.62583573544917614 * x4 + 0.62583573544917614 * y4 ) # 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) if sh_degree <= 5: return out[:, 25] = ( 0.65638205684017015 * y * (10.0 * x2 * y2 - 5.0 * x4 - y4) ) # 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) out[:, 26] = ( 8.3026492595241645 * xy * z * (x2 - y2) ) # 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) out[:, 27] = ( -0.48923829943525038 * y * (3.0 * x2 - y2) * (9.0 * z2 - 1.0) ) # -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) out[:, 28] = ( 4.7935367849733241 * xy * z * (3.0 * z2 - 1.0) ) # sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) out[:, 29] = ( 0.45294665119569694 * y * (14.0 * z2 - 21.0 * z4 - 1.0) ) # sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) out[:, 30] = ( 0.1169503224534236 * z * (-70.0 * z2 + 63.0 * z4 + 15.0) ) # sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) out[:, 31] = ( 0.45294665119569694 * x * (14.0 * z2 - 21.0 * z4 - 1.0) ) # sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) out[:, 32] = ( 2.3967683924866621 * z * (x2 - y2) * (3.0 * z2 - 1.0) ) # sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) out[:, 33] = ( -0.48923829943525038 * x * (x2 - 3.0 * y2) * (9.0 * z2 - 1.0) ) # -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) out[:, 34] = ( 2.0756623148810411 * z * (-6.0 * x2 * y2 + x4 + y4) ) # 3*sqrt(385)*z*(-6*x2*y2 + x4 + 
y4)/(16*sqrt(pi)) out[:, 35] = ( 0.65638205684017015 * x * (10.0 * x2 * y2 - x4 - 5.0 * y4) ) # 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) if sh_degree <= 6: return out[:, 36] = ( 1.3663682103838286 * xy * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4) ) # sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) out[:, 37] = ( 2.3666191622317521 * yz * (10.0 * x2 * y2 - 5.0 * x4 - y4) ) # 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) out[:, 38] = ( 2.0182596029148963 * xy * (x2 - y2) * (11.0 * z2 - 1.0) ) # 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) out[:, 39] = ( -0.92120525951492349 * yz * (3.0 * x2 - y2) * (11.0 * z2 - 3.0) ) # -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) out[:, 40] = ( 0.92120525951492349 * xy * (-18.0 * z2 + 33.0 * z4 + 1.0) ) # sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) out[:, 41] = ( 0.58262136251873131 * yz * (30.0 * z2 - 33.0 * z4 - 5.0) ) # sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) out[:, 42] = ( 6.6747662381009842 * z2 - 20.024298714302954 * z4 + 14.684485723822165 * z6 - 0.31784601133814211 ) # sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) out[:, 43] = ( 0.58262136251873131 * xz * (30.0 * z2 - 33.0 * z4 - 5.0) ) # sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) out[:, 44] = ( 0.46060262975746175 * (x2 - y2) * (11.0 * z2 * (3.0 * z2 - 1.0) - 7.0 * z2 + 1.0) ) # sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) out[:, 45] = ( -0.92120525951492349 * xz * (x2 - 3.0 * y2) * (11.0 * z2 - 3.0) ) # -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) out[:, 46] = ( 0.50456490072872406 * (11.0 * z2 - 1.0) * (-6.0 * x2 * y2 + x4 + y4) ) # 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) out[:, 47] = ( 2.3666191622317521 * xz * (10.0 * x2 * y2 - x4 - 5.0 * y4) ) # 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) out[:, 48] = ( 10.247761577878714 * x2 * y4 - 10.247761577878714 * x4 * y2 + 0.6831841051919143 * x6 - 0.6831841051919143 * y6 ) # sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - 
y6)/(64*sqrt(pi)) if sh_degree <= 7: return out[:, 49] = ( 0.70716273252459627 * y * (-21.0 * x2 * y4 + 35.0 * x4 * y2 - 7.0 * x6 + y6) ) # 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) out[:, 50] = ( 5.2919213236038001 * xy * z * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4) ) # 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) out[:, 51] = ( -0.51891557872026028 * y * (13.0 * z2 - 1.0) * (-10.0 * x2 * y2 + 5.0 * x4 + y4) ) # -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) out[:, 52] = ( 4.1513246297620823 * xy * z * (x2 - y2) * (13.0 * z2 - 3.0) ) # 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) out[:, 53] = ( -0.15645893386229404 * y * (3.0 * x2 - y2) * (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0) ) # -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) out[:, 54] = ( 0.44253269244498261 * xy * z * (-110.0 * z2 + 143.0 * z4 + 15.0) ) # 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) out[:, 55] = ( 0.090331607582517306 * y * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0) ) # sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) out[:, 56] = ( 0.068284276912004949 * z * (315.0 * z2 - 693.0 * z4 + 429.0 * z6 - 35.0) ) # sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) out[:, 57] = ( 0.090331607582517306 * x * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0) ) # sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) out[:, 58] = ( 0.07375544874083044 * z * (x2 - y2) * (143.0 * z2 * (3.0 * z2 - 1.0) - 187.0 * z2 + 45.0) ) # sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) out[:, 59] = ( -0.15645893386229404 * x * (x2 - 3.0 * y2) * (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0) ) # -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) out[:, 60] = ( 1.0378311574405206 * z * (13.0 * z2 - 3.0) * (-6.0 * x2 * y2 + x4 + y4) ) # 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) out[:, 61] = ( -0.51891557872026028 * x * (13.0 * z2 - 1.0) * 
(-10.0 * x2 * y2 + x4 + 5.0 * y4) ) # -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) out[:, 62] = ( 2.6459606618019 * z * (15.0 * x2 * y4 - 15.0 * x4 * y2 + x6 - y6) ) # 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) out[:, 63] = ( 0.70716273252459627 * x * (-35.0 * x2 * y4 + 21.0 * x4 * y2 - x6 + 7.0 * y6) ) # 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) _sh() return out.view(batch_size, *shape, sh_degree**2) ================================================ FILE: shap_e/models/nn/meta.py ================================================ """ Meta-learning modules based on: https://github.com/tristandeleu/pytorch-meta MIT License Copyright (c) 2019-2020 Tristan Deleu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
def subdict(dictionary, key=None):
    """
    Select the entries of `dictionary` whose name starts with "<key>." and
    return them with that prefix stripped.

    `None` is passed through, and a missing or empty `key` returns
    `dictionary` itself.
    """
    if dictionary is None:
        return None
    if key is None or key == "":
        return dictionary
    pattern = re.compile(r"^{0}\.(.+)".format(re.escape(key)))
    selected = OrderedDict()
    for name, value in dictionary.items():
        if pattern.match(name) is None:
            continue
        selected[pattern.sub(r"\1", name)] = value
    return AttrDict(selected)


def superdict(dictionary, key=None):
    """
    Inverse of `subdict`: prefix every name in `dictionary` with "<key>.".

    `None` is passed through, and a missing or empty `key` returns
    `dictionary` itself.
    """
    if dictionary is None:
        return None
    if key is None or key == "":
        return dictionary
    prefixed = OrderedDict()
    for name, value in dictionary.items():
        prefixed[key + "." + name] = value
    return AttrDict(prefixed)


def leveldict(dictionary, depth=0):
    """
    Collect the entries yielded by `leveliter` into an AttrDict.
    """
    entries = leveliter(dictionary, depth=depth)
    return AttrDict(entries)


def leveliter(dictionary, depth=0):
    """
    Yield the (name, value) pairs whose name contains exactly `depth` dots.

    depth == 0 is root
    """
    yield from (
        (name, value) for name, value in dictionary.items() if name.count(".") == depth
    )
""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._meta_state_dict = set() self._meta_params = set() def register_meta_buffer(self, name: str, param: nn.Parameter): """ Registers a trainable or nontrainable parameter as a meta buffer. This can be later retrieved by meta_state_dict """ self.register_buffer(name, param) self._meta_state_dict.add(name) def register_meta_parameter(self, name: str, parameter: nn.Parameter): """ Registers a meta parameter so it is included in named_meta_parameters and meta_state_dict. """ self.register_parameter(name, parameter) self._meta_params.add(name) self._meta_state_dict.add(name) def register_meta(self, name: str, parameter: nn.Parameter, trainable: bool = True): if trainable: self.register_meta_parameter(name, parameter) else: self.register_meta_buffer(name, parameter) def register(self, name: str, parameter: nn.Parameter, meta: bool, trainable: bool = True): if meta: if trainable: self.register_meta_parameter(name, parameter) else: self.register_meta_buffer(name, parameter) else: if trainable: self.register_parameter(name, parameter) else: self.register_buffer(name, parameter) def named_meta_parameters(self, prefix="", recurse=True): """ Returns an iterator over all the names and meta parameters """ def meta_iterator(module): meta = module._meta_params if isinstance(module, MetaModule) else set() for name, param in module._parameters.items(): if name in meta: yield name, param gen = self._named_members( meta_iterator, prefix=prefix, recurse=recurse, ) for name, param in gen: yield name, param def named_nonmeta_parameters(self, prefix="", recurse=True): def _iterator(module): meta = module._meta_params if isinstance(module, MetaModule) else set() for name, param in module._parameters.items(): if name not in meta: yield name, param gen = self._named_members( _iterator, prefix=prefix, recurse=recurse, ) for name, param in gen: yield name, param def nonmeta_parameters(self, prefix="", recurse=True): 
for _, param in self.named_nonmeta_parameters(prefix=prefix, recurse=recurse): yield param def meta_state_dict(self, prefix="", recurse=True): """ Returns an iterator over all the names and meta parameters/buffers. One difference between module.state_dict() is that this preserves requires_grad, because we may want to compute the gradient w.r.t. meta buffers, but don't necessarily update them automatically. """ def meta_iterator(module): meta = module._meta_state_dict if isinstance(module, MetaModule) else set() for name, param in itertools.chain(module._buffers.items(), module._parameters.items()): if name in meta: yield name, param gen = self._named_members( meta_iterator, prefix=prefix, recurse=recurse, ) return dict(gen) def update(self, params=None): """ Updates the parameter list before the forward prop so that if `params` is None or doesn't have a certain key, the module uses the default parameter/buffer registered in the module. """ if params is None: params = AttrDict() params = AttrDict(params) named_params = set([name for name, _ in self.named_parameters()]) for name, param in self.named_parameters(): params.setdefault(name, param) for name, param in self.state_dict().items(): if name not in named_params: params.setdefault(name, param) return params def batch_meta_parameters(net, batch_size): params = AttrDict() for name, param in net.named_meta_parameters(): params[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape)) return params def batch_meta_state_dict(net, batch_size): state_dict = AttrDict() meta_parameters = set([name for name, _ in net.named_meta_parameters()]) for name, param in net.meta_state_dict().items(): state_dict[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape)) return state_dict ================================================ FILE: shap_e/models/nn/ops.py ================================================ import math from typing import List, Optional, Tuple, Union import numpy as np 
import torch import torch.nn as nn import torch.nn.functional as F from shap_e.util.collections import AttrDict from .meta import MetaModule, subdict from .pointnet2_utils import sample_and_group, sample_and_group_all def gelu(x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) def swish(x): return x * torch.sigmoid(x) def quick_gelu(x): return x * torch.sigmoid(1.702 * x) def torch_gelu(x): return torch.nn.functional.gelu(x) def geglu(x): v, gates = x.chunk(2, dim=-1) return v * gelu(gates) class SirenSin: def __init__(self, w0=30.0): self.w0 = w0 def __call__(self, x): return torch.sin(self.w0 * x) def get_act(name): return { "relu": torch.nn.functional.relu, "leaky_relu": torch.nn.functional.leaky_relu, "swish": swish, "tanh": torch.tanh, "gelu": gelu, "quick_gelu": quick_gelu, "torch_gelu": torch_gelu, "gelu2": quick_gelu, "geglu": geglu, "sigmoid": torch.sigmoid, "sin": torch.sin, "sin30": SirenSin(w0=30.0), "softplus": F.softplus, "exp": torch.exp, "identity": lambda x: x, }[name] def zero_init(affine): nn.init.constant_(affine.weight, 0.0) if affine.bias is not None: nn.init.constant_(affine.bias, 0.0) def siren_init_first_layer(affine, init_scale: float = 1.0): n_input = affine.weight.shape[1] u = init_scale / n_input nn.init.uniform_(affine.weight, -u, u) if affine.bias is not None: nn.init.constant_(affine.bias, 0.0) def siren_init(affine, coeff=1.0, init_scale: float = 1.0): n_input = affine.weight.shape[1] u = init_scale * np.sqrt(6.0 / n_input) / coeff nn.init.uniform_(affine.weight, -u, u) if affine.bias is not None: nn.init.constant_(affine.bias, 0.0) def siren_init_30(affine, init_scale: float = 1.0): siren_init(affine, coeff=30.0, init_scale=init_scale) def std_init(affine, init_scale: float = 1.0): n_in = affine.weight.shape[1] stddev = init_scale / math.sqrt(n_in) nn.init.normal_(affine.weight, std=stddev) if affine.bias is not None: nn.init.constant_(affine.bias, 0.0) def mlp_init(affines, init: 
Optional[str] = None, init_scale: float = 1.0): if init == "siren30": for idx, affine in enumerate(affines): init = siren_init_first_layer if idx == 0 else siren_init_30 init(affine, init_scale=init_scale) elif init == "siren": for idx, affine in enumerate(affines): init = siren_init_first_layer if idx == 0 else siren_init init(affine, init_scale=init_scale) elif init is None: for affine in affines: std_init(affine, init_scale=init_scale) else: raise NotImplementedError(init) class MetaLinear(MetaModule): def __init__( self, n_in, n_out, bias: bool = True, meta_scale: bool = True, meta_shift: bool = True, meta_proj: bool = False, meta_bias: bool = False, trainable_meta: bool = False, **kwargs, ): super().__init__() # n_in, n_out, bias=bias) register_meta_fn = ( self.register_meta_parameter if trainable_meta else self.register_meta_buffer ) if meta_scale: register_meta_fn("scale", nn.Parameter(torch.ones(n_out, **kwargs))) if meta_shift: register_meta_fn("shift", nn.Parameter(torch.zeros(n_out, **kwargs))) register_proj_fn = self.register_parameter if not meta_proj else register_meta_fn register_proj_fn("weight", nn.Parameter(torch.empty((n_out, n_in), **kwargs))) if not bias: self.register_parameter("bias", None) else: register_bias_fn = self.register_parameter if not meta_bias else register_meta_fn register_bias_fn("bias", nn.Parameter(torch.empty(n_out, **kwargs))) self.reset_parameters() def reset_parameters(self) -> None: # from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). 
For details, see # https://github.com/pytorch/pytorch/issues/57109 nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) def _bcast(self, op, left, right): if right.ndim == 2: # Has dimension [batch x d_output] right = right.unsqueeze(1) return op(left, right) def forward(self, x, params=None): params = self.update(params) batch_size, *shape, d_in = x.shape x = x.view(batch_size, -1, d_in) if params.weight.ndim == 2: h = torch.einsum("bni,oi->bno", x, params.weight) elif params.weight.ndim == 3: h = torch.einsum("bni,boi->bno", x, params.weight) if params.bias is not None: h = self._bcast(torch.add, h, params.bias) if params.scale is not None: h = self._bcast(torch.mul, h, params.scale) if params.shift is not None: h = self._bcast(torch.add, h, params.shift) h = h.view(batch_size, *shape, -1) return h def Conv(n_dim, d_in, d_out, kernel, stride=1, padding=0, dilation=1, **kwargs): cls = { 1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d, }[n_dim] return cls(d_in, d_out, kernel, stride=stride, padding=padding, dilation=dilation, **kwargs) def flatten(x): batch_size, *shape, n_channels = x.shape n_ctx = np.prod(shape) return x.view(batch_size, n_ctx, n_channels), AttrDict( shape=shape, n_ctx=n_ctx, n_channels=n_channels ) def unflatten(x, info): batch_size = x.shape[0] return x.view(batch_size, *info.shape, info.n_channels) def torchify(x): extent = list(range(1, x.ndim - 1)) return x.permute([0, x.ndim - 1, *extent]) def untorchify(x): extent = list(range(2, x.ndim)) return x.permute([0, *extent, 1]) class MLP(nn.Module): def __init__( self, d_input: int, d_hidden: List[int], d_output: int, act_name: str = "quick_gelu", bias: bool = True, init: Optional[str] = None, init_scale: float = 1.0, zero_out: bool = False, ): """ Required: d_input, d_hidden, d_output Optional: act_name, bias """ 
super().__init__() ds = [d_input] + d_hidden + [d_output] affines = [nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(ds[:-1], ds[1:])] self.d = ds self.affines = nn.ModuleList(affines) self.act = get_act(act_name) mlp_init(self.affines, init=init, init_scale=init_scale) if zero_out: zero_init(affines[-1]) def forward(self, h, options: Optional[AttrDict] = None, log_prefix: str = ""): options = AttrDict() if options is None else AttrDict(options) *hid, out = self.affines for i, f in enumerate(hid): h = self.act(f(h)) h = out(h) return h class MetaMLP(MetaModule): def __init__( self, d_input: int, d_hidden: List[int], d_output: int, act_name: str = "quick_gelu", bias: bool = True, meta_scale: bool = True, meta_shift: bool = True, meta_proj: bool = False, meta_bias: bool = False, trainable_meta: bool = False, init: Optional[str] = None, init_scale: float = 1.0, zero_out: bool = False, ): super().__init__() ds = [d_input] + d_hidden + [d_output] affines = [ MetaLinear( d_in, d_out, bias=bias, meta_scale=meta_scale, meta_shift=meta_shift, meta_proj=meta_proj, meta_bias=meta_bias, trainable_meta=trainable_meta, ) for d_in, d_out in zip(ds[:-1], ds[1:]) ] self.d = ds self.affines = nn.ModuleList(affines) self.act = get_act(act_name) mlp_init(affines, init=init, init_scale=init_scale) if zero_out: zero_init(affines[-1]) def forward(self, h, params=None, options: Optional[AttrDict] = None, log_prefix: str = ""): options = AttrDict() if options is None else AttrDict(options) params = self.update(params) *hid, out = self.affines for i, layer in enumerate(hid): h = self.act(layer(h, params=subdict(params, f"{log_prefix}affines.{i}"))) last = len(self.affines) - 1 h = out(h, params=subdict(params, f"{log_prefix}affines.{last}")) return h class LayerNorm(nn.LayerNorm): def __init__( self, norm_shape: Union[int, Tuple[int]], eps: float = 1e-5, elementwise_affine: bool = True ): super().__init__(norm_shape, eps=eps, elementwise_affine=elementwise_affine) self.width = 
np.prod(norm_shape) self.max_numel = 65535 * self.width def forward(self, input): if input.numel() > self.max_numel: return F.layer_norm( input.float(), self.normalized_shape, self.weight, self.bias, self.eps ).type_as(input) else: return super(LayerNorm, self).forward(input.float()).type_as(input) class PointSetEmbedding(nn.Module): def __init__( self, *, radius: float, n_point: int, n_sample: int, d_input: int, d_hidden: List[int], patch_size: int = 1, stride: int = 1, activation: str = "swish", group_all: bool = False, padding_mode: str = "zeros", fps_method: str = "fps", **kwargs, ): super().__init__() self.n_point = n_point self.radius = radius self.n_sample = n_sample self.mlp_convs = nn.ModuleList() self.act = get_act(activation) self.patch_size = patch_size self.stride = stride last_channel = d_input + 3 for out_channel in d_hidden: self.mlp_convs.append( nn.Conv2d( last_channel, out_channel, kernel_size=(patch_size, 1), stride=(stride, 1), padding=(patch_size // 2, 0), padding_mode=padding_mode, **kwargs, ) ) last_channel = out_channel self.group_all = group_all self.fps_method = fps_method def forward(self, xyz, points): """ Input: xyz: input points position data, [B, C, N] points: input points data, [B, D, N] Return: new_points: sample points feature data, [B, d_hidden[-1], n_point] """ xyz = xyz.permute(0, 2, 1) if points is not None: points = points.permute(0, 2, 1) if self.group_all: new_xyz, new_points = sample_and_group_all(xyz, points) else: new_xyz, new_points = sample_and_group( self.n_point, self.radius, self.n_sample, xyz, points, deterministic=not self.training, fps_method=self.fps_method, ) # new_xyz: sampled points position data, [B, n_point, C] # new_points: sampled points data, [B, n_point, n_sample, C+D] new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, n_sample, n_point] for i, conv in enumerate(self.mlp_convs): new_points = self.act(self.apply_conv(new_points, conv)) new_points = new_points.mean(dim=2) return new_points def 
apply_conv(self, points: torch.Tensor, conv: nn.Module): batch, channels, n_samples, _ = points.shape # Shuffle the representations if self.patch_size > 1: # TODO shuffle deterministically when not self.training _, indices = torch.rand(batch, channels, n_samples, 1, device=points.device).sort(dim=2) points = torch.gather(points, 2, torch.broadcast_to(indices, points.shape)) return conv(points) ================================================ FILE: shap_e/models/nn/pointnet2_utils.py ================================================ """ Based on https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet2_utils.py MIT License Copyright (c) 2019 benny Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" from time import time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def timeit(tag, t): print("{}: {}s".format(tag, time() - t)) return time() def pc_normalize(pc): l = pc.shape[0] centroid = np.mean(pc, axis=0) pc = pc - centroid m = np.max(np.sqrt(np.sum(pc**2, axis=1))) pc = pc / m return pc def square_distance(src, dst): """ Calculate Euclid distance between each two points. src^T * dst = xn * xm + yn * ym + zn * zm; sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst Input: src: source points, [B, N, C] dst: target points, [B, M, C] Output: dist: per-point square distance, [B, N, M] """ B, N, _ = src.shape _, M, _ = dst.shape dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) dist += torch.sum(src**2, -1).view(B, N, 1) dist += torch.sum(dst**2, -1).view(B, 1, M) return dist def index_points(points, idx): """ Input: points: input points data, [B, N, C] idx: sample index data, [B, S] Return: new_points:, indexed points data, [B, S, C] """ device = points.device B = points.shape[0] view_shape = list(idx.shape) view_shape[1:] = [1] * (len(view_shape) - 1) repeat_shape = list(idx.shape) repeat_shape[0] = 1 batch_indices = ( torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) ) new_points = points[batch_indices, idx, :] return new_points def farthest_point_sample(xyz, npoint, deterministic=False): """ Input: xyz: pointcloud data, [B, N, 3] npoint: number of samples Return: centroids: sampled pointcloud index, [B, npoint] """ device = xyz.device B, N, C = xyz.shape centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) distance = torch.ones(B, N).to(device) * 1e10 if deterministic: farthest = torch.arange(0, B, dtype=torch.long).to(device) else: farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) batch_indices = torch.arange(B, 
dtype=torch.long).to(device) for i in range(npoint): centroids[:, i] = farthest centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) dist = torch.sum((xyz - centroid) ** 2, -1) mask = dist < distance distance[mask] = dist[mask] farthest = torch.max(distance, -1)[1] return centroids def query_ball_point(radius, nsample, xyz, new_xyz): """ Input: radius: local region radius nsample: max sample number in local region xyz: all points, [B, N, 3] new_xyz: query points, [B, S, 3] Return: group_idx: grouped points index, [B, S, nsample] """ device = xyz.device B, N, C = xyz.shape _, S, _ = new_xyz.shape group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) sqrdists = square_distance(new_xyz, xyz) group_idx[sqrdists > radius**2] = N group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) mask = group_idx == N group_idx[mask] = group_first[mask] return group_idx def sample_and_group( npoint, radius, nsample, xyz, points, returnfps=False, deterministic=False, fps_method: str = "fps", ): """ Input: npoint: radius: nsample: xyz: input points position data, [B, N, 3] points: input points data, [B, N, D] Return: new_xyz: sampled points position data, [B, npoint, nsample, 3] new_points: sampled points data, [B, npoint, nsample, 3+D] """ B, N, C = xyz.shape S = npoint if fps_method == "fps": fps_idx = farthest_point_sample(xyz, npoint, deterministic=deterministic) # [B, npoint, C] elif fps_method == "first": fps_idx = torch.arange(npoint)[None].repeat(B, 1) else: raise ValueError(f"Unknown FPS method: {fps_method}") new_xyz = index_points(xyz, fps_idx) idx = query_ball_point(radius, nsample, xyz, new_xyz) grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) if points is not None: grouped_points = index_points(points, idx) new_points = torch.cat( [grouped_xyz_norm, grouped_points], dim=-1 ) # [B, npoint, 
nsample, C+D] else: new_points = grouped_xyz_norm if returnfps: return new_xyz, new_points, grouped_xyz, fps_idx else: return new_xyz, new_points def sample_and_group_all(xyz, points): """ Input: xyz: input points position data, [B, N, 3] points: input points data, [B, N, D] Return: new_xyz: sampled points position data, [B, 1, 3] new_points: sampled points data, [B, 1, N, 3+D] """ device = xyz.device B, N, C = xyz.shape new_xyz = torch.zeros(B, 1, C).to(device) grouped_xyz = xyz.view(B, 1, N, C) if points is not None: new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) else: new_points = grouped_xyz return new_xyz, new_points class PointNetSetAbstraction(nn.Module): def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): super(PointNetSetAbstraction, self).__init__() self.npoint = npoint self.radius = radius self.nsample = nsample self.mlp_convs = nn.ModuleList() self.mlp_bns = nn.ModuleList() last_channel = in_channel for out_channel in mlp: self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) self.mlp_bns.append(nn.BatchNorm2d(out_channel)) last_channel = out_channel self.group_all = group_all def forward(self, xyz, points): """ Input: xyz: input points position data, [B, C, N] points: input points data, [B, D, N] Return: new_xyz: sampled points position data, [B, C, S] new_points_concat: sample points feature data, [B, D', S] """ xyz = xyz.permute(0, 2, 1) if points is not None: points = points.permute(0, 2, 1) if self.group_all: new_xyz, new_points = sample_and_group_all(xyz, points) else: new_xyz, new_points = sample_and_group( self.npoint, self.radius, self.nsample, xyz, points, deterministic=not self.training ) # new_xyz: sampled points position data, [B, npoint, C] # new_points: sampled points data, [B, npoint, nsample, C+D] new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] for i, conv in enumerate(self.mlp_convs): bn = self.mlp_bns[i] new_points = F.relu(bn(conv(new_points))) new_points = 
torch.max(new_points, 2)[0] new_xyz = new_xyz.permute(0, 2, 1) return new_xyz, new_points class PointNetSetAbstractionMsg(nn.Module): def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): super(PointNetSetAbstractionMsg, self).__init__() self.npoint = npoint self.radius_list = radius_list self.nsample_list = nsample_list self.conv_blocks = nn.ModuleList() self.bn_blocks = nn.ModuleList() for i in range(len(mlp_list)): convs = nn.ModuleList() bns = nn.ModuleList() last_channel = in_channel + 3 for out_channel in mlp_list[i]: convs.append(nn.Conv2d(last_channel, out_channel, 1)) bns.append(nn.BatchNorm2d(out_channel)) last_channel = out_channel self.conv_blocks.append(convs) self.bn_blocks.append(bns) def forward(self, xyz, points): """ Input: xyz: input points position data, [B, C, N] points: input points data, [B, D, N] Return: new_xyz: sampled points position data, [B, C, S] new_points_concat: sample points feature data, [B, D', S] """ xyz = xyz.permute(0, 2, 1) if points is not None: points = points.permute(0, 2, 1) B, N, C = xyz.shape S = self.npoint new_xyz = index_points(xyz, farthest_point_sample(xyz, S, deterministic=not self.training)) new_points_list = [] for i, radius in enumerate(self.radius_list): K = self.nsample_list[i] group_idx = query_ball_point(radius, K, xyz, new_xyz) grouped_xyz = index_points(xyz, group_idx) grouped_xyz -= new_xyz.view(B, S, 1, C) if points is not None: grouped_points = index_points(points, group_idx) grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) else: grouped_points = grouped_xyz grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] for j in range(len(self.conv_blocks[i])): conv = self.conv_blocks[i][j] bn = self.bn_blocks[i][j] grouped_points = F.relu(bn(conv(grouped_points))) new_points = torch.max(grouped_points, 2)[0] # [B, D', S] new_points_list.append(new_points) new_xyz = new_xyz.permute(0, 2, 1) new_points_concat = torch.cat(new_points_list, dim=1) return new_xyz, 
new_points_concat class PointNetFeaturePropagation(nn.Module): def __init__(self, in_channel, mlp): super(PointNetFeaturePropagation, self).__init__() self.mlp_convs = nn.ModuleList() self.mlp_bns = nn.ModuleList() last_channel = in_channel for out_channel in mlp: self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) self.mlp_bns.append(nn.BatchNorm1d(out_channel)) last_channel = out_channel def forward(self, xyz1, xyz2, points1, points2): """ Input: xyz1: input points position data, [B, C, N] xyz2: sampled input points position data, [B, C, S] points1: input points data, [B, D, N] points2: input points data, [B, D, S] Return: new_points: upsampled points data, [B, D', N] """ xyz1 = xyz1.permute(0, 2, 1) xyz2 = xyz2.permute(0, 2, 1) points2 = points2.permute(0, 2, 1) B, N, C = xyz1.shape _, S, _ = xyz2.shape if S == 1: interpolated_points = points2.repeat(1, N, 1) else: dists = square_distance(xyz1, xyz2) dists, idx = dists.sort(dim=-1) dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] dist_recip = 1.0 / (dists + 1e-8) norm = torch.sum(dist_recip, dim=2, keepdim=True) weight = dist_recip / norm interpolated_points = torch.sum( index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2 ) if points1 is not None: points1 = points1.permute(0, 2, 1) new_points = torch.cat([points1, interpolated_points], dim=-1) else: new_points = interpolated_points new_points = new_points.permute(0, 2, 1) for i, conv in enumerate(self.mlp_convs): bn = self.mlp_bns[i] new_points = F.relu(bn(conv(new_points))) return new_points ================================================ FILE: shap_e/models/nn/utils.py ================================================ from typing import Iterable, Union import numpy as np import torch ArrayType = Union[np.ndarray, Iterable[int], torch.Tensor] def to_torch(arr: ArrayType, dtype=torch.float): if isinstance(arr, torch.Tensor): return arr return torch.from_numpy(np.array(arr)).to(dtype) def sample_pmf(pmf: torch.Tensor, n_samples: int) -> 
torch.Tensor: """ Sample from the given discrete probability distribution with replacement. The i-th bin is assumed to have mass pmf[i]. :param pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all() :param n_samples: number of samples :return: indices sampled with replacement """ *shape, support_size, last_dim = pmf.shape assert last_dim == 1 cdf = torch.cumsum(pmf.view(-1, support_size), dim=1) inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device)) return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1) def safe_divide(a, b, epsilon=1e-6): return a / torch.where(b < 0, b - epsilon, b + epsilon) ================================================ FILE: shap_e/models/query.py ================================================ from dataclasses import dataclass from typing import Callable, Optional import torch @dataclass class Query: # Both of these are of shape [batch_size x ... x 3] position: torch.Tensor direction: Optional[torch.Tensor] = None t_min: Optional[torch.Tensor] = None t_max: Optional[torch.Tensor] = None def copy(self) -> "Query": return Query( position=self.position, direction=self.direction, t_min=self.t_min, t_max=self.t_max, ) def map_tensors(self, f: Callable[[torch.Tensor], torch.Tensor]) -> "Query": return Query( position=f(self.position), direction=f(self.direction) if self.direction is not None else None, t_min=f(self.t_min) if self.t_min is not None else None, t_max=f(self.t_max) if self.t_max is not None else None, ) ================================================ FILE: shap_e/models/renderer.py ================================================ from abc import abstractmethod from typing import Callable, Dict, List, Optional, Tuple import numpy as np import torch from shap_e.models.nn.camera import ( DifferentiableCamera, DifferentiableProjectiveCamera, get_image_coords, projective_camera_frame, ) from shap_e.models.nn.meta import MetaModule from shap_e.util.collections import AttrDict 
class Renderer(MetaModule): """ A rendering abstraction that can render rays and views by calling the appropriate models. The models are instantiated outside but registered in this module. """ @abstractmethod def render_views( self, batch: AttrDict, params: Optional[Dict] = None, options: Optional[Dict] = None, ) -> AttrDict: """ Returns a backproppable rendering of a view :param batch: contains - height: Optional[int] - width: Optional[int] - inner_batch_size or ray_batch_size: Optional[int] defaults to 4096 rays And additionally, to specify poses with a default up direction: - poses: [batch_size x *shape x 2 x 3] where poses[:, ..., 0, :] are the camera positions, and poses[:, ..., 1, :] are the z-axis (toward the object) of the camera frame. - camera: DifferentiableCamera. Assumes the same camera position across batch for simplicity. Could eventually support batched cameras. or to specify a batch of arbitrary poses: - cameras: DifferentiableCameraBatch of shape [batch_size x *shape]. :param params: Meta parameters :param options: Optional[Dict] """ class RayRenderer(Renderer): """ A rendering abstraction that can render rays and views by calling the appropriate models. The models are instantiated outside but registered in this module. """ @abstractmethod def render_rays( self, batch: AttrDict, params: Optional[Dict] = None, options: Optional[Dict] = None, ) -> AttrDict: """ :param batch: has - rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray. - radii (optional): [batch_size x ... x 1] the "thickness" of each ray. 
:param options: Optional[Dict] """ def render_views( self, batch: AttrDict, params: Optional[Dict] = None, options: Optional[Dict] = None, device: torch.device = torch.device("cuda"), ) -> AttrDict: output = render_views_from_rays( self.render_rays, batch, params=params, options=options, device=self.device, ) return output def forward( self, batch: AttrDict, params: Optional[Dict] = None, options: Optional[Dict] = None, ) -> AttrDict: """ :param batch: must contain either - rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray. or - poses: [batch_size x 2 x 3] where poses[:, 0] are the camera positions, and poses[:, 1] are the z-axis (toward the object) of the camera frame. - camera: an instance of Camera that implements camera_rays or - cameras: DifferentiableCameraBatch of shape [batch_size x *shape]. For both of the above two options, these may be specified. - height: Optional[int] - width: Optional[int] - ray_batch_size or inner_batch_size: Optional[int] defaults to 4096 rays :param params: a dictionary of optional meta parameters. 
:param options: A Dict of other hyperparameters that could be related to rendering or debugging :return: a dictionary containing - channels: [batch_size, *shape, n_channels] - distances: [batch_size, *shape, 1] - transmittance: [batch_size, *shape, 1] - aux_losses: Dict[str, torch.Tensor] """ if "rays" in batch: for key in ["poses", "camera", "height", "width"]: assert key not in batch return self.render_rays(batch, params=params, options=options) elif "poses" in batch or "cameras" in batch: assert "rays" not in batch if "poses" in batch: assert "camera" in batch else: assert "camera" not in batch return self.render_views(batch, params=params, options=options) raise NotImplementedError def get_camera_from_batch(batch: AttrDict) -> Tuple[DifferentiableCamera, int, Tuple[int]]: if "poses" in batch: assert not "cameras" in batch batch_size, *inner_shape, n_vecs, spatial_dim = batch.poses.shape assert n_vecs == 2 and spatial_dim == 3 inner_batch_size = int(np.prod(inner_shape)) poses = batch.poses.view(batch_size * inner_batch_size, 2, 3) position, direction = poses[:, 0], poses[:, 1] camera = projective_camera_frame(position, direction, batch.camera) elif "cameras" in batch: assert not "camera" in batch batch_size, *inner_shape = batch.cameras.shape camera = batch.cameras.flat_camera else: raise ValueError(f'neither "poses" nor "cameras" found in keys: {batch.keys()}') if "height" in batch and "width" in batch: camera = camera.resize_image(batch.width, batch.height) return camera, batch_size, inner_shape def append_tensor(val_list: Optional[List[torch.Tensor]], output: Optional[torch.Tensor]): if val_list is None: return [output] return val_list + [output] def render_views_from_rays( render_rays: Callable[[AttrDict, AttrDict, AttrDict], AttrDict], batch: AttrDict, params: Optional[Dict] = None, options: Optional[Dict] = None, device: torch.device = torch.device("cuda"), ) -> AttrDict: camera, batch_size, inner_shape = get_camera_from_batch(batch) inner_batch_size = 
int(np.prod(inner_shape)) coords = get_image_coords(camera.width, camera.height).to(device) coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape]) rays = camera.camera_rays(coords) # mip-NeRF radii calculation from: https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/datasets.py#L193-L200 directions = rays.view(batch_size, inner_batch_size, camera.height, camera.width, 2, 3)[ ..., 1, : ] neighbor_dists = torch.linalg.norm(directions[:, :, :, 1:] - directions[:, :, :, :-1], dim=-1) neighbor_dists = torch.cat([neighbor_dists, neighbor_dists[:, :, :, -2:-1]], dim=3) radii = (neighbor_dists * 2 / np.sqrt(12)).view(batch_size, -1, 1) rays = rays.view(batch_size, inner_batch_size * camera.height * camera.width, 2, 3) if isinstance(camera, DifferentiableProjectiveCamera): # Compute the camera z direction corresponding to every ray's pixel. # Used for depth computations below. z_directions = ( (camera.z / torch.linalg.norm(camera.z, dim=-1, keepdim=True)) .reshape([batch_size, inner_batch_size, 1, 3]) .repeat(1, 1, camera.width * camera.height, 1) .reshape(1, inner_batch_size * camera.height * camera.width, 3) ) ray_batch_size = batch.get("ray_batch_size", batch.get("inner_batch_size", 4096)) assert rays.shape[1] % ray_batch_size == 0 n_batches = rays.shape[1] // ray_batch_size output_list = AttrDict(aux_losses=dict()) for idx in range(n_batches): rays_batch = AttrDict( rays=rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size], radii=radii[:, idx * ray_batch_size : (idx + 1) * ray_batch_size], ) output = render_rays(rays_batch, params=params, options=options) if isinstance(camera, DifferentiableProjectiveCamera): z_batch = z_directions[:, idx * ray_batch_size : (idx + 1) * ray_batch_size] ray_directions = rays_batch.rays[:, :, 1] z_dots = (ray_directions * z_batch).sum(-1, keepdim=True) output.depth = output.distances * z_dots output_list = output_list.combine(output, append_tensor) def 
_resize(val_list: List[torch.Tensor]): val = torch.cat(val_list, dim=1) assert val.shape[1] == inner_batch_size * camera.height * camera.width return val.view(batch_size, *inner_shape, camera.height, camera.width, -1) def _avg(_key: str, loss_list: List[torch.Tensor]): return sum(loss_list) / n_batches output = AttrDict( {name: _resize(val_list) for name, val_list in output_list.items() if name != "aux_losses"} ) output.aux_losses = output_list.aux_losses.map(_avg) return output ================================================ FILE: shap_e/models/stf/__init__.py ================================================ ================================================ FILE: shap_e/models/stf/base.py ================================================ from abc import ABC, abstractmethod from typing import Any, Dict, Optional import torch from shap_e.models.query import Query from shap_e.models.renderer import append_tensor from shap_e.util.collections import AttrDict class Model(ABC): @abstractmethod def forward( self, query: Query, params: Optional[Dict[str, torch.Tensor]] = None, options: Optional[Dict[str, Any]] = None, ) -> AttrDict[str, Any]: """ Predict an attribute given position """ def forward_batched( self, query: Query, query_batch_size: int = 4096, params: Optional[Dict[str, torch.Tensor]] = None, options: Optional[Dict[str, Any]] = None, ) -> AttrDict[str, Any]: if not query.position.numel(): # Avoid torch.cat() of zero tensors. 
return self(query, params=params, options=options) if options.cache is None: created_cache = True options.cache = AttrDict() else: created_cache = False results_list = AttrDict() for i in range(0, query.position.shape[1], query_batch_size): out = self( query=query.map_tensors(lambda x, i=i: x[:, i : i + query_batch_size]), params=params, options=options, ) results_list = results_list.combine(out, append_tensor) if created_cache: del options["cache"] return results_list.map(lambda key, tensor_list: torch.cat(tensor_list, dim=1)) ================================================ FILE: shap_e/models/stf/mlp.py ================================================ from functools import partial from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn from shap_e.models.nn.checkpoint import checkpoint from shap_e.models.nn.encoding import encode_position, maybe_encode_direction from shap_e.models.nn.meta import MetaModule, subdict from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init from shap_e.models.query import Query from shap_e.util.collections import AttrDict from .base import Model class MLPModel(MetaModule, Model): def __init__( self, n_output: int, output_activation: str, # Positional encoding parameters posenc_version: str = "v1", # Direction related channel prediction insert_direction_at: Optional[int] = None, # MLP parameters d_hidden: int = 256, n_hidden_layers: int = 4, activation: str = "relu", init: Optional[str] = None, init_scale: float = 1.0, meta_parameters: bool = False, trainable_meta: bool = False, meta_proj: bool = True, meta_bias: bool = True, meta_start: int = 0, meta_stop: Optional[int] = None, n_meta_layers: Optional[int] = None, register_freqs: bool = False, device: torch.device = torch.device("cuda"), ): super().__init__() if register_freqs: self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10)) # Positional encoding self.posenc_version = posenc_version dummy = torch.eye(1, 3) 
def forward(
    self,
    query: Query,
    params: Optional[Dict[str, torch.Tensor]] = None,
    options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    """
    Run the MLP on a query and apply the output activation.

    :param query: the query; `query.position` is [batch_size x ... x 3] and
                  `query.direction` may be None.
    :param params: Meta parameters
    :param options: Optional hyperparameters
    :return: the output activation applied to the final MLP activations.
             Subclasses in this file wrap this value in an AttrDict.
    """

    # query.direction is None typically for SDF models and training
    h_final, _h_directionless = self._mlp(
        query.position, query.direction, params=params, options=options
    )
    return self.output_activation(h_final)


def _run_mlp(
    self,
    position: torch.Tensor,
    direction: Optional[torch.Tensor],
    params: AttrDict[str, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply the layers of `self.mlp` to the encoded position.

    :param position: positions to encode and feed through the MLP.
    :param direction: optional directions; only used at layer
                      `self.insert_direction_at`, where their encoding is
                      concatenated onto the hidden state.
    :param params: meta parameters; the entries under `mlp.{i}` are handed to
                   layer i when that layer is a MetaLinear.
    :return: the final and directionless activations at the given query
    """
    h_preact = h = encode_position(self.posenc_version, position=position)
    h_directionless = None
    for i, layer in enumerate(self.mlp):
        if i == self.insert_direction_at:
            # Remember the last pre-activation that did not depend on the
            # direction before mixing the direction encoding in.
            h_directionless = h_preact
            h_direction = maybe_encode_direction(
                self.posenc_version, position=position, direction=direction
            )
            h = torch.cat([h, h_direction], dim=-1)
        if isinstance(layer, MetaLinear):
            h = layer(h, params=subdict(params, f"mlp.{i}"))
        else:
            h = layer(h)
        h_preact = h
        # No hidden activation after the last layer; the output activation is
        # applied by forward().
        if i < len(self.mlp) - 1:
            h = self.activation(h)
    h_final = h
    if h_directionless is None:
        # The direction was never inserted, so every activation is
        # directionless; fall back to the final pre-activation.
        h_directionless = h_preact
    return h_final, h_directionless
class MLPSDFModel(MLPModel):
    """
    MLP with a single identity-activated output, returned under the
    `signed_distance` key.
    """

    def __init__(self, initial_bias: float = -0.1, **kwargs):
        super().__init__(n_output=1, output_activation="identity", **kwargs)
        # Start the last layer's bias at a constant value.
        last_layer = self.mlp[-1]
        last_layer.bias.data.fill_(initial_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        :return: an AttrDict holding the parent model's output as
                 `signed_distance`.
        """
        prediction = super().forward(query=query, params=params, options=options)
        result = AttrDict(signed_distance=prediction)
        return result


class MLPTextureFieldModel(MLPModel):
    """
    MLP with `n_channels` sigmoid-activated outputs, returned under the
    `channels` key.
    """

    def __init__(
        self,
        n_channels: int = 3,
        **kwargs,
    ):
        super().__init__(n_output=n_channels, output_activation="sigmoid", **kwargs)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        :return: an AttrDict holding the parent model's output as `channels`.
        """
        prediction = super().forward(query=query, params=params, options=options)
        result = AttrDict(channels=prediction)
        return result
class STFRendererBase(ABC):
    """
    Interface for renderers that expose their signed-distance and texture
    fields as separate queries.
    """

    @abstractmethod
    def get_signed_distance(
        self,
        position: torch.Tensor,
        params: Dict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> torch.Tensor:
        """
        Evaluate the signed-distance field.

        NOTE(review): STFRenderer in this file names its first argument
        `query` and annotates it as a Query rather than a position tensor;
        confirm which of the two signatures is intended.
        """

    @abstractmethod
    def get_texture(
        self,
        position: torch.Tensor,
        params: Dict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> torch.Tensor:
        """
        Evaluate the texture field.

        NOTE(review): STFRenderer in this file names its first argument
        `query` and annotates it as a Query rather than a position tensor;
        confirm which of the two signatures is intended.
        """
def render_views_from_stf(
    batch: Dict,
    options: AttrDict[str, Any],
    *,
    sdf_fn: Optional[Callable],
    tf_fn: Optional[Callable],
    nerstf_fn: Optional[Callable],
    volume: BoundingBoxVolume,
    grid_size: int,
    channel_scale: torch.Tensor,
    texture_channels: Sequence[str] = ("R", "G", "B"),
    ambient_color: Union[float, Tuple[float]] = 0.0,
    diffuse_color: Union[float, Tuple[float]] = 1.0,
    specular_color: Union[float, Tuple[float]] = 0.2,
    output_srgb: bool = False,
    device: torch.device = torch.device("cuda"),
) -> AttrDict:
    """
    Extract a mesh per batch element from a signed-distance field, color its
    vertices with a texture field, and render the requested views.

    :param batch: contains either ["poses", "camera"], or ["cameras"]. Can
                  optionally contain any of ["height", "width", "query_batch_size"]
    :param options: controls checkpointing, caching, and rendering
    :param sdf_fn: called with (query, query_batch_size, options); its result
                   must provide `signed_distance` of shape
                   [batch_size, n_points, 1] and may provide `density`.
                   When None, nerstf_fn is used instead.
    :param tf_fn: called with (query, query_batch_size, options); its result
                  must provide `channels` of shape
                  [batch_size, n_points, len(texture_channels)].
                  When None, nerstf_fn is used instead.
    :param nerstf_fn: fallback for whichever of sdf_fn / tf_fn is None.
    :param volume: AABB volume
    :param grid_size: SDF sampling resolution
    :param texture_channels: what texture to predict
    :param channel_scale: how each channel is scaled
    :param device: unused; the device of the camera origin is used instead.
    :return: at least
        channels: [batch_size, len(cameras), height, width, 3]
        transmittance: [batch_size, len(cameras), height, width, 1]
        aux_losses: AttrDict[str, torch.Tensor]
    """
    camera, batch_size, inner_shape = get_camera_from_batch(batch)
    inner_batch_size = int(np.prod(inner_shape))
    assert camera.width == camera.height, "only square views are supported"
    assert camera.x_fov == camera.y_fov, "only square views are supported"
    assert isinstance(camera, DifferentiableProjectiveCamera)

    device = camera.origin.device
    device_type = device.type

    # Needed by the texture query below on every path, so it must not be
    # bound only inside the cache-miss branch.
    query_batch_size = batch.get("query_batch_size", batch.get("ray_batch_size", 4096))

    # NOTE(review): "meshes" is required for a cache hit but is never written
    # to the cache by this function, so the hit branch is only reachable if
    # something else stores that key. Confirm whether this is intended.
    TO_CACHE = ["fields", "raw_meshes", "raw_signed_distance", "raw_density", "mesh_mask", "meshes"]
    if options.cache is not None and all(key in options.cache for key in TO_CACHE):
        fields = options.cache.fields
        raw_meshes = options.cache.raw_meshes
        raw_signed_distance = options.cache.raw_signed_distance
        raw_density = options.cache.raw_density
        mesh_mask = options.cache.mesh_mask
    else:
        query_points = volume_query_points(volume, grid_size)
        fn = nerstf_fn if sdf_fn is None else sdf_fn
        sdf_out = fn(
            query=Query(position=query_points[None].repeat(batch_size, 1, 1)),
            query_batch_size=query_batch_size,
            options=options,
        )
        raw_signed_distance = sdf_out.signed_distance
        raw_density = None
        if "density" in sdf_out:
            raw_density = sdf_out.density

        with torch.autocast(device_type, enabled=False):
            fields = sdf_out.signed_distance.float()
            assert (
                len(fields.shape) == 3 and fields.shape[-1] == 1
            ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
            fields = fields.reshape(batch_size, *([grid_size] * 3))

            # Force a negative border around the SDFs to close off all the models.
            full_grid = torch.full(
                (batch_size, grid_size + 2, grid_size + 2, grid_size + 2),
                -1.0,
                device=fields.device,
                dtype=fields.dtype,
            )
            full_grid[:, 1:-1, 1:-1, 1:-1] = fields
            fields = full_grid

            raw_meshes = []
            mesh_mask = []
            for field in fields:
                raw_mesh = marching_cubes(field, volume.bbox_min, volume.bbox_max - volume.bbox_min)
                if len(raw_mesh.faces) == 0:
                    # DDP deadlocks when there are unused parameters on some ranks
                    # and not others, so we make sure the field is a dependency in
                    # the graph regardless of empty meshes.
                    vertex_dependency = field.mean()
                    raw_mesh = TorchMesh(
                        verts=torch.zeros(3, 3, device=device) + vertex_dependency,
                        faces=torch.tensor([[0, 1, 2]], dtype=torch.long, device=device),
                    )
                    # Make sure we only feed back zero gradients to the field
                    # by masking out the final renderings of this mesh.
                    mesh_mask.append(False)
                else:
                    mesh_mask.append(True)
                raw_meshes.append(raw_mesh)
            mesh_mask = torch.tensor(mesh_mask, device=device)

    # Meshes have different vertex counts; pad each one to the largest count
    # by cycling through its own vertices so the positions can be stacked.
    max_vertices = max(len(m.verts) for m in raw_meshes)

    fn = nerstf_fn if tf_fn is None else tf_fn
    tf_out = fn(
        query=Query(
            position=torch.stack(
                [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],
                dim=0,
            )
        ),
        query_batch_size=query_batch_size,
        options=options,
    )

    if "cache" in options:
        options.cache.fields = fields
        options.cache.raw_meshes = raw_meshes
        options.cache.raw_signed_distance = raw_signed_distance
        options.cache.raw_density = raw_density
        options.cache.mesh_mask = mesh_mask

    if output_srgb:
        tf_out.channels = _convert_srgb_to_linear(tf_out.channels)

    # Make sure the raw meshes have colors.
    with torch.autocast(device_type, enabled=False):
        textures = tf_out.channels.float()
        assert len(textures.shape) == 3 and textures.shape[-1] == len(
            texture_channels
        ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
        for m, texture in zip(raw_meshes, textures):
            # Drop the padding that was added to reach max_vertices.
            texture = texture[: len(m.verts)]
            m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))}

    args = dict(
        options=options,
        texture_channels=texture_channels,
        ambient_color=ambient_color,
        diffuse_color=diffuse_color,
        specular_color=specular_color,
        camera=camera,
        batch_size=batch_size,
        inner_batch_size=inner_batch_size,
        inner_shape=inner_shape,
        raw_meshes=raw_meshes,
        tf_out=tf_out,
    )

    try:
        out = _render_with_pytorch3d(**args)
    except ModuleNotFoundError as exc:
        warnings.warn(f"exception rendering with PyTorch3D: {exc}")
        warnings.warn(
            "falling back on native PyTorch renderer, which does not support full gradients"
        )
        out = _render_with_raycast(**args)

    # Apply mask to prevent gradients for empty meshes.
    reshaped_mask = mesh_mask.view([-1] + [1] * (len(out.channels.shape) - 1))
    out.channels = torch.where(reshaped_mask, out.channels, torch.zeros_like(out.channels))
    out.transmittance = torch.where(
        reshaped_mask, out.transmittance, torch.ones_like(out.transmittance)
    )

    if output_srgb:
        out.channels = _convert_linear_to_srgb(out.channels)
    out.channels = out.channels * (1 - out.transmittance) * channel_scale.view(-1)

    # This might be useful information to have downstream
    out.raw_meshes = raw_meshes
    out.fields = fields
    out.mesh_mask = mesh_mask
    out.raw_signed_distance = raw_signed_distance
    out.aux_losses = AttrDict(cross_entropy=cross_entropy_sdf_loss(fields))
    if raw_density is not None:
        out.raw_density = raw_density

    return out
cam_shape = [batch_size, inner_batch_size, -1] position = camera.origin.reshape(cam_shape) x = camera.x.reshape(cam_shape) y = camera.y.reshape(cam_shape) z = camera.z.reshape(cam_shape) results = [] for i in range(inner_batch_size): sub_cams = convert_cameras_torch( position[:, i], x[:, i], y[:, i], z[:, i], fov=camera.x_fov ) imgs = render_images( camera.width, meshes, sub_cams, lights, use_checkpoint=options.checkpoint_render, **options.get("render_options", {}), ) results.append(imgs) views = torch.stack(results, dim=1) views = views.view(batch_size, *inner_shape, camera.height, camera.width, n_channels + 1) out = AttrDict( channels=views[..., :-1], # [batch_size, *inner_shape, height, width, n_channels] transmittance=1 - views[..., -1:], # [batch_size, *inner_shape, height, width, 1] meshes=meshes, ) return out def _render_with_raycast( options: AttrDict, texture_channels: Sequence[str], ambient_color: Union[float, Tuple[float]], diffuse_color: Union[float, Tuple[float]], specular_color: Union[float, Tuple[float]], camera: DifferentiableCamera, batch_size: int, inner_shape: Sequence[int], inner_batch_size: int, raw_meshes: List[TorchMesh], tf_out: AttrDict, ): assert np.mean(np.array(specular_color)) == 0 from shap_e.rendering.raycast.render import render_diffuse_mesh from shap_e.rendering.raycast.types import TriMesh as TorchTriMesh device = camera.origin.device device_type = device.type cam_shape = [batch_size, inner_batch_size, -1] origin = camera.origin.reshape(cam_shape) x = camera.x.reshape(cam_shape) y = camera.y.reshape(cam_shape) z = camera.z.reshape(cam_shape) with torch.autocast(device_type, enabled=False): all_meshes = [] for i, mesh in enumerate(raw_meshes): all_meshes.append( TorchTriMesh( faces=mesh.faces.long(), vertices=mesh.verts.float(), vertex_colors=tf_out.channels[i, : len(mesh.verts)].float(), ) ) all_images = [] for i, mesh in enumerate(all_meshes): for j in range(inner_batch_size): all_images.append( render_diffuse_mesh( 
camera=ProjectiveCamera( origin=origin[i, j].detach().cpu().numpy(), x=x[i, j].detach().cpu().numpy(), y=y[i, j].detach().cpu().numpy(), z=z[i, j].detach().cpu().numpy(), width=camera.width, height=camera.height, x_fov=camera.x_fov, y_fov=camera.y_fov, ), mesh=mesh, diffuse=float(np.array(diffuse_color).mean()), ambient=float(np.array(ambient_color).mean()), ray_batch_size=16, # low memory usage checkpoint=options.checkpoint_render, ) ) n_channels = len(texture_channels) views = torch.stack(all_images).view( batch_size, *inner_shape, camera.height, camera.width, n_channels + 1 ) return AttrDict( channels=views[..., :-1], # [batch_size, *inner_shape, height, width, n_channels] transmittance=1 - views[..., -1:], # [batch_size, *inner_shape, height, width, 1] meshes=all_meshes, ) def _convert_srgb_to_linear(u: torch.Tensor) -> torch.Tensor: return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4) def _convert_linear_to_srgb(u: torch.Tensor) -> torch.Tensor: return torch.where(u <= 0.0031308, 12.92 * u, 1.055 * (u ** (1 / 2.4)) - 0.055) def cross_entropy_sdf_loss(fields: torch.Tensor): logits = F.logsigmoid(fields) signs = (fields > 0).float() losses = [] for dim in range(1, 4): n = logits.shape[dim] for (t_start, t_end, p_start, p_end) in [(0, -1, 1, n), (1, n, 0, -1)]: targets = slice_fields(signs, dim, t_start, t_end) preds = slice_fields(logits, dim, p_start, p_end) losses.append( F.binary_cross_entropy_with_logits(preds, targets, reduction="none") .flatten(1) .mean() ) return torch.stack(losses, dim=-1).sum() def slice_fields(fields: torch.Tensor, dim: int, start: int, end: int): if dim == 1: return fields[:, start:end] elif dim == 2: return fields[:, :, start:end] elif dim == 3: return fields[:, :, :, start:end] else: raise ValueError(f"cannot slice dimension {dim}") def volume_query_points( volume: Volume, grid_size: int, ): assert isinstance(volume, BoundingBoxVolume) indices = torch.arange(grid_size**3, device=volume.bbox_min.device) zs = 
class Encoder(nn.Module, ABC):
    def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]]):
        """
        Instantiate the encoder with information about the renderer's input
        parameters. This information can be used to create output layers to
        generate the necessary latents.
        """
        super().__init__()
        self.param_shapes = param_shapes
        self.device = device

    @abstractmethod
    def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Encode a batch of data into a batch of latent information.
        """


class VectorEncoder(Encoder):
    """
    Encoder that produces one latent vector per batch element, passes it
    through a bottleneck and a warp, and projects it to renderer parameters.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
    ):
        """
        :param params_proj: config for the latent-to-parameters projection.
        :param d_latent: latent size handed to the projection and bottleneck.
        :param latent_bottleneck: bottleneck config; identity when None.
        :param latent_warp: warp config; identity when None.
        """
        super().__init__(device=device, param_shapes=param_shapes)
        if latent_bottleneck is None:
            latent_bottleneck = dict(name="identity")
        if latent_warp is None:
            latent_warp = dict(name="identity")
        self.d_latent = d_latent
        self.params_proj = params_proj_from_config(
            params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent
        )
        self.latent_bottleneck = latent_bottleneck_from_config(
            latent_bottleneck, device=device, d_latent=d_latent
        )
        self.latent_warp = latent_warp_from_config(latent_warp, device=device)

    def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Encode the batch to a bottlenecked latent and project it to params.
        """
        h = self.encode_to_bottleneck(batch, options=options)
        return self.bottleneck_to_params(h, options=options)

    def encode_to_bottleneck(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Encode the batch, apply the bottleneck, then warp the result.
        """
        return self.latent_warp.warp(
            self.latent_bottleneck(self.encode_to_vector(batch, options=options), options=options),
            options=options,
        )

    @abstractmethod
    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        Encode the batch into a single latent vector.
        """

    def bottleneck_to_params(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> AttrDict:
        """
        Undo the warp and project the latent to renderer parameters.
        """
        return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)


class ChannelsEncoder(VectorEncoder):
    """
    VectorEncoder whose latent is a set of `latent_ctx` channels that is
    flattened into a single vector for the bottleneck and warp.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            d_latent=d_latent,
            latent_bottleneck=latent_bottleneck,
            latent_warp=latent_warp,
        )
        self.flat_shapes = flatten_param_shapes(param_shapes)
        # One latent token per leading entry of every flattened parameter shape.
        self.latent_ctx = sum(flat[0] for flat in self.flat_shapes.values())

    @abstractmethod
    def encode_to_channels(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Encode the batch into a per-data-point set of latents.
        :return: [batch_size, latent_ctx, latent_width]
        """

    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        Flatten the channel latents into one vector per batch element.
        """
        return self.encode_to_channels(batch, options=options).flatten(1)

    def bottleneck_to_channels(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Reshape a flat latent back to [batch_size, latent_ctx, -1].
        """
        _ = options
        return vector.view(vector.shape[0], self.latent_ctx, -1)

    def bottleneck_to_params(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> AttrDict:
        """
        Undo the warp, restore the channel layout and project to params.
        """
        # Forward options to unwarp() like VectorEncoder does, so a warp that
        # honors options behaves the same on both code paths.
        return self.params_proj(
            self.bottleneck_to_channels(self.latent_warp.unwarp(vector, options=options)),
            options=options,
        )


class Transmitter(nn.Module):
    """
    An encoder followed by a renderer.
    """

    def __init__(self, encoder: Encoder, renderer: Renderer):
        super().__init__()
        self.encoder = encoder
        self.renderer = renderer

    def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Transmit the batch through the encoder and then the renderer.
        """
        params = self.encoder(batch, options=options)
        return self.renderer(batch, params=params, options=options)


class VectorDecoder(nn.Module):
    """
    Decoder half of a VectorEncoder: unwarps a latent vector and projects it
    to renderer parameters. Also holds the renderer itself.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int,
        latent_warp: Optional[Dict[str, Any]] = None,
        renderer: Renderer,
    ):
        """
        :param params_proj: config for the latent-to-parameters projection.
        :param d_latent: latent size handed to the projection.
        :param latent_warp: warp config; identity when None.
        :param renderer: stored as `self.renderer`.
        """
        super().__init__()
        self.device = device
        self.param_shapes = param_shapes

        if latent_warp is None:
            latent_warp = dict(name="identity")
        self.d_latent = d_latent
        self.params_proj = params_proj_from_config(
            params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent
        )
        self.latent_warp = latent_warp_from_config(latent_warp, device=device)
        self.renderer = renderer

    def bottleneck_to_params(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> AttrDict:
        """
        Undo the warp and project the latent to renderer parameters.
        """
        return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)


class ChannelsDecoder(VectorDecoder):
    """
    VectorDecoder for latents laid out as `latent_ctx` channels.
    """

    def __init__(
        self,
        *,
        latent_ctx: int,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.latent_ctx = latent_ctx

    def bottleneck_to_channels(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Reshape a flat latent back to [batch_size, latent_ctx, -1].
        """
        _ = options
        return vector.view(vector.shape[0], self.latent_ctx, -1)

    def bottleneck_to_params(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> AttrDict:
        """
        Undo the warp, restore the channel layout and project to params.
        """
        # Forward options to unwarp() like VectorDecoder does.
        return self.params_proj(
            self.bottleneck_to_channels(self.latent_warp.unwarp(vector, options=options)),
            options=options,
        )
class LatentWarp(nn.Module, ABC):
    """
    Interface for an invertible elementwise transform of a latent.
    """

    def __init__(self, *, device: torch.device):
        super().__init__()
        self.device = device

    @abstractmethod
    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Apply the transform to x.
        """

    @abstractmethod
    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Apply the inverse transform to x.
        """


class IdentityLatentWarp(LatentWarp):
    """
    Warp that returns its input unchanged in both directions.
    """

    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        return x

    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        return x


class Tan2LatentWarp(LatentWarp):
    """
    Warp that applies tan twice (scaling by `coeff1` in between) and divides
    by a constant chosen so that an input of 1.0 maps to 1.0.
    """

    def __init__(self, *, coeff1: float = 1.0, device: torch.device):
        super().__init__(device=device)
        self.coeff1 = coeff1
        # Value of the un-normalized warp at x = 1.0.
        self.scale = np.tan(np.tan(1.0) * coeff1)

    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        # Compute in float32 and cast back to the input dtype at the end.
        inner = torch.tan(x.float()) * self.coeff1
        warped = torch.tan(inner) / self.scale
        return warped.to(x.dtype)

    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        # Inverse of warp(): undo the normalization, then the two tangents.
        inner = torch.arctan(x.float() * self.scale) / self.coeff1
        unwarped = torch.arctan(inner)
        return unwarped.to(x.dtype)


class IdentityLatentBottleneck(LatentBottleneck):
    """
    Bottleneck that returns its input unchanged.
    """

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        return x


class ClampNoiseBottleneck(LatentBottleneck):
    """
    Bottleneck that squashes the latent with tanh and, in training mode only,
    adds Gaussian noise scaled by `noise_scale`.
    """

    def __init__(self, *, device: torch.device, d_latent: int, noise_scale: float):
        super().__init__(device=device, d_latent=d_latent)
        self.noise_scale = noise_scale

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        squashed = torch.tanh(x)
        if self.training:
            return squashed + torch.randn_like(squashed) * self.noise_scale
        return squashed
def latent_bottleneck_from_config(config: Dict[str, Any], device: torch.device, d_latent: int):
    """
    Build a latent bottleneck from a config dict.

    :param config: must contain "name" ("clamp_noise", "identity" or
                   "clamp_diffusion_noise"); every other entry is passed to
                   the bottleneck's constructor as a keyword argument.
    :param device: forwarded to the constructor.
    :param d_latent: forwarded to the constructor.
    :raises KeyError: if "name" is missing.
    :raises ValueError: if the name is not recognized.
    """
    # Work on a copy so the caller's config keeps its "name" entry and can be
    # reused to build another instance.
    config = dict(config)
    name = config.pop("name")
    if name == "clamp_noise":
        return ClampNoiseBottleneck(**config, device=device, d_latent=d_latent)
    elif name == "identity":
        return IdentityLatentBottleneck(**config, device=device, d_latent=d_latent)
    elif name == "clamp_diffusion_noise":
        return ClampDiffusionNoiseBottleneck(**config, device=device, d_latent=d_latent)
    else:
        raise ValueError(f"unknown latent bottleneck: {name}")


def latent_warp_from_config(config: Dict[str, Any], device: torch.device):
    """
    Build a latent warp from a config dict.

    :param config: must contain "name" ("identity" or "tan2"); every other
                   entry is passed to the warp's constructor as a keyword
                   argument.
    :param device: forwarded to the constructor.
    :raises KeyError: if "name" is missing.
    :raises ValueError: if the name is not recognized.
    """
    # Work on a copy so the caller's config is left untouched.
    config = dict(config)
    name = config.pop("name")
    if name == "identity":
        return IdentityLatentWarp(**config, device=device)
    elif name == "tan2":
        return Tan2LatentWarp(**config, device=device)
    else:
        raise ValueError(f"unknown latent warping function: {name}")
class TransformerChannelsEncoder(ChannelsEncoder, ABC):
    """
    Encode point clouds using a transformer model with an extra output
    token used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int = 512,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        latent_scale: float = 1.0,
    ):
        """
        :param n_ctx: number of input tokens; the backbone context is this
                      plus `latent_ctx` learned output tokens.
        :param width: token width of the backbone.
        :param layers: forwarded to the Transformer backbone.
        :param heads: forwarded to the Transformer backbone.
        :param init_scale: forwarded to the Transformer backbone.
        :param latent_scale: stored on the instance; not read by this class.
        """
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            d_latent=d_latent,
            latent_bottleneck=latent_bottleneck,
            latent_warp=latent_warp,
        )
        self.width = width
        self.device = device
        self.dtype = dtype
        self.n_ctx = n_ctx

        # Shared placement arguments for every submodule created below.
        placement = dict(device=device, dtype=dtype)

        self.backbone = Transformer(
            n_ctx=n_ctx + self.latent_ctx,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
            **placement,
        )
        self.ln_pre = nn.LayerNorm(width, **placement)
        self.ln_post = nn.LayerNorm(width, **placement)
        # Learned tokens appended to every input sequence; their outputs
        # become the latent channels.
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(self.latent_ctx, width, **placement)),
        )
        self.output_proj = nn.Linear(width, d_latent, **placement)
        self.latent_scale = latent_scale

    @abstractmethod
    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        Turn the batch into the token sequence fed to the backbone.
        """

    def encode_to_channels(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Run the backbone over the input tokens followed by the learned output
        tokens, and project the outputs at the last `latent_ctx` positions.
        """
        tokens = self.encode_input(batch, options=options)
        batch_tokens = self.output_tokens[None].repeat(len(tokens), 1, 1)
        sequence = torch.cat([tokens, batch_tokens], dim=1)
        hidden = self.backbone(self.ln_pre(sequence))
        # Only the positions of the appended output tokens carry the latent.
        latent_hidden = hidden[:, -self.latent_ctx :]
        return self.output_proj(self.ln_post(latent_hidden))
@dataclass
class DatasetIterator:
    """
    Endless iterator yielding ``batch_size`` consecutive entries along dim 1
    of ``embs``, reshuffling every outer row independently each time the
    data is exhausted.
    """

    embs: torch.Tensor  # [batch_size, dataset_size, *shape]
    batch_size: int

    def __iter__(self):
        self._reset()
        return self

    def __next__(self):
        """
        :return: a slice of shape [outer_batch_size, batch_size, *shape].
        :raises ValueError: if ``batch_size`` is not in [1, dataset_size].
        """
        _outer_batch_size, dataset_size, *_shape = self.embs.shape
        # An oversized batch used to reshuffle forever without returning.
        if not 0 < self.batch_size <= dataset_size:
            raise ValueError(
                f"batch_size must be in [1, {dataset_size}], got {self.batch_size}"
            )
        # Permit next() before iter(); `idx` only exists after a reset.
        if not hasattr(self, "idx"):
            self._reset()
        # Start a fresh shuffled pass when the remainder is too short.
        if self.idx + self.batch_size > dataset_size:
            self._reset()

        start = self.idx
        end = start + self.batch_size
        self.idx = end
        return self.embs[:, start:end]

    def _reset(self):
        self._shuffle()
        self.idx = 0  # pylint: disable=attribute-defined-outside-init

    def _shuffle(self):
        """Permute dim 1 of ``embs`` with an independent permutation per row."""
        outer_batch_size, dataset_size, *shape = self.embs.shape
        idx = torch.stack(
            [
                torch.randperm(dataset_size, device=self.embs.device)
                for _ in range(outer_batch_size)
            ],
            dim=0,
        )
        # Expand the per-row permutation over the trailing dims for gather.
        idx = idx.view(outer_batch_size, dataset_size, *([1] * len(shape)))
        idx = torch.broadcast_to(idx, self.embs.shape)
        self.embs = torch.gather(self.embs, 1, idx)
class PointCloudTransformerChannelsEncoder(TransformerChannelsEncoder):
    """
    Transformer channels encoder whose input tokens are linearly projected
    point-cloud features.
    """

    def __init__(
        self,
        *,
        input_channels: int = 6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        projection = nn.Linear(input_channels, self.width, device=self.device, dtype=self.dtype)
        self.input_proj = projection

    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        :param batch: must provide ``points``; the last two axes are swapped
                      (NCL -> NLC) before projection.
        :return: the projected per-point tokens.
        """
        _ = options
        channels_last = batch.points.permute(0, 2, 1)  # NCL -> NLC
        return self.input_proj(channels_last)
class PointCloudPerceiverChannelsEncoder(PerceiverChannelsEncoder):
    """
    Encode point clouds using a perceiver model with an extra output
    token used to extract a latent vector.

    The query tokens come from the point cloud; the data that is
    cross-attended to is chosen by ``cross_attention_dataset``.
    """

    def __init__(
        self,
        *,
        cross_attention_dataset: str = "pcl",
        fps_method: str = "fps",
        # point cloud hyperparameters
        input_channels: int = 6,
        pos_emb: Optional[str] = None,
        # multiview hyperparameters
        image_size: int = 256,
        patch_size: int = 32,
        pose_dropout: float = 0.0,
        use_depth: bool = False,
        max_depth: float = 5.0,
        # point conv hyperparameters
        pointconv_radius: float = 0.5,
        pointconv_samples: int = 32,
        pointconv_hidden: Optional[List[int]] = None,
        pointconv_patch_size: int = 1,
        pointconv_stride: int = 1,
        pointconv_padding_mode: str = "zeros",
        use_pointconv: bool = False,
        # other hyperparameters
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert cross_attention_dataset in (
            "pcl",
            "multiview",
            "dense_pose_multiview",
            "multiview_pcl",
            "pcl_and_multiview_pcl",
            "incorrect_multiview_pcl",
            "pcl_and_incorrect_multiview_pcl",
        )
        assert fps_method in ("fps", "first")
        self.cross_attention_dataset = cross_attention_dataset
        self.fps_method = fps_method
        self.input_channels = input_channels
        self.input_proj = PosEmbLinear(
            pos_emb,
            input_channels,
            self.width,
            device=self.device,
            dtype=self.dtype,
        )
        self.use_pointconv = use_pointconv
        if use_pointconv:
            if pointconv_hidden is None:
                pointconv_hidden = [self.width]
            self.point_conv = PointSetEmbedding(
                n_point=self.data_ctx,
                radius=pointconv_radius,
                n_sample=pointconv_samples,
                d_input=self.input_proj.weight.shape[0],
                d_hidden=pointconv_hidden,
                patch_size=pointconv_patch_size,
                stride=pointconv_stride,
                padding_mode=pointconv_padding_mode,
                fps_method=fps_method,
                device=self.device,
                dtype=self.dtype,
            )

        if self.cross_attention_dataset == "multiview":
            self.image_size = image_size
            self.patch_size = patch_size
            self.pose_dropout = pose_dropout
            self.use_depth = use_depth
            self.max_depth = max_depth
            pos_ctx = (image_size // patch_size) ** 2
            # inner_batch_size is stored as a list by the parent class, so it
            # must be indexed: `int * list` repeats the list instead of giving
            # the number of rows. get_multiview_dataset adds this embedding to
            # inner_batch_size[0] * n_patches tokens.
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    torch.randn(
                        pos_ctx * self.inner_batch_size[0],
                        self.width,
                        device=self.device,
                        dtype=self.dtype,
                    )
                ),
            )
            self.patch_emb = nn.Conv2d(
                in_channels=3 if not use_depth else 4,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
            self.camera_emb = nn.Sequential(
                nn.Linear(
                    3 * 4 + 1, self.width, device=self.device, dtype=self.dtype
                ),  # input size is for origin+x+y+z+fov
                nn.GELU(),
                nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),
            )
        elif self.cross_attention_dataset == "dense_pose_multiview":
            # The number of output features is halved, because a patch_size of
            # 32 ends up with a large patch_emb weight.
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.use_depth = use_depth
            self.max_depth = max_depth
            self.mv_pose_embed = MultiviewPoseEmbedding(
                posemb_version="nerf",
                n_channels=4 if self.use_depth else 3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            # Positional embedding is unnecessary because pose information is baked into each pixel
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
        elif self.cross_attention_dataset in (
            "multiview_pcl",
            "incorrect_multiview_pcl",
            "pcl_and_multiview_pcl",
            "pcl_and_incorrect_multiview_pcl",
        ):
            # All four modes build the same multiview point-cloud modules.
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.max_depth = max_depth
            assert use_depth
            self.mv_pcl_embed = MultiviewPointCloudEmbedding(
                posemb_version="nerf",
                n_channels=3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )

    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        :raises KeyError: if no dataset iterator exists for the configured
                          ``cross_attention_dataset``.
        """
        options = AttrDict() if options is None else options

        # Build the initial query embeddings
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        if self.use_pointconv:
            points = self.input_proj(points).permute(0, 2, 1)  # NLC -> NCL
            xyz = batch.points[:, :3]
            data_tokens = self.point_conv(xyz, points).permute(0, 2, 1)  # NCL -> NLC
        else:
            fps_samples = self.sample_pcl_fps(points)
            data_tokens = self.input_proj(fps_samples)
        batch_size = points.shape[0]
        latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)
        h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))
        assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)

        # Build the dataset embedding iterator
        dataset_fns = {
            "pcl": self.get_pcl_dataset,
            "multiview": self.get_multiview_dataset,
            "dense_pose_multiview": self.get_dense_pose_multiview_dataset,
            "pcl_and_multiview_pcl": self.get_pcl_and_multiview_pcl_dataset,
            "multiview_pcl": self.get_multiview_pcl_dataset,
        }
        # NOTE(review): the "incorrect_*" modes are accepted by the constructor
        # but have no entry here. Keep the KeyError type, but say what is wrong.
        try:
            dataset_fn = dataset_fns[self.cross_attention_dataset]
        except KeyError:
            raise KeyError(
                "no dataset iterator is implemented for "
                f"cross_attention_dataset={self.cross_attention_dataset!r}"
            ) from None
        it = dataset_fn(batch, options=options)

        return h, it

    def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
        """Subsample ``points`` to ``data_ctx`` entries with ``fps_method``."""
        return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)

    def get_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict[str, Any]] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        """
        :return: an endless iterator over chunks of ``inner_batch_size``
                 projected points (defaults to ``self.inner_batch_size[0]``).
        """
        _ = options
        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        dataset_emb = self.input_proj(points)
        assert dataset_emb.shape[1] >= inner_batch_size
        return iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

    def get_multiview_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        """
        :return: an endless generator of [batch_size, inner_batch_size * n_patches, width]
                 view tokens with ``pos_emb`` added.
        """
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        dataset_emb = self.encode_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                examples = next(it)
                # Compare against the resolved int; self.inner_batch_size is a
                # list and would make this shape check fail unconditionally.
                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width) + self.pos_emb
                yield views

        return gen()

    def get_dense_pose_multiview_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        """
        :return: an endless generator of [batch_size, inner_batch_size * n_patches, width]
                 dense-pose view tokens.
        """
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        dataset_emb = self.encode_dense_pose_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                examples = next(it)
                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width)
                yield views

        return gen()

    def get_pcl_and_multiview_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        use_distance: bool = True,
    ) -> Iterable:
        """
        :return: an endless generator of (point chunk, multiview point-cloud chunk)
                 tuples, sized by ``inner_batch_size[0]`` and ``inner_batch_size[1]``.
        """
        _ = options

        pcl_it = self.get_pcl_dataset(
            batch, options=options, inner_batch_size=self.inner_batch_size[0]
        )
        multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape

        assert num_views >= self.inner_batch_size[1]

        multiview_pcl_it = iter(
            DatasetIterator(multiview_pcl_emb, batch_size=self.inner_batch_size[1])
        )

        def gen():
            while True:
                pcl = next(pcl_it)
                multiview_pcl = next(multiview_pcl_it)
                assert multiview_pcl.shape == (
                    batch_size,
                    self.inner_batch_size[1],
                    n_patches,
                    self.width,
                )
                yield pcl, multiview_pcl.reshape(batch_size, -1, width)

        return gen()

    def get_multiview_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
        use_distance: bool = True,
    ) -> Iterable:
        """
        :return: an endless generator of [batch_size, inner_batch_size * n_patches, width]
                 multiview point-cloud tokens.
        """
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape

        assert num_views >= inner_batch_size

        multiview_pcl_it = iter(DatasetIterator(multiview_pcl_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                multiview_pcl = next(multiview_pcl_it)
                assert multiview_pcl.shape == (
                    batch_size,
                    inner_batch_size,
                    n_patches,
                    self.width,
                )
                yield multiview_pcl.reshape(batch_size, -1, width)

        return gen()

    def encode_views(self, batch: AttrDict) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)

        batch_size, num_views, _, _, _ = all_views.shape

        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        # [batch_size, num_views, 1, 2 * width]
        camera_proj = self.camera_emb(all_cameras).reshape(
            [batch_size, num_views, 1, self.width * 2]
        )
        # Pose dropout zeroes the camera conditioning for whole batch elements
        # and is only active in training mode.
        pose_dropout = self.pose_dropout if self.training else 0.0
        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
        scale, shift = camera_proj.chunk(2, dim=3)
        views_proj = views_proj * (scale + 1.0) + shift
        return views_proj

    def encode_dense_pose_views(self, batch: AttrDict) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            depths = self.depths_to_tensor(batch.depths)
            all_views = torch.cat([all_views, depths], dim=2)

        dense_poses, _ = self.dense_pose_cameras_to_tensor(batch.cameras)
        # [batch, views, H, W, 2, 3] -> [batch, views, 2, 3, H, W]
        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)
        position, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
        all_view_poses = self.mv_pose_embed(all_views, position, direction)

        batch_size, num_views, _, _, _ = all_view_poses.shape

        views_proj = self.patch_emb(
            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        return views_proj

    def encode_multiview_pcl(self, batch: AttrDict, use_distance: bool = True) -> torch.Tensor:
        """
        :param use_distance: if True, divide the depths by the dot product of
                             each ray direction with the camera z axis before
                             computing positions along the rays.
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        depths = self.raw_depths_to_tensor(batch.depths)
        all_view_alphas = self.view_alphas_to_tensor(batch.view_alphas).to(self.device)
        mask = all_view_alphas >= 0.999

        dense_poses, camera_z = self.dense_pose_cameras_to_tensor(batch.cameras)
        # [batch, views, H, W, 2, 3] -> [batch, views, 2, 3, H, W]
        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)

        origin, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
        if use_distance:
            ray_depth_factor = torch.sum(
                direction * camera_z[..., None, None], dim=2, keepdim=True
            )
            depths = depths / ray_depth_factor
        position = origin + depths * direction
        all_view_poses = self.mv_pcl_embed(all_views, origin, position, mask)

        batch_size, num_views, _, _, _ = all_view_poses.shape

        views_proj = self.patch_emb(
            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        return views_proj

    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].

        A tensor input is returned unchanged.
        """
        if isinstance(views, torch.Tensor):
            return views

        tensor_batch = []
        num_views = len(views[0])
        for inner_list in views:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                img = img.resize((self.image_size,) * 2).convert("RGB")
                inner_batch.append(
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 127.5
                    - 1
                )
            tensor_batch.append(torch.stack(inner_batch, dim=0))
        # Channels-last images -> channels-first.
        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)

    def depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].

        A tensor input is returned unchanged. List elements are passed to
        ``torch.from_numpy``, so they must be numpy arrays.
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
                tensor = tensor * 2 - 1
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def view_alphas_to_tensor(
        self, view_alphas: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [0, 1].

        A tensor input is returned unchanged.
        """
        if isinstance(view_alphas, torch.Tensor):
            return view_alphas

        tensor_batch = []
        num_views = len(view_alphas[0])
        for inner_list in view_alphas:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                tensor = (
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 255.0
                )
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor)
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def raw_depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor

        Depths are clamped to ``max_depth`` but not rescaled. A tensor input
        is returned unchanged; list elements must be numpy arrays.
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth)
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3*4+1] tensor of camera information.

        Each row is x, y, z, origin and x_fov, in that order. A tensor input
        is returned unchanged.
        """
        if isinstance(cameras, torch.Tensor):
            return cameras
        outer_batch = []
        for inner_list in cameras:
            inner_batch = []
            for camera in inner_list:
                inner_batch.append(
                    np.array(
                        [
                            *camera.x,
                            *camera.y,
                            *camera.z,
                            *camera.origin,
                            camera.x_fov,
                        ]
                    )
                )
            outer_batch.append(np.stack(inner_batch, axis=0))
        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()

    def dense_pose_cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a tuple of (rays, z_directions) where
            - rays: [batch, num_views, height, width, 2, 3] tensor of camera information.
            - z_directions: [batch, num_views, 3] tensor of camera z directions.

        Width, height and field of view are taken from the first camera only.

        :raises NotImplementedError: if ``cameras`` is a tensor.
        """
        if isinstance(cameras, torch.Tensor):
            raise NotImplementedError
        for inner_list in cameras:
            assert len(inner_list) == len(cameras[0])

        camera = cameras[0][0]

        def stack_field(field: str) -> torch.Tensor:
            # Flatten [batch][view] into one leading axis for the given attribute.
            return torch.from_numpy(
                np.stack(
                    [getattr(cam, field) for inner_list in cameras for cam in inner_list],
                    axis=0,
                )
            ).to(self.device)

        flat_camera = DifferentiableProjectiveCamera(
            origin=stack_field("origin"),
            x=stack_field("x"),
            y=stack_field("y"),
            z=stack_field("z"),
            width=camera.width,
            height=camera.height,
            x_fov=camera.x_fov,
            y_fov=camera.y_fov,
        )
        batch_size = len(cameras) * len(cameras[0])
        coords = (
            flat_camera.image_coords()
            .to(flat_camera.origin.device)
            .unsqueeze(0)
            .repeat(batch_size, 1, 1)
        )
        rays = flat_camera.camera_rays(coords)
        return (
            rays.view(len(cameras), len(cameras[0]), camera.height, camera.width, 2, 3).to(
                self.device
            ),
            flat_camera.z.view(len(cameras), len(cameras[0]), 3).to(self.device),
        )
- z_directions: [batch, num_views, 3] tensor of camera z directions. """ if isinstance(cameras, torch.Tensor): raise NotImplementedError for inner_list in cameras: assert len(inner_list) == len(cameras[0]) camera = cameras[0][0] flat_camera = DifferentiableProjectiveCamera( origin=torch.from_numpy( np.stack( [cam.origin for inner_list in cameras for cam in inner_list], axis=0, ) ).to(self.device), x=torch.from_numpy( np.stack( [cam.x for inner_list in cameras for cam in inner_list], axis=0, ) ).to(self.device), y=torch.from_numpy( np.stack( [cam.y for inner_list in cameras for cam in inner_list], axis=0, ) ).to(self.device), z=torch.from_numpy( np.stack( [cam.z for inner_list in cameras for cam in inner_list], axis=0, ) ).to(self.device), width=camera.width, height=camera.height, x_fov=camera.x_fov, y_fov=camera.y_fov, ) batch_size = len(cameras) * len(cameras[0]) coords = ( flat_camera.image_coords() .to(flat_camera.origin.device) .unsqueeze(0) .repeat(batch_size, 1, 1) ) rays = flat_camera.camera_rays(coords) return ( rays.view(len(cameras), len(cameras[0]), camera.height, camera.width, 2, 3).to( self.device ), flat_camera.z.view(len(cameras), len(cameras[0]), 3).to(self.device), ) def sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = "fps") -> torch.Tensor: """ Run farthest-point sampling on a batch of point clouds. :param points: batch of shape [N x num_points]. :param data_ctx: subsample count. :param method: either 'fps' or 'first'. Using 'first' assumes that the points are already sorted according to FPS sampling. :return: batch of shape [N x min(num_points, data_ctx)]. 
""" n_points = points.shape[1] if n_points == data_ctx: return points if method == "first": return points[:, :data_ctx] elif method == "fps": batch = points.cpu().split(1, dim=0) fps = [sample_fps(x, n_samples=data_ctx) for x in batch] return torch.cat(fps, dim=0).to(points.device) else: raise ValueError(f"unsupported farthest-point sampling method: {method}") def sample_fps(example: torch.Tensor, n_samples: int) -> torch.Tensor: """ :param example: [1, n_points, 3 + n_channels] :return: [1, n_samples, 3 + n_channels] """ points = example.cpu().squeeze(0).numpy() coords, raw_channels = points[:, :3], points[:, 3:] n_points, n_channels = raw_channels.shape assert n_samples <= n_points channels = {str(idx): raw_channels[:, idx] for idx in range(n_channels)} max_points = min(32768, n_points) fps_pcl = ( PointCloud(coords=coords, channels=channels) .random_sample(max_points) .farthest_point_sample(n_samples) ) fps_channels = np.stack([fps_pcl.channels[str(idx)] for idx in range(n_channels)], axis=1) fps = np.concatenate([fps_pcl.coords, fps_channels], axis=1) fps = torch.from_numpy(fps).unsqueeze(0) assert fps.shape == (1, n_samples, 3 + n_channels) return fps ================================================ FILE: shap_e/models/transmitter/multiview_encoder.py ================================================ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from shap_e.models.generation.transformer import Transformer from shap_e.rendering.view_data import ProjectiveCamera from shap_e.util.collections import AttrDict from .base import VectorEncoder class MultiviewTransformerEncoder(VectorEncoder): """ Encode cameras and views using a transformer model with extra output token(s) used to extract a latent vector. 
""" def __init__( self, *, device: torch.device, dtype: torch.dtype, param_shapes: Dict[str, Tuple[int]], params_proj: Dict[str, Any], latent_bottleneck: Optional[Dict[str, Any]] = None, d_latent: int = 512, latent_ctx: int = 1, num_views: int = 20, image_size: int = 256, patch_size: int = 32, use_depth: bool = False, max_depth: float = 5.0, width: int = 512, layers: int = 12, heads: int = 8, init_scale: float = 0.25, pos_emb_init_scale: float = 1.0, ): super().__init__( device=device, param_shapes=param_shapes, params_proj=params_proj, latent_bottleneck=latent_bottleneck, d_latent=d_latent, ) self.num_views = num_views self.image_size = image_size self.patch_size = patch_size self.use_depth = use_depth self.max_depth = max_depth self.n_ctx = num_views * (1 + (image_size // patch_size) ** 2) self.latent_ctx = latent_ctx self.width = width assert d_latent % latent_ctx == 0 self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) self.backbone = Transformer( device=device, dtype=dtype, n_ctx=self.n_ctx + latent_ctx, width=width, layers=layers, heads=heads, init_scale=init_scale, ) self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) self.register_parameter( "output_tokens", nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)), ) self.register_parameter( "pos_emb", nn.Parameter( pos_emb_init_scale * torch.randn(self.n_ctx, width, device=device, dtype=dtype) ), ) self.patch_emb = nn.Conv2d( in_channels=3 if not use_depth else 4, out_channels=width, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype, ) self.camera_emb = nn.Sequential( nn.Linear( 3 * 4 + 1, width, device=device, dtype=dtype ), # input size is for origin+x+y+z+fov nn.GELU(), nn.Linear(width, width, device=device, dtype=dtype), ) self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype) def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: _ = options all_views = 
def flatten_param_shapes(param_shapes: Dict[str, Tuple[int]]):
    """
    Collapse every parameter shape to two dimensions.

    :param param_shapes: mapping from parameter name to its full shape; each
                         shape needs at least one dimension.
    :return: an OrderedDict, in input order, mapping each name to
             (product of all leading dims, last dim).
    """
    # Multiply the leading dims directly instead of dividing the total by the
    # last dim: same result for positive sizes, and no ZeroDivisionError when
    # the last dim is 0.
    flat_shapes = OrderedDict(
        (name, (int(np.prod(shape[:-1])), shape[-1])) for name, shape in param_shapes.items()
    )
    return flat_shapes


class ParamsProj(nn.Module, ABC):
    """
    Base class for modules that map a latent tensor to a dict of parameters
    with the shapes given in ``param_shapes``.
    """

    def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int):
        super().__init__()
        self.device = device
        self.param_shapes = param_shapes
        self.d_latent = d_latent

    @abstractmethod
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        :param x: the latent tensor to project.
        :return: an AttrDict keyed like ``param_shapes``.
        """
class LinearParamsProj(ParamsProj):
    """
    One linear layer per parameter, mapping the latent to the flattened
    parameter, which is reshaped on output.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        d_latent: int,
        init_scale: Optional[float] = None,
    ):
        super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
        self.param_shapes = param_shapes
        self.projections = nn.ModuleDict({})
        # Without init_scale the default nn.Linear initialisation is kept.
        std = None if init_scale is None else init_scale / math.sqrt(d_latent)
        for name, shape in param_shapes.items():
            layer = nn.Linear(d_latent, int(np.prod(shape)), device=device)
            self.projections[_sanitize_name(name)] = layer
            if std is not None:
                nn.init.normal_(layer.weight, std=std)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        :return: an AttrDict mapping each parameter name to a tensor of shape
                 [len(x), *param_shape].
        """
        result = AttrDict()
        for name, shape in self.param_shapes.items():
            layer = self.projections[_sanitize_name(name)]
            result[name] = layer(x).reshape([len(x), *shape])
        return result


class MLPParamsProj(ParamsProj):
    """
    Like ``LinearParamsProj``, but each parameter is produced by a
    Linear -> GELU -> Linear stack.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        d_latent: int,
        hidden_size: Optional[int] = None,
    ):
        super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
        # The hidden width defaults to the latent width.
        inner = d_latent if hidden_size is None else hidden_size
        self.param_shapes = param_shapes
        self.projections = nn.ModuleDict({})
        for name, shape in param_shapes.items():
            stack = nn.Sequential(
                nn.Linear(d_latent, inner, device=device),
                nn.GELU(),
                nn.Linear(inner, int(np.prod(shape)), device=device),
            )
            self.projections[_sanitize_name(name)] = stack

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        :return: an AttrDict mapping each parameter name to a tensor of shape
                 [len(x), *param_shape].
        """
        result = AttrDict()
        for name, shape in self.param_shapes.items():
            stack = self.projections[_sanitize_name(name)]
            result[name] = stack(x).reshape([len(x), *shape])
        return result


class ChannelsProj(nn.Module):
    """
    Project ``vectors`` latent rows to ``channels`` features each, with a
    separate weight matrix per row, optionally followed by a LayerNorm or a
    learned per-channel gain.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        vectors: int,
        channels: int,
        d_latent: int,
        init_scale: float = 1.0,
        learned_scale: Optional[float] = None,
        use_ln: bool = False,
    ):
        super().__init__()
        self.proj = nn.Linear(d_latent, vectors * channels, device=device)
        self.use_ln = use_ln
        self.learned_scale = learned_scale

        has_gain = learned_scale is not None
        if use_ln:
            self.norm = nn.LayerNorm(normalized_shape=(channels,), device=device)
            if has_gain:
                self.norm.weight.data.fill_(learned_scale)
        elif has_gain:
            initial_gain = torch.ones((channels,), device=device) * learned_scale
            self.register_parameter("gain", nn.Parameter(initial_gain))

        # With neither a norm nor a gain, the init std also shrinks with the
        # number of channels.
        fan = d_latent if (use_ln or has_gain) else d_latent * channels
        nn.init.normal_(self.proj.weight, std=init_scale / math.sqrt(fan))
        nn.init.zeros_(self.proj.bias)
        self.d_latent = d_latent
        self.vectors = vectors
        self.channels = channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: [batch, vectors, d_latent]
        :return: [batch, vectors, channels]
        """
        weight_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
        bias_vc = self.proj.bias.view(1, self.vectors, self.channels)
        out = torch.einsum("bvd,vcd->bvc", x, weight_vcd)
        if self.use_ln:
            out = self.norm(out)
        elif self.learned_scale is not None:
            out = out * self.gain.view(1, 1, -1)
        # The bias is added after the normalisation or gain.
        return out + bias_vc
class ChannelsParamsProj(ParamsProj):
    """
    Produce every parameter from its own contiguous slice of latent rows,
    using one ``ChannelsProj`` per parameter.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        d_latent: int,
        init_scale: float = 1.0,
        learned_scale: Optional[float] = None,
        use_ln: bool = False,
    ):
        super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
        self.param_shapes = param_shapes
        self.projections = nn.ModuleDict({})
        self.flat_shapes = flatten_param_shapes(param_shapes)
        self.learned_scale = learned_scale
        self.use_ln = use_ln
        for name, (n_vectors, n_channels) in self.flat_shapes.items():
            self.projections[_sanitize_name(name)] = ChannelsProj(
                device=device,
                vectors=n_vectors,
                channels=n_channels,
                d_latent=d_latent,
                init_scale=init_scale,
                learned_scale=learned_scale,
                use_ln=use_ln,
            )

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        :param x: latent rows along dim 1, laid out in ``param_shapes`` order.
        :return: an AttrDict mapping each parameter name to a tensor of shape
                 [len(x), *param_shape].
        """
        result = AttrDict()
        offset = 0
        for name, shape in self.param_shapes.items():
            n_vectors, _ = self.flat_shapes[name]
            rows = x[:, offset : offset + n_vectors]
            result[name] = self.projections[_sanitize_name(name)](rows).reshape(len(x), *shape)
            offset += n_vectors
        return result
self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape) start = end return out def params_proj_from_config( config: Dict[str, Any], device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int ): name = config.pop("name") if name == "linear": return LinearParamsProj( **config, device=device, param_shapes=param_shapes, d_latent=d_latent ) elif name == "mlp": return MLPParamsProj(**config, device=device, param_shapes=param_shapes, d_latent=d_latent) elif name == "channels": return ChannelsParamsProj( **config, device=device, param_shapes=param_shapes, d_latent=d_latent ) else: raise ValueError(f"unknown params proj: {name}") def _sanitize_name(x: str) -> str: return x.replace(".", "__") ================================================ FILE: shap_e/models/transmitter/pc_encoder.py ================================================ from abc import abstractmethod from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from PIL import Image from torch import torch from shap_e.models.generation.perceiver import SimplePerceiver from shap_e.models.generation.transformer import Transformer from shap_e.models.nn.encoding import PosEmbLinear from shap_e.rendering.view_data import ProjectiveCamera from shap_e.util.collections import AttrDict from .base import VectorEncoder from .channels_encoder import DatasetIterator, sample_pcl_fps class PointCloudTransformerEncoder(VectorEncoder): """ Encode point clouds using a transformer model with an extra output token used to extract a latent vector. 
""" def __init__( self, *, device: torch.device, dtype: torch.dtype, param_shapes: Dict[str, Tuple[int]], params_proj: Dict[str, Any], latent_bottleneck: Optional[Dict[str, Any]] = None, d_latent: int = 512, latent_ctx: int = 1, input_channels: int = 6, n_ctx: int = 1024, width: int = 512, layers: int = 12, heads: int = 8, init_scale: float = 0.25, pos_emb: Optional[str] = None, ): super().__init__( device=device, param_shapes=param_shapes, params_proj=params_proj, latent_bottleneck=latent_bottleneck, d_latent=d_latent, ) self.input_channels = input_channels self.n_ctx = n_ctx self.latent_ctx = latent_ctx assert d_latent % latent_ctx == 0 self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) self.backbone = Transformer( device=device, dtype=dtype, n_ctx=n_ctx + latent_ctx, width=width, layers=layers, heads=heads, init_scale=init_scale, ) self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) self.register_parameter( "output_tokens", nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)), ) self.input_proj = PosEmbLinear(pos_emb, input_channels, width, device=device, dtype=dtype) self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype) def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: _ = options points = batch.points.permute(0, 2, 1) # NCL -> NLC h = self.input_proj(points) h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1) h = self.ln_pre(h) h = self.backbone(h) h = self.ln_post(h) h = h[:, self.n_ctx :] h = self.output_proj(h).flatten(1) return h class PerceiverEncoder(VectorEncoder): """ Encode point clouds using a perceiver model with an extra output token used to extract a latent vector. 
""" def __init__( self, *, device: torch.device, dtype: torch.dtype, param_shapes: Dict[str, Tuple[int]], params_proj: Dict[str, Any], latent_bottleneck: Optional[Dict[str, Any]] = None, d_latent: int = 512, latent_ctx: int = 1, width: int = 512, layers: int = 12, xattn_layers: int = 1, heads: int = 8, init_scale: float = 0.25, # Training hparams inner_batch_size: int = 1, data_ctx: int = 1, min_unrolls: int, max_unrolls: int, ): super().__init__( device=device, param_shapes=param_shapes, params_proj=params_proj, latent_bottleneck=latent_bottleneck, d_latent=d_latent, ) self.width = width self.device = device self.dtype = dtype self.latent_ctx = latent_ctx self.inner_batch_size = inner_batch_size self.data_ctx = data_ctx self.min_unrolls = min_unrolls self.max_unrolls = max_unrolls self.encoder = SimplePerceiver( device=device, dtype=dtype, n_ctx=self.data_ctx + self.latent_ctx, n_data=self.inner_batch_size, width=width, layers=xattn_layers, heads=heads, init_scale=init_scale, ) self.processor = Transformer( device=device, dtype=dtype, n_ctx=self.data_ctx + self.latent_ctx, layers=layers - xattn_layers, width=width, heads=heads, init_scale=init_scale, ) self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) self.register_parameter( "output_tokens", nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)), ) self.output_proj = nn.Linear(width, d_latent // self.latent_ctx, device=device, dtype=dtype) @abstractmethod def get_h_and_iterator( self, batch: AttrDict, options: Optional[AttrDict] = None ) -> Tuple[torch.Tensor, Iterable]: """ :return: a tuple of ( the initial output tokens of size [batch_size, data_ctx + latent_ctx, width], an iterator over the given data ) """ def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: h, it = self.get_h_and_iterator(batch, options=options) n_unrolls = self.get_n_unrolls() for _ in 
class PointCloudPerceiverEncoder(PerceiverEncoder):
    """
    Encode point clouds using a perceiver model with extra output tokens used
    to extract a latent vector.

    The initial query tokens are built from points selected by
    `sample_pcl_fps`; the data that is cross-attended to is either the embedded
    point cloud ("pcl") or embedded image views ("multiview").
    """

    def __init__(
        self,
        *,
        cross_attention_dataset: str = "pcl",
        fps_method: str = "fps",
        # point cloud hyperparameters
        input_channels: int = 6,
        pos_emb: Optional[str] = None,
        # multiview hyperparameters
        image_size: int = 256,
        patch_size: int = 32,
        pose_dropout: float = 0.0,
        use_depth: bool = False,
        max_depth: float = 5.0,
        # other hyperparameters
        **kwargs,
    ):
        """
        :param cross_attention_dataset: "pcl" or "multiview"; selects which
                                        dataset the perceiver attends to.
        :param fps_method: "fps" or "first"; forwarded to `sample_pcl_fps`.
        :param input_channels: number of channels per input point.
        :param pos_emb: positional embedding name passed to PosEmbLinear.
        :param image_size: side length views and depths are resized to.
        :param patch_size: kernel size and stride of the patch embedding.
        :param pose_dropout: during training, the camera embedding of a batch
                             element is zeroed when a uniform sample falls
                             below this value.
        :param use_depth: if True, a depth channel is appended to each view.
        :param max_depth: depths are clamped to at most this value before
                          being rescaled.
        :param kwargs: forwarded to PerceiverEncoder.
        """
        super().__init__(**kwargs)
        assert cross_attention_dataset in ("pcl", "multiview")
        assert fps_method in ("fps", "first")
        self.cross_attention_dataset = cross_attention_dataset
        self.fps_method = fps_method
        self.input_channels = input_channels
        self.input_proj = PosEmbLinear(
            pos_emb, input_channels, self.width, device=self.device, dtype=self.dtype
        )
        # The multiview-specific attributes and modules only exist in
        # "multiview" mode.
        if self.cross_attention_dataset == "multiview":
            self.image_size = image_size
            self.patch_size = patch_size
            self.pose_dropout = pose_dropout
            self.use_depth = use_depth
            self.max_depth = max_depth
            pos_ctx = (image_size // patch_size) ** 2
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    torch.randn(
                        pos_ctx * self.inner_batch_size,
                        self.width,
                        device=self.device,
                        dtype=self.dtype,
                    )
                ),
            )
            self.patch_emb = nn.Conv2d(
                in_channels=3 if not use_depth else 4,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
            self.camera_emb = nn.Sequential(
                nn.Linear(
                    3 * 4 + 1, self.width, device=self.device, dtype=self.dtype
                ),  # input size is for origin+x+y+z+fov
                nn.GELU(),
                nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),
            )

    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        """
        options = AttrDict() if options is None else options

        # Build the initial query embeddings
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        fps_samples = self.sample_pcl_fps(points)
        batch_size = points.shape[0]
        data_tokens = self.input_proj(fps_samples)
        latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)
        h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))
        assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)

        # Build the dataset embedding iterator
        dataset_fn = {
            "pcl": self.get_pcl_dataset,
            "multiview": self.get_multiview_dataset,
        }[self.cross_attention_dataset]
        it = dataset_fn(batch, options=options)

        return h, it

    def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
        """
        Select `data_ctx` points from `points` using the configured method.
        """
        return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)

    def get_pcl_dataset(
        self, batch: AttrDict, options: Optional[AttrDict[str, Any]] = None
    ) -> Iterable:
        """
        :return: an iterator over chunks of `inner_batch_size` embedded points.
        """
        _ = options
        dataset_emb = self.input_proj(batch.points.permute(0, 2, 1))  # NCL -> NLC
        assert dataset_emb.shape[1] >= self.inner_batch_size
        return iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size))

    def get_multiview_dataset(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Iterable:
        """
        :return: a generator yielding tensors of shape
                 [batch_size, inner_batch_size * n_patches, width], each being
                 a group of embedded views with `pos_emb` added.
        """
        _ = options

        dataset_emb = self.encode_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= self.inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size))

        def gen():
            # Iterating with `for` instead of calling next() inside
            # `while True` lets this generator finish cleanly if the
            # underlying iterator is ever exhausted; a StopIteration escaping
            # a generator body is turned into a RuntimeError (PEP 479).
            for examples in it:
                assert examples.shape == (batch_size, self.inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width) + self.pos_emb
                yield views

        return gen()

    def encode_views(self, batch: AttrDict) -> torch.Tensor:
        """
        Embed the views (and optionally depths) of a batch as patches and
        modulate them with an embedding of the corresponding cameras.

        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            # Depths supplied as a ready-made tensor are returned untouched by
            # depths_to_tensor, so move them explicitly to the same device as
            # the views before concatenating along the channel axis.
            all_depths = self.depths_to_tensor(batch.depths).to(self.device)
            all_views = torch.cat([all_views, all_depths], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)

        batch_size, num_views, _, _, _ = all_views.shape

        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        # [batch_size, num_views, 1, 2 * width]
        camera_proj = self.camera_emb(all_cameras).reshape(
            [batch_size, num_views, 1, self.width * 2]
        )
        # Pose dropout is only active in training mode; the mask is drawn once
        # per batch element and shared across its views.
        pose_dropout = self.pose_dropout if self.training else 0.0
        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
        scale, shift = camera_proj.chunk(2, dim=3)
        views_proj = views_proj * (scale + 1.0) + shift
        return views_proj

    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].

        A tensor argument is returned as-is; images are resized to
        `image_size` and converted to RGB first.
        """
        if isinstance(views, torch.Tensor):
            return views

        tensor_batch = []
        num_views = len(views[0])
        for inner_list in views:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                img = img.resize((self.image_size,) * 2).convert("RGB")
                inner_batch.append(
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 127.5
                    - 1
                )
            tensor_batch.append(torch.stack(inner_batch, dim=0))
        # Channels-last images -> channels-first.
        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)

    def depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].

        A tensor argument is returned as-is; otherwise every element is passed
        to `torch.from_numpy`, clamped to at most `max_depth`, rescaled and
        resized to `image_size` with nearest-neighbour interpolation.
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
                tensor = tensor * 2 - 1
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3*4+1] tensor of camera information.

        Each camera is flattened as (x, y, z, origin, x_fov). A tensor
        argument is returned as-is.
        """
        if isinstance(cameras, torch.Tensor):
            return cameras
        outer_batch = []
        for inner_list in cameras:
            inner_batch = []
            for camera in inner_list:
                inner_batch.append(
                    np.array(
                        [
                            *camera.x,
                            *camera.y,
                            *camera.z,
                            *camera.origin,
                            camera.x_fov,
                        ]
                    )
                )
            outer_batch.append(np.stack(inner_batch, axis=0))
        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()
""" return self.t1 * self.intersected.float() def extend(self, another: "VolumeRange") -> "VolumeRange": """ The ranges at which rays intersect with either one, or both, or none of the self and another are merged together. """ return VolumeRange( t0=torch.where(self.intersected, self.t0, another.t0), t1=torch.where(another.intersected, another.t1, self.t1), intersected=torch.logical_or(self.intersected, another.intersected), ) def partition(self, ts) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Partitions t0 and t1 into n_samples intervals. :param ts: [batch_size, *shape, n_samples, 1] :return: a tuple of ( lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size, *shape, n_samples, 1] ) where ts \\in [lower, upper] deltas = upper - lower """ mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5 lower = torch.cat([self.t0[..., None, :], mids], dim=-2) upper = torch.cat([mids, self.t1[..., None, :]], dim=-2) delta = upper - lower assert lower.shape == upper.shape == delta.shape == ts.shape return lower, upper, delta class Volume(ABC): """ An abstraction of rendering volume. """ @abstractmethod def intersect( self, origin: torch.Tensor, direction: torch.Tensor, t0_lower: Optional[torch.Tensor] = None, params: Optional[Dict] = None, epsilon: float = 1e-6, ) -> VolumeRange: """ :param origin: [batch_size, *shape, 3] :param direction: [batch_size, *shape, 3] :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. :param params: Optional meta parameters in case Volume is parametric :param epsilon: to stabilize calculations :return: A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to be on the boundary of the volume. 
""" class BoundingBoxVolume(MetaModule, Volume): """ Axis-aligned bounding box defined by the two opposite corners. """ def __init__( self, *, bbox_min: ArrayType, bbox_max: ArrayType, min_dist: float = 0.0, min_t_range: float = 1e-3, device: torch.device = torch.device("cuda"), ): """ :param bbox_min: the left/bottommost corner of the bounding box :param bbox_max: the other corner of the bounding box :param min_dist: all rays should start at least this distance away from the origin. """ super().__init__() self.bbox_min = to_torch(bbox_min).to(device) self.bbox_max = to_torch(bbox_max).to(device) self.min_dist = min_dist self.min_t_range = min_t_range self.bbox = torch.stack([self.bbox_min, self.bbox_max]) assert self.bbox.shape == (2, 3) assert self.min_dist >= 0.0 assert self.min_t_range > 0.0 self.device = device def intersect( self, origin: torch.Tensor, direction: torch.Tensor, t0_lower: Optional[torch.Tensor] = None, params: Optional[Dict] = None, epsilon=1e-6, ) -> VolumeRange: """ :param origin: [batch_size, *shape, 3] :param direction: [batch_size, *shape, 3] :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. :param params: Optional meta parameters in case Volume is parametric :param epsilon: to stabilize calculations :return: A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to be on the boundary of the volume. """ batch_size, *shape, _ = origin.shape ones = [1] * len(shape) bbox = self.bbox.view(1, *ones, 2, 3) ts = safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) # Cases to think about: # # 1. t1 <= t0: the ray does not pass through the AABB. # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin. # 3. t0 <= 0 <= t1: the ray starts from inside the BB # 4. 
class UnboundedVolume(MetaModule, Volume):
    """
    Originally used in NeRF. Unbounded volume but with a limited visibility
    when rendering (e.g. objects that are farther away than the max_dist from
    the ray origin are not considered)
    """

    def __init__(
        self,
        *,
        max_dist: float,
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param max_dist: length of the visible range along each ray.
        :param min_dist: all rays should start at least this distance away
                         from the origin.
        :param min_t_range: minimum length of [t0, t1] for a ray to count as
                            intersecting.
        :param device: stored on the instance; not otherwise used here.
        """
        super().__init__()
        self.max_dist = max_dist
        self.min_dist = min_dist
        self.min_t_range = min_t_range
        assert self.min_dist >= 0.0
        assert self.min_t_range > 0.0
        self.device = device

    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        epsilon: float = 1e-6,
    ) -> VolumeRange:
        """
        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3]; unused, since the range
                          depends only on distances along the ray.
        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this
                         volume.
        :param params: Optional meta parameters in case Volume is parametric; unused.
        :param epsilon: accepted so that the signature matches
                        Volume.intersect; unused because no division is
                        performed here.

        :return: A tuple of (t0, t1, intersected) where each has a shape
            [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
            in the volume for all t in [t0, t1].
        """
        batch_size, *shape, _ = origin.shape
        t0 = torch.zeros(batch_size, *shape, 1, dtype=origin.dtype, device=origin.device)
        if t0_lower is not None:
            t0 = torch.maximum(t0, t0_lower)
        # The far bound is measured from the start before min_dist is applied.
        t1 = t0 + self.max_dist
        t0 = t0.clamp(self.min_dist)
        return VolumeRange(t0=t0, t1=t1, intersected=t0 + self.min_t_range < t1)
# # The corners of a cube are indexed as follows # # (0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0), # (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1) # # Here is a visualization of the cube indices: # # 6 + -----------------------+ 7 # /| /| # / | / | # / | / | # 4 +------------------------+ 5 | # | | | | # | | | | # | | | | # | | 2 | | 3 # | +--------------------|---+ # | / | / # | / | / # |/ |/ # +------------------------+ # 0 1 # # Derived using model3d, in particular this function: # https://github.com/unixpickle/model3d/blob/7a3adb982c154c80c1a22032b5a0695160a7f96d/model3d/mc.go#L434 # MC_TABLE = [ [], [[0, 1, 0, 2, 0, 4]], [[1, 0, 1, 5, 1, 3]], [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2]], [[2, 0, 2, 3, 2, 6]], [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4]], [[1, 0, 1, 5, 1, 3], [2, 6, 0, 2, 3, 2]], [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4]], [[3, 1, 3, 7, 3, 2]], [[0, 2, 0, 4, 0, 1], [3, 7, 2, 3, 1, 3]], [[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0]], [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5]], [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6]], [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6]], [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7]], [[0, 4, 1, 5, 3, 7], [0, 4, 3, 7, 2, 6]], [[4, 0, 4, 6, 4, 5]], [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1]], [[1, 5, 1, 3, 1, 0], [4, 6, 5, 4, 0, 4]], [[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2]], [[2, 0, 2, 3, 2, 6], [4, 5, 0, 4, 6, 4]], [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1]], [[2, 6, 2, 0, 3, 2], [1, 0, 1, 5, 3, 1], [6, 4, 5, 4, 0, 4]], [[1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6]], [[3, 1, 3, 7, 3, 2], [6, 4, 5, 4, 0, 4]], [[4, 5, 0, 1, 4, 6], [0, 1, 0, 2, 4, 6], [7, 3, 2, 3, 1, 3]], [[3, 2, 1, 0, 3, 7], [1, 0, 1, 5, 3, 7], [6, 4, 5, 4, 0, 4]], [[3, 7, 3, 2, 1, 5], [3, 2, 6, 4, 1, 5], [1, 5, 6, 4, 5, 4], [3, 2, 2, 0, 6, 4]], [[3, 7, 2, 6, 3, 1], [2, 6, 2, 0, 3, 1], [5, 4, 0, 4, 6, 4]], [[1, 0, 1, 3, 5, 4], [1, 3, 2, 6, 5, 
4], [1, 3, 3, 7, 2, 6], [5, 4, 2, 6, 4, 6]], [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7], [4, 5, 0, 4, 4, 6]], [[6, 2, 4, 6, 4, 5], [4, 5, 5, 1, 6, 2], [6, 2, 5, 1, 7, 3]], [[5, 1, 5, 4, 5, 7]], [[0, 1, 0, 2, 0, 4], [5, 7, 1, 5, 4, 5]], [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3]], [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3]], [[2, 0, 2, 3, 2, 6], [7, 5, 1, 5, 4, 5]], [[2, 6, 0, 4, 2, 3], [0, 4, 0, 1, 2, 3], [7, 5, 1, 5, 4, 5]], [[5, 7, 1, 3, 5, 4], [1, 3, 1, 0, 5, 4], [6, 2, 0, 2, 3, 2]], [[3, 1, 3, 2, 7, 5], [3, 2, 0, 4, 7, 5], [3, 2, 2, 6, 0, 4], [7, 5, 0, 4, 5, 4]], [[3, 7, 3, 2, 3, 1], [5, 4, 7, 5, 1, 5]], [[0, 4, 0, 1, 2, 0], [3, 1, 3, 7, 2, 3], [4, 5, 7, 5, 1, 5]], [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0]], [[0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7]], [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6], [4, 5, 7, 5, 1, 5]], [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6], [5, 7, 1, 5, 5, 4]], [[2, 6, 2, 0, 3, 7], [2, 0, 4, 5, 3, 7], [3, 7, 4, 5, 7, 5], [2, 0, 0, 1, 4, 5]], [[4, 0, 5, 4, 5, 7], [5, 7, 7, 3, 4, 0], [4, 0, 7, 3, 6, 2]], [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0]], [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6]], [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7]], [[0, 2, 4, 6, 5, 7], [0, 2, 5, 7, 1, 3]], [[5, 1, 4, 0, 5, 7], [4, 0, 4, 6, 5, 7], [3, 2, 6, 2, 0, 2]], [[2, 3, 2, 6, 0, 1], [2, 6, 7, 5, 0, 1], [0, 1, 7, 5, 1, 5], [2, 6, 6, 4, 7, 5]], [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7], [2, 6, 0, 2, 2, 3]], [[3, 1, 2, 3, 2, 6], [2, 6, 6, 4, 3, 1], [3, 1, 6, 4, 7, 5]], [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0], [2, 3, 1, 3, 7, 3]], [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6], [3, 2, 1, 3, 3, 7]], [[0, 1, 0, 4, 2, 3], [0, 4, 5, 7, 2, 3], [0, 4, 4, 6, 5, 7], [2, 3, 5, 7, 3, 7]], [[7, 5, 3, 7, 3, 2], [3, 2, 2, 0, 7, 5], [7, 5, 2, 0, 6, 4]], [[0, 4, 4, 6, 5, 7], [0, 4, 5, 7, 1, 5], [0, 2, 1, 3, 3, 7], [3, 7, 2, 6, 
0, 2]], [ [3, 1, 7, 3, 6, 2], [6, 2, 0, 1, 3, 1], [6, 4, 0, 1, 6, 2], [6, 4, 5, 1, 0, 1], [6, 4, 7, 5, 5, 1], ], [ [4, 0, 6, 4, 7, 5], [7, 5, 1, 0, 4, 0], [7, 3, 1, 0, 7, 5], [7, 3, 2, 0, 1, 0], [7, 3, 6, 2, 2, 0], ], [[7, 3, 6, 2, 6, 4], [7, 5, 7, 3, 6, 4]], [[6, 2, 6, 7, 6, 4]], [[0, 4, 0, 1, 0, 2], [6, 7, 4, 6, 2, 6]], [[1, 0, 1, 5, 1, 3], [7, 6, 4, 6, 2, 6]], [[1, 3, 0, 2, 1, 5], [0, 2, 0, 4, 1, 5], [7, 6, 4, 6, 2, 6]], [[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0]], [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3]], [[6, 4, 2, 0, 6, 7], [2, 0, 2, 3, 6, 7], [5, 1, 3, 1, 0, 1]], [[1, 5, 1, 3, 0, 4], [1, 3, 7, 6, 0, 4], [0, 4, 7, 6, 4, 6], [1, 3, 3, 2, 7, 6]], [[3, 2, 3, 1, 3, 7], [6, 4, 2, 6, 7, 6]], [[3, 7, 3, 2, 1, 3], [0, 2, 0, 4, 1, 0], [7, 6, 4, 6, 2, 6]], [[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0], [4, 6, 2, 6, 7, 6]], [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5], [6, 4, 2, 6, 6, 7]], [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0]], [[0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7]], [[0, 2, 0, 1, 4, 6], [0, 1, 3, 7, 4, 6], [0, 1, 1, 5, 3, 7], [4, 6, 3, 7, 6, 7]], [[7, 3, 6, 7, 6, 4], [6, 4, 4, 0, 7, 3], [7, 3, 4, 0, 5, 1]], [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5]], [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5]], [[6, 7, 4, 5, 6, 2], [4, 5, 4, 0, 6, 2], [3, 1, 0, 1, 5, 1]], [[2, 0, 2, 6, 3, 1], [2, 6, 4, 5, 3, 1], [2, 6, 6, 7, 4, 5], [3, 1, 4, 5, 1, 5]], [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7]], [[0, 1, 2, 3, 6, 7], [0, 1, 6, 7, 4, 5]], [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7], [1, 3, 0, 1, 1, 5]], [[5, 4, 1, 5, 1, 3], [1, 3, 3, 2, 5, 4], [5, 4, 3, 2, 7, 6]], [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5], [1, 3, 7, 3, 2, 3]], [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5], [3, 7, 2, 3, 3, 1]], [[0, 1, 1, 5, 3, 7], [0, 1, 3, 7, 2, 3], [0, 4, 2, 6, 6, 7], [6, 7, 4, 5, 0, 4]], [ [6, 2, 7, 6, 5, 4], [5, 4, 0, 2, 6, 2], [5, 1, 0, 2, 5, 4], [5, 
1, 3, 2, 0, 2], [5, 1, 7, 3, 3, 2], ], [[3, 1, 3, 7, 2, 0], [3, 7, 5, 4, 2, 0], [2, 0, 5, 4, 0, 4], [3, 7, 7, 6, 5, 4]], [[1, 0, 3, 1, 3, 7], [3, 7, 7, 6, 1, 0], [1, 0, 7, 6, 5, 4]], [ [1, 0, 5, 1, 7, 3], [7, 3, 2, 0, 1, 0], [7, 6, 2, 0, 7, 3], [7, 6, 4, 0, 2, 0], [7, 6, 5, 4, 4, 0], ], [[7, 6, 5, 4, 5, 1], [7, 3, 7, 6, 5, 1]], [[5, 7, 5, 1, 5, 4], [6, 2, 7, 6, 4, 6]], [[0, 2, 0, 4, 1, 0], [5, 4, 5, 7, 1, 5], [2, 6, 7, 6, 4, 6]], [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3], [2, 6, 7, 6, 4, 6]], [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3], [6, 7, 4, 6, 6, 2]], [[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0], [1, 5, 4, 5, 7, 5]], [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3], [5, 1, 4, 5, 5, 7]], [[0, 2, 2, 3, 6, 7], [0, 2, 6, 7, 4, 6], [0, 1, 4, 5, 5, 7], [5, 7, 1, 3, 0, 1]], [ [5, 4, 7, 5, 3, 1], [3, 1, 0, 4, 5, 4], [3, 2, 0, 4, 3, 1], [3, 2, 6, 4, 0, 4], [3, 2, 7, 6, 6, 4], ], [[5, 4, 5, 7, 1, 5], [3, 7, 3, 2, 1, 3], [4, 6, 2, 6, 7, 6]], [[1, 0, 0, 2, 0, 4], [1, 5, 5, 4, 5, 7], [3, 2, 1, 3, 3, 7], [2, 6, 7, 6, 4, 6]], [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0], [6, 2, 7, 6, 6, 4]], [ [0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7], [6, 7, 4, 6, 2, 6], ], [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0], [5, 4, 7, 5, 5, 1]], [ [0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7], [5, 7, 1, 5, 4, 5], ], [ [6, 7, 4, 6, 0, 2], [0, 2, 3, 7, 6, 7], [0, 1, 3, 7, 0, 2], [0, 1, 5, 7, 3, 7], [0, 1, 4, 5, 5, 7], ], [[4, 0, 6, 7, 4, 6], [4, 0, 7, 3, 6, 7], [4, 0, 5, 7, 7, 3], [4, 5, 5, 7, 4, 0]], [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0]], [[0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7]], [[1, 3, 1, 0, 5, 7], [1, 0, 2, 6, 5, 7], [5, 7, 2, 6, 7, 6], [1, 0, 0, 4, 2, 6]], [[2, 0, 6, 2, 6, 7], [6, 7, 7, 5, 2, 0], [2, 0, 7, 5, 3, 1]], [[0, 4, 0, 2, 1, 5], [0, 2, 6, 7, 1, 5], [0, 2, 2, 3, 6, 7], [1, 5, 6, 7, 5, 7]], [[7, 6, 5, 7, 
5, 1], [5, 1, 1, 0, 7, 6], [7, 6, 1, 0, 3, 2]], [ [2, 0, 3, 2, 7, 6], [7, 6, 4, 0, 2, 0], [7, 5, 4, 0, 7, 6], [7, 5, 1, 0, 4, 0], [7, 5, 3, 1, 1, 0], ], [[7, 5, 3, 1, 3, 2], [7, 6, 7, 5, 3, 2]], [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0], [3, 1, 7, 3, 3, 2]], [ [0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7], [3, 7, 2, 3, 1, 3], ], [ [3, 7, 2, 3, 0, 1], [0, 1, 5, 7, 3, 7], [0, 4, 5, 7, 0, 1], [0, 4, 6, 7, 5, 7], [0, 4, 2, 6, 6, 7], ], [[2, 0, 3, 7, 2, 3], [2, 0, 7, 5, 3, 7], [2, 0, 6, 7, 7, 5], [2, 6, 6, 7, 2, 0]], [ [5, 7, 1, 5, 0, 4], [0, 4, 6, 7, 5, 7], [0, 2, 6, 7, 0, 4], [0, 2, 3, 7, 6, 7], [0, 2, 1, 3, 3, 7], ], [[1, 0, 5, 7, 1, 5], [1, 0, 7, 6, 5, 7], [1, 0, 3, 7, 7, 6], [1, 3, 3, 7, 1, 0]], [[0, 2, 0, 1, 0, 4], [3, 7, 6, 7, 5, 7]], [[7, 5, 7, 3, 7, 6]], [[7, 3, 7, 5, 7, 6]], [[0, 1, 0, 2, 0, 4], [6, 7, 3, 7, 5, 7]], [[1, 3, 1, 0, 1, 5], [7, 6, 3, 7, 5, 7]], [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2], [6, 7, 3, 7, 5, 7]], [[2, 6, 2, 0, 2, 3], [7, 5, 6, 7, 3, 7]], [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4], [5, 7, 6, 7, 3, 7]], [[1, 5, 1, 3, 0, 1], [2, 3, 2, 6, 0, 2], [5, 7, 6, 7, 3, 7]], [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4], [7, 6, 3, 7, 7, 5]], [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2]], [[7, 6, 3, 2, 7, 5], [3, 2, 3, 1, 7, 5], [4, 0, 1, 0, 2, 0]], [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2]], [[2, 3, 2, 0, 6, 7], [2, 0, 1, 5, 6, 7], [2, 0, 0, 4, 1, 5], [6, 7, 1, 5, 7, 5]], [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1]], [[0, 4, 0, 1, 2, 6], [0, 1, 5, 7, 2, 6], [2, 6, 5, 7, 6, 7], [0, 1, 1, 3, 5, 7]], [[1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6]], [[5, 1, 7, 5, 7, 6], [7, 6, 6, 2, 5, 1], [5, 1, 6, 2, 4, 0]], [[4, 5, 4, 0, 4, 6], [7, 3, 5, 7, 6, 7]], [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1], [3, 7, 5, 7, 6, 7]], [[4, 6, 4, 5, 0, 4], [1, 5, 1, 3, 0, 1], [6, 7, 3, 7, 5, 7]], [[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2], 
[7, 3, 5, 7, 7, 6]], [[2, 3, 2, 6, 0, 2], [4, 6, 4, 5, 0, 4], [3, 7, 5, 7, 6, 7]], [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1], [7, 5, 6, 7, 7, 3]], [[0, 1, 1, 5, 1, 3], [0, 2, 2, 3, 2, 6], [4, 5, 0, 4, 4, 6], [5, 7, 6, 7, 3, 7]], [ [1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6], [7, 6, 3, 7, 5, 7], ], [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2], [0, 4, 6, 4, 5, 4]], [[1, 0, 0, 2, 4, 6], [1, 0, 4, 6, 5, 4], [1, 3, 5, 7, 7, 6], [7, 6, 3, 2, 1, 3]], [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2], [4, 6, 5, 4, 4, 0]], [ [7, 5, 6, 7, 2, 3], [2, 3, 1, 5, 7, 5], [2, 0, 1, 5, 2, 3], [2, 0, 4, 5, 1, 5], [2, 0, 6, 4, 4, 5], ], [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1], [4, 0, 6, 4, 4, 5]], [ [4, 6, 5, 4, 1, 0], [1, 0, 2, 6, 4, 6], [1, 3, 2, 6, 1, 0], [1, 3, 7, 6, 2, 6], [1, 3, 5, 7, 7, 6], ], [ [1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6], [4, 6, 5, 4, 0, 4], ], [[5, 1, 4, 6, 5, 4], [5, 1, 6, 2, 4, 6], [5, 1, 7, 6, 6, 2], [5, 7, 7, 6, 5, 1]], [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1]], [[7, 3, 5, 1, 7, 6], [5, 1, 5, 4, 7, 6], [2, 0, 4, 0, 1, 0]], [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4]], [[0, 2, 0, 4, 1, 3], [0, 4, 6, 7, 1, 3], [1, 3, 6, 7, 3, 7], [0, 4, 4, 5, 6, 7]], [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1], [0, 2, 3, 2, 6, 2]], [[1, 5, 5, 4, 7, 6], [1, 5, 7, 6, 3, 7], [1, 0, 3, 2, 2, 6], [2, 6, 0, 4, 1, 0]], [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4], [2, 0, 3, 2, 2, 6]], [ [2, 3, 6, 2, 4, 0], [4, 0, 1, 3, 2, 3], [4, 5, 1, 3, 4, 0], [4, 5, 7, 3, 1, 3], [4, 5, 6, 7, 7, 3], ], [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6]], [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6], [0, 4, 1, 0, 0, 2]], [[1, 0, 5, 4, 7, 6], [1, 0, 7, 6, 3, 2]], [[2, 3, 0, 2, 0, 4], [0, 4, 4, 5, 2, 3], [2, 3, 4, 5, 6, 7]], [[1, 3, 1, 5, 0, 2], [1, 5, 7, 6, 0, 2], [1, 5, 5, 4, 7, 6], [0, 2, 7, 6, 2, 6]], [ [5, 1, 4, 5, 6, 7], [6, 7, 3, 
1, 5, 1], [6, 2, 3, 1, 6, 7], [6, 2, 0, 1, 3, 1], [6, 2, 4, 0, 0, 1], ], [[6, 7, 2, 6, 2, 0], [2, 0, 0, 1, 6, 7], [6, 7, 0, 1, 4, 5]], [[6, 2, 4, 0, 4, 5], [6, 7, 6, 2, 4, 5]], [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1]], [[1, 5, 1, 0, 3, 7], [1, 0, 4, 6, 3, 7], [1, 0, 0, 2, 4, 6], [3, 7, 4, 6, 7, 6]], [[1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6]], [[6, 4, 7, 6, 7, 3], [7, 3, 3, 1, 6, 4], [6, 4, 3, 1, 2, 0]], [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1], [2, 3, 6, 2, 2, 0]], [ [7, 6, 3, 7, 1, 5], [1, 5, 4, 6, 7, 6], [1, 0, 4, 6, 1, 5], [1, 0, 2, 6, 4, 6], [1, 0, 3, 2, 2, 6], ], [ [1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6], [2, 6, 0, 2, 3, 2], ], [[3, 1, 7, 6, 3, 7], [3, 1, 6, 4, 7, 6], [3, 1, 2, 6, 6, 4], [3, 2, 2, 6, 3, 1]], [[3, 2, 3, 1, 7, 6], [3, 1, 0, 4, 7, 6], [7, 6, 0, 4, 6, 4], [3, 1, 1, 5, 0, 4]], [ [0, 1, 2, 0, 6, 4], [6, 4, 5, 1, 0, 1], [6, 7, 5, 1, 6, 4], [6, 7, 3, 1, 5, 1], [6, 7, 2, 3, 3, 1], ], [[0, 1, 4, 0, 4, 6], [4, 6, 6, 7, 0, 1], [0, 1, 6, 7, 2, 3]], [[6, 7, 2, 3, 2, 0], [6, 4, 6, 7, 2, 0]], [ [2, 6, 0, 2, 1, 3], [1, 3, 7, 6, 2, 6], [1, 5, 7, 6, 1, 3], [1, 5, 4, 6, 7, 6], [1, 5, 0, 4, 4, 6], ], [[1, 5, 1, 0, 1, 3], [4, 6, 7, 6, 2, 6]], [[0, 1, 2, 6, 0, 2], [0, 1, 6, 7, 2, 6], [0, 1, 4, 6, 6, 7], [0, 4, 4, 6, 0, 1]], [[6, 7, 6, 2, 6, 4]], [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4]], [[7, 5, 6, 4, 7, 3], [6, 4, 6, 2, 7, 3], [1, 0, 2, 0, 4, 0]], [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4], [0, 1, 5, 1, 3, 1]], [[2, 0, 0, 4, 1, 5], [2, 0, 1, 5, 3, 1], [2, 6, 3, 7, 7, 5], [7, 5, 6, 4, 2, 6]], [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4]], [[3, 2, 3, 7, 1, 0], [3, 7, 6, 4, 1, 0], [3, 7, 7, 5, 6, 4], [1, 0, 6, 4, 0, 4]], [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4], [1, 5, 3, 1, 1, 0]], [ [7, 3, 5, 7, 4, 6], [4, 6, 2, 3, 7, 3], [4, 0, 2, 3, 4, 6], [4, 0, 1, 3, 2, 3], [4, 0, 5, 1, 1, 3], ], [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 
4], [6, 4, 3, 1, 7, 5]], [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5], [0, 1, 2, 0, 0, 4]], [[1, 0, 1, 5, 3, 2], [1, 5, 4, 6, 3, 2], [3, 2, 4, 6, 2, 6], [1, 5, 5, 7, 4, 6]], [ [0, 2, 4, 0, 5, 1], [5, 1, 3, 2, 0, 2], [5, 7, 3, 2, 5, 1], [5, 7, 6, 2, 3, 2], [5, 7, 4, 6, 6, 2], ], [[2, 0, 3, 1, 7, 5], [2, 0, 7, 5, 6, 4]], [[4, 6, 0, 4, 0, 1], [0, 1, 1, 3, 4, 6], [4, 6, 1, 3, 5, 7]], [[0, 2, 1, 0, 1, 5], [1, 5, 5, 7, 0, 2], [0, 2, 5, 7, 4, 6]], [[5, 7, 4, 6, 4, 0], [5, 1, 5, 7, 4, 0]], [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2]], [[0, 1, 0, 2, 4, 5], [0, 2, 3, 7, 4, 5], [4, 5, 3, 7, 5, 7], [0, 2, 2, 6, 3, 7]], [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2], [1, 0, 5, 1, 1, 3]], [ [1, 5, 3, 1, 2, 0], [2, 0, 4, 5, 1, 5], [2, 6, 4, 5, 2, 0], [2, 6, 7, 5, 4, 5], [2, 6, 3, 7, 7, 5], ], [[2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5]], [[3, 2, 7, 3, 7, 5], [7, 5, 5, 4, 3, 2], [3, 2, 5, 4, 1, 0]], [ [2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5], [1, 5, 3, 1, 0, 1], ], [[3, 2, 1, 5, 3, 1], [3, 2, 5, 4, 1, 5], [3, 2, 7, 5, 5, 4], [3, 7, 7, 5, 3, 2]], [[2, 6, 2, 3, 0, 4], [2, 3, 7, 5, 0, 4], [2, 3, 3, 1, 7, 5], [0, 4, 7, 5, 4, 5]], [ [3, 2, 1, 3, 5, 7], [5, 7, 6, 2, 3, 2], [5, 4, 6, 2, 5, 7], [5, 4, 0, 2, 6, 2], [5, 4, 1, 0, 0, 2], ], [ [4, 5, 0, 4, 2, 6], [2, 6, 7, 5, 4, 5], [2, 3, 7, 5, 2, 6], [2, 3, 1, 5, 7, 5], [2, 3, 0, 1, 1, 5], ], [[2, 3, 2, 0, 2, 6], [1, 5, 7, 5, 4, 5]], [[5, 7, 4, 5, 4, 0], [4, 0, 0, 2, 5, 7], [5, 7, 0, 2, 1, 3]], [[5, 4, 1, 0, 1, 3], [5, 7, 5, 4, 1, 3]], [[0, 2, 4, 5, 0, 4], [0, 2, 5, 7, 4, 5], [0, 2, 1, 5, 5, 7], [0, 1, 1, 5, 0, 2]], [[5, 4, 5, 1, 5, 7]], [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3]], [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3], [0, 2, 4, 0, 0, 1]], [[3, 7, 3, 1, 2, 6], [3, 1, 5, 4, 2, 6], [3, 1, 1, 0, 5, 4], [2, 6, 5, 4, 6, 4]], [ [6, 4, 2, 6, 3, 7], [3, 7, 5, 4, 6, 4], [3, 1, 5, 4, 3, 7], [3, 
1, 0, 4, 5, 4], [3, 1, 2, 0, 0, 4], ], [[2, 0, 2, 3, 6, 4], [2, 3, 1, 5, 6, 4], [6, 4, 1, 5, 4, 5], [2, 3, 3, 7, 1, 5]], [ [0, 4, 1, 0, 3, 2], [3, 2, 6, 4, 0, 4], [3, 7, 6, 4, 3, 2], [3, 7, 5, 4, 6, 4], [3, 7, 1, 5, 5, 4], ], [ [1, 3, 0, 1, 4, 5], [4, 5, 7, 3, 1, 3], [4, 6, 7, 3, 4, 5], [4, 6, 2, 3, 7, 3], [4, 6, 0, 2, 2, 3], ], [[3, 7, 3, 1, 3, 2], [5, 4, 6, 4, 0, 4]], [[3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4]], [ [3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4], [0, 4, 1, 0, 2, 0], ], [[4, 5, 6, 4, 6, 2], [6, 2, 2, 3, 4, 5], [4, 5, 2, 3, 0, 1]], [[2, 3, 6, 4, 2, 6], [2, 3, 4, 5, 6, 4], [2, 3, 0, 4, 4, 5], [2, 0, 0, 4, 2, 3]], [[1, 3, 5, 1, 5, 4], [5, 4, 4, 6, 1, 3], [1, 3, 4, 6, 0, 2]], [[1, 3, 0, 4, 1, 0], [1, 3, 4, 6, 0, 4], [1, 3, 5, 4, 4, 6], [1, 5, 5, 4, 1, 3]], [[4, 6, 0, 2, 0, 1], [4, 5, 4, 6, 0, 1]], [[4, 6, 4, 0, 4, 5]], [[4, 0, 6, 2, 7, 3], [4, 0, 7, 3, 5, 1]], [[1, 5, 0, 1, 0, 2], [0, 2, 2, 6, 1, 5], [1, 5, 2, 6, 3, 7]], [[3, 7, 1, 3, 1, 0], [1, 0, 0, 4, 3, 7], [3, 7, 0, 4, 2, 6]], [[3, 1, 2, 0, 2, 6], [3, 7, 3, 1, 2, 6]], [[0, 4, 2, 0, 2, 3], [2, 3, 3, 7, 0, 4], [0, 4, 3, 7, 1, 5]], [[3, 7, 1, 5, 1, 0], [3, 2, 3, 7, 1, 0]], [[0, 4, 1, 3, 0, 1], [0, 4, 3, 7, 1, 3], [0, 4, 2, 3, 3, 7], [0, 2, 2, 3, 0, 4]], [[3, 7, 3, 1, 3, 2]], [[2, 6, 3, 2, 3, 1], [3, 1, 1, 5, 2, 6], [2, 6, 1, 5, 0, 4]], [[1, 5, 3, 2, 1, 3], [1, 5, 2, 6, 3, 2], [1, 5, 0, 2, 2, 6], [1, 0, 0, 2, 1, 5]], [[2, 3, 0, 1, 0, 4], [2, 6, 2, 3, 0, 4]], [[2, 3, 2, 0, 2, 6]], [[1, 5, 0, 4, 0, 2], [1, 3, 1, 5, 0, 2]], [[1, 5, 1, 0, 1, 3]], [[0, 2, 0, 1, 0, 4]], [], ] ================================================ FILE: shap_e/rendering/blender/__init__.py ================================================ from .render import render_mesh, render_model from .view_data import BlenderViewData __all__ = ["BlenderViewData", "render_model"] ================================================ FILE: shap_e/rendering/blender/blender_script.py 
================================================ """ Script to run within blender. Provide arguments after `--`. For example: `blender -b -P blender_script.py -- --help` """ import argparse import json import math import os import random import sys import bpy from mathutils import Vector from mathutils.noise import random_unit_vector MAX_DEPTH = 5.0 FORMAT_VERSION = 6 # Set by main(), these constants are passed to the script to avoid # duplicating them across multiple files. UNIFORM_LIGHT_DIRECTION = None BASIC_AMBIENT_COLOR = None BASIC_DIFFUSE_COLOR = None def clear_scene(): bpy.ops.object.select_all(action="SELECT") bpy.ops.object.delete() def clear_lights(): bpy.ops.object.select_all(action="DESELECT") for obj in bpy.context.scene.objects.values(): if isinstance(obj.data, bpy.types.Light): obj.select_set(True) bpy.ops.object.delete() def import_model(path): clear_scene() _, ext = os.path.splitext(path) ext = ext.lower() if ext == ".obj": bpy.ops.import_scene.obj(filepath=path) elif ext in [".glb", ".gltf"]: bpy.ops.import_scene.gltf(filepath=path) elif ext == ".stl": bpy.ops.import_mesh.stl(filepath=path) elif ext == ".fbx": bpy.ops.import_scene.fbx(filepath=path) elif ext == ".dae": bpy.ops.wm.collada_import(filepath=path) elif ext == ".ply": bpy.ops.import_mesh.ply(filepath=path) else: raise RuntimeError(f"unexpected extension: {ext}") def scene_root_objects(): for obj in bpy.context.scene.objects.values(): if not obj.parent: yield obj def scene_bbox(single_obj=None, ignore_matrix=False): bbox_min = (math.inf,) * 3 bbox_max = (-math.inf,) * 3 found = False for obj in scene_meshes() if single_obj is None else [single_obj]: found = True for coord in obj.bound_box: coord = Vector(coord) if not ignore_matrix: coord = obj.matrix_world @ coord bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) if not found: raise RuntimeError("no objects in scene to compute bounding box for") return 
Vector(bbox_min), Vector(bbox_max) def scene_meshes(): for obj in bpy.context.scene.objects.values(): if isinstance(obj.data, (bpy.types.Mesh)): yield obj def normalize_scene(): if len(list(scene_root_objects())) > 1: # Create an empty object to be used as a parent for all root objects parent_empty = bpy.data.objects.new("ParentEmpty", None) bpy.context.scene.collection.objects.link(parent_empty) # Parent all root objects to the empty object for obj in scene_root_objects(): if obj != parent_empty: obj.parent = parent_empty bbox_min, bbox_max = scene_bbox() scale = 1 / max(bbox_max - bbox_min) for obj in scene_root_objects(): obj.scale = obj.scale * scale # Apply scale to matrix_world. bpy.context.view_layer.update() bbox_min, bbox_max = scene_bbox() offset = -(bbox_min + bbox_max) / 2 for obj in scene_root_objects(): obj.matrix_world.translation += offset bpy.ops.object.select_all(action="DESELECT") def create_camera(): # https://b3d.interplanety.org/en/how-to-create-camera-through-the-blender-python-api/ camera_data = bpy.data.cameras.new(name="Camera") camera_object = bpy.data.objects.new("Camera", camera_data) bpy.context.scene.collection.objects.link(camera_object) bpy.context.scene.camera = camera_object def set_camera(direction, camera_dist=2.0): camera_pos = -camera_dist * direction bpy.context.scene.camera.location = camera_pos # https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically rot_quat = direction.to_track_quat("-Z", "Y") bpy.context.scene.camera.rotation_euler = rot_quat.to_euler() bpy.context.view_layer.update() def randomize_camera(camera_dist=2.0): direction = random_unit_vector() set_camera(direction, camera_dist=camera_dist) def pan_camera(time, axis="Z", camera_dist=2.0, elevation=0.1): angle = time * math.pi * 2 direction = [-math.cos(angle), -math.sin(angle), elevation] assert axis in ["X", "Y", "Z"] if axis == "X": direction = [direction[2], *direction[:2]] elif axis == "Y": direction 
= [direction[0], elevation, direction[1]] direction = Vector(direction).normalized() set_camera(direction, camera_dist=camera_dist) def place_camera(time, camera_pose_mode="random", camera_dist_min=2.0, camera_dist_max=2.0): camera_dist = random.uniform(camera_dist_min, camera_dist_max) if camera_pose_mode == "random": randomize_camera(camera_dist=camera_dist) elif camera_pose_mode == "z-circular": pan_camera(time, axis="Z", camera_dist=camera_dist) elif camera_pose_mode == "z-circular-elevated": pan_camera(time, axis="Z", camera_dist=camera_dist, elevation=-0.2617993878) else: raise ValueError(f"Unknown camera pose mode: {camera_pose_mode}") def create_light(location, energy=1.0, angle=0.5 * math.pi / 180): # https://blender.stackexchange.com/questions/215624/how-to-create-a-light-with-the-python-api-in-blender-2-92 light_data = bpy.data.lights.new(name="Light", type="SUN") light_data.energy = energy light_data.angle = angle light_object = bpy.data.objects.new(name="Light", object_data=light_data) direction = -location rot_quat = direction.to_track_quat("-Z", "Y") light_object.rotation_euler = rot_quat.to_euler() bpy.context.view_layer.update() bpy.context.collection.objects.link(light_object) light_object.location = location def create_random_lights(count=4, distance=2.0, energy=1.5): clear_lights() for _ in range(count): create_light(random_unit_vector() * distance, energy=energy) def create_camera_light(): clear_lights() create_light(bpy.context.scene.camera.location, energy=5.0) def create_uniform_light(backend): clear_lights() # Random direction to decorrelate axis-aligned sides. pos = Vector(UNIFORM_LIGHT_DIRECTION) angle = 0.0092 if backend == "CYCLES" else math.pi create_light(pos, energy=5.0, angle=angle) create_light(-pos, energy=5.0, angle=angle) def create_vertex_color_shaders(): # By default, Blender will ignore vertex colors in both the # Eevee and Cycles backends, since these colors aren't # associated with a material. 
# # What we do here is create a simple material shader and link # the vertex color to the material color. for obj in bpy.context.scene.objects.values(): if not isinstance(obj.data, (bpy.types.Mesh)): continue if len(obj.data.materials): # We don't want to override any existing materials. continue color_keys = (obj.data.vertex_colors or {}).keys() if not len(color_keys): # Many objects will have no materials *or* vertex colors. continue mat = bpy.data.materials.new(name="VertexColored") mat.use_nodes = True # There should be a Principled BSDF by default. bsdf_node = None for node in mat.node_tree.nodes: if node.type == "BSDF_PRINCIPLED": bsdf_node = node assert bsdf_node is not None, "material has no Principled BSDF node to modify" socket_map = {} for input in bsdf_node.inputs: socket_map[input.name] = input # Make sure nothing lights the object except for the diffuse color. socket_map["Specular"].default_value = 0.0 socket_map["Roughness"].default_value = 1.0 v_color = mat.node_tree.nodes.new("ShaderNodeVertexColor") v_color.layer_name = color_keys[0] mat.node_tree.links.new(v_color.outputs[0], socket_map["Base Color"]) obj.data.materials.append(mat) def create_default_materials(): for obj in bpy.context.scene.objects.values(): if isinstance(obj.data, (bpy.types.Mesh)): if not len(obj.data.materials): mat = bpy.data.materials.new(name="DefaultMaterial") mat.use_nodes = True obj.data.materials.append(mat) def find_materials(): all_materials = set() for obj in bpy.context.scene.objects.values(): if not isinstance(obj.data, bpy.types.Mesh): continue for mat in obj.data.materials: all_materials.add(mat) return all_materials def delete_all_materials(): for obj in bpy.context.scene.objects.values(): if isinstance(obj.data, bpy.types.Mesh): # https://blender.stackexchange.com/questions/146714/removing-all-material-slots-in-one-go obj.data.materials.clear() def setup_material_extraction_shaders(capturing_material_alpha: bool): """ Change every material to emit texture 
colors (or alpha) rather than having an actual reflective color. Returns a function to undo the changes to the materials. """ # Objects can share materials, so we first find all of the # materials in the project, and then modify them each once. undo_fns = [] for mat in find_materials(): undo_fn = setup_material_extraction_shader_for_material(mat, capturing_material_alpha) if undo_fn is not None: undo_fns.append(undo_fn) return lambda: [undo_fn() for undo_fn in undo_fns] def setup_material_extraction_shader_for_material(mat, capturing_material_alpha: bool): mat.use_nodes = True # By default, most imported models should use the regular # "Principled BSDF" material, so we should always find this. # If not, this shader manipulation logic won't work. bsdf_node = None for node in mat.node_tree.nodes: if node.type == "BSDF_PRINCIPLED": bsdf_node = node assert bsdf_node is not None, "material has no Principled BSDF node to modify" socket_map = {} for input in bsdf_node.inputs: socket_map[input.name] = input for name in ["Base Color", "Emission", "Emission Strength", "Alpha", "Specular"]: assert name in socket_map.keys(), f"{name} not in {list(socket_map.keys())}" old_base_color = get_socket_value(mat.node_tree, socket_map["Base Color"]) old_alpha = get_socket_value(mat.node_tree, socket_map["Alpha"]) old_emission = get_socket_value(mat.node_tree, socket_map["Emission"]) old_emission_strength = get_socket_value(mat.node_tree, socket_map["Emission Strength"]) old_specular = get_socket_value(mat.node_tree, socket_map["Specular"]) # Make sure the base color of all objects is black and the opacity # is 1, so that we are effectively just telling the shader what color # to make the pixels. 
clear_socket_input(mat.node_tree, socket_map["Base Color"]) socket_map["Base Color"].default_value = [0, 0, 0, 1] clear_socket_input(mat.node_tree, socket_map["Alpha"]) socket_map["Alpha"].default_value = 1 clear_socket_input(mat.node_tree, socket_map["Specular"]) socket_map["Specular"].default_value = 0.0 old_blend_method = mat.blend_method mat.blend_method = "OPAQUE" if capturing_material_alpha: set_socket_value(mat.node_tree, socket_map["Emission"], old_alpha) else: set_socket_value(mat.node_tree, socket_map["Emission"], old_base_color) clear_socket_input(mat.node_tree, socket_map["Emission Strength"]) socket_map["Emission Strength"].default_value = 1.0 def undo_fn(): mat.blend_method = old_blend_method set_socket_value(mat.node_tree, socket_map["Base Color"], old_base_color) set_socket_value(mat.node_tree, socket_map["Alpha"], old_alpha) set_socket_value(mat.node_tree, socket_map["Emission"], old_emission) set_socket_value(mat.node_tree, socket_map["Emission Strength"], old_emission_strength) set_socket_value(mat.node_tree, socket_map["Specular"], old_specular) return undo_fn def get_socket_value(tree, socket): default = socket.default_value if not isinstance(default, float): default = list(default) for link in tree.links: if link.to_socket == socket: return (link.from_socket, default) return (None, default) def clear_socket_input(tree, socket): for link in list(tree.links): if link.to_socket == socket: tree.links.remove(link) def set_socket_value(tree, socket, socket_and_default): clear_socket_input(tree, socket) old_source_socket, default = socket_and_default if isinstance(default, float) and not isinstance(socket.default_value, float): # Codepath for setting Emission to a previous alpha value. 
socket.default_value = [default] * 3 + [1.0] else: socket.default_value = default if old_source_socket is not None: tree.links.new(old_source_socket, socket) def setup_nodes(output_path, capturing_material_alpha: bool = False, basic_lighting: bool = False): tree = bpy.context.scene.node_tree links = tree.links for node in tree.nodes: tree.nodes.remove(node) # Helpers to perform math on links and constants. def node_op(op: str, *args, clamp=False): node = tree.nodes.new(type="CompositorNodeMath") node.operation = op if clamp: node.use_clamp = True for i, arg in enumerate(args): if isinstance(arg, (int, float)): node.inputs[i].default_value = arg else: links.new(arg, node.inputs[i]) return node.outputs[0] def node_clamp(x, maximum=1.0): return node_op("MINIMUM", x, maximum) def node_mul(x, y, **kwargs): return node_op("MULTIPLY", x, y, **kwargs) def node_add(x, y, **kwargs): return node_op("ADD", x, y, **kwargs) def node_abs(x, **kwargs): return node_op("ABSOLUTE", x, **kwargs) input_node = tree.nodes.new(type="CompositorNodeRLayers") input_node.scene = bpy.context.scene input_sockets = {} for output in input_node.outputs: input_sockets[output.name] = output if capturing_material_alpha: color_socket = input_sockets["Image"] else: raw_color_socket = input_sockets["Image"] if basic_lighting: # Compute diffuse lighting normal_xyz = tree.nodes.new(type="CompositorNodeSeparateXYZ") tree.links.new(input_sockets["Normal"], normal_xyz.inputs[0]) normal_x, normal_y, normal_z = [normal_xyz.outputs[i] for i in range(3)] dot = node_add( node_mul(UNIFORM_LIGHT_DIRECTION[0], normal_x), node_add( node_mul(UNIFORM_LIGHT_DIRECTION[1], normal_y), node_mul(UNIFORM_LIGHT_DIRECTION[2], normal_z), ), ) diffuse = node_abs(dot) # Compute ambient + diffuse lighting brightness = node_add(BASIC_AMBIENT_COLOR, node_mul(BASIC_DIFFUSE_COLOR, diffuse)) # Modulate the RGB channels using the total brightness. 
rgba_node = tree.nodes.new(type="CompositorNodeSepRGBA") tree.links.new(raw_color_socket, rgba_node.inputs[0]) combine_node = tree.nodes.new(type="CompositorNodeCombRGBA") for i in range(3): tree.links.new(node_mul(rgba_node.outputs[i], brightness), combine_node.inputs[i]) tree.links.new(rgba_node.outputs[3], combine_node.inputs[3]) raw_color_socket = combine_node.outputs[0] # We apply sRGB here so that our fixed-point depth map and material # alpha values are not sRGB, and so that we perform ambient+diffuse # lighting in linear RGB space. color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace") color_node.from_color_space = "Linear" color_node.to_color_space = "sRGB" tree.links.new(raw_color_socket, color_node.inputs[0]) color_socket = color_node.outputs[0] split_node = tree.nodes.new(type="CompositorNodeSepRGBA") tree.links.new(color_socket, split_node.inputs[0]) # Create separate file output nodes for every channel we care about. # The process calling this script must decide how to recombine these # channels, possibly into a single image. for i, channel in enumerate("rgba") if not capturing_material_alpha else [(0, "MatAlpha")]: output_node = tree.nodes.new(type="CompositorNodeOutputFile") output_node.base_path = f"{output_path}_{channel}" links.new(split_node.outputs[i], output_node.inputs[0]) if capturing_material_alpha: # No need to re-write depth here. return depth_out = node_clamp(node_mul(input_sockets["Depth"], 1 / MAX_DEPTH)) output_node = tree.nodes.new(type="CompositorNodeOutputFile") output_node.base_path = f"{output_path}_depth" links.new(depth_out, output_node.inputs[0]) def render_scene(output_path, fast_mode: bool, extract_material: bool, basic_lighting: bool): use_workbench = bpy.context.scene.render.engine == "BLENDER_WORKBENCH" if use_workbench: # We must use a different engine to compute depth maps. bpy.context.scene.render.engine = "BLENDER_EEVEE" bpy.context.scene.eevee.taa_render_samples = 1 # faster, since we discard image. 
if fast_mode: if bpy.context.scene.render.engine == "BLENDER_EEVEE": bpy.context.scene.eevee.taa_render_samples = 1 elif bpy.context.scene.render.engine == "CYCLES": bpy.context.scene.cycles.samples = 256 else: if bpy.context.scene.render.engine == "CYCLES": # We should still impose a per-frame time limit # so that we don't timeout completely. bpy.context.scene.cycles.time_limit = 40 bpy.context.view_layer.update() bpy.context.scene.use_nodes = True bpy.context.scene.view_layers["ViewLayer"].use_pass_z = True if basic_lighting: bpy.context.scene.view_layers["ViewLayer"].use_pass_normal = True bpy.context.scene.view_settings.view_transform = "Raw" # sRGB done in graph nodes bpy.context.scene.render.film_transparent = True bpy.context.scene.render.resolution_x = 512 bpy.context.scene.render.resolution_y = 512 bpy.context.scene.render.image_settings.file_format = "PNG" bpy.context.scene.render.image_settings.color_mode = "BW" bpy.context.scene.render.image_settings.color_depth = "16" bpy.context.scene.render.filepath = output_path if extract_material: for do_alpha in [False, True]: undo_fn = setup_material_extraction_shaders(capturing_material_alpha=do_alpha) setup_nodes(output_path, capturing_material_alpha=do_alpha) bpy.ops.render.render(write_still=True) undo_fn() else: setup_nodes(output_path, basic_lighting=basic_lighting) bpy.ops.render.render(write_still=True) # The output images must be moved from their own sub-directories, or # discarded if we are using workbench for the color. 
for channel_name in ["r", "g", "b", "a", "depth", *(["MatAlpha"] if extract_material else [])]: sub_dir = f"{output_path}_{channel_name}" image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0]) name, ext = os.path.splitext(output_path) if channel_name == "depth" or not use_workbench: os.rename(image_path, f"{name}_{channel_name}{ext}") else: os.remove(image_path) os.removedirs(sub_dir) if use_workbench: # Re-render RGBA using workbench with texture mode, since this seems # to show the most reasonable colors when lighting is broken. bpy.context.scene.use_nodes = False bpy.context.scene.render.engine = "BLENDER_WORKBENCH" bpy.context.scene.render.image_settings.color_mode = "RGBA" bpy.context.scene.render.image_settings.color_depth = "8" bpy.context.scene.display.shading.color_type = "TEXTURE" bpy.context.scene.display.shading.light = "FLAT" if fast_mode: # Single pass anti-aliasing. bpy.context.scene.display.render_aa = "FXAA" os.remove(output_path) bpy.ops.render.render(write_still=True) bpy.context.scene.render.image_settings.color_mode = "BW" bpy.context.scene.render.image_settings.color_depth = "16" def scene_fov(): x_fov = bpy.context.scene.camera.data.angle_x y_fov = bpy.context.scene.camera.data.angle_y width = bpy.context.scene.render.resolution_x height = bpy.context.scene.render.resolution_y if bpy.context.scene.camera.data.angle == x_fov: y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width) else: x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height) return x_fov, y_fov def write_camera_metadata(path): x_fov, y_fov = scene_fov() bbox_min, bbox_max = scene_bbox() matrix = bpy.context.scene.camera.matrix_world with open(path, "w") as f: json.dump( dict( format_version=FORMAT_VERSION, max_depth=MAX_DEPTH, bbox=[list(bbox_min), list(bbox_max)], origin=list(matrix.col[3])[:3], x_fov=x_fov, y_fov=y_fov, x=list(matrix.col[0])[:3], y=list(-matrix.col[1])[:3], z=list(-matrix.col[2])[:3], ), f, ) def save_rendering_dataset( input_path: str, 
output_path: str, num_images: int, backend: str, light_mode: str, camera_pose: str, camera_dist_min: float, camera_dist_max: float, fast_mode: bool, extract_material: bool, delete_material: bool, ): assert light_mode in ["random", "uniform", "camera", "basic"] assert camera_pose in ["random", "z-circular", "z-circular-elevated"] basic_lighting = light_mode == "basic" assert not (basic_lighting and extract_material), "cannot extract material with basic lighting" assert not (delete_material and extract_material), "cannot extract material and delete it" import_model(input_path) bpy.context.scene.render.engine = backend normalize_scene() if light_mode == "random": create_random_lights() elif light_mode == "uniform": create_uniform_light(backend) create_camera() create_vertex_color_shaders() if delete_material: delete_all_materials() if extract_material or basic_lighting: create_default_materials() if basic_lighting: # Make sure materials are uniformly lit, so that we can light # them in the output shader. setup_material_extraction_shaders(capturing_material_alpha=False) for i in range(num_images): t = i / max(num_images - 1, 1) # same as np.linspace(0, 1, num_images) place_camera( t, camera_pose_mode=camera_pose, camera_dist_min=camera_dist_min, camera_dist_max=camera_dist_max, ) if light_mode == "camera": create_camera_light() render_scene( os.path.join(output_path, f"{i:05}.png"), fast_mode=fast_mode, extract_material=extract_material, basic_lighting=basic_lighting, ) write_camera_metadata(os.path.join(output_path, f"{i:05}.json")) with open(os.path.join(output_path, "info.json"), "w") as f: info = dict( backend=backend, light_mode=light_mode, fast_mode=fast_mode, extract_material=extract_material, format_version=FORMAT_VERSION, channels=["R", "G", "B", "A", "D", *(["MatAlpha"] if extract_material else [])], scale=0.5, # The scene is bounded by [-scale, scale]. 
) json.dump(info, f) def main(): global UNIFORM_LIGHT_DIRECTION, BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR try: dash_index = sys.argv.index("--") except ValueError as exc: raise ValueError("arguments must be preceded by '--'") from exc raw_args = sys.argv[dash_index + 1 :] parser = argparse.ArgumentParser() parser.add_argument("--input_path", required=True, type=str) parser.add_argument("--output_path", required=True, type=str) parser.add_argument("--num_images", required=True, type=int) parser.add_argument("--backend", type=str, default="BLENDER_EEVEE") parser.add_argument("--light_mode", type=str, default="random") parser.add_argument("--camera_pose", type=str, default="random") parser.add_argument("--camera_dist_min", type=float, default=2.0) parser.add_argument("--camera_dist_max", type=float, default=2.0) parser.add_argument("--fast_mode", action="store_true") parser.add_argument("--extract_material", action="store_true") parser.add_argument("--delete_material", action="store_true") # Prevent constants from being repeated. 
parser.add_argument("--uniform_light_direction", required=True, type=float, nargs="+") parser.add_argument("--basic_ambient", required=True, type=float) parser.add_argument("--basic_diffuse", required=True, type=float) args = parser.parse_args(raw_args) UNIFORM_LIGHT_DIRECTION = args.uniform_light_direction BASIC_AMBIENT_COLOR = args.basic_ambient BASIC_DIFFUSE_COLOR = args.basic_diffuse save_rendering_dataset( input_path=args.input_path, output_path=args.output_path, num_images=args.num_images, backend=args.backend, light_mode=args.light_mode, camera_pose=args.camera_pose, camera_dist_min=args.camera_dist_min, camera_dist_max=args.camera_dist_max, fast_mode=args.fast_mode, extract_material=args.extract_material, delete_material=args.delete_material, ) main() ================================================ FILE: shap_e/rendering/blender/constants.py ================================================ UNIFORM_LIGHT_DIRECTION = [0.09387503, -0.63953443, -0.7630093] BASIC_AMBIENT_COLOR = 0.3 BASIC_DIFFUSE_COLOR = 0.7 ================================================ FILE: shap_e/rendering/blender/render.py ================================================ import os import platform import subprocess import tempfile import zipfile import blobfile as bf import numpy as np from PIL import Image from shap_e.rendering.mesh import TriMesh from .constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION SCRIPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "blender_script.py") def render_model( model_path: str, output_path: str, num_images: int, backend: str = "BLENDER_EEVEE", light_mode: str = "random", camera_pose: str = "random", camera_dist_min: float = 2.0, camera_dist_max: float = 2.0, fast_mode: bool = False, extract_material: bool = False, delete_material: bool = False, verbose: bool = False, timeout: float = 15 * 60, ): with tempfile.TemporaryDirectory() as tmp_dir: tmp_in = model_path tmp_out = os.path.join(tmp_dir, "out") 
zip_out = tmp_out + ".zip" os.mkdir(tmp_out) args = [] if platform.system() == "Linux": # Needed to enable Eevee backend on headless linux. args = ["xvfb-run", "-a"] args.extend( [ _blender_binary_path(), "-b", "-P", SCRIPT_PATH, "--", "--input_path", tmp_in, "--output_path", tmp_out, "--num_images", str(num_images), "--backend", backend, "--light_mode", light_mode, "--camera_pose", camera_pose, "--camera_dist_min", str(camera_dist_min), "--camera_dist_max", str(camera_dist_max), "--uniform_light_direction", *[str(x) for x in UNIFORM_LIGHT_DIRECTION], "--basic_ambient", str(BASIC_AMBIENT_COLOR), "--basic_diffuse", str(BASIC_DIFFUSE_COLOR), ] ) if fast_mode: args.append("--fast_mode") if extract_material: args.append("--extract_material") if delete_material: args.append("--delete_material") if verbose: subprocess.check_call(args) else: try: output = subprocess.check_output(args, stderr=subprocess.STDOUT, timeout=timeout) except subprocess.CalledProcessError as exc: raise RuntimeError(f"{exc}: {exc.output}") from exc if not os.path.exists(os.path.join(tmp_out, "info.json")): if verbose: # There is no output available, since it was # logged directly to stdout/stderr. raise RuntimeError(f"render failed: output file missing") else: raise RuntimeError(f"render failed: output file missing. 
Output: {output}") _combine_rgba(tmp_out) with zipfile.ZipFile(zip_out, mode="w") as zf: for name in os.listdir(tmp_out): zf.write(os.path.join(tmp_out, name), name) bf.copy(zip_out, output_path, overwrite=True) def render_mesh( mesh: TriMesh, output_path: str, num_images: int, backend: str = "BLENDER_EEVEE", **kwargs, ): if mesh.has_vertex_colors() and backend not in ["BLENDER_EEVEE", "CYCLES"]: raise ValueError(f"backend does not support vertex colors: {backend}") with tempfile.TemporaryDirectory() as tmp_dir: ply_path = os.path.join(tmp_dir, "out.ply") with open(ply_path, "wb") as f: mesh.write_ply(f) render_model( ply_path, output_path=output_path, num_images=num_images, backend=backend, **kwargs ) def _combine_rgba(out_dir: str): i = 0 while True: paths = [os.path.join(out_dir, f"{i:05}_{ch}.png") for ch in "rgba"] if not os.path.exists(paths[0]): break joined = np.stack( [(np.array(Image.open(path)) >> 8).astype(np.uint8) for path in paths], axis=-1 ) Image.fromarray(joined).save(os.path.join(out_dir, f"{i:05}.png")) for path in paths: os.remove(path) i += 1 def _blender_binary_path() -> str: path = os.getenv("BLENDER_PATH", None) if path is not None: return path if os.path.exists("/Applications/Blender.app/Contents/MacOS/Blender"): return "/Applications/Blender.app/Contents/MacOS/Blender" raise EnvironmentError( "To render 3D models, install Blender version 3.3.1 or higher and " "set the environment variable `BLENDER_PATH` to the path of the Blender executable." ) ================================================ FILE: shap_e/rendering/blender/view_data.py ================================================ import itertools import json import zipfile from typing import BinaryIO, List, Tuple import numpy as np from PIL import Image from shap_e.rendering.view_data import Camera, ProjectiveCamera, ViewData class BlenderViewData(ViewData): """ Interact with a dataset zipfile exported by view_data.py. 
def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]:
    """
    Load the camera and pixel data for a single view.

    :param index: the view to read. It selects the `{index:05}*.png` entries
                  of the zipfile and the matching entry of self.infos.
    :param channels: the channel names to return, each of which must be in
                     self.channel_names. The output follows this order.
    :return: a tuple (camera, data), where data stacks one plane per
             requested channel along its last axis.
    :raises ValueError: if a requested channel is not available.
    """
    available = self.channel_names
    for name in channels:
        if name not in available:
            raise ValueError(f"unsupported channel: {name}")

    # Decode only the images that are needed. This may produce a superset
    # of the requested channels (e.g. all of RGBA when only R is wanted).
    planes = {}

    wants_color = any(name in channels for name in "RGBA")
    if wants_color:
        with self.zipfile.open(f"{index:05}.png", "r") as f:
            rgba = np.array(Image.open(f)).astype(np.float32) / 255.0
            for name, plane in zip("RGBA", rgba.transpose([2, 0, 1])):
                planes[name] = plane

    if "D" in channels:
        with self.zipfile.open(f"{index:05}_depth.png", "r") as f:
            # Depth is a 16-bit fixed-point number scaled by the view's
            # max_depth; the code 0xFFFF marks an infinite distance.
            raw_depth = np.array(Image.open(f))
            is_infinite = raw_depth == 0xFFFF
            finite_depth = self.infos[index]["max_depth"] * (
                raw_depth.astype(np.float32) / 65536
            )
            planes["D"] = np.where(is_infinite, np.inf, finite_depth)

    if "MatAlpha" in channels:
        with self.zipfile.open(f"{index:05}_MatAlpha.png", "r") as f:
            raw_alpha = np.array(Image.open(f))
            planes["MatAlpha"] = raw_alpha.astype(np.float32) / 65536

    # Emit the planes in the order the caller asked for.
    stacked = np.stack([planes[name] for name in channels], axis=-1)

    height, width, _ = stacked.shape
    return self.camera(index, width, height), stacked
corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype) corner_coords[range(grid_size[0]), :, :, 0] = torch.arange( grid_size[0], device=dev, dtype=field.dtype )[:, None, None] corner_coords[:, range(grid_size[1]), :, 1] = torch.arange( grid_size[1], device=dev, dtype=field.dtype )[:, None] corner_coords[:, :, range(grid_size[2]), 2] = torch.arange( grid_size[2], device=dev, dtype=field.dtype ) # Compute all vertices across all edges in the grid, even though we will # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices. # These are all midpoints, and don't account for interpolation (which is # done later based on the used edge midpoints). edge_midpoints = torch.cat( [ ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3), ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3), ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3), ], dim=0, ) # Create a flat array of [X, Y, Z] indices for each cube. cube_indices = torch.zeros( grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long ) cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[ :, None, None ] cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[ :, None ] cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev) flat_cube_indices = cube_indices.reshape(-1, 3) # Create a flat array mapping each cube to 12 global edge indices. edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size) # Apply the LUT to figure out the triangles. flat_bitmasks = bitmasks.reshape( -1 ).long() # must cast to long for indexing to believe this not a mask local_tris = lut.cases[flat_bitmasks] local_masks = lut.masks[flat_bitmasks] # Compute the global edge indices for the triangles. 
def _create_flat_edge_indices(
    flat_cube_indices: torch.Tensor, grid_size: Tuple[int, int, int]
) -> torch.Tensor:
    """
    Map every cube to the global indices of its 12 edges.

    Global edge indices enumerate all x-spanning edges first, then all
    y-spanning edges, then all z-spanning edges.

    :param flat_cube_indices: an [N x 3] integer tensor of (x, y, z) cube
                              coordinates.
    :param grid_size: the number of field samples along each axis.
    :return: an [N x 12] tensor of global edge indices, where columns 0-3 are
             x-spanning edges, 4-7 are y-spanning and 8-11 are z-spanning.
    """
    nx, ny, nz = grid_size[0], grid_size[1], grid_size[2]
    cx = flat_cube_indices[:, 0]
    cy = flat_cube_indices[:, 1]
    cz = flat_cube_indices[:, 2]

    # Offsets of the y- and z-spanning edges within the global numbering.
    y_base = (nx - 1) * ny * nz
    z_base = y_base + nx * (ny - 1) * nz

    def x_edge(dy: int, dz: int) -> torch.Tensor:
        return cx * (ny * nz) + (cy + dy) * nz + (cz + dz)

    def y_edge(dx: int, dz: int) -> torch.Tensor:
        return y_base + (cx + dx) * ((ny - 1) * nz) + cy * nz + (cz + dz)

    def z_edge(dx: int, dy: int) -> torch.Tensor:
        return z_base + (cx + dx) * (ny * (nz - 1)) + (cy + dy) * (nz - 1) + cz

    columns = [
        # Edges spanning x-axis.
        x_edge(0, 0),
        x_edge(1, 0),
        x_edge(0, 1),
        x_edge(1, 1),
        # Edges spanning y-axis.
        y_edge(0, 0),
        y_edge(1, 0),
        y_edge(0, 1),
        y_edge(1, 1),
        # Edges spanning z-axis.
        z_edge(0, 0),
        z_edge(1, 0),
        z_edge(0, 1),
        z_edge(1, 1),
    ]
    return torch.stack(columns, dim=-1)
@dataclass
class TriMesh:
    """
    A 3D triangle mesh with optional data at the vertices and faces.
    """

    # [N x 3] array of vertex coordinates.
    verts: np.ndarray

    # [M x 3] array of triangles, pointing to indices in verts.
    faces: np.ndarray

    # [P x 3] array of normal vectors per face.
    normals: Optional[np.ndarray] = None

    # Extra data per vertex and face. Either may be None, which is treated
    # the same as an empty dict.
    vertex_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict)
    face_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict)

    @classmethod
    def load(cls, f: Union[str, BinaryIO]) -> "TriMesh":
        """
        Load the mesh from a .npz file.

        :param f: a path (opened through blobfile) or a readable binary file.
        :return: the mesh, where archive keys prefixed with "v_" and "f_"
                 become vertex and face channels, respectively.
        """
        if isinstance(f, str):
            with bf.BlobFile(f, "rb") as reader:
                return cls.load(reader)
        else:
            obj = np.load(f)
            keys = list(obj.keys())
            verts = obj["verts"]
            faces = obj["faces"]
            normals = obj["normals"] if "normals" in keys else None
            vertex_channels = {}
            face_channels = {}
            for key in keys:
                if key.startswith("v_"):
                    vertex_channels[key[2:]] = obj[key]
                elif key.startswith("f_"):
                    face_channels[key[2:]] = obj[key]
            return cls(
                verts=verts,
                faces=faces,
                normals=normals,
                vertex_channels=vertex_channels,
                face_channels=face_channels,
            )

    def save(self, f: Union[str, BinaryIO]):
        """
        Save the mesh to a .npz file.

        :param f: a path (opened through blobfile) or a writable binary file.
        """
        if isinstance(f, str):
            with bf.BlobFile(f, "wb") as writer:
                self.save(writer)
        else:
            obj_dict = dict(verts=self.verts, faces=self.faces)
            if self.normals is not None:
                obj_dict["normals"] = self.normals
            # The channel dicts are Optional, so None must not crash here.
            for k, v in (self.vertex_channels or {}).items():
                obj_dict[f"v_{k}"] = v
            for k, v in (self.face_channels or {}).items():
                obj_dict[f"f_{k}"] = v
            np.savez(f, **obj_dict)

    def has_vertex_colors(self) -> bool:
        """
        Check whether the R, G, and B vertex channels are all present.
        """
        return self.vertex_channels is not None and all(x in self.vertex_channels for x in "RGB")

    def write_ply(self, raw_f: BinaryIO):
        """
        Write the mesh, with vertex colors if available, as a PLY file.
        """
        write_ply(
            raw_f,
            coords=self.verts,
            rgb=(
                np.stack([self.vertex_channels[x] for x in "RGB"], axis=1)
                if self.has_vertex_colors()
                else None
            ),
            faces=self.faces,
        )

    def write_obj(self, raw_f: BinaryIO):
        """
        Write the mesh as Wavefront OBJ text.

        Vertices are written as "v x y z" lines, with the RGB vertex colors
        appended when available, followed by 1-indexed "f a b c" face lines.
        No trailing newline is written.

        :param raw_f: the output file. Note that str data is written to it,
                      so despite the annotation it must accept text.
        """
        if self.has_vertex_colors():
            vertex_colors = np.stack([self.vertex_channels[x] for x in "RGB"], axis=1)
            vertices = [
                "{} {} {} {} {} {}".format(*coord, *color)
                for coord, color in zip(self.verts.tolist(), vertex_colors.tolist())
            ]
        else:
            vertices = ["{} {} {}".format(*coord) for coord in self.verts.tolist()]

        faces = [
            "f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1))
            for tri in self.faces.tolist()
        ]

        combined_data = ["v " + vertex for vertex in vertices] + faces

        # Write the joined text in one call; passing a str to writelines()
        # would iterate over it one character at a time.
        raw_f.write("\n".join(combined_data))
struct from typing import BinaryIO, Optional import numpy as np from shap_e.util.io import buffered_writer def write_ply( raw_f: BinaryIO, coords: np.ndarray, rgb: Optional[np.ndarray] = None, faces: Optional[np.ndarray] = None, ): """ Write a PLY file for a mesh or a point cloud. :param coords: an [N x 3] array of floating point coordinates. :param rgb: an [N x 3] array of vertex colors, in the range [0.0, 1.0]. :param faces: an [N x 3] array of triangles encoded as integer indices. """ with buffered_writer(raw_f) as f: f.write(b"ply\n") f.write(b"format binary_little_endian 1.0\n") f.write(bytes(f"element vertex {len(coords)}\n", "ascii")) f.write(b"property float x\n") f.write(b"property float y\n") f.write(b"property float z\n") if rgb is not None: f.write(b"property uchar red\n") f.write(b"property uchar green\n") f.write(b"property uchar blue\n") if faces is not None: f.write(bytes(f"element face {len(faces)}\n", "ascii")) f.write(b"property list uchar int vertex_index\n") f.write(b"end_header\n") if rgb is not None: rgb = (rgb * 255.499).round().astype(int) vertices = [ (*coord, *rgb) for coord, rgb in zip( coords.tolist(), rgb.tolist(), ) ] format = struct.Struct("<3f3B") for item in vertices: f.write(format.pack(*item)) else: format = struct.Struct("<3f") for vertex in coords.tolist(): f.write(format.pack(*vertex)) if faces is not None: format = struct.Struct(" "PointCloud": """ Construct a point cloud from the given view data. The data must have a depth channel. All other channels will be stored in the `channels` attribute of the result. Pixels in the rendered views are not converted into points in the cloud if they have infinite depth or less than 1.0 alpha. 
""" channel_names = vd.channel_names if "D" not in channel_names: raise ValueError(f"view data must have depth channel") depth_index = channel_names.index("D") all_coords = [] all_channels = defaultdict(list) if num_views is None: num_views = vd.num_views for i in range(num_views): camera, channel_values = vd.load_view(i, channel_names) flat_values = channel_values.reshape([-1, len(channel_names)]) # Create an array of integer (x, y) image coordinates for Camera methods. image_coords = camera.image_coords() # Select subset of pixels that have meaningful depth/color. image_mask = np.isfinite(flat_values[:, depth_index]) if "A" in channel_names: image_mask = image_mask & (flat_values[:, channel_names.index("A")] >= 1 - 1e-5) image_coords = image_coords[image_mask] flat_values = flat_values[image_mask] # Use the depth and camera information to compute the coordinates # corresponding to every visible pixel. camera_rays = camera.camera_rays(image_coords) camera_origins = camera_rays[:, 0] camera_directions = camera_rays[:, 1] depth_dirs = camera.depth_directions(image_coords) ray_scales = flat_values[:, depth_index] / np.sum( camera_directions * depth_dirs, axis=-1 ) coords = camera_origins + camera_directions * ray_scales[:, None] all_coords.append(coords) for j, name in enumerate(channel_names): if name != "D": all_channels[name].append(flat_values[:, j]) if len(all_coords) == 0: return cls(coords=np.zeros([0, 3], dtype=np.float32), channels={}) return cls( coords=np.concatenate(all_coords, axis=0), channels={k: np.concatenate(v, axis=0) for k, v in all_channels.items()}, ) @classmethod def load(cls, f: Union[str, BinaryIO]) -> "PointCloud": """ Load the point cloud from a .npz file. 
def farthest_point_sample(
    self, num_points: int, init_idx: Optional[int] = None, **subsample_kwargs
) -> "PointCloud":
    """
    Sample a subset of the point cloud that is evenly distributed in space.

    Starting from an initial point (random unless init_idx is given), each
    successive point is the one furthest from all points selected so far.

    The running time is O(NM) for N original points and M sampled points,
    so it can help to reduce the cloud with random_sample() beforehand.

    :param num_points: maximum number of points to sample.
    :param init_idx: if specified, the first point to sample.
    :param subsample_kwargs: arguments to self.subsample().
    :return: a reduced PointCloud, or self if num_points is not less than
             the current number of points.
    """
    total = len(self.coords)
    if total <= num_points:
        return self

    if init_idx is None:
        init_idx = random.randrange(total)

    selected = np.zeros([num_points], dtype=np.int64)
    selected[0] = init_idx

    points = self.coords
    squared_norms = np.sum(points**2, axis=-1)

    def sq_dists_to(center: int):
        # Utilize equality: ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B).
        return squared_norms + squared_norms[center] - 2 * (points @ points[center])

    # For every point, the squared distance to its closest selected point.
    min_dists = sq_dists_to(init_idx)
    slot = 1
    while slot < num_points:
        farthest = np.argmax(min_dists)
        selected[slot] = farthest
        # Force the chosen point below every real distance; otherwise
        # rounding errors could select duplicated points more than once.
        min_dists[farthest] = -1
        min_dists = np.minimum(min_dists, sq_dists_to(farthest))
        slot += 1

    return self.subsample(selected, **subsample_kwargs)
def nearest_points(self, points: np.ndarray, batch_size: int = 16384) -> np.ndarray:
    """
    For each point in another set of points, compute the point in this
    pointcloud which is closest.

    :param points: an [N x 3] array of points.
    :param batch_size: the number of neighbor distances to compute at once.
                       Smaller values save memory, while larger values may
                       make the computation faster.
    :return: an [N] array of indices into self.coords. This is empty when
             no query points are given.
    """
    if len(points) == 0:
        # np.concatenate() rejects an empty list, so handle the case of no
        # query points explicitly.
        return np.zeros([0], dtype=np.intp)

    norms = np.sum(self.coords**2, axis=-1)
    all_indices = []
    for i in range(0, len(points), batch_size):
        batch = points[i : i + batch_size]
        # Squared distances via ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B),
        # giving one row per query point and one column per stored point.
        dists = norms + np.sum(batch**2, axis=-1)[:, None] - 2 * (batch @ self.coords.T)
        all_indices.append(np.argmin(dists, axis=-1))
    return np.concatenate(all_indices, axis=0)
from pytorch3d.renderer.utils import TensorProperties from pytorch3d.structures import Meshes from shap_e.models.nn.checkpoint import checkpoint from .blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION from .torch_mesh import TorchMesh from .view_data import ProjectiveCamera # Using a lower value like 1e-4 seems to result in weird issues # for our high-poly meshes. DEFAULT_RENDER_SIGMA = 1e-5 DEFAULT_RENDER_GAMMA = 1e-4 def render_images( image_size: int, meshes: Meshes, cameras: Any, lights: Any, sigma: float = DEFAULT_RENDER_SIGMA, gamma: float = DEFAULT_RENDER_GAMMA, max_faces_per_bin=100000, faces_per_pixel=50, bin_size=None, use_checkpoint: bool = False, ) -> torch.Tensor: if use_checkpoint: # Decompose all of our arguments into a bunch of tensor lists # so that autograd can keep track of what the op depends on. verts_list = meshes.verts_list() faces_list = meshes.faces_list() assert isinstance(meshes.textures, TexturesVertex) assert isinstance(lights, BidirectionalLights) textures = meshes.textures.verts_features_padded() light_vecs, light_fn = _deconstruct_tensor_props(lights) camera_vecs, camera_fn = _deconstruct_tensor_props(cameras) def ckpt_fn( *args: torch.Tensor, num_verts=len(verts_list), num_light_vecs=len(light_vecs), num_camera_vecs=len(camera_vecs), light_fn=light_fn, camera_fn=camera_fn, faces_list=faces_list ): args = list(args) verts_list = args[:num_verts] del args[:num_verts] light_vecs = args[:num_light_vecs] del args[:num_light_vecs] camera_vecs = args[:num_camera_vecs] del args[:num_camera_vecs] textures = args.pop(0) meshes = Meshes(verts=verts_list, faces=faces_list, textures=TexturesVertex(textures)) lights = light_fn(light_vecs) cameras = camera_fn(camera_vecs) return render_images( image_size=image_size, meshes=meshes, cameras=cameras, lights=lights, sigma=sigma, gamma=gamma, max_faces_per_bin=max_faces_per_bin, faces_per_pixel=faces_per_pixel, bin_size=bin_size, use_checkpoint=False, ) result = 
def convert_meshes(raw_meshes: Sequence[TorchMesh], default_brightness=0.8) -> Meshes:
    """
    Convert a batch of TorchMesh objects into a PyTorch3D Meshes object with
    per-vertex textures.

    :param raw_meshes: the meshes to convert.
    :param default_brightness: the value used for all three color channels
                               of every vertex of a mesh that does not have
                               R, G, and B vertex channels.
    :return: a Meshes object whose textures are a TexturesVertex.
    """
    meshes = Meshes(
        verts=[mesh.verts for mesh in raw_meshes], faces=[mesh.faces for mesh in raw_meshes]
    )
    rgbs = []
    for mesh in raw_meshes:
        if mesh.vertex_channels and all(k in mesh.vertex_channels for k in "RGB"):
            rgbs.append(torch.stack([mesh.vertex_channels[k] for k in "RGB"], dim=-1))
        else:
            # One uniform color per vertex. The brightness is the fill value
            # of the [N x 3] tensor, not a factor of its (integer) size.
            rgbs.append(
                torch.full(
                    (len(mesh.verts), 3),
                    default_brightness,
                    device=mesh.verts.device,
                    dtype=mesh.verts.dtype,
                )
            )
    meshes.textures = TexturesVertex(verts_features=rgbs)
    return meshes
def convert_cameras_torch(
    origins: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor, zs: torch.Tensor, fov: float
) -> FoVPerspectiveCameras:
    """
    Build a batch of FoVPerspectiveCameras from camera origins and axes.

    For each camera, R has the columns (-x, -y, z), and T is the negated
    product of R's transpose with the camera origin.

    :param origins: a batch of camera origins.
    :param xs: a batch of camera x axes.
    :param ys: a batch of camera y axes.
    :param zs: a batch of camera z axes.
    :param fov: the field of view shared by all cameras, passed along with
                degrees=False.
    :return: cameras on the same device as origins.
    """
    rotations = []
    translations = []
    for origin, x_axis, y_axis, z_axis in zip(origins, xs, ys, zs):
        # The rows of this matrix are the (partially flipped) camera axes,
        # so its transpose is the rotation with those axes as columns.
        axis_rows = torch.stack([-x_axis, -y_axis, z_axis], dim=0)
        rotations.append(axis_rows.T)
        translations.append(-axis_rows @ origin)
    return FoVPerspectiveCameras(
        R=torch.stack(rotations, dim=0),
        T=torch.stack(translations, dim=0),
        fov=fov,
        degrees=False,
        device=origins.device,
    )
def cross_product(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """
    Compute the cross product of 3D vectors stored along the last dimension.

    The components are combined elementwise, so the leading dimensions of
    the two arguments may broadcast against each other.

    :param v1: a [... x 3] tensor.
    :param v2: a [... x 3] tensor.
    :return: a [... x 3] tensor of v1 x v2.
    """
    a_x, a_y, a_z = v1[..., 0], v1[..., 1], v1[..., 2]
    b_x, b_y, b_z = v2[..., 0], v2[..., 1], v2[..., 2]

    out_x = a_y * b_z - b_y * a_z
    out_y = -(a_x * b_z - b_x * a_z)
    out_z = a_x * b_y - b_x * a_y
    return torch.stack([out_x, out_y, out_z], dim=-1)
def cast_camera(
    camera: ProjectiveCamera,
    mesh: TriMesh,
    ray_batch_size: Optional[int] = None,
    checkpoint: Optional[bool] = None,
) -> Iterator[RayCollisions]:
    """
    Cast one ray per pixel of a camera onto a mesh, in batches.

    Pixels are enumerated row by row, i.e. pixel i has the image coordinates
    (i % width, i // width).

    :param camera: the camera producing the rays.
    :param mesh: the mesh to cast the rays onto.
    :param ray_batch_size: the number of rays to cast at once. If not
                           specified (or zero), all rays form one batch.
    :param checkpoint: passed to cast_rays(). If None, it is enabled exactly
                       when the rays are split into more than one batch.
    :return: an iterator over the collisions of each batch of rays.
    """
    width = camera.width
    flat_pixels = np.arange(width * camera.height)
    pixel_xy = np.stack([flat_pixels % width, flat_pixels // width], axis=1)
    rays = camera.camera_rays(pixel_xy)

    total = len(rays)
    step = ray_batch_size or total
    use_checkpoint = step < total if checkpoint is None else checkpoint

    for start in range(0, total, step):
        chunk = rays[start : start + step]
        # Move the rays to the device and dtype of the mesh vertices.
        batch = Rays(
            origins=torch.from_numpy(chunk[:, 0]).to(mesh.vertices),
            directions=torch.from_numpy(chunk[:, 1]).to(mesh.vertices),
        )
        yield cast_rays(batch, mesh, checkpoint=use_checkpoint)
""" if checkpoint: collides, ray_dists, tri_indices, barycentric, normals = RayCollisionFunction.apply( rays.origins, rays.directions, mesh.faces, mesh.vertices ) return RayCollisions( collides=collides, ray_dists=ray_dists, tri_indices=tri_indices, barycentric=barycentric, normals=normals, ) # https://github.com/unixpickle/vae-textures/blob/2968549ddd4a3487f9437d4db00793324453cd59/vae_textures/render.py#L98 normals = mesh.normals() # [N x 3] directions = rays.directions # [M x 3] collides = (directions @ normals.T).abs() > 1e-8 # [N x M] tris = mesh.vertices[mesh.faces] # [N x 3 x 3] v1 = tris[:, 1] - tris[:, 0] v2 = tris[:, 2] - tris[:, 0] cross1 = cross_product(directions[:, None], v2[None]) # [N x M x 3] det = torch.sum(cross1 * v1[None], dim=-1) # [N x M] collides = torch.logical_and(collides, det.abs() > 1e-8) invDet = 1 / det # [N x M] o = rays.origins[:, None] - tris[None, :, 0] # [N x M x 3] bary1 = invDet * torch.sum(o * cross1, dim=-1) # [N x M] collides = torch.logical_and(collides, torch.logical_and(bary1 >= 0, bary1 <= 1)) cross2 = cross_product(o, v1[None]) # [N x M x 3] bary2 = invDet * torch.sum(directions[:, None] * cross2, dim=-1) # [N x M] collides = torch.logical_and(collides, torch.logical_and(bary2 >= 0, bary2 <= 1)) bary0 = 1 - (bary1 + bary2) # Make sure this is in the positive part of the ray. 
class RayCollisionFunction(torch.autograd.Function):
    """
    An autograd function wrapping cast_rays() that runs the forward pass
    without recording a graph, and re-runs cast_rays() with gradients enabled
    during the backward pass.
    """

    @staticmethod
    def forward(
        ctx, origins, directions, faces, vertices
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Cast the rays under no_grad(), saving only the inputs.

        :return: a tuple (collides, ray_dists, tri_indices, barycentric,
                 normals) of the fields of the resulting RayCollisions.
        """
        ctx.save_for_backward(origins, directions, faces, vertices)
        with torch.no_grad():
            res = cast_rays(
                Rays(origins=origins, directions=directions),
                TriMesh(faces=faces, vertices=vertices),
                checkpoint=False,
            )
        return (res.collides, res.ray_dists, res.tri_indices, res.barycentric, res.normals)

    @staticmethod
    def backward(
        ctx, _collides_grad, ray_dists_grad, _tri_indices_grad, barycentric_grad, normals_grad
    ):
        """
        Recompute the ray casting with gradients enabled and backpropagate
        through ray_dists, barycentric, and normals.

        :return: gradients for (origins, directions, faces, vertices), where
                 the gradient for faces is always None.
        """
        # Tensors stored with save_for_backward() are read back through
        # saved_tensors; no other attribute is set on ctx by forward().
        origins, directions, faces, vertices = ctx.saved_tensors

        origins = origins.detach().requires_grad_(True)
        directions = directions.detach().requires_grad_(True)
        vertices = vertices.detach().requires_grad_(True)

        with torch.enable_grad():
            outputs = cast_rays(
                Rays(origins=origins, directions=directions),
                TriMesh(faces=faces, vertices=vertices),
                checkpoint=False,
            )

        origins_grad, directions_grad, vertices_grad = torch.autograd.grad(
            (outputs.ray_dists, outputs.barycentric, outputs.normals),
            (origins, directions, vertices),
            (ray_dists_grad, barycentric_grad, normals_grad),
        )
        return (origins_grad, directions_grad, None, vertices_grad)
def render_diffuse_mesh(
    camera: ProjectiveCamera,
    mesh: TriMesh,
    light_direction: Sequence[float] = tuple(UNIFORM_LIGHT_DIRECTION),
    diffuse: float = BASIC_DIFFUSE_COLOR,
    ambient: float = BASIC_AMBIENT_COLOR,
    ray_batch_size: Optional[int] = None,
    checkpoint: Optional[bool] = None,
) -> torch.Tensor:
    """
    Return an [H x W x 4] RGBA tensor of the rendered image.
    The pixels are floating points, with alpha in the range [0, 1] and the
    other colors matching the scale used by the mesh's vertex colors.

    :param camera: the camera to render from.
    :param mesh: the mesh to render. If it has no vertex colors, every
                 vertex is colored (0.8, 0.8, 0.8).
    :param light_direction: the direction of the light. The absolute value
                            of its dot product with each normal is used, so
                            both sides of a surface are lit.
    :param diffuse: the weight of the direction-dependent lighting term.
    :param ambient: the constant lighting term.
    :param ray_batch_size: passed to cast_camera().
    :param checkpoint: passed to cast_camera().
    """
    light_direction = torch.tensor(
        light_direction, device=mesh.vertices.device, dtype=mesh.vertices.dtype
    )

    all_collisions = RayCollisions.collect(
        cast_camera(
            camera=camera,
            mesh=mesh,
            ray_batch_size=ray_batch_size,
            checkpoint=checkpoint,
        )
    )

    if mesh.vertex_colors is None:
        # The colors are indexed through mesh.faces below, so the default
        # needs one row per vertex of the mesh.
        vertex_colors = (
            torch.tensor([[0.8, 0.8, 0.8]]).to(mesh.vertices).repeat(len(mesh.vertices), 1)
        )
    else:
        vertex_colors = mesh.vertex_colors

    light_coeffs = ambient + (
        diffuse * torch.sum(all_collisions.normals * light_direction, dim=-1).abs()
    )

    # Gather the colors of the three corners of the triangle hit by each
    # ray, and blend them using the barycentric coordinates of the hit.
    corner_colors = vertex_colors[mesh.faces[all_collisions.tri_indices]]
    bary_products = torch.sum(corner_colors * all_collisions.barycentric[..., None], dim=-2)
    out_colors = bary_products * light_coeffs[..., None]

    # Rays that hit nothing produce black, fully transparent pixels.
    res = torch.where(all_collisions.collides[:, None], out_colors, torch.zeros_like(out_colors))
    return torch.cat([res, all_collisions.collides[:, None].float()], dim=-1).view(
        camera.height, camera.width, 4
    )
shap_e.rendering.mesh from ._utils import cross_product, normalize @dataclass class Rays: """ A ray in ray casting. """ origins: torch.Tensor # [N x 3] float tensor directions: torch.Tensor # [N x 3] float tensor def normalized_directions(self) -> torch.Tensor: return normalize(self.directions) @dataclass class RayCollisions: """ The result of casting N rays onto a mesh. """ collides: torch.Tensor # [N] boolean tensor ray_dists: torch.Tensor # [N] float tensor tri_indices: torch.Tensor # [N] long tensor barycentric: torch.Tensor # [N x 3] float tensor normals: torch.Tensor # [N x 3] float tensor @classmethod def collect(cls, it: Iterable["RayCollisions"]) -> "RayCollisions": res = None for x in it: if res is None: res = x else: res = cls( collides=torch.cat([res.collides, x.collides]), ray_dists=torch.cat([res.ray_dists, x.ray_dists]), tri_indices=torch.cat([res.tri_indices, x.tri_indices]), barycentric=torch.cat([res.barycentric, x.barycentric]), normals=torch.cat([res.normals, x.normals]), ) if res is None: raise ValueError("cannot collect an empty iterable of RayCollisions") return res @dataclass class TriMesh: faces: torch.Tensor # [N x 3] long tensor vertices: torch.Tensor # [N x 3] float tensor vertex_colors: Optional[torch.Tensor] = None def normals(self) -> torch.Tensor: """ Returns an [N x 3] batch of normal vectors per triangle assuming the right-hand rule. 
""" tris = self.vertices[self.faces] v1 = tris[:, 1] - tris[:, 0] v2 = tris[:, 2] - tris[:, 0] return normalize(cross_product(v1, v2)) @classmethod def from_numpy(cls, x: shap_e.rendering.mesh.TriMesh) -> "TriMesh": vertex_colors = None if all(ch in x.vertex_channels for ch in "RGB"): vertex_colors = torch.from_numpy( np.stack([x.vertex_channels[ch] for ch in "RGB"], axis=-1) ) return cls( faces=torch.from_numpy(x.faces), vertices=torch.from_numpy(x.verts), vertex_colors=vertex_colors, ) def to(self, *args, **kwargs) -> "TriMesh": return TriMesh( faces=self.faces.to(*args, **kwargs), vertices=self.vertices.to(*args, **kwargs), vertex_colors=None if self.vertex_colors is None else self.vertex_colors.to(*args, **kwargs), ) ================================================ FILE: shap_e/rendering/torch_mesh.py ================================================ from dataclasses import dataclass, field from typing import Dict, Optional import torch from .mesh import TriMesh @dataclass class TorchMesh: """ A 3D triangle mesh with optional data at the vertices and faces. """ # [N x 3] array of vertex coordinates. verts: torch.Tensor # [M x 3] array of triangles, pointing to indices in verts. faces: torch.Tensor # Extra data per vertex and face. vertex_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict) face_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict) def tri_mesh(self) -> TriMesh: """ Create a CPU version of the mesh. 
""" return TriMesh( verts=self.verts.detach().cpu().numpy(), faces=self.faces.cpu().numpy(), vertex_channels=( {k: v.detach().cpu().numpy() for k, v in self.vertex_channels.items()} if self.vertex_channels is not None else None ), face_channels=( {k: v.detach().cpu().numpy() for k, v in self.face_channels.items()} if self.face_channels is not None else None ), ) ================================================ FILE: shap_e/rendering/view_data.py ================================================ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Dict, List, Tuple import numpy as np @dataclass class Camera(ABC): """ An object describing how a camera corresponds to pixels in an image. """ @abstractmethod def image_coords(self) -> np.ndarray: """ :return: ([self.height, self.width, 2]).reshape(self.height * self.width, 2) image coordinates """ @abstractmethod def camera_rays(self, coords: np.ndarray) -> np.ndarray: """ For every (x, y) coordinate in a rendered image, compute the ray of the corresponding pixel. :param coords: an [N x 2] integer array of 2D image coordinates. :return: an [N x 2 x 3] array of [2 x 3] (origin, direction) tuples. The direction should always be unit length. """ def depth_directions(self, coords: np.ndarray) -> np.ndarray: """ For every (x, y) coordinate in a rendered image, get the direction that corresponds to "depth" in an RGBD rendering. This may raise an exception if there is no "D" channel in the corresponding ViewData. :param coords: an [N x 2] integer array of 2D image coordinates. :return: an [N x 3] array of normalized depth directions. """ _ = coords raise NotImplementedError @abstractmethod def center_crop(self) -> "Camera": """ Creates a new camera with the same intrinsics and direction as this one, but with a center crop to a square of the smaller dimension. 
""" @abstractmethod def resize_image(self, width: int, height: int) -> "Camera": """ Creates a new camera with the same intrinsics and direction as this one, but with resized image dimensions. """ @abstractmethod def scale_scene(self, factor: float) -> "Camera": """ Creates a new camera with the same intrinsics and direction as this one, but with the scene rescaled by the given factor. """ @dataclass class ProjectiveCamera(Camera): """ A Camera implementation for a standard pinhole camera. The camera rays shoot away from the origin in the z direction, with the x and y directions corresponding to the positive horizontal and vertical axes in image space. """ origin: np.ndarray x: np.ndarray y: np.ndarray z: np.ndarray width: int height: int x_fov: float y_fov: float def image_coords(self) -> np.ndarray: ind = np.arange(self.width * self.height) coords = np.stack([ind % self.width, ind // self.width], axis=1).astype(np.float32) return coords def camera_rays(self, coords: np.ndarray) -> np.ndarray: fracs = (coords / (np.array([self.width, self.height], dtype=np.float32) - 1)) * 2 - 1 fracs = fracs * np.tan(np.array([self.x_fov, self.y_fov]) / 2) directions = self.z + self.x * fracs[:, :1] + self.y * fracs[:, 1:] directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True) return np.stack([np.broadcast_to(self.origin, directions.shape), directions], axis=1) def depth_directions(self, coords: np.ndarray) -> np.ndarray: return np.tile((self.z / np.linalg.norm(self.z))[None], [len(coords), 1]) def resize_image(self, width: int, height: int) -> "ProjectiveCamera": """ Creates a new camera for the resized view assuming the aspect ratio does not change. """ assert width * self.height == height * self.width, "The aspect ratio should not change." 
return ProjectiveCamera( origin=self.origin, x=self.x, y=self.y, z=self.z, width=width, height=height, x_fov=self.x_fov, y_fov=self.y_fov, ) def center_crop(self) -> "ProjectiveCamera": """ Creates a new camera for the center-cropped view """ size = min(self.width, self.height) fov = min(self.x_fov, self.y_fov) return ProjectiveCamera( origin=self.origin, x=self.x, y=self.y, z=self.z, width=size, height=size, x_fov=fov, y_fov=fov, ) def scale_scene(self, factor: float) -> "ProjectiveCamera": """ Creates a new camera with the same intrinsics and direction as this one, but with the camera frame rescaled by the given factor. """ return ProjectiveCamera( origin=self.origin * factor, x=self.x, y=self.y, z=self.z, width=self.width, height=self.height, x_fov=self.x_fov, y_fov=self.y_fov, ) class ViewData(ABC): """ A collection of rendered camera views of a scene or object. This is a generalization of a NeRF dataset, since NeRF datasets only encode RGB or RGBA data, whereas this dataset supports arbitrary channels. """ @property @abstractmethod def num_views(self) -> int: """ The number of rendered views. """ @property @abstractmethod def channel_names(self) -> List[str]: """ Get all of the supported channels available for the views. This can be arbitrary, but there are some standard names: "R", "G", "B", "A" (alpha), and "D" (depth). """ @abstractmethod def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]: """ Load the given channels from the view at the given index. :return: a tuple (camera_view, data), where data is a float array of shape [height x width x num_channels]. """ class MemoryViewData(ViewData): """ A ViewData that is implemented in memory. 
""" def __init__(self, channels: Dict[str, np.ndarray], cameras: List[Camera]): assert all(v.shape[0] == len(cameras) for v in channels.values()) self.channels = channels self.cameras = cameras @property def num_views(self) -> int: return len(self.cameras) @property def channel_names(self) -> List[str]: return list(self.channels.keys()) def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]: outputs = [self.channels[channel][index] for channel in channels] return self.cameras[index], np.stack(outputs, axis=-1) ================================================ FILE: shap_e/util/__init__.py ================================================ ================================================ FILE: shap_e/util/collections.py ================================================ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional from typing import OrderedDict, Generic, TypeVar K = TypeVar('K') V = TypeVar('V') class AttrDict(OrderedDict[K, V], Generic[K, V]): """ An attribute dictionary that automatically handles nested keys joined by "/". 
Originally copied from: https://stackoverflow.com/questions/3031219/recursively-access-dict-via-attributes-as-well-as-index-access """ MARKER = object() # pylint: disable=super-init-not-called def __init__(self, *args, **kwargs): if len(args) == 0: for key, value in kwargs.items(): self.__setitem__(key, value) else: assert len(args) == 1 assert isinstance(args[0], (dict, AttrDict)) for key, value in args[0].items(): self.__setitem__(key, value) def __contains__(self, key): if "/" in key: keys = key.split("/") key, next_key = keys[0], "/".join(keys[1:]) return key in self and next_key in self[key] return super(AttrDict, self).__contains__(key) def __setitem__(self, key, value): if "/" in key: keys = key.split("/") key, next_key = keys[0], "/".join(keys[1:]) if key not in self: self[key] = AttrDict() self[key].__setitem__(next_key, value) return if isinstance(value, dict) and not isinstance(value, AttrDict): value = AttrDict(**value) if isinstance(value, list): value = [AttrDict(val) if isinstance(val, dict) else val for val in value] super(AttrDict, self).__setitem__(key, value) def __getitem__(self, key): if "/" in key: keys = key.split("/") key, next_key = keys[0], "/".join(keys[1:]) val = self[key] if not isinstance(val, AttrDict): raise ValueError return val.__getitem__(next_key) return self.get(key, None) def all_keys( self, leaves_only: bool = False, parent: Optional[str] = None, ) -> List[str]: keys = [] for key in self.keys(): cur = key if parent is None else f"{parent}/{key}" if not leaves_only or not isinstance(self[key], dict): keys.append(cur) if isinstance(self[key], dict): keys.extend(self[key].all_keys(leaves_only=leaves_only, parent=cur)) return keys def dumpable(self, strip=True): """ Casts into OrderedDict and removes internal attributes """ def _dump(val): if isinstance(val, AttrDict): return val.dumpable() elif isinstance(val, list): return [_dump(v) for v in val] return val if strip: return {k: _dump(v) for k, v in self.items() if not 
k.startswith("_")} return {k: _dump(v if not k.startswith("_") else repr(v)) for k, v in self.items()} def map( self, map_fn: Callable[[Any, Any], Any], should_map: Optional[Callable[[Any, Any], bool]] = None, ) -> "AttrDict": """ Creates a copy of self where some or all values are transformed by map_fn. :param should_map: If provided, only those values that evaluate to true are converted; otherwise, all values are mapped. """ def _apply(key, val): if isinstance(val, AttrDict): return val.map(map_fn, should_map) elif should_map is None or should_map(key, val): return map_fn(key, val) return val return AttrDict({k: _apply(k, v) for k, v in self.items()}) def __eq__(self, other): return self.keys() == other.keys() and all(self[k] == other[k] for k in self.keys()) def combine( self, other: Dict[str, Any], combine_fn: Callable[[Optional[Any], Optional[Any]], Any], ) -> "AttrDict": """ Some values may be missing, but the dictionary structures must be the same. :param combine_fn: a (possibly non-commutative) function to combine the values """ def _apply(val, other_val): if val is not None and isinstance(val, AttrDict): assert isinstance(other_val, AttrDict) return val.combine(other_val, combine_fn) return combine_fn(val, other_val) # TODO nit: this changes the ordering.. 
keys = self.keys() | other.keys() return AttrDict({k: _apply(self[k], other[k]) for k in keys}) __setattr__, __getattr__ = __setitem__, __getitem__ ================================================ FILE: shap_e/util/data_util.py ================================================ import tempfile from contextlib import contextmanager from typing import Iterator, Optional, Union import blobfile as bf import numpy as np import torch from PIL import Image from shap_e.rendering.blender.render import render_mesh, render_model from shap_e.rendering.blender.view_data import BlenderViewData from shap_e.rendering.mesh import TriMesh from shap_e.rendering.point_cloud import PointCloud from shap_e.rendering.view_data import ViewData from shap_e.util.collections import AttrDict from shap_e.util.image_util import center_crop, get_alpha, remove_alpha, resize def load_or_create_multimodal_batch( device: torch.device, *, mesh_path: Optional[str] = None, model_path: Optional[str] = None, cache_dir: Optional[str] = None, point_count: int = 2**14, random_sample_count: int = 2**19, pc_num_views: int = 40, mv_light_mode: Optional[str] = None, mv_num_views: int = 20, mv_image_size: int = 512, mv_alpha_removal: str = "black", verbose: bool = False, ) -> AttrDict: if verbose: print("creating point cloud...") pc = load_or_create_pc( mesh_path=mesh_path, model_path=model_path, cache_dir=cache_dir, random_sample_count=random_sample_count, point_count=point_count, num_views=pc_num_views, verbose=verbose, ) raw_pc = np.concatenate([pc.coords, pc.select_channels(["R", "G", "B"])], axis=-1) encode_me = torch.from_numpy(raw_pc).float().to(device) batch = AttrDict(points=encode_me.t()[None]) if mv_light_mode: if verbose: print("creating multiview...") with load_or_create_multiview( mesh_path=mesh_path, model_path=model_path, cache_dir=cache_dir, num_views=mv_num_views, extract_material=False, light_mode=mv_light_mode, verbose=verbose, ) as mv: cameras, views, view_alphas, depths = [], [], [], [] for 
view_idx in range(mv.num_views): camera, view = mv.load_view( view_idx, ["R", "G", "B", "A"] if "A" in mv.channel_names else ["R", "G", "B"], ) depth = None if "D" in mv.channel_names: _, depth = mv.load_view(view_idx, ["D"]) depth = process_depth(depth, mv_image_size) view, alpha = process_image( np.round(view * 255.0).astype(np.uint8), mv_alpha_removal, mv_image_size ) camera = camera.center_crop().resize_image(mv_image_size, mv_image_size) cameras.append(camera) views.append(view) view_alphas.append(alpha) depths.append(depth) batch.depths = [depths] batch.views = [views] batch.view_alphas = [view_alphas] batch.cameras = [cameras] return normalize_input_batch(batch, pc_scale=2.0, color_scale=1.0 / 255.0) def load_or_create_pc( *, mesh_path: Optional[str], model_path: Optional[str], cache_dir: Optional[str], random_sample_count: int, point_count: int, num_views: int, verbose: bool = False, ) -> PointCloud: assert (model_path is not None) ^ ( mesh_path is not None ), "must specify exactly one of model_path or mesh_path" path = model_path if model_path is not None else mesh_path if cache_dir is not None: cache_path = bf.join( cache_dir, f"pc_{bf.basename(path)}_mat_{num_views}_{random_sample_count}_{point_count}.npz", ) if bf.exists(cache_path): return PointCloud.load(cache_path) else: cache_path = None with load_or_create_multiview( mesh_path=mesh_path, model_path=model_path, cache_dir=cache_dir, num_views=num_views, verbose=verbose, ) as mv: if verbose: print("extracting point cloud from multiview...") pc = mv_to_pc( multiview=mv, random_sample_count=random_sample_count, point_count=point_count ) if cache_path is not None: pc.save(cache_path) return pc @contextmanager def load_or_create_multiview( *, mesh_path: Optional[str], model_path: Optional[str], cache_dir: Optional[str], num_views: int = 20, extract_material: bool = True, light_mode: Optional[str] = None, verbose: bool = False, ) -> Iterator[BlenderViewData]: assert (model_path is not None) ^ ( mesh_path 
is not None ), "must specify exactly one of model_path or mesh_path" path = model_path if model_path is not None else mesh_path if extract_material: assert light_mode is None, "light_mode is ignored when extract_material=True" else: assert light_mode is not None, "must specify light_mode when extract_material=False" if cache_dir is not None: if extract_material: cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_mat_{num_views}.zip") else: cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_{light_mode}_{num_views}.zip") if bf.exists(cache_path): with bf.BlobFile(cache_path, "rb") as f: yield BlenderViewData(f) return else: cache_path = None common_kwargs = dict( fast_mode=True, extract_material=extract_material, camera_pose="random", light_mode=light_mode or "uniform", verbose=verbose, ) with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = bf.join(tmp_dir, "out.zip") if mesh_path is not None: mesh = TriMesh.load(mesh_path) render_mesh( mesh=mesh, output_path=tmp_path, num_images=num_views, backend="BLENDER_EEVEE", **common_kwargs, ) elif model_path is not None: render_model( model_path, output_path=tmp_path, num_images=num_views, backend="BLENDER_EEVEE", **common_kwargs, ) if cache_path is not None: bf.copy(tmp_path, cache_path) with bf.BlobFile(tmp_path, "rb") as f: yield BlenderViewData(f) def mv_to_pc(multiview: ViewData, random_sample_count: int, point_count: int) -> PointCloud: pc = PointCloud.from_rgbd(multiview) # Handle empty samples. if len(pc.coords) == 0: pc = PointCloud( coords=np.zeros([1, 3]), channels=dict(zip("RGB", np.zeros([3, 1]))), ) while len(pc.coords) < point_count: pc = pc.combine(pc) # Prevent duplicate points; some models may not like it. 
pc.coords += np.random.normal(size=pc.coords.shape) * 1e-4 pc = pc.random_sample(random_sample_count) pc = pc.farthest_point_sample(point_count, average_neighbors=True) return pc def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict: res = batch.copy() scale_vec = torch.tensor([*([pc_scale] * 3), *([color_scale] * 3)], device=batch.points.device) res.points = res.points * scale_vec[:, None] if "cameras" in res: res.cameras = [[cam.scale_scene(pc_scale) for cam in cams] for cams in res.cameras] if "depths" in res: res.depths = [[depth * pc_scale for depth in depths] for depths in res.depths] return res def process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray: depth_img = center_crop(depth_img) depth_img = resize(depth_img, width=image_size, height=image_size) return np.squeeze(depth_img) def process_image( img_or_img_arr: Union[Image.Image, np.ndarray], alpha_removal: str, image_size: int ): if isinstance(img_or_img_arr, np.ndarray): img = Image.fromarray(img_or_img_arr) img_arr = img_or_img_arr else: img = img_or_img_arr img_arr = np.array(img) if len(img_arr.shape) == 2: # Grayscale rgb = Image.new("RGB", img.size) rgb.paste(img) img = rgb img_arr = np.array(img) img = center_crop(img) alpha = get_alpha(img) img = remove_alpha(img, mode=alpha_removal) alpha = alpha.resize((image_size,) * 2, resample=Image.BILINEAR) img = img.resize((image_size,) * 2, resample=Image.BILINEAR) return img, alpha ================================================ FILE: shap_e/util/image_util.py ================================================ import random from typing import Any, List, Optional, Union import blobfile as bf import numpy as np import torch import torch.nn.functional as F from PIL import Image def center_crop( img: Union[Image.Image, torch.Tensor, np.ndarray] ) -> Union[Image.Image, torch.Tensor, np.ndarray]: """ Center crops an image. 
""" if isinstance(img, (np.ndarray, torch.Tensor)): height, width = img.shape[:2] else: width, height = img.size size = min(width, height) left, top = (width - size) // 2, (height - size) // 2 right, bottom = left + size, top + size if isinstance(img, (np.ndarray, torch.Tensor)): img = img[top:bottom, left:right] else: img = img.crop((left, top, right, bottom)) return img def resize( img: Union[Image.Image, torch.Tensor, np.ndarray], *, height: int, width: int, min_value: Optional[Any] = None, max_value: Optional[Any] = None, ) -> Union[Image.Image, torch.Tensor, np.ndarray]: """ :param: img: image in HWC order :return: currently written for downsampling """ orig, cls = img, type(img) if isinstance(img, Image.Image): img = np.array(img) dtype = img.dtype if isinstance(img, np.ndarray): img = torch.from_numpy(img) ndim = img.ndim if img.ndim == 2: img = img.unsqueeze(-1) if min_value is None and max_value is None: # .clamp throws an error when both are None min_value = -np.inf img = img.permute(2, 0, 1) size = (height, width) img = ( F.interpolate(img[None].float(), size=size, mode="area")[0] .clamp(min_value, max_value) .to(img.dtype) .permute(1, 2, 0) ) if ndim < img.ndim: img = img.squeeze(-1) if not isinstance(orig, torch.Tensor): img = img.numpy() img = img.astype(dtype) if isinstance(orig, Image.Image): img = Image.fromarray(img) return img def get_alpha(img: Image.Image) -> Image.Image: """ :return: the alpha channel separated out as a grayscale image """ img_arr = np.asarray(img) if img_arr.shape[2] == 4: alpha = img_arr[:, :, 3] else: alpha = np.full(img_arr.shape[:2], 255, dtype=np.uint8) alpha = Image.fromarray(alpha) return alpha def remove_alpha(img: Image.Image, mode: str = "random") -> Image.Image: """ No op if the image doesn't have an alpha channel. 
:param: mode: Defaults to "random" but has an option to use a "black" or "white" background :return: image with alpha removed """ img_arr = np.asarray(img) if img_arr.shape[2] == 4: # Add bg to get rid of alpha channel if mode == "random": height, width = img_arr.shape[:2] bg = Image.fromarray( random.choice([_black_bg, _gray_bg, _checker_bg, _noise_bg])(height, width) ) bg.paste(img, mask=img) img = bg elif mode == "black" or mode == "white": img_arr = img_arr.astype(float) rgb, alpha = img_arr[:, :, :3], img_arr[:, :, -1:] / 255 background = np.zeros((1, 1, 3)) if mode == "black" else np.full((1, 1, 3), 255) rgb = rgb * alpha + background * (1 - alpha) img = Image.fromarray(np.round(rgb).astype(np.uint8)) return img def _black_bg(h: int, w: int) -> np.ndarray: return np.zeros([h, w, 3], dtype=np.uint8) def _gray_bg(h: int, w: int) -> np.ndarray: return (np.zeros([h, w, 3]) + np.random.randint(low=0, high=256)).astype(np.uint8) def _checker_bg(h: int, w: int) -> np.ndarray: checker_size = np.ceil(np.exp(np.random.uniform() * np.log(min(h, w)))) c1 = np.random.randint(low=0, high=256) c2 = np.random.randint(low=0, high=256) xs = np.arange(w)[None, :, None] + np.random.randint(low=0, high=checker_size + 1) ys = np.arange(h)[:, None, None] + np.random.randint(low=0, high=checker_size + 1) fields = np.logical_xor((xs // checker_size) % 2 == 0, (ys // checker_size) % 2 == 0) return np.where(fields, np.array([c1] * 3), np.array([c2] * 3)).astype(np.uint8) def _noise_bg(h: int, w: int) -> np.ndarray: return np.random.randint(low=0, high=256, size=[h, w, 3]).astype(np.uint8) def load_image(image_path: str) -> Image.Image: with bf.BlobFile(image_path, "rb") as thefile: img = Image.open(thefile) img.load() return img def make_tile(images: List[Union[np.ndarray, Image.Image]], columns=8) -> Image.Image: """ to test, run >>> display(make_tile([(np.zeros((128, 128, 3)) + c).astype(np.uint8) for c in np.linspace(0, 255, 15)])) """ images = list(map(np.array, images)) size = 
images[0].shape[0] n = round_up(len(images), columns) n_blanks = n - len(images) images.extend([np.zeros((size, size, 3), dtype=np.uint8)] * n_blanks) images = ( np.array(images) .reshape(n // columns, columns, size, size, 3) .transpose([0, 2, 1, 3, 4]) .reshape(n // columns * size, columns * size, 3) ) return Image.fromarray(images) def round_up(n: int, b: int) -> int: return (n + b - 1) // b * b ================================================ FILE: shap_e/util/io.py ================================================ import io from contextlib import contextmanager from typing import Any, BinaryIO, Iterator, Union import blobfile as bf import yaml from shap_e.util.collections import AttrDict def read_config(path_or_file: Union[str, io.IOBase]) -> Any: if isinstance(path_or_file, io.IOBase): obj = yaml.load(path_or_file, Loader=yaml.SafeLoader) else: with bf.BlobFile(path_or_file, "rb") as f: try: obj = yaml.load(f, Loader=yaml.SafeLoader) except Exception as exc: with bf.BlobFile(path_or_file, "rb") as f: print(f.read()) raise exc if isinstance(obj, dict): return AttrDict(obj) return obj @contextmanager def buffered_writer(raw_f: BinaryIO) -> Iterator[io.BufferedIOBase]: if isinstance(raw_f, io.BufferedIOBase): yield raw_f else: f = io.BufferedWriter(raw_f) yield f f.flush() ================================================ FILE: shap_e/util/notebooks.py ================================================ import base64 import io from typing import Union import ipywidgets as widgets import numpy as np import torch from PIL import Image from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera from shap_e.models.transmitter.base import Transmitter, VectorDecoder from shap_e.rendering.torch_mesh import TorchMesh from shap_e.util.collections import AttrDict def create_pan_cameras(size: int, device: torch.device) -> DifferentiableCameraBatch: origins = [] xs = [] ys = [] zs = [] for theta in np.linspace(0, 2 * np.pi, num=20): z = 
np.array([np.sin(theta), np.cos(theta), -0.5]) z /= np.sqrt(np.sum(z**2)) origin = -z * 4 x = np.array([np.cos(theta), -np.sin(theta), 0.0]) y = np.cross(z, x) origins.append(origin) xs.append(x) ys.append(y) zs.append(z) return DifferentiableCameraBatch( shape=(1, len(xs)), flat_camera=DifferentiableProjectiveCamera( origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device), x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device), y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device), z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device), width=size, height=size, x_fov=0.7, y_fov=0.7, ), ) @torch.no_grad() def decode_latent_images( xm: Union[Transmitter, VectorDecoder], latent: torch.Tensor, cameras: DifferentiableCameraBatch, rendering_mode: str = "stf", ): decoded = xm.renderer.render_views( AttrDict(cameras=cameras), params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params( latent[None] ), options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False), ) arr = decoded.channels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy() return [Image.fromarray(x) for x in arr] @torch.no_grad() def decode_latent_mesh( xm: Union[Transmitter, VectorDecoder], latent: torch.Tensor, ) -> TorchMesh: decoded = xm.renderer.render_views( AttrDict(cameras=create_pan_cameras(2, latent.device)), # lowest resolution possible params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params( latent[None] ), options=AttrDict(rendering_mode="stf", render_with_direction=False), ) return decoded.raw_meshes[0] def gif_widget(images): writer = io.BytesIO() images[0].save( writer, format="GIF", save_all=True, append_images=images[1:], duration=100, loop=0 ) writer.seek(0) data = base64.b64encode(writer.read()).decode("ascii") return widgets.HTML(f'')