Repository: openai/shap-e
Branch: main
Commit: 50131012ee11
Files: 79
Total size: 17.7 MB

Directory structure:
gitextract__yazl4ub/

├── .gitignore
├── LICENSE
├── README.md
├── model-card.md
├── samples.md
├── setup.py
└── shap_e/
    ├── __init__.py
    ├── diffusion/
    │   ├── __init__.py
    │   ├── gaussian_diffusion.py
    │   ├── k_diffusion.py
    │   └── sample.py
    ├── examples/
    │   ├── encode_model.ipynb
    │   ├── example_data/
    │   │   └── cactus/
    │   │       ├── material.mtl
    │   │       └── object.obj
    │   ├── sample_image_to_3d.ipynb
    │   └── sample_text_to_3d.ipynb
    ├── models/
    │   ├── __init__.py
    │   ├── configs.py
    │   ├── download.py
    │   ├── generation/
    │   │   ├── __init__.py
    │   │   ├── latent_diffusion.py
    │   │   ├── perceiver.py
    │   │   ├── pooled_mlp.py
    │   │   ├── pretrained_clip.py
    │   │   ├── transformer.py
    │   │   └── util.py
    │   ├── nerf/
    │   │   ├── __init__.py
    │   │   ├── model.py
    │   │   ├── ray.py
    │   │   └── renderer.py
    │   ├── nerstf/
    │   │   ├── mlp.py
    │   │   └── renderer.py
    │   ├── nn/
    │   │   ├── __init__.py
    │   │   ├── camera.py
    │   │   ├── checkpoint.py
    │   │   ├── encoding.py
    │   │   ├── meta.py
    │   │   ├── ops.py
    │   │   ├── pointnet2_utils.py
    │   │   └── utils.py
    │   ├── query.py
    │   ├── renderer.py
    │   ├── stf/
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── mlp.py
    │   │   └── renderer.py
    │   ├── transmitter/
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── bottleneck.py
    │   │   ├── channels_encoder.py
    │   │   ├── multiview_encoder.py
    │   │   ├── params_proj.py
    │   │   └── pc_encoder.py
    │   └── volume.py
    ├── rendering/
    │   ├── __init__.py
    │   ├── _mc_table.py
    │   ├── blender/
    │   │   ├── __init__.py
    │   │   ├── blender_script.py
    │   │   ├── constants.py
    │   │   ├── render.py
    │   │   └── view_data.py
    │   ├── mc.py
    │   ├── mesh.py
    │   ├── ply_util.py
    │   ├── point_cloud.py
    │   ├── pytorch3d_util.py
    │   ├── raycast/
    │   │   ├── __init__.py
    │   │   ├── _utils.py
    │   │   ├── cast.py
    │   │   ├── render.py
    │   │   └── types.py
    │   ├── torch_mesh.py
    │   └── view_data.py
    └── util/
        ├── __init__.py
        ├── collections.py
        ├── data_util.py
        ├── image_util.py
        ├── io.py
        └── notebooks.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__/
.DS_Store
*.egg-info/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: README.md
================================================
# Shap-E

This is the official code and model release for [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463).

 * See [Usage](#usage) for guidance on how to use this repository.
 * See [Samples](#samples) for examples of what our text-conditional model can generate.

# Samples

Here are some highlighted samples from our text-conditional model. For random samples on selected prompts, see [samples.md](samples.md).

<table>
    <tbody>
        <tr>
            <td align="center">
                <img src="samples/a_chair_that_looks_like_an_avocado/2.gif" alt="A chair that looks like an avocado">
            </td>
            <td align="center">
                <img src="samples/an_airplane_that_looks_like_a_banana/3.gif" alt="An airplane that looks like a banana">
            </td>
            <td align="center">
                <img src="samples/a_spaceship/0.gif" alt="A spaceship">
            </td>
        </tr>
        <tr>
            <td align="center">A chair that looks<br>like an avocado</td>
            <td align="center">An airplane that looks<br>like a banana</td>
            <td align="center">A spaceship</td>
        </tr>
        <tr>
            <td align="center">
                <img src="samples/a_birthday_cupcake/3.gif" alt="A birthday cupcake">
            </td>
            <td align="center">
                <img src="samples/a_chair_that_looks_like_a_tree/2.gif" alt="A chair that looks like a tree">
            </td>
            <td align="center">
                <img src="samples/a_green_boot/3.gif" alt="A green boot">
            </td>
        </tr>
        <tr>
            <td align="center">A birthday cupcake</td>
            <td align="center">A chair that looks<br>like a tree</td>
            <td align="center">A green boot</td>
        </tr>
        <tr>
            <td align="center">
                <img src="samples/a_penguin/1.gif" alt="A penguin">
            </td>
            <td align="center">
                <img src="samples/ube_ice_cream_cone/3.gif" alt="Ube ice cream cone">
            </td>
            <td align="center">
                <img src="samples/a_bowl_of_vegetables/2.gif" alt="A bowl of vegetables">
            </td>
        </tr>
        <tr>
            <td align="center">A penguin</td>
            <td align="center">Ube ice cream cone</td>
            <td align="center">A bowl of vegetables</td>
        </tr>
    </tbody>
</table>

# Usage

Install with `pip install -e .`.

To get started with examples, see the following notebooks:

* [sample_text_to_3d.ipynb](shap_e/examples/sample_text_to_3d.ipynb) - sample a 3D model, conditioned on a text prompt.
* [sample_image_to_3d.ipynb](shap_e/examples/sample_image_to_3d.ipynb) - sample a 3D model, conditioned on a synthetic view image. For best results, remove the background from the input image.
* [encode_model.ipynb](shap_e/examples/encode_model.ipynb) - loads a 3D model or a trimesh, creates a batch of multiview renders and a point cloud, encodes them into a latent, and renders it back. For this to work, install Blender version 3.3.1 or higher, and set the environment variable `BLENDER_PATH` to the path of the Blender executable.
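
The README does not ship a check for the Blender requirement, so below is a hypothetical helper (`parse_blender_version` and `check_blender_path` are our names, not part of `shap_e`) that validates `BLENDER_PATH` before running `encode_model.ipynb`. It assumes that `blender --version` prints a line such as `Blender 3.6.2`; treat it as a sketch rather than the repository's own tooling.

```python
import os
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Minimum version stated in the README for encode_model.ipynb.
MIN_BLENDER_VERSION = (3, 3, 1)


def parse_blender_version(output: str) -> Optional[Tuple[int, int, int]]:
    """Extract (major, minor, patch) from text like 'Blender 3.6.2'.

    Assumption: `blender --version` reports its version in this form.
    A missing patch component is treated as 0.
    """
    match = re.search(r"Blender\s+(\d+)\.(\d+)(?:\.(\d+))?", output)
    if match is None:
        return None
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def check_blender_path() -> Tuple[int, int, int]:
    """Validate BLENDER_PATH and return the detected version, raising if unusable."""
    path = os.environ.get("BLENDER_PATH")
    # shutil.which accepts a bare name (searched on PATH) or a full path to an executable.
    if not path or shutil.which(path) is None:
        raise RuntimeError("BLENDER_PATH is unset or does not point to an executable")
    result = subprocess.run([path, "--version"], capture_output=True, text=True, check=True)
    version = parse_blender_version(result.stdout)
    # Tuples compare element by element, so (3, 3, 0) < (3, 3, 1).
    if version is None or version < MIN_BLENDER_VERSION:
        raise RuntimeError(f"Blender >= 3.3.1 is required, found {version}")
    return version
```

Calling `check_blender_path()` at the top of a notebook fails fast with a readable error instead of a failure deep inside the rendering code.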


================================================
FILE: model-card.md
================================================
# Model Card: Shap-E

This is the official codebase for running the latent diffusion models described in [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463). These models were trained and released by OpenAI. Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about how the models were trained and evaluated.

# Model Details

Shap-E includes two kinds of models: an encoder and a latent diffusion model.

 1. **The encoder** converts 3D assets into the parameters of small neural networks which represent the 3D shape and texture as an implicit function. The resulting implicit function can be rendered from arbitrary viewpoints or imported into downstream applications as a mesh.
 2. **The latent diffusion model** generates novel implicit functions conditioned on either images or text descriptions. As above, these samples can be rendered or exported as a mesh. Specifically, these models produce latents which must be linearly projected to get the final implicit function parameters. The final projection layer of the encoder is used for this purpose.
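
The idea of linearly projecting a latent into the parameters of a small implicit-function network can be illustrated with a toy NumPy sketch. Every size, the two-layer MLP, the single flat projection, and the 4-channel output (for example density plus RGB) are assumptions chosen for illustration. This is not the Shap-E architecture or its API.

```python
import numpy as np

# Hypothetical, deliberately tiny sizes; not the real Shap-E dimensions.
LATENT_DIM = 64
IN_DIM, HIDDEN, OUT_DIM = 3, 16, 4  # xyz query -> hidden -> e.g. density + RGB (assumed)

# Parameter shapes of a tiny two-layer MLP acting as the implicit function.
PARAM_SHAPES = {
    "w1": (IN_DIM, HIDDEN),
    "b1": (HIDDEN,),
    "w2": (HIDDEN, OUT_DIM),
    "b2": (OUT_DIM,),
}
N_PARAMS = sum(int(np.prod(shape)) for shape in PARAM_SHAPES.values())


def project_latent(latent, proj_w, proj_b):
    """Linearly map a latent to one flat parameter vector, then split it into MLP weights."""
    flat = latent @ proj_w + proj_b
    params, offset = {}, 0
    for name, shape in PARAM_SHAPES.items():
        size = int(np.prod(shape))
        params[name] = flat[offset:offset + size].reshape(shape)
        offset += size
    return params


def query(params, points):
    """Evaluate the implicit function at query points of shape [N, 3]."""
    hidden = np.maximum(points @ params["w1"] + params["b1"], 0.0)  # ReLU
    return hidden @ params["w2"] + params["b2"]


rng = np.random.default_rng(0)
latent = rng.standard_normal(LATENT_DIM)  # stands in for a diffusion sample
proj_w = rng.standard_normal((LATENT_DIM, N_PARAMS)) * 0.01
proj_b = np.zeros(N_PARAMS)
params = project_latent(latent, proj_w, proj_b)
out = query(params, rng.uniform(-1.0, 1.0, size=(5, IN_DIM)))
print(out.shape)  # (5, 4)
```

Because the projection is a fixed linear layer, the diffusion model only has to generate `latent`. Any latent, whether sampled or produced by the encoder, decodes into a network that can be queried at arbitrary 3D points.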

Like [Point-E](https://github.com/openai/point-e/blob/main/model-card.md), Shap-E can often generate coherent 3D objects when conditioned on a rendering from a single viewpoint. When conditioned on text prompts directly, Shap-E is also often capable of producing recognizable objects, although it sometimes struggles to combine multiple objects or concepts.

Samples from Shap-E are typically lower fidelity than professional 3D assets and often have rough edges, holes, or blurry surface textures.

# Model Date

April 2023

# Model Versions

The following model checkpoints are available in this repository:

 * `transmitter` - the encoder and corresponding projection layers for converting encoder outputs into implicit neural representations.
 * `decoder` - just the final projection layer component of `transmitter`. This is a smaller checkpoint than `transmitter` since it does not include parameters for encoding 3D assets. This is the minimum required model to convert diffusion outputs into implicit neural representations.
 * `text300M` - the text-conditional latent diffusion model.
 * `image300M` - the image-conditional latent diffusion model.

# Paper & Samples

[Paper link](https://arxiv.org/abs/2305.02463) / [Samples](samples.md)

# Training data

The encoder and image-conditional diffusion models are trained on the [same dataset as Point-E](https://github.com/openai/point-e/blob/main/model-card.md#training-data). However, a few changes to the post-processing were made:

 * We rendered 60 views (instead of 20) of each model when computing point clouds, to avoid small cracks.
 * We produced 16K points in each point cloud instead of 4K.
 * We simplified the lighting and material setup to only include diffuse materials.

For our text-conditional diffusion model, we expanded our dataset with roughly a million more 3D assets. Additionally, we collected 120K captions from human annotators for a high-quality subset of our 3D assets.

# Evaluated Use

We release these models with the intention of furthering progress in the field of generative modeling. However, we acknowledge that our models have certain constraints and biases, which is why we advise against employing them for commercial purposes at this time. We are aware that the utilization of our models could extend to areas beyond our expectations, and defining specific criteria for what is considered suitable for "research" purposes presents a challenge. Specifically, we advise caution when using these models in contexts that demand high accuracy, where minor imperfections in the generated 3D assets could have adverse consequences.

In particular, these models have been evaluated on the following tasks for research purposes:

 * Generating 3D renderings or meshes conditioned on single, synthetic images
 * Generating 3D renderings or meshes conditioned on text descriptions

# Performance & Limitations

Our image-conditional model has only been evaluated on a highly specific distribution of synthetic renderings. Even in these cases, the model still sometimes fails to infer the correct occluded parts of an object or produces geometry that is inconsistent with the given rendered images. These failure modes are similar to those of Point-E. The resulting 3D assets often have rough edges, holes, or blurry surface textures.

Our text-conditional model can produce a fairly large and diverse vocabulary of objects. It is often capable of producing objects with requested colors and textures, and sometimes even of combining multiple objects. However, it often fails on more complex prompts that require placing multiple objects in a scene or binding attributes to objects. It also typically fails to produce the requested number of objects when a specific quantity is asked for.

We find that our text-conditional model can sometimes produce samples that reflect gender biases. For example, samples for "a nurse" typically have a different body shape than samples for "a doctor". When probing for potential misuses, we also found that our text-conditional model is capable of producing 3D assets related to violence, such as guns or tanks. However, these samples are of poor enough quality that they look unrealistic and toy-like.

As with Point-E, our dataset consists of many simple, cartoonish 3D assets, and our generative models are prone to imitating this style.

We believe our models will have many potential use cases. For example, our text-conditional model could enable users to quickly produce many 3D assets, allowing for rapid prototyping for computer graphics applications or 3D printing.

The use of 3D printing in concert with our models could potentially be harmful, for example if used to create dangerous objects or fabricate tools or parts that are deployed without external validation.

Generative 3D models share many challenges and constraints with image generation models. This includes the tendency to generate content that may be biased or detrimental, as well as the potential for dual-use applications. As the capabilities of these models evolve, further investigation is required to gain a clearer understanding of how these risks manifest.


================================================
FILE: samples.md
================================================
# Samples

Here is a collection of prompts and four random text-conditional samples for each prompt. Samples are rendered at 128x128 resolution with NeRF.

<table><tbody>
<tr><th align="center">Prompt</th><th></th><th></th><th></th><th></th></tr>
<tr><td align="center">a penguin</td><td align="center"><img src="samples/a_penguin/0.gif" alt="a penguin"></td><td align="center"><img src="samples/a_penguin/1.gif" alt="a penguin"></td><td align="center"><img src="samples/a_penguin/2.gif" alt="a penguin"></td><td align="center"><img src="samples/a_penguin/3.gif" alt="a penguin"></td></tr>
<tr><td align="center">a campfire</td><td align="center"><img src="samples/a_campfire/0.gif" alt="a campfire"></td><td align="center"><img src="samples/a_campfire/1.gif" alt="a campfire"></td><td align="center"><img src="samples/a_campfire/2.gif" alt="a campfire"></td><td align="center"><img src="samples/a_campfire/3.gif" alt="a campfire"></td></tr>
<tr><td align="center">an elephant</td><td align="center"><img src="samples/an_elephant/0.gif" alt="an elephant"></td><td align="center"><img src="samples/an_elephant/1.gif" alt="an elephant"></td><td align="center"><img src="samples/an_elephant/2.gif" alt="an elephant"></td><td align="center"><img src="samples/an_elephant/3.gif" alt="an elephant"></td></tr>
<tr><td align="center">a donut with pink icing</td><td align="center"><img src="samples/a_donut_with_pink_icing/0.gif" alt="a donut with pink icing"></td><td align="center"><img src="samples/a_donut_with_pink_icing/1.gif" alt="a donut with pink icing"></td><td align="center"><img src="samples/a_donut_with_pink_icing/2.gif" alt="a donut with pink icing"></td><td align="center"><img src="samples/a_donut_with_pink_icing/3.gif" alt="a donut with pink icing"></td></tr>
<tr><td align="center">a voxelized dog</td><td align="center"><img src="samples/a_voxelized_dog/0.gif" alt="a voxelized dog"></td><td align="center"><img src="samples/a_voxelized_dog/1.gif" alt="a voxelized dog"></td><td align="center"><img src="samples/a_voxelized_dog/2.gif" alt="a voxelized dog"></td><td align="center"><img src="samples/a_voxelized_dog/3.gif" alt="a voxelized dog"></td></tr>
<tr><td align="center">ube ice cream cone</td><td align="center"><img src="samples/ube_ice_cream_cone/0.gif" alt="ube ice cream cone"></td><td align="center"><img src="samples/ube_ice_cream_cone/1.gif" alt="ube ice cream cone"></td><td align="center"><img src="samples/ube_ice_cream_cone/2.gif" alt="ube ice cream cone"></td><td align="center"><img src="samples/ube_ice_cream_cone/3.gif" alt="ube ice cream cone"></td></tr>
<tr><td align="center">a birthday cupcake</td><td align="center"><img src="samples/a_birthday_cupcake/0.gif" alt="a birthday cupcake"></td><td align="center"><img src="samples/a_birthday_cupcake/1.gif" alt="a birthday cupcake"></td><td align="center"><img src="samples/a_birthday_cupcake/2.gif" alt="a birthday cupcake"></td><td align="center"><img src="samples/a_birthday_cupcake/3.gif" alt="a birthday cupcake"></td></tr>
<tr><td align="center">shepherds pie</td><td align="center"><img src="samples/shepherds_pie/0.gif" alt="shepherds pie"></td><td align="center"><img src="samples/shepherds_pie/1.gif" alt="shepherds pie"></td><td align="center"><img src="samples/shepherds_pie/2.gif" alt="shepherds pie"></td><td align="center"><img src="samples/shepherds_pie/3.gif" alt="shepherds pie"></td></tr>
<tr><td align="center">a bowl of vegetables</td><td align="center"><img src="samples/a_bowl_of_vegetables/0.gif" alt="a bowl of vegetables"></td><td align="center"><img src="samples/a_bowl_of_vegetables/1.gif" alt="a bowl of vegetables"></td><td align="center"><img src="samples/a_bowl_of_vegetables/2.gif" alt="a bowl of vegetables"></td><td align="center"><img src="samples/a_bowl_of_vegetables/3.gif" alt="a bowl of vegetables"></td></tr>
<tr><td align="center">a cheeseburger</td><td align="center"><img src="samples/a_cheeseburger/0.gif" alt="a cheeseburger"></td><td align="center"><img src="samples/a_cheeseburger/1.gif" alt="a cheeseburger"></td><td align="center"><img src="samples/a_cheeseburger/2.gif" alt="a cheeseburger"></td><td align="center"><img src="samples/a_cheeseburger/3.gif" alt="a cheeseburger"></td></tr>
<tr><td align="center">a plate of mushy green peas</td><td align="center"><img src="samples/a_plate_of_mushy_green_peas/0.gif" alt="a plate of mushy green peas"></td><td align="center"><img src="samples/a_plate_of_mushy_green_peas/1.gif" alt="a plate of mushy green peas"></td><td align="center"><img src="samples/a_plate_of_mushy_green_peas/2.gif" alt="a plate of mushy green peas"></td><td align="center"><img src="samples/a_plate_of_mushy_green_peas/3.gif" alt="a plate of mushy green peas"></td></tr>
<tr><td align="center">a traffic cone</td><td align="center"><img src="samples/a_traffic_cone/0.gif" alt="a traffic cone"></td><td align="center"><img src="samples/a_traffic_cone/1.gif" alt="a traffic cone"></td><td align="center"><img src="samples/a_traffic_cone/2.gif" alt="a traffic cone"></td><td align="center"><img src="samples/a_traffic_cone/3.gif" alt="a traffic cone"></td></tr>
<tr><td align="center">a car that looks like an avocado</td><td align="center"><img src="samples/a_car_that_looks_like_an_avocado/0.gif" alt="a car that looks like an avocado"></td><td align="center"><img src="samples/a_car_that_looks_like_an_avocado/1.gif" alt="a car that looks like an avocado"></td><td align="center"><img src="samples/a_car_that_looks_like_an_avocado/2.gif" alt="a car that looks like an avocado"></td><td align="center"><img src="samples/a_car_that_looks_like_an_avocado/3.gif" alt="a car that looks like an avocado"></td></tr>
<tr><td align="center">an airplane that looks like a banana</td><td align="center"><img src="samples/an_airplane_that_looks_like_a_banana/0.gif" alt="an airplane that looks like a banana"></td><td align="center"><img src="samples/an_airplane_that_looks_like_a_banana/1.gif" alt="an airplane that looks like a banana"></td><td align="center"><img src="samples/an_airplane_that_looks_like_a_banana/2.gif" alt="an airplane that looks like a banana"></td><td align="center"><img src="samples/an_airplane_that_looks_like_a_banana/3.gif" alt="an airplane that looks like a banana"></td></tr>
<tr><td align="center">a stop sign</td><td align="center"><img src="samples/a_stop_sign/0.gif" alt="a stop sign"></td><td align="center"><img src="samples/a_stop_sign/1.gif" alt="a stop sign"></td><td align="center"><img src="samples/a_stop_sign/2.gif" alt="a stop sign"></td><td align="center"><img src="samples/a_stop_sign/3.gif" alt="a stop sign"></td></tr>
<tr><td align="center">a spaceship</td><td align="center"><img src="samples/a_spaceship/0.gif" alt="a spaceship"></td><td align="center"><img src="samples/a_spaceship/1.gif" alt="a spaceship"></td><td align="center"><img src="samples/a_spaceship/2.gif" alt="a spaceship"></td><td align="center"><img src="samples/a_spaceship/3.gif" alt="a spaceship"></td></tr>
<tr><td align="center">a race car</td><td align="center"><img src="samples/a_race_car/0.gif" alt="a race car"></td><td align="center"><img src="samples/a_race_car/1.gif" alt="a race car"></td><td align="center"><img src="samples/a_race_car/2.gif" alt="a race car"></td><td align="center"><img src="samples/a_race_car/3.gif" alt="a race car"></td></tr>
<tr><td align="center">a schoolbus</td><td align="center"><img src="samples/a_schoolbus/0.gif" alt="a schoolbus"></td><td align="center"><img src="samples/a_schoolbus/1.gif" alt="a schoolbus"></td><td align="center"><img src="samples/a_schoolbus/2.gif" alt="a schoolbus"></td><td align="center"><img src="samples/a_schoolbus/3.gif" alt="a schoolbus"></td></tr>
<tr><td align="center">a firetruck</td><td align="center"><img src="samples/a_firetruck/0.gif" alt="a firetruck"></td><td align="center"><img src="samples/a_firetruck/1.gif" alt="a firetruck"></td><td align="center"><img src="samples/a_firetruck/2.gif" alt="a firetruck"></td><td align="center"><img src="samples/a_firetruck/3.gif" alt="a firetruck"></td></tr>
<tr><td align="center">a rusty old car</td><td align="center"><img src="samples/a_rusty_old_car/0.gif" alt="a rusty old car"></td><td align="center"><img src="samples/a_rusty_old_car/1.gif" alt="a rusty old car"></td><td align="center"><img src="samples/a_rusty_old_car/2.gif" alt="a rusty old car"></td><td align="center"><img src="samples/a_rusty_old_car/3.gif" alt="a rusty old car"></td></tr>
<tr><td align="center">a fast car</td><td align="center"><img src="samples/a_fast_car/0.gif" alt="a fast car"></td><td align="center"><img src="samples/a_fast_car/1.gif" alt="a fast car"></td><td align="center"><img src="samples/a_fast_car/2.gif" alt="a fast car"></td><td align="center"><img src="samples/a_fast_car/3.gif" alt="a fast car"></td></tr>
<tr><td align="center">a chair that looks like an avocado</td><td align="center"><img src="samples/a_chair_that_looks_like_an_avocado/0.gif" alt="a chair that looks like an avocado"></td><td align="center"><img src="samples/a_chair_that_looks_like_an_avocado/1.gif" alt="a chair that looks like an avocado"></td><td align="center"><img src="samples/a_chair_that_looks_like_an_avocado/2.gif" alt="a chair that looks like an avocado"></td><td align="center"><img src="samples/a_chair_that_looks_like_an_avocado/3.gif" alt="a chair that looks like an avocado"></td></tr>
<tr><td align="center">a chair that looks like fruit</td><td align="center"><img src="samples/a_chair_that_looks_like_fruit/0.gif" alt="a chair that looks like fruit"></td><td align="center"><img src="samples/a_chair_that_looks_like_fruit/1.gif" alt="a chair that looks like fruit"></td><td align="center"><img src="samples/a_chair_that_looks_like_fruit/2.gif" alt="a chair that looks like fruit"></td><td align="center"><img src="samples/a_chair_that_looks_like_fruit/3.gif" alt="a chair that looks like fruit"></td></tr>
<tr><td align="center">a chair that looks like a tree</td><td align="center"><img src="samples/a_chair_that_looks_like_a_tree/0.gif" alt="a chair that looks like a tree"></td><td align="center"><img src="samples/a_chair_that_looks_like_a_tree/1.gif" alt="a chair that looks like a tree"></td><td align="center"><img src="samples/a_chair_that_looks_like_a_tree/2.gif" alt="a chair that looks like a tree"></td><td align="center"><img src="samples/a_chair_that_looks_like_a_tree/3.gif" alt="a chair that looks like a tree"></td></tr>
<tr><td align="center">a chair that looks like a zebra</td><td align="center"><img src="samples/a_chair_that_looks_like_a_zebra/0.gif" alt="a chair that looks like a zebra"></td><td align="center"><img src="samples/a_chair_that_looks_like_a_zebra/1.gif" alt="a chair that looks like a zebra"></td><td align="center"><img src="samples/a_chair_that_looks_like_a_zebra/2.gif" alt="a chair that looks like a zebra"></td><td align="center"><img src="samples/a_chair_that_looks_like_a_zebra/3.gif" alt="a chair that looks like a zebra"></td></tr>
<tr><td align="center">a chair that looks like a swimming pool</td><td align="center"><img src="samples/a_chair_that_looks_like_a_swimming_pool/0.gif" alt="a chair that looks like a swimming pool"></td><td align="center"><img src="samples/a_chair_that_looks_like_a_swimming_pool/1.gif" alt="a chair that looks like a swimming pool"></td><td align="center"><img src="samples/a_chair_that_looks_like_a_swimming_pool/2.gif" alt="a chair that looks like a swimming pool"></td><td align="center"><img src="samples/a_chair_that_looks_like_a_swimming_pool/3.gif" alt="a chair that looks like a swimming pool"></td></tr>
<tr><td align="center">the person is running</td><td align="center"><img src="samples/the_person_is_running/0.gif" alt="the person is running"></td><td align="center"><img src="samples/the_person_is_running/1.gif" alt="the person is running"></td><td align="center"><img src="samples/the_person_is_running/2.gif" alt="the person is running"></td><td align="center"><img src="samples/the_person_is_running/3.gif" alt="the person is running"></td></tr>
<tr><td align="center">the person is sitting</td><td align="center"><img src="samples/the_person_is_sitting/0.gif" alt="the person is sitting"></td><td align="center"><img src="samples/the_person_is_sitting/1.gif" alt="the person is sitting"></td><td align="center"><img src="samples/the_person_is_sitting/2.gif" alt="the person is sitting"></td><td align="center"><img src="samples/the_person_is_sitting/3.gif" alt="the person is sitting"></td></tr>
<tr><td align="center">the person is lying down</td><td align="center"><img src="samples/the_person_is_lying_down/0.gif" alt="the person is lying down"></td><td align="center"><img src="samples/the_person_is_lying_down/1.gif" alt="the person is lying down"></td><td align="center"><img src="samples/the_person_is_lying_down/2.gif" alt="the person is lying down"></td><td align="center"><img src="samples/the_person_is_lying_down/3.gif" alt="the person is lying down"></td></tr>
<tr><td align="center">a person that looks like a zebra</td><td align="center"><img src="samples/a_person_that_looks_like_a_zebra/0.gif" alt="a person that looks like a zebra"></td><td align="center"><img src="samples/a_person_that_looks_like_a_zebra/1.gif" alt="a person that looks like a zebra"></td><td align="center"><img src="samples/a_person_that_looks_like_a_zebra/2.gif" alt="a person that looks like a zebra"></td><td align="center"><img src="samples/a_person_that_looks_like_a_zebra/3.gif" alt="a person that looks like a zebra"></td></tr>
<tr><td align="center">a person that looks like a leopard</td><td align="center"><img src="samples/a_person_that_looks_like_a_leopard/0.gif" alt="a person that looks like a leopard"></td><td align="center"><img src="samples/a_person_that_looks_like_a_leopard/1.gif" alt="a person that looks like a leopard"></td><td align="center"><img src="samples/a_person_that_looks_like_a_leopard/2.gif" alt="a person that looks like a leopard"></td><td align="center"><img src="samples/a_person_that_looks_like_a_leopard/3.gif" alt="a person that looks like a leopard"></td></tr>
<tr><td align="center">a pair of shorts</td><td align="center"><img src="samples/a_pair_of_shorts/0.gif" alt="a pair of shorts"></td><td align="center"><img src="samples/a_pair_of_shorts/1.gif" alt="a pair of shorts"></td><td align="center"><img src="samples/a_pair_of_shorts/2.gif" alt="a pair of shorts"></td><td align="center"><img src="samples/a_pair_of_shorts/3.gif" alt="a pair of shorts"></td></tr>
<tr><td align="center">a designer dress</td><td align="center"><img src="samples/a_designer_dress/0.gif" alt="a designer dress"></td><td align="center"><img src="samples/a_designer_dress/1.gif" alt="a designer dress"></td><td align="center"><img src="samples/a_designer_dress/2.gif" alt="a designer dress"></td><td align="center"><img src="samples/a_designer_dress/3.gif" alt="a designer dress"></td></tr>
<tr><td align="center">banana shoes</td><td align="center"><img src="samples/banana_shoes/0.gif" alt="banana shoes"></td><td align="center"><img src="samples/banana_shoes/1.gif" alt="banana shoes"></td><td align="center"><img src="samples/banana_shoes/2.gif" alt="banana shoes"></td><td align="center"><img src="samples/banana_shoes/3.gif" alt="banana shoes"></td></tr>
<tr><td align="center">a green boot</td><td align="center"><img src="samples/a_green_boot/0.gif" alt="a green boot"></td><td align="center"><img src="samples/a_green_boot/1.gif" alt="a green boot"></td><td align="center"><img src="samples/a_green_boot/2.gif" alt="a green boot"></td><td align="center"><img src="samples/a_green_boot/3.gif" alt="a green boot"></td></tr>
<tr><td align="center">a pair of sunglasses</td><td align="center"><img src="samples/a_pair_of_sunglasses/0.gif" alt="a pair of sunglasses"></td><td align="center"><img src="samples/a_pair_of_sunglasses/1.gif" alt="a pair of sunglasses"></td><td align="center"><img src="samples/a_pair_of_sunglasses/2.gif" alt="a pair of sunglasses"></td><td align="center"><img src="samples/a_pair_of_sunglasses/3.gif" alt="a pair of sunglasses"></td></tr>
</tbody></table>
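
The table above follows a regular pattern: one row per prompt, with four GIFs stored under `samples/<prompt with spaces replaced by underscores>/<index>.gif`. The following is a hypothetical generator for such a table. The function names are ours, and the slug rule is inferred from the `src` paths. The authors may have produced the file differently.

```python
import html
from typing import Iterable


def slugify(prompt: str) -> str:
    """Directory name for a prompt, inferred from the table's paths: spaces become underscores."""
    return prompt.strip().replace(" ", "_")


def sample_row(prompt: str, num_samples: int = 4) -> str:
    """One <tr>: the prompt in the first cell, then one GIF cell per random sample."""
    slug = slugify(prompt)
    text = html.escape(prompt, quote=True)  # safe both as cell text and inside alt="..."
    cells = [f'<td align="center">{text}</td>']
    for i in range(num_samples):
        cells.append(f'<td align="center"><img src="samples/{slug}/{i}.gif" alt="{text}"></td>')
    return "<tr>" + "".join(cells) + "</tr>"


def samples_table(prompts: Iterable[str], num_samples: int = 4) -> str:
    """Full HTML table, one row per line so that diffs stay readable."""
    header = '<tr><th align="center">Prompt</th>' + "<th></th>" * num_samples + "</tr>"
    rows = [sample_row(p, num_samples) for p in prompts]
    return "\n".join(["<table><tbody>", header, *rows, "</tbody></table>"])
```

For example, `samples_table(["a penguin", "a campfire"])` yields a header row plus two sample rows. Emitting one row per line also avoids the single very long line that is easy to corrupt when the file is wrapped or copied.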



================================================
FILE: setup.py
================================================
from setuptools import setup

setup(
    name="shap-e",
    packages=[
        "shap_e",
        "shap_e.diffusion",
        "shap_e.models",
        "shap_e.models.generation",
        "shap_e.models.nerf",
        "shap_e.models.nerstf",
        "shap_e.models.nn",
        "shap_e.models.stf",
        "shap_e.models.transmitter",
        "shap_e.rendering",
        "shap_e.rendering.blender",
        "shap_e.rendering.raycast",
        "shap_e.util",
    ],
    install_requires=[
        "filelock",
        "Pillow",
        "torch",
        "fire",
        "humanize",
        "requests",
        "tqdm",
        "matplotlib",
        "scikit-image",
        "scipy",
        "numpy",
        "blobfile",
        "pyyaml",
        "clip @ git+https://github.com/openai/CLIP.git",
    ],
    author="OpenAI",
)


================================================
FILE: shap_e/__init__.py
================================================


================================================
FILE: shap_e/diffusion/__init__.py
================================================


================================================
FILE: shap_e/diffusion/gaussian_diffusion.py
================================================
"""
Based on https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
"""

import math
from typing import Any, Dict, Iterable, Optional, Sequence, Union

import blobfile as bf
import numpy as np
import torch as th
import yaml


def diffusion_from_config(config: Union[str, Dict[str, Any]]) -> "GaussianDiffusion":
    if isinstance(config, str):
        with bf.BlobFile(config, "rb") as f:
            obj = yaml.load(f, Loader=yaml.SafeLoader)
        return diffusion_from_config(obj)

    schedule = config["schedule"]
    steps = config["timesteps"]
    respace = config.get("respacing", None)
    mean_type = config.get("mean_type", "epsilon")
    betas = get_named_beta_schedule(schedule, steps, **config.get("schedule_args", {}))
    channel_scales = config.get("channel_scales", None)
    channel_biases = config.get("channel_biases", None)
    if channel_scales is not None:
        channel_scales = np.array(channel_scales)
    if channel_biases is not None:
        channel_biases = np.array(channel_biases)
    kwargs = dict(
        betas=betas,
        model_mean_type=mean_type,
        model_var_type="learned_range",
        loss_type="mse",
        channel_scales=channel_scales,
        channel_biases=channel_biases,
    )
    if respace is None:
        return GaussianDiffusion(**kwargs)
    else:
        return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs)
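

# Illustrative sketch, not part of upstream shap-e: diffusion_from_config
# always requests model_var_type="learned_range". In that mode the network
# emits one extra value v in [-1, 1] per output channel, and
# GaussianDiffusion.p_mean_variance turns v into a log-variance by linear
# interpolation between the clipped posterior log-variance (v = -1) and
# log(beta_t) (v = +1). So v = 0 gives the geometric mean of the two
# variances. The helper name is hypothetical. It restates that arithmetic
# for plain floats only.
def _sketch_learned_range_log_variance(v, min_log, max_log):
    # Map v from [-1, 1] onto an interpolation fraction in [0, 1].
    frac = (v + 1) / 2
    # Interpolate in log space, using the same formula as p_mean_variance.
    return frac * max_log + (1 - frac) * min_log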


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.

    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, **extra_args: float):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule_name == "inv_parabola":
        exponent = extra_args.get("power", 2.0)
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: 1 - t**exponent,
        )
    elif schedule_name == "translated_parabola":
        exponent = extra_args.get("power", 2.0)
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: (1 - t) ** exponent,
        )
    elif schedule_name == "exp":
        coefficient = extra_args.get("coefficient", -12.0)
        return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: math.exp(t * coefficient))
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
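

# Illustrative sketch, not part of upstream shap-e: the "linear" schedule
# multiplies the Ho et al. endpoints (1e-4, 0.02) by scale = 1000 / T.
# Because the betas are evenly spaced, their sum is
# T * scale * (1e-4 + 0.02) / 2 = 1000 * 0.01005 = 10.05 for every T, which
# is one concrete sense in which the schedule "remains similar" as
# num_diffusion_timesteps changes. For T < 20 the last beta exceeds 1, which
# the assertion in GaussianDiffusion.__init__ rejects. The helper name is
# hypothetical and it uses only built-in Python.
def _sketch_linear_betas(num_diffusion_timesteps):
    n = num_diffusion_timesteps  # requires n >= 2
    scale = 1000 / n
    start, end = scale * 0.0001, scale * 0.02
    # Mathematically the same spacing as np.linspace(start, end, n).
    return [start + (end - start) * i / (n - 1) for i in range(n)]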


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time, for t in [0, 1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
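

# Illustrative sketch, not part of upstream shap-e: betas_for_alpha_bar sets
# 1 - beta_i = alpha_bar(t_{i+1}) / alpha_bar(t_i), so the running product of
# (1 - beta) telescopes back to alpha_bar(t_{i+1}) / alpha_bar(0) wherever
# max_beta does not clip. The hypothetical helper below repeats that
# discretization with built-in Python and also returns the running product,
# so the property can be checked for any alpha_bar, e.g. the
# "translated_parabola" curve (1 - t) ** 2.
def _sketch_alpha_bar_roundtrip(num_steps, alpha_bar, max_beta=0.999):
    betas = []
    for i in range(num_steps):
        t1, t2 = i / num_steps, (i + 1) / num_steps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    # Cumulative product of (1 - beta), i.e. what GaussianDiffusion stores
    # as alphas_cumprod.
    cumulative, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        cumulative.append(prod)
    return betas, cumulative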


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there are 300 timesteps and the section counts are [10,15,20],
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")
        elif section_counts.startswith("exact"):
            res = set(int(x) for x in section_counts[len("exact") :].split(","))
            for x in res:
                if x < 0 or x >= num_timesteps:
                    raise ValueError(f"timestep out of bounds: {x}")
            return res
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(f"cannot divide section of {size} steps into {section_count}")
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
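

# Illustrative sketch, not part of upstream shap-e: the two striding rules in
# space_timesteps restated for plain integers, with hypothetical helper names.
# _sketch_section_steps reproduces one section of the comma-separated-counts
# rule (evenly spaced picks, with the first and last step of the section both
# included when the count is > 1). _sketch_ddim_stride finds the smallest
# integer stride used by the "ddimN" special case.
def _sketch_section_steps(start_idx, size, section_count):
    if size < section_count:
        raise ValueError(f"cannot divide section of {size} steps into {section_count}")
    frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
    cur_idx = 0.0
    steps = []
    for _ in range(section_count):
        steps.append(start_idx + round(cur_idx))
        cur_idx += frac_stride
    return steps


def _sketch_ddim_stride(num_timesteps, desired_count):
    # First stride whose range(0, num_timesteps, stride) has exactly
    # desired_count elements, as in the "ddimN" branch.
    for stride in range(1, num_timesteps):
        if len(range(0, num_timesteps, stride)) == desired_count:
            return stride
    raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")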

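
# Illustrative sketch, not part of upstream shap-e: for a single scalar with
# cumulative product alpha_bar, GaussianDiffusion.q_sample computes
#     x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps
# and _predict_xstart_from_eps inverts it using the precomputed
# sqrt_recip_alphas_cumprod and sqrt_recipm1_alphas_cumprod tables:
#     x_0 = sqrt(1 / alpha_bar) * x_t - sqrt(1 / alpha_bar - 1) * eps
# The helper names are hypothetical. They only restate those two formulas.
def _sketch_q_sample(x_start, eps, alpha_bar):
    return alpha_bar**0.5 * x_start + (1.0 - alpha_bar) ** 0.5 * eps


def _sketch_xstart_from_eps(x_t, eps, alpha_bar):
    return (1.0 / alpha_bar) ** 0.5 * x_t - (1.0 / alpha_bar - 1.0) ** 0.5 * eps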

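
# Illustrative sketch, not part of upstream shap-e: ddim_sample scales its
# noise by
#     sigma = eta * sqrt((1 - a_prev) / (1 - a)) * sqrt(1 - a / a_prev)
# where a = alpha_bar_t and a_prev = alpha_bar_{t-1}. Since a / a_prev is
# 1 - beta_t, eta = 1 makes sigma**2 equal the DDPM posterior variance
# beta_t * (1 - a_prev) / (1 - a) held in posterior_variance, while eta = 0
# (the default) makes the update deterministic. Helper names are
# hypothetical. Scalars only, and only the t > 0 case (no masking).
def _sketch_ddim_sigma(alpha_bar, alpha_bar_prev, eta):
    return (
        eta
        * ((1 - alpha_bar_prev) / (1 - alpha_bar)) ** 0.5
        * (1 - alpha_bar / alpha_bar_prev) ** 0.5
    )


def _sketch_ddim_step(pred_xstart, eps, alpha_bar, alpha_bar_prev, eta, noise):
    # Equation 12 of the DDIM paper, in the form used by ddim_sample.
    sigma = _sketch_ddim_sigma(alpha_bar, alpha_bar_prev, eta)
    mean_pred = (
        pred_xstart * alpha_bar_prev**0.5
        + (1 - alpha_bar_prev - sigma**2) ** 0.5 * eps
    )
    return mean_pred + sigma * noise

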
class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here:
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D array of betas, one per diffusion timestep, ordered
                  from the first step (t = 1) to the last (t = T).
    :param model_mean_type: a string determining what the model outputs.
    :param model_var_type: a string determining how variance is output.
    :param loss_type: a string determining the loss function to use.
    :param discretized_t0: if True, use discrete gaussian loss for t=0. Only
                           makes sense for images.
    :param channel_scales: a multiplier to apply to x_start in training_losses
                           and sampling functions.
    """

    def __init__(
        self,
        *,
        betas: Sequence[float],
        model_mean_type: str,
        model_var_type: str,
        loss_type: str,
        discretized_t0: bool = False,
        channel_scales: Optional[np.ndarray] = None,
        channel_biases: Optional[np.ndarray] = None,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.discretized_t0 = discretized_t0
        self.channel_scales = channel_scales
        self.channel_biases = channel_biases

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
        )

    def get_sigmas(self, t):
        return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape)

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        if self.model_var_type in ["learned", "learned_range"]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.model_var_type == "learned":
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixed_large, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                "fixed_large": (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                "fixed_small": (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == "x_prev":
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in ["x_start", "epsilon"]:
            if self.model_mean_type == "x_start":
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
        else:
            raise NotImplementedError(self.model_mean_type)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
            "extra": extra,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, **(model_kwargs or {}))
        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **(model_kwargs or {}))

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_t.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        temp=1.0,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the initial noise to start sampling from.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param temp: a multiplier applied to the initial Gaussian noise when
                     `noise` is not specified.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            temp=temp,
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        temp=1.0,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device) * temp
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )
                yield self.unscale_out_dict(out)
                img = out["sample"]

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        temp=1.0,
    ):
        """
        Generate samples from the model using DDIM.

        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
            temp=temp,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        temp=1.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.

        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device) * temp
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield self.unscale_out_dict(out)
                img = out["sample"]

    def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False, model_kwargs=None):
        """
        Get a term for the variational lower-bound.

        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.

        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        if not self.discretized_t0:
            decoder_nll = th.zeros_like(decoder_nll)
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {
            "output": output,
            "pred_xstart": out["pred_xstart"],
            "extra": out["extra"],
        }

    def training_losses(
        self, model, x_start, t, model_kwargs=None, noise=None
    ) -> Dict[str, th.Tensor]:
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        x_start = self.scale_channels(x_start)
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == "kl" or self.loss_type == "rescaled_kl":
            vb_terms = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )
            terms["loss"] = vb_terms["output"]
            if self.loss_type == "rescaled_kl":
                terms["loss"] *= self.num_timesteps
            extra = vb_terms["extra"]
        elif self.loss_type == "mse" or self.loss_type == "rescaled_mse":
            model_output = model(x_t, t, **model_kwargs)
            if isinstance(model_output, tuple):
                model_output, extra = model_output
            else:
                extra = {}

            if self.model_var_type in [
                "learned",
                "learned_range",
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (
                    B,
                    C * 2,
                    *x_t.shape[2:],
                ), f"{model_output.shape} != {(B, C * 2, *x_t.shape[2:])}"
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == "rescaled_mse":
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {
                "x_prev": self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
                "x_start": x_start,
                "epsilon": noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms["mse"] = mean_flat((target - model_output) ** 2)
            if "vb" in terms:
                terms["loss"] = terms["mse"] + terms["vb"]
            else:
                terms["loss"] = terms["mse"]
        else:
            raise NotImplementedError(self.loss_type)

        if "losses" in extra:
            terms.update({k: loss for k, (loss, _scale) in extra["losses"].items()})
            for loss, scale in extra["losses"].values():
                terms["loss"] = terms["loss"] + loss * scale

        return terms

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.

        This term can't be optimized, as it only depends on the encoder.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.

        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = th.tensor([t] * batch_size, device=device)
            noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with th.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = th.stack(vb, dim=1)
        xstart_mse = th.stack(xstart_mse, dim=1)
        mse = th.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }

    def scale_channels(self, x: th.Tensor) -> th.Tensor:
        if self.channel_scales is not None:
            x = x * th.from_numpy(self.channel_scales).to(x).reshape(
                [1, -1, *([1] * (len(x.shape) - 2))]
            )
        if self.channel_biases is not None:
            x = x + th.from_numpy(self.channel_biases).to(x).reshape(
                [1, -1, *([1] * (len(x.shape) - 2))]
            )
        return x

    def unscale_channels(self, x: th.Tensor) -> th.Tensor:
        if self.channel_biases is not None:
            x = x - th.from_numpy(self.channel_biases).to(x).reshape(
                [1, -1, *([1] * (len(x.shape) - 2))]
            )
        if self.channel_scales is not None:
            x = x / th.from_numpy(self.channel_scales).to(x).reshape(
                [1, -1, *([1] * (len(x.shape) - 2))]
            )
        return x

    def unscale_out_dict(
        self, out: Dict[str, Union[th.Tensor, Any]]
    ) -> Dict[str, Union[th.Tensor, Any]]:
        return {
            k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v) for k, v in out.items()
        }
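`scale_channels` and `unscale_channels` above apply, and then invert, a per-channel affine map along dim 1 (multiply by the scale, then add the bias; the inverse subtracts the bias first, then divides). The NumPy sketch below mirrors that pair outside of torch. The scales and biases are made-up values, not the ones used by any Shap-E config.

```python
import numpy as np


def scale_channels(x, scales, biases):
    # Reshape per-channel vectors to [1, C, 1, ...] so they broadcast over dim 1.
    shape = [1, -1] + [1] * (x.ndim - 2)
    return x * scales.reshape(shape) + biases.reshape(shape)


def unscale_channels(x, scales, biases):
    # Undo the bias first, then the scale: the inverse of the affine map above.
    shape = [1, -1] + [1] * (x.ndim - 2)
    return (x - biases.reshape(shape)) / scales.reshape(shape)


# Hypothetical per-channel statistics for a 3-channel input.
scales = np.array([2.0, 0.5, 4.0])
biases = np.array([0.1, -0.2, 0.0])
x = np.random.default_rng(0).normal(size=(2, 3, 5))
y = scale_channels(x, scales, biases)
x_back = unscale_channels(y, scales, biases)
```

The round trip recovers the input up to floating-point error. This is why `unscale_out_dict` can be applied to sampler outputs that were produced in the scaled space.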


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: (unordered) timesteps from the original diffusion
                          process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps: Iterable[int], **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(model, self.timestep_map, self.original_num_steps)


class _WrappedModel:
    def __init__(self, model, timestep_map, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        return self.model(x, new_ts, **kwargs)
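`SpacedDiffusion.__init__` keeps a subset of timesteps and picks each new beta so that the cumulative product of `1 - beta` over the kept steps reproduces the original `alphas_cumprod` at those steps. `_WrappedModel` then maps a spaced index `j` back to `timestep_map[j]` before calling the model. The following is a NumPy-only sketch of that loop; the linear beta schedule is a hypothetical stand-in for whatever schedule the config actually uses.

```python
import numpy as np


def respace_betas(betas, use_timesteps):
    """Mirror of the loop in SpacedDiffusion.__init__, using NumPy only."""
    alphas_cumprod = np.cumprod(1.0 - betas)
    use_timesteps = set(use_timesteps)
    last_alpha_cumprod = 1.0
    new_betas, timestep_map = [], []
    for i, alpha_cumprod in enumerate(alphas_cumprod):
        if i in use_timesteps:
            # Choose the new beta so the product over kept steps equals alpha_cumprod[i].
            new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
            timestep_map.append(i)
    return np.array(new_betas), timestep_map


betas = np.linspace(1e-4, 0.02, 1000)  # hypothetical linear schedule
kept = range(0, 1000, 100)
new_betas, timestep_map = respace_betas(betas, kept)
```

Because the new betas telescope, `np.cumprod(1 - new_betas)` matches the original `alphas_cumprod` at the retained indices, so the spaced process has the same marginals q(x_t | x_0) at those steps.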


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )
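`normal_kl` is the closed-form KL divergence between two diagonal Gaussians, written in terms of log-variances. The scalar sketch below uses `math` instead of torch to check it against values you can compute by hand. It also shows the division by ln 2 that `_prior_bpd` and `_vb_terms_bpd` use to convert nats to bits.

```python
import math


def normal_kl(mean1, logvar1, mean2, logvar2):
    # Closed-form KL(N(mean1, e^logvar1) || N(mean2, e^logvar2)), same expression as above.
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + math.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * math.exp(-logvar2)
    )


# Unit-variance Gaussians whose means differ by 1: KL = 0.5 * 1**2 = 0.5 nats.
kl_shift = normal_kl(1.0, 0.0, 0.0, 0.0)
# Dividing by ln 2 converts nats to bits.
kl_shift_bits = kl_shift / math.log(2.0)
```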


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
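`approx_standard_normal_cdf` uses the same tanh polynomial as the tanh form of GELU. The sketch below compares it against the exact CDF, computed with `math.erf`, on a grid over [-5, 5] to show how small the approximation error is in practice.

```python
import math


def approx_standard_normal_cdf(x):
    # tanh approximation used above.
    return 0.5 * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def exact_standard_normal_cdf(x):
    # Exact CDF via the error function.
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


grid = [i / 100.0 for i in range(-500, 501)]
max_err = max(
    abs(approx_standard_normal_cdf(x) - exact_standard_normal_cdf(x)) for x in grid
)
```

On this grid the maximum absolute error stays well below 1e-3, which is negligible next to the 1/255 bin half-width used by `discretized_gaussian_log_likelihood`.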


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretized to a
    given image.
    :param x: the target images. It is assumed that these were uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
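In `discretized_gaussian_log_likelihood`, each of the 256 uint8 values owns a bin of width 2/255 in [-1, 1]. The two edge bins absorb the Gaussian tails, so the bin probabilities for any fixed mean and scale should sum to one. The sketch below ports the function to NumPy and checks that property. The mean 0.3 and scale 0.2 are arbitrary illustration values.

```python
import numpy as np


def approx_cdf(x):
    return 0.5 * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def discretized_gaussian_log_likelihood(x, means, log_scales):
    # NumPy port of the function above: each uint8 value occupies a bin of width 2/255.
    centered_x = x - means
    inv_stdv = np.exp(-log_scales)
    cdf_plus = approx_cdf(inv_stdv * (centered_x + 1.0 / 255.0))
    cdf_min = approx_cdf(inv_stdv * (centered_x - 1.0 / 255.0))
    log_cdf_plus = np.log(np.maximum(cdf_plus, 1e-12))
    log_one_minus_cdf_min = np.log(np.maximum(1.0 - cdf_min, 1e-12))
    log_cdf_delta = np.log(np.maximum(cdf_plus - cdf_min, 1e-12))
    # Edge bins absorb the tails so that the 256 probabilities sum to one.
    return np.where(
        x < -0.999,
        log_cdf_plus,
        np.where(x > 0.999, log_one_minus_cdf_min, log_cdf_delta),
    )


# All 256 uint8 values rescaled to [-1, 1].
bins = np.arange(256) / 127.5 - 1.0
means = np.full_like(bins, 0.3)
log_scales = np.full_like(bins, np.log(0.2))
probs = np.exp(discretized_gaussian_log_likelihood(bins, means, log_scales))
```

Adjacent bins share their boundaries (x_k + 1/255 = x_{k+1} - 1/255), so the sum telescopes to one, apart from the tiny 1e-12 clamps.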


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.flatten(1).mean(1)


================================================
FILE: shap_e/diffusion/k_diffusion.py
================================================
"""
Based on: https://github.com/crowsonkb/k-diffusion

Copyright (c) 2022 Katherine Crowson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion, mean_flat


class KarrasDenoiser:
    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def get_snr(self, sigmas):
        return sigmas**-2

    def get_sigmas(self, sigmas):
        return sigmas

    def get_scalings(self, sigma):
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out, c_in

    def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None):
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)

        terms = {}

        dims = x_start.ndim
        x_t = x_start + noise * append_dims(sigmas, dims)
        c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)]
        model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)
        target = (x_start - c_skip * x_t) / c_out

        terms["mse"] = mean_flat((model_output - target) ** 2)
        terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)

        terms["loss"] = terms["mse"]

        return terms

    def denoise(self, model, x_t, sigmas, **model_kwargs):
        c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)]
        rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
        model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
        denoised = c_out * model_output + c_skip * x_t
        return model_output, denoised
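`get_scalings` implements the Karras et al. (2022) preconditioning. Assume `x_start` has variance `sigma_data**2` and is independent of unit Gaussian noise. Under that assumption, both the network input `c_in * x_t` and the training target `(x_start - c_skip * x_t) / c_out` used in `training_losses` have unit variance at every noise level. The sketch below checks this analytically; the sigma values are arbitrary.

```python
import math


def get_scalings(sigma, sigma_data=0.5):
    # Same c_skip / c_out / c_in formulas as KarrasDenoiser.get_scalings.
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5
    return c_skip, c_out, c_in


def preconditioned_variances(sigma, sigma_data=0.5):
    """Variances of the network input and the training target, assuming x_start
    has variance sigma_data**2 and is independent of unit Gaussian noise."""
    c_skip, c_out, c_in = get_scalings(sigma, sigma_data)
    # x_t = x_start + sigma * noise, so Var[c_in * x_t] = c_in**2 * (sigma_data**2 + sigma**2).
    input_var = c_in**2 * (sigma_data**2 + sigma**2)
    # target = ((1 - c_skip) * x_start - c_skip * sigma * noise) / c_out
    target_var = ((1 - c_skip) ** 2 * sigma_data**2 + (c_skip * sigma) ** 2) / c_out**2
    return input_var, target_var


results = {s: preconditioned_variances(s) for s in (0.01, 0.5, 10.0, 80.0)}
```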


class GaussianToKarrasDenoiser:
    def __init__(self, model, diffusion):
        from scipy import interpolate

        self.model = model
        self.diffusion = diffusion
        self.alpha_cumprod_to_t = interpolate.interp1d(
            diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps)
        )

    def sigma_to_t(self, sigma):
        alpha_cumprod = 1.0 / (sigma**2 + 1)
        if alpha_cumprod > self.diffusion.alphas_cumprod[0]:
            return 0
        elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]:
            return self.diffusion.num_timesteps - 1
        else:
            return float(self.alpha_cumprod_to_t(alpha_cumprod))

    def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None):
        t = th.tensor(
            [self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()],
            dtype=th.long,
            device=sigmas.device,
        )
        c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim)
        out = self.diffusion.p_mean_variance(
            self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        return None, out["pred_xstart"]
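`GaussianToKarrasDenoiser` relies on the relation `alpha_cumprod = 1 / (sigma**2 + 1)`, that is, `sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod)`, and interpolates to recover a (fractional) timestep. The sketch below reproduces `sigma_to_t` with `np.interp` instead of `scipy.interpolate.interp1d`, using a hypothetical linear beta schedule in place of `diffusion.alphas_cumprod`.

```python
import numpy as np

# Hypothetical linear beta schedule standing in for diffusion.alphas_cumprod.
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)


def t_to_sigma(t):
    # Inverse of alpha_cumprod = 1 / (sigma**2 + 1).
    a = alphas_cumprod[t]
    return np.sqrt((1.0 - a) / a)


def sigma_to_t(sigma):
    """Like GaussianToKarrasDenoiser.sigma_to_t, but using np.interp instead of scipy."""
    alpha_cumprod = 1.0 / (sigma**2 + 1)
    if alpha_cumprod > alphas_cumprod[0]:
        return 0.0
    if alpha_cumprod <= alphas_cumprod[-1]:
        return float(len(alphas_cumprod) - 1)
    # np.interp needs increasing x-coordinates; alphas_cumprod decreases, so reverse both.
    ts = np.arange(len(alphas_cumprod), dtype=np.float64)
    return float(np.interp(alpha_cumprod, alphas_cumprod[::-1], ts[::-1]))


recovered = [sigma_to_t(t_to_sigma(t)) for t in (0, 10, 500, 998)]
```

Note that `denoise` above then builds its timestep tensor with `dtype=th.long`, so any fractional timestep returned by `sigma_to_t` is truncated to an integer before the model is called.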


def karras_sample(*args, **kwargs):
    last = None
    for x in karras_sample_progressive(*args, **kwargs):
        last = x["x"]
    return last


def karras_sample_progressive(
    diffusion,
    model,
    shape,
    steps,
    clip_denoised=True,
    progress=False,
    model_kwargs=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,  # higher for highres?
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    guidance_scale=0.0,
):
    if model_kwargs is None:
        model_kwargs = {}
    sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
    x_T = th.randn(*shape, device=device) * sigma_max
    sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
        sampler
    ]

    if sampler != "ancestral":
        sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
    else:
        sampler_args = {}

    if isinstance(diffusion, KarrasDenoiser):

        def denoiser(x_t, sigma):
            _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
            if clip_denoised:
                denoised = denoised.clamp(-1, 1)
            return denoised

    elif isinstance(diffusion, GaussianDiffusion):
        model = GaussianToKarrasDenoiser(model, diffusion)

        def denoiser(x_t, sigma):
            _, denoised = model.denoise(
                x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs
            )
            return denoised

    else:
        raise NotImplementedError

    if guidance_scale != 0 and guidance_scale != 1:

        def guided_denoiser(x_t, sigma):
            x_t = th.cat([x_t, x_t], dim=0)
            sigma = th.cat([sigma, sigma], dim=0)
            x_0 = denoiser(x_t, sigma)
            cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)
            x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)
            return x_0

    else:
        guided_denoiser = denoiser

    for obj in sample_fn(
        guided_denoiser,
        x_T,
        sigmas,
        progress=progress,
        **sampler_args,
    ):
        if isinstance(diffusion, GaussianDiffusion):
            yield diffusion.unscale_out_dict(obj)
        else:
            yield obj


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Constructs the noise schedule of Karras et al. (2022)."""
    ramp = th.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)
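`get_sigmas_karras` spaces noise levels linearly in `sigma**(1/rho)` and then appends a final 0. With `rho = 7`, this places more steps at low noise. A NumPy sketch using the defaults from `shap_e/diffusion/sample.py` (64 steps, sigma in [1e-3, 160]):

```python
import numpy as np


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0):
    # Interpolate linearly in sigma**(1/rho) space, then raise back to the rho power.
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    # Append a final 0 so the last sampler step lands on the clean sample.
    return np.append(sigmas, 0.0)


sigmas = get_sigmas_karras(64, 1e-3, 160)
```

The result has `n + 1` entries: it starts at `sigma_max`, reaches `sigma_min` at index `n - 1`, and ends at 0.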


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / append_dims(sigma, x.ndim)


def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
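`get_ancestral_step` splits a step from `sigma_from` to `sigma_to` into two parts: a deterministic Euler move to `sigma_down`, followed by fresh noise with standard deviation `sigma_up`. Together they restore the target noise variance, `sigma_down**2 + sigma_up**2 == sigma_to**2`. A scalar check:

```python
import math


def get_ancestral_step(sigma_from, sigma_to):
    # Deterministic part (sigma_down) plus freshly injected noise (sigma_up).
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


sigma_down, sigma_up = get_ancestral_step(2.0, 1.0)
```

Algebraically, `sigma_down` simplifies to `sigma_to**2 / sigma_from`, which gives 0.5 here.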


@th.no_grad()
def sample_euler_ancestral(model, x, sigmas, progress=False):
    """Ancestral sampling with Euler method steps."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        denoised = model(x, sigmas[i] * s_in)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised}
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        x = x + th.randn_like(x) * sigma_up
    yield {"x": x, "pred_xstart": x}


@th.no_grad()
def sample_heun(
    denoiser,
    x,
    sigmas,
    progress=False,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        )
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised}
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    yield {"x": x, "pred_xstart": denoised}
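With `s_churn=0`, `sample_heun` integrates the probability-flow ODE dx/dsigma = (x - D(x, sigma)) / sigma deterministically. It takes Heun steps throughout and a single Euler step for the final move to sigma = 0. The sketch below checks that behavior on a toy problem, assuming data distributed as N(0, 0.5**2). For that distribution the optimal denoiser is x * s**2 / (s**2 + sigma**2), and the ODE has a closed-form solution to compare against. The data scale, sigma range, and starting points are all illustration values.

```python
import numpy as np


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0):
    ramp = np.linspace(0, 1, n)
    min_inv_rho, max_inv_rho = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return np.append((max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho, 0.0)


def sample_heun_deterministic(denoiser, x, sigmas):
    """Heun sampler as in sample_heun with s_churn=0 (no noise injection)."""
    for i in range(len(sigmas) - 1):
        d = (x - denoiser(x, sigmas[i])) / sigmas[i]
        dt = sigmas[i + 1] - sigmas[i]
        if sigmas[i + 1] == 0:
            x = x + d * dt  # Euler for the final step to sigma = 0
        else:
            x_2 = x + d * dt
            d_2 = (x_2 - denoiser(x_2, sigmas[i + 1])) / sigmas[i + 1]
            x = x + (d + d_2) / 2 * dt  # Heun: average of the two slopes
    return x


data_std = 0.5  # hypothetical: data distributed as N(0, data_std**2)


def gaussian_denoiser(x, sigma):
    # Optimal denoiser E[x_0 | x_t] for Gaussian data at noise level sigma.
    return x * data_std**2 / (data_std**2 + sigma**2)


sigma_max = 80.0
sigmas = get_sigmas_karras(64, 0.002, sigma_max)
x_T = np.array([-1.0, 0.5, 2.0]) * sigma_max
x_0 = sample_heun_deterministic(gaussian_denoiser, x_T, sigmas)
# Exact ODE solution: x(sigma) = x_T * sqrt(data_std**2 + sigma**2)
#                                    / sqrt(data_std**2 + sigma_max**2), evaluated at sigma = 0.
expected = x_T * data_std / np.sqrt(data_std**2 + sigma_max**2)
```

With 64 steps the sampled endpoints agree with the analytic solution to well within 1%.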


@th.no_grad()
def sample_dpm(
    denoiser,
    x,
    sigmas,
    progress=False,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        )
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised}
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigmas[i + 1] - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = denoiser(x_2, sigma_mid * s_in)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    yield {"x": x, "pred_xstart": denoised}


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]


def append_zero(x):
    return th.cat([x, x.new_zeros([1])])


================================================
FILE: shap_e/diffusion/sample.py
================================================
from typing import Any, Callable, Dict, Optional

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample

DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0


def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    def model_fn(x_t, ts, **kwargs):
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
        half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    return model_fn
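`uncond_guide_model` runs the conditional half and the unconditional half (zeroed conditioning) of the batch together. It then combines their predictions with classifier-free guidance, `uncond + scale * (cond - uncond)`, and writes the guided result back into both halves. The NumPy sketch below shows only that recombination. It deliberately ignores the `[:, :3]` channel split in `uncond_guide_model`, which guides only the first three output channels and passes the rest through unchanged.

```python
import numpy as np


def guide(cond_eps, uncond_eps, scale):
    # Classifier-free guidance: extrapolate from the unconditional toward the conditional output.
    return uncond_eps + scale * (cond_eps - uncond_eps)


def guided_half_batch(model_out, scale):
    """Rows [0, B) are conditional, rows [B, 2B) unconditional; the guided
    result is duplicated into both halves, as in uncond_guide_model."""
    cond, uncond = np.split(model_out, 2, axis=0)
    half = guide(cond, uncond, scale)
    return np.concatenate([half, half], axis=0)


cond = np.array([[1.0, 2.0]])
uncond = np.array([[0.0, 1.0]])
out = guided_half_batch(np.concatenate([cond, uncond], axis=0), scale=3.0)
```

A scale of 1 returns the conditional prediction and a scale of 0 returns the unconditional one. This is why `sample_latents` only doubles the batch when `guidance_scale` is neither 0 nor 1.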


def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
) -> torch.Tensor:
    sample_shape = (batch_size, model.d_latent)

    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
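    if guidance_scale != 1.0 and guidance_scale != 0.0:
        # Append an unconditional (all-zero) copy of each conditioning tensor.
        # Build a new dict so the caller's model_kwargs is not mutated.
        model_kwargs = {
            k: torch.cat([v, torch.zeros_like(v)], dim=0) for k, v in model_kwargs.items()
        }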

    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
            )
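        else:
            internal_batch_size = batch_size
            # Match the Karras branch: guidance applies only when model_kwargs were
            # doubled above (guidance_scale not in {0, 1}).
            if guidance_scale != 1.0 and guidance_scale != 0.0:
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )
            # Only the first half of the internal batch follows the guided prediction.
            samples = samples[:batch_size]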

    return samples


================================================
FILE: shap_e/examples/encode_model.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from shap_e.models.download import load_model\n",
    "from shap_e.util.data_util import load_or_create_multimodal_batch\n",
    "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xm = load_model('transmitter', device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"example_data/cactus/object.obj\"\n",
    "\n",
    "# This may take a few minutes, since it requires rendering the model twice\n",
    "# in two different modes.\n",
    "batch = load_or_create_multimodal_batch(\n",
    "    device,\n",
    "    model_path=model_path,\n",
    "    mv_light_mode=\"basic\",\n",
    "    mv_image_size=256,\n",
    "    cache_dir=\"example_data/cactus/cached\",\n",
    "    verbose=True, # this will show Blender output during renders\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    latent = xm.encoder.encode_to_bottleneck(batch)\n",
    "\n",
    "    render_mode = 'stf' # you can change this to 'nerf'\n",
    "    size = 128 # recommended that you lower resolution when using nerf\n",
    "\n",
    "    cameras = create_pan_cameras(size, device)\n",
    "    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
    "    display(gif_widget(images))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: shap_e/examples/example_data/cactus/material.mtl
================================================
newmtl mat0
Ka 0.0000 0.7000 0.0000
Kd 0.0000 0.7000 0.0000
Ks 0.0000 0.0000 0.0000
newmtl mat1
Ka 0.6600 0.4400 0.2000
Kd 0.6600 0.4400 0.2000
Ks 0.0000 0.0000 0.0000
newmtl mat2
Ka 0.3000 0.3000 0.3000
Kd 0.3000 0.3000 0.3000
Ks 0.0000 0.0000 0.0000
newmtl mat3
Ka 0.0000 0.5000 0.0000
Kd 0.0000 0.5000 0.0000
Ks 0.0000 0.0000 0.0000
newmtl mat4
Ka 0.0000 0.5667 0.0000
Kd 0.0000 0.5667 0.0000
Ks 0.0000 0.0000 0.0000
newmtl mat5
Ka 0.5400 0.3933 0.2333
Kd 0.5400 0.3933 0.2333
Ks 0.0000 0.0000 0.0000
newmtl mat6
Ka 0.0000 0.6333 0.0000
Kd 0.0000 0.6333 0.0000
Ks 0.0000 0.0000 0.0000
newmtl mat7
Ka 0.2000 0.3667 0.2000
Kd 0.2000 0.3667 0.2000
Ks 0.0000 0.0000 0.0000
newmtl mat8
Ka 0.4200 0.3467 0.2667
Kd 0.4200 0.3467 0.2667
Ks 0.0000 0.0000 0.0000
newmtl mat9
Ka 0.1000 0.4333 0.1000
Kd 0.1000 0.4333 0.1000
Ks 0.0000 0.0000 0.0000
newmtl mat10
Ka 0.1000 0.5667 0.1000
Kd 0.1000 0.5667 0.1000
Ks 0.0000 0.0000 0.0000
newmtl mat11
Ka 0.2000 0.4333 0.2000
Kd 0.2000 0.4333 0.2000
Ks 0.0000 0.0000 0.0000
newmtl mat12
Ka 0.1000 0.5000 0.1000
Kd 0.1000 0.5000 0.1000
Ks 0.0000 0.0000 0.0000
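In this Wavefront `.mtl` file, each `newmtl` block declares a material. `Ka`, `Kd`, and `Ks` give its ambient, diffuse, and specular RGB colors in [0, 1]; here every specular term is zero and the diffuse colors are greens and browns. The minimal sketch below extracts the diffuse colors per material. It is a standalone illustration, not the loader Shap-E or Blender uses, and it handles only the directives that appear in this file.

```python
def parse_mtl_diffuse(text):
    """Return {material name: (r, g, b)} from the Kd lines of a Wavefront .mtl file."""
    colors, current = {}, None
    for line in text.splitlines():
        parts = line.split()
        if not parts:
            continue
        if parts[0] == "newmtl":
            current = parts[1]
        elif parts[0] == "Kd" and current is not None:
            colors[current] = tuple(float(v) for v in parts[1:4])
    return colors


sample = """newmtl mat0
Ka 0.0000 0.7000 0.0000
Kd 0.0000 0.7000 0.0000
Ks 0.0000 0.0000 0.0000
newmtl mat1
Ka 0.6600 0.4400 0.2000
Kd 0.6600 0.4400 0.2000
Ks 0.0000 0.0000 0.0000
"""
colors = parse_mtl_diffuse(sample)
```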


================================================
FILE: shap_e/examples/example_data/cactus/object.obj
================================================
[File too large to display: 17.2 MB]

================================================
FILE: shap_e/examples/sample_image_to_3d.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "964ccced",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from shap_e.diffusion.sample import sample_latents\n",
    "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
    "from shap_e.models.download import load_model, load_config\n",
    "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget\n",
    "from shap_e.util.image_util import load_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eed3a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d922637",
   "metadata": {},
   "outputs": [],
   "source": [
    "xm = load_model('transmitter', device=device)\n",
    "model = load_model('image300M', device=device)\n",
    "diffusion = diffusion_from_config(load_config('diffusion'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53d329d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4\n",
    "guidance_scale = 3.0\n",
    "\n",
    "# To get the best result, you should remove the background and show only the object of interest to the model.\n",
    "image = load_image(\"example_data/corgi.png\")\n",
    "\n",
    "latents = sample_latents(\n",
    "    batch_size=batch_size,\n",
    "    model=model,\n",
    "    diffusion=diffusion,\n",
    "    guidance_scale=guidance_scale,\n",
    "    model_kwargs=dict(images=[image] * batch_size),\n",
    "    progress=True,\n",
    "    clip_denoised=True,\n",
    "    use_fp16=True,\n",
    "    use_karras=True,\n",
    "    karras_steps=64,\n",
    "    sigma_min=1e-3,\n",
    "    sigma_max=160,\n",
    "    s_churn=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633da2ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "render_mode = 'nerf' # you can change this to 'stf' for mesh rendering\n",
    "size = 64 # this is the size of the renders; higher values take longer to render.\n",
    "\n",
    "cameras = create_pan_cameras(size, device)\n",
    "for i, latent in enumerate(latents):\n",
    "    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
    "    display(gif_widget(images))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
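Both notebooks pass a `guidance_scale` to `sample_latents`: 3.0 for image conditioning and 15.0 for text. The sampling code is not part of this section, so the following is a hedged sketch that assumes the standard classifier-free guidance rule. Under that rule, the conditional and unconditional model predictions are combined as `uncond + s * (cond - uncond)`. The function name `classifier_free_guidance` is hypothetical, and it works on plain lists to keep the arithmetic visible.

```python
from typing import List


def classifier_free_guidance(
    cond: List[float], uncond: List[float], guidance_scale: float
) -> List[float]:
    # Extrapolate from the unconditional prediction toward (and past) the
    # conditional one. guidance_scale == 1.0 returns the conditional prediction;
    # larger values push samples harder toward the conditioning signal.
    assert len(cond) == len(uncond), "predictions must have the same length"
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]


cond = [0.5, -1.0]
uncond = [0.25, -0.5]
print(classifier_free_guidance(cond, uncond, 3.0))   # [1.0, -2.0]  (image notebook scale)
print(classifier_free_guidance(cond, uncond, 15.0))  # [4.0, -8.0]  (text notebook scale)
```

The much larger scale in the text notebook amplifies the difference between the conditional and unconditional predictions five times more than the image notebook's scale does.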


================================================
FILE: shap_e/examples/sample_text_to_3d.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "964ccced",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from shap_e.diffusion.sample import sample_latents\n",
    "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
    "from shap_e.models.download import load_model, load_config\n",
    "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eed3a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d922637",
   "metadata": {},
   "outputs": [],
   "source": [
    "xm = load_model('transmitter', device=device)\n",
    "model = load_model('text300M', device=device)\n",
    "diffusion = diffusion_from_config(load_config('diffusion'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53d329d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4\n",
    "guidance_scale = 15.0\n",
    "prompt = \"a shark\"\n",
    "\n",
    "latents = sample_latents(\n",
    "    batch_size=batch_size,\n",
    "    model=model,\n",
    "    diffusion=diffusion,\n",
    "    guidance_scale=guidance_scale,\n",
    "    model_kwargs=dict(texts=[prompt] * batch_size),\n",
    "    progress=True,\n",
    "    clip_denoised=True,\n",
    "    use_fp16=True,\n",
    "    use_karras=True,\n",
    "    karras_steps=64,\n",
    "    sigma_min=1e-3,\n",
    "    sigma_max=160,\n",
    "    s_churn=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633da2ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "render_mode = 'nerf' # you can change this to 'stf' for mesh rendering\n",
    "size = 64 # this is the size of the renders; higher values take longer to render.\n",
    "\n",
    "cameras = create_pan_cameras(size, device)\n",
    "for i, latent in enumerate(latents):\n",
    "    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
    "    display(gif_widget(images))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a4dce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example of saving the latents as meshes.\n",
    "from shap_e.util.notebooks import decode_latent_mesh\n",
    "\n",
    "for i, latent in enumerate(latents):\n",
    "    t = decode_latent_mesh(xm, latent).tri_mesh()\n",
    "    with open(f'example_mesh_{i}.ply', 'wb') as f:\n",
    "        t.write_ply(f)\n",
    "    with open(f'example_mesh_{i}.obj', 'w') as f:\n",
    "        t.write_obj(f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
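The notebooks sample with `use_karras=True`, `karras_steps=64`, `sigma_min=1e-3` and `sigma_max=160`. `k_diffusion.py` is not shown in this section, so this sketch assumes the noise schedule from Karras et al. (2022) with its commonly used `rho=7`. `karras_sigmas` is a hypothetical helper. Implementations such as k-diffusion typically append a final sigma of 0, which this sketch omits.

```python
from typing import List


def karras_sigmas(n: int, sigma_min: float, sigma_max: float, rho: float = 7.0) -> List[float]:
    # Interpolate linearly between sigma_max**(1/rho) and sigma_min**(1/rho),
    # then raise back to the rho power. Larger rho concentrates steps at low noise.
    inv_max = sigma_max ** (1.0 / rho)
    inv_min = sigma_min ** (1.0 / rho)
    return [(inv_max + i / (n - 1) * (inv_min - inv_max)) ** rho for i in range(n)]


sigmas = karras_sigmas(64, sigma_min=1e-3, sigma_max=160)
print(len(sigmas), round(sigmas[0], 6), round(sigmas[-1], 6))  # 64 160.0 0.001
```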


================================================
FILE: shap_e/models/__init__.py
================================================


================================================
FILE: shap_e/models/configs.py
================================================
from typing import Any, Dict, Union

import blobfile as bf
import torch
import torch.nn as nn
import yaml

from shap_e.models.generation.latent_diffusion import SplitVectorDiffusion
from shap_e.models.generation.perceiver import PointDiffusionPerceiver
from shap_e.models.generation.pooled_mlp import PooledMLP
from shap_e.models.generation.transformer import (
    CLIPImageGridPointDiffusionTransformer,
    CLIPImageGridUpsamplePointDiffusionTransformer,
    CLIPImagePointDiffusionTransformer,
    PointDiffusionTransformer,
    UpsamplePointDiffusionTransformer,
)
from shap_e.models.nerf.model import MLPNeRFModel, VoidNeRFModel
from shap_e.models.nerf.renderer import OneStepNeRFRenderer, TwoStepNeRFRenderer
from shap_e.models.nerstf.mlp import MLPDensitySDFModel, MLPNeRSTFModel
from shap_e.models.nerstf.renderer import NeRSTFRenderer
from shap_e.models.nn.meta import batch_meta_state_dict
from shap_e.models.stf.mlp import MLPSDFModel, MLPTextureFieldModel
from shap_e.models.stf.renderer import STFRenderer
from shap_e.models.transmitter.base import ChannelsDecoder, Transmitter, VectorDecoder
from shap_e.models.transmitter.channels_encoder import (
    PointCloudPerceiverChannelsEncoder,
    PointCloudTransformerChannelsEncoder,
)
from shap_e.models.transmitter.multiview_encoder import MultiviewTransformerEncoder
from shap_e.models.transmitter.pc_encoder import (
    PointCloudPerceiverEncoder,
    PointCloudTransformerEncoder,
)
from shap_e.models.volume import BoundingBoxVolume, SphericalVolume, UnboundedVolume


def model_from_config(config: Union[str, Dict[str, Any]], device: torch.device) -> nn.Module:
    if isinstance(config, str):
        with bf.BlobFile(config, "rb") as f:
            obj = yaml.load(f, Loader=yaml.SafeLoader)
        return model_from_config(obj, device=device)

    config = config.copy()
    name = config.pop("name")

    if name == "PointCloudTransformerEncoder":
        return PointCloudTransformerEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudPerceiverEncoder":
        return PointCloudPerceiverEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudTransformerChannelsEncoder":
        return PointCloudTransformerChannelsEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudPerceiverChannelsEncoder":
        return PointCloudPerceiverChannelsEncoder(device=device, dtype=torch.float32, **config)
    elif name == "MultiviewTransformerEncoder":
        return MultiviewTransformerEncoder(device=device, dtype=torch.float32, **config)
    elif name == "Transmitter":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        encoder_config = config.pop("encoder").copy()
        encoder_config["param_shapes"] = param_shapes
        encoder = model_from_config(encoder_config, device=device)
        return Transmitter(encoder=encoder, renderer=renderer, **config)
    elif name == "VectorDecoder":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        return VectorDecoder(param_shapes=param_shapes, renderer=renderer, device=device, **config)
    elif name == "ChannelsDecoder":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        return ChannelsDecoder(
            param_shapes=param_shapes, renderer=renderer, device=device, **config
        )
    elif name == "OneStepNeRFRenderer":
        config = config.copy()
        for field in [
            # Required
            "void_model",
            "foreground_model",
            "volume",
            # Optional to use NeRF++
            "background_model",
            "outer_volume",
        ]:
            if field in config:
                config[field] = model_from_config(config.pop(field).copy(), device)
        return OneStepNeRFRenderer(device=device, **config)
    elif name == "TwoStepNeRFRenderer":
        config = config.copy()
        for field in [
            # Required
            "void_model",
            "coarse_model",
            "fine_model",
            "volume",
            # Optional to use NeRF++
            "coarse_background_model",
            "fine_background_model",
            "outer_volume",
        ]:
            if field in config:
                config[field] = model_from_config(config.pop(field).copy(), device)
        return TwoStepNeRFRenderer(device=device, **config)
    elif name == "PooledMLP":
        return PooledMLP(device, **config)
    elif name == "PointDiffusionTransformer":
        return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "PointDiffusionPerceiver":
        return PointDiffusionPerceiver(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImagePointDiffusionTransformer":
        return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridPointDiffusionTransformer":
        return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "UpsamplePointDiffusionTransformer":
        return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
        return CLIPImageGridUpsamplePointDiffusionTransformer(
            device=device, dtype=torch.float32, **config
        )
    elif name == "SplitVectorDiffusion":
        inner_config = config.pop("inner")
        d_latent = config.pop("d_latent")
        latent_ctx = config.pop("latent_ctx", 1)
        inner_config["input_channels"] = d_latent // latent_ctx
        inner_config["n_ctx"] = latent_ctx
        inner_config["output_channels"] = d_latent // latent_ctx * 2
        inner_model = model_from_config(inner_config, device)
        return SplitVectorDiffusion(
            device=device, wrapped=inner_model, n_ctx=latent_ctx, d_latent=d_latent
        )
    elif name == "STFRenderer":
        config = config.copy()
        for field in ["sdf", "tf", "volume"]:
            config[field] = model_from_config(config.pop(field), device)
        return STFRenderer(device=device, **config)
    elif name == "NeRSTFRenderer":
        config = config.copy()
        for field in ["sdf", "tf", "nerstf", "void", "volume"]:
            if field not in config:
                continue
            config[field] = model_from_config(config.pop(field), device)
        config.setdefault("sdf", None)
        config.setdefault("tf", None)
        config.setdefault("nerstf", None)
        return NeRSTFRenderer(device=device, **config)

    model_cls = {
        "MLPSDFModel": MLPSDFModel,
        "MLPTextureFieldModel": MLPTextureFieldModel,
        "MLPNeRFModel": MLPNeRFModel,
        "MLPDensitySDFModel": MLPDensitySDFModel,
        "MLPNeRSTFModel": MLPNeRSTFModel,
        "VoidNeRFModel": VoidNeRFModel,
        "BoundingBoxVolume": BoundingBoxVolume,
        "SphericalVolume": SphericalVolume,
        "UnboundedVolume": UnboundedVolume,
    }[name]
    return model_cls(device=device, **config)
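`model_from_config` dispatches on the config's `"name"` key. It recursively builds nested sub-configs such as renderers, volumes and encoders, then passes the remaining keys to the constructor as keyword arguments. Below is a minimal, torch-free sketch of that pattern. `REGISTRY`, `NESTED_FIELDS`, `register`, `build`, `Volume` and `Renderer` are hypothetical stand-ins, not shap_e classes.

```python
from typing import Any, Callable, Dict, List

# Hypothetical registry mapping a config "name" to a constructor.
REGISTRY: Dict[str, Callable[..., Any]] = {}

# Hypothetical: which keyword arguments of each type hold nested sub-configs.
NESTED_FIELDS: Dict[str, List[str]] = {"Renderer": ["volume"]}


def register(name: str):
    def decorator(cls):
        REGISTRY[name] = cls
        return cls
    return decorator


@register("Volume")
class Volume:
    def __init__(self, bbox_min, bbox_max):
        self.bbox_min = bbox_min
        self.bbox_max = bbox_max


@register("Renderer")
class Renderer:
    def __init__(self, volume, n_samples=64):
        self.volume = volume
        self.n_samples = n_samples


def build(config: Dict[str, Any]) -> Any:
    config = dict(config)  # shallow copy: popping keys must not mutate the caller's dict
    name = config.pop("name")
    for field in NESTED_FIELDS.get(name, []):
        if field in config:
            config[field] = build(config[field])  # build sub-objects first
    return REGISTRY[name](**config)  # remaining keys become constructor kwargs


cfg = {
    "name": "Renderer",
    "n_samples": 32,
    "volume": {"name": "Volume", "bbox_min": [-1.0] * 3, "bbox_max": [1.0] * 3},
}
renderer = build(cfg)
print(type(renderer.volume).__name__, renderer.n_samples)  # Volume 32
```

Copying before `pop` mirrors the `config = config.copy()` at the top of `model_from_config`. Without it, a config loaded once and reused would lose its `"name"` key after the first build.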


================================================
FILE: shap_e/models/download.py
================================================
"""
Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py
"""

import hashlib
import os
from functools import lru_cache
from typing import Dict, Optional

import requests
import torch
import yaml
from filelock import FileLock
from tqdm.auto import tqdm

MODEL_PATHS = {
    "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt",
    "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt",
    "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt",
    "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt",
}

CONFIG_PATHS = {
    "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml",
    "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml",
    "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml",
    "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml",
    "diffusion": "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml",
}

URL_HASHES = {
    "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt": "af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b",
    "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt": "d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98",
    "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt": "e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4",
    "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt": "cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa",
    "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml": "ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e",
    "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml": "e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c",
    "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml": "f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1",
    "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml": "4e0745605a533c543c72add803a78d233e2a6401e0abfa0cad58afb4d74ad0b0",
    "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml": "efcb2cd7ee545b2d27223979d41857802448143990572a42645cd09c2942ed57",
}


@lru_cache()
def default_cache_dir() -> str:
    return os.path.join(os.path.abspath(os.getcwd()), "shap_e_model_cache")


def fetch_file_cached(
    url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
) -> str:
    """
    Download the file at the given URL into a local file and return the path.
    If cache_dir is specified, files are downloaded into it; otherwise,
    default_cache_dir() is used. The file's SHA-256 hash is checked against
    URL_HASHES both on a cache hit and after a fresh download.
    """
    expected_hash = URL_HASHES[url]

    if cache_dir is None:
        cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(local_path):
        check_hash(local_path, expected_hash)
        return local_path

    response = requests.get(url, stream=True)
    response.raise_for_status()
    size = int(response.headers.get("content-length", "0"))
    with FileLock(local_path + ".lock"):
        if progress:
            pbar = tqdm(total=size, unit="iB", unit_scale=True)
        tmp_path = local_path + ".tmp"
        with open(tmp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size):
                if progress:
                    pbar.update(len(chunk))
                f.write(chunk)
        os.rename(tmp_path, local_path)
        if progress:
            pbar.close()
        check_hash(local_path, expected_hash)
        return local_path


def check_hash(path: str, expected_hash: str):
    actual_hash = hash_file(path)
    if actual_hash != expected_hash:
        raise RuntimeError(
            f"The file {path} should have hash {expected_hash} but has {actual_hash}. "
            "Try deleting it and running this call again."
        )


def hash_file(path: str) -> str:
    sha256_hash = hashlib.sha256()
    with open(path, "rb") as file:
        while True:
            data = file.read(4096)
            if not len(data):
                break
            sha256_hash.update(data)
    return sha256_hash.hexdigest()


def load_config(
    config_name: str,
    progress: bool = False,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
):
    if config_name not in CONFIG_PATHS:
        raise ValueError(
            f"Unknown config name {config_name}. Known names are: {CONFIG_PATHS.keys()}."
        )
    path = fetch_file_cached(
        CONFIG_PATHS[config_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
    )
    with open(path, "r") as f:
        return yaml.safe_load(f)


def load_checkpoint(
    checkpoint_name: str,
    device: torch.device,
    progress: bool = True,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
) -> Dict[str, torch.Tensor]:
    if checkpoint_name not in MODEL_PATHS:
        raise ValueError(
            f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
        )
    path = fetch_file_cached(
        MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
    )
    return torch.load(path, map_location=device)


def load_model(
    model_name: str,
    device: torch.device,
    **kwargs,
) -> torch.nn.Module:
    from .configs import model_from_config

    model = model_from_config(load_config(model_name, **kwargs), device=device)
    model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs))
    model.eval()
    return model
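`fetch_file_cached` streams the download into a `.tmp` file, renames it into place and verifies its SHA-256 hash against `URL_HASHES`. It re-checks the hash on every cache hit as well. The sketch below reproduces that flow offline. It assumes a hypothetical `fetch_chunks` callable in place of `requests`. It uses `os.replace` rather than `os.rename`, because `os.replace` also overwrites an existing destination on Windows.

```python
import hashlib
import os
import tempfile
from typing import Callable, Iterable


def sha256_of(path: str, chunk_size: int = 4096) -> str:
    # Hash in fixed-size chunks so large checkpoints never have to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def fetch_cached(
    name: str,
    expected_hash: str,
    cache_dir: str,
    fetch_chunks: Callable[[], Iterable[bytes]],  # hypothetical stand-in for a streaming HTTP GET
) -> str:
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, name)
    if not os.path.exists(local_path):
        # Write to a temporary name first so an interrupted download never
        # leaves a truncated file at local_path.
        tmp_path = local_path + ".tmp"
        with open(tmp_path, "wb") as f:
            for chunk in fetch_chunks():
                f.write(chunk)
        os.replace(tmp_path, local_path)
    # Verify both fresh downloads and cache hits.
    actual = sha256_of(local_path)
    if actual != expected_hash:
        raise RuntimeError(f"{local_path} should have hash {expected_hash} but has {actual}")
    return local_path


payload = b"hello shap-e"
calls = []


def fake_fetch():
    calls.append(1)  # count how many times a "download" actually happens
    yield payload[:5]
    yield payload[5:]


with tempfile.TemporaryDirectory() as cache:
    expected = hashlib.sha256(payload).hexdigest()
    first = fetch_cached("demo.bin", expected, cache, fake_fetch)
    second = fetch_cached("demo.bin", expected, cache, fake_fetch)  # served from cache
    print(first == second, len(calls))  # True 1
```

This sketch omits the `FileLock` that the original wraps around the download, so it is not safe for several processes filling the same cache at once.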


================================================
FILE: shap_e/models/generation/__init__.py
================================================


================================================
FILE: shap_e/models/generation/latent_diffusion.py
================================================
from typing import Any, Dict

import torch
import torch.nn as nn


class SplitVectorDiffusion(nn.Module):
    def __init__(self, *, device: torch.device, wrapped: nn.Module, n_ctx: int, d_latent: int):
        super().__init__()
        self.device = device
        self.n_ctx = n_ctx
        self.d_latent = d_latent
        self.wrapped = wrapped

        if hasattr(self.wrapped, "cached_model_kwargs"):
            self.cached_model_kwargs = self.wrapped.cached_model_kwargs

    def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs):
        h = x.reshape(x.shape[0], self.n_ctx, -1).permute(0, 2, 1)
        pre_channels = h.shape[1]
        h = self.wrapped(h, t, **kwargs)
        assert (
            h.shape[1] == pre_channels * 2
        ), "expected twice as many outputs for variance prediction"
        eps, var = torch.chunk(h, 2, dim=1)
        return torch.cat(
            [
                eps.permute(0, 2, 1).flatten(1),
                var.permute(0, 2, 1).flatten(1),
            ],
            dim=1,
        )
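`SplitVectorDiffusion` lets a sequence model operate on a flat latent vector. It reshapes `[N, d_latent]` into `[N, C, n_ctx]` with `C = d_latent // n_ctx`, which matches the `input_channels` that `model_from_config` assigns to the inner model. It expects the wrapped model to return `2 * C` channels, then flattens the epsilon and variance halves back into `[N, 2 * d_latent]`. This NumPy sketch mirrors only the shape bookkeeping. The toy `fake_wrapped` function is a hypothetical stand-in for the inner transformer.

```python
import numpy as np


def split_vector_forward(x: np.ndarray, n_ctx: int, wrapped) -> np.ndarray:
    # [N, d_latent] -> [N, n_ctx, C] -> [N, C, n_ctx]
    h = x.reshape(x.shape[0], n_ctx, -1).transpose(0, 2, 1)
    pre_channels = h.shape[1]
    h = wrapped(h)  # expected to return [N, 2 * C, n_ctx]
    assert h.shape[1] == pre_channels * 2, "expected twice as many outputs for variance prediction"
    eps, var = np.split(h, 2, axis=1)
    # Undo the transpose and flatten each half back to [N, d_latent].
    return np.concatenate(
        [
            eps.transpose(0, 2, 1).reshape(x.shape[0], -1),
            var.transpose(0, 2, 1).reshape(x.shape[0], -1),
        ],
        axis=1,
    )


def fake_wrapped(h: np.ndarray) -> np.ndarray:
    # Toy inner model: "eps" is the input itself, "var" is all zeros.
    return np.concatenate([h, np.zeros_like(h)], axis=1)


N, n_ctx, d_latent = 2, 4, 12
x = np.arange(N * d_latent, dtype=np.float32).reshape(N, d_latent)
out = split_vector_forward(x, n_ctx, fake_wrapped)
print(out.shape)  # (2, 24)
```

With an identity "eps" branch, the first half of the output reproduces `x` exactly. This confirms that the reshape and transpose are undone in the right order.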


================================================
FILE: shap_e/models/generation/perceiver.py
================================================
import math
from typing import Optional

import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint

from .transformer import MLP, Transformer, init_linear
from .util import timestep_embedding


class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        n_data: int,
        width: int,
        heads: int,
        init_scale: float,
        data_width: Optional[int] = None,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.n_data = n_data
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = nn.Linear(width, width, device=device, dtype=dtype)
        self.c_kv = nn.Linear(self.data_width, width * 2, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadCrossAttention(
            device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, n_data=n_data
        )
        init_linear(self.c_q, init_scale)
        init_linear(self.c_kv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x, data):
        x = self.c_q(x)
        data = self.c_kv(data)
        x = checkpoint(self.attention, (x, data), (), True)
        x = self.c_proj(x)
        return x


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(
        self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, n_data: int
    ):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_ctx = n_ctx
        self.n_data = n_data

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        attn_ch = width // self.heads // 2
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)
        weight = torch.einsum(
            "bthc,bshc->bhts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        wdtype = weight.dtype
        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
        return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)


class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        n_data: int,
        width: int,
        heads: int,
        data_width: Optional[int] = None,
        init_scale: float = 1.0,
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            n_data=n_data,
            width=width,
            heads=heads,
            data_width=data_width,
            init_scale=init_scale,
        )
        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x


class SimplePerceiver(nn.Module):
    """
    A stack of residual cross-attention blocks: queries attend only to `data`,
    never to each other (there is no self-attention).
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        n_data: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        data_width: Optional[int] = None,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        init_scale = init_scale * math.sqrt(1.0 / width)
        self.resblocks = nn.ModuleList(
            [
                ResidualCrossAttentionBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    n_data=n_data,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                    data_width=data_width,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        for block in self.resblocks:
            x = block(x, data)
        return x


class PointDiffusionPerceiver(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        input_channels: int = 3,
        output_channels: int = 3,
        n_ctx: int = 1024,
        n_latent: int = 128,
        width: int = 512,
        encoder_layers: int = 12,
        latent_layers: int = 12,
        decoder_layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
    ):
        super().__init__()
        self.time_embed = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
        )
        self.latent_embed = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
        )
        self.n_latent = n_latent

        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.encoder = SimplePerceiver(
            device=device,
            dtype=dtype,
            n_ctx=n_latent,
            n_data=n_ctx,
            width=width,
            layers=encoder_layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.processor = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_latent,
            width=width,
            layers=latent_layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.decoder = SimplePerceiver(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            n_data=n_latent,
            width=width,
            layers=decoder_layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
        with torch.no_grad():
            self.output_proj.weight.zero_()
            self.output_proj.bias.zero_()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.decoder.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.encoder.width))
        data = self.input_proj(x.permute(0, 2, 1)) + t_embed[:, None]
        data = self.ln_pre(data)

        l = torch.arange(self.n_latent).to(x.device)
        h = self.latent_embed(timestep_embedding(l, self.decoder.width))
        h = h.unsqueeze(0).repeat(x.shape[0], 1, 1)

        h = self.encoder(h, data)
        h = self.processor(h)
        h = self.decoder(data, h)
        h = self.ln_post(h)
        h = self.output_proj(h)
        return h.permute(0, 2, 1)
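`QKVMultiheadCrossAttention` scales both `q` and `k` by `attn_ch ** -0.25` before the dot product, rather than dividing the logits by `sqrt(attn_ch)` afterwards. Because `(q * s) . (k * s) = (q . k) * s**2 = (q . k) / sqrt(attn_ch)`, the two forms are mathematically identical. Scaling first keeps intermediate values smaller, which the source comment notes is more stable in float16. This NumPy sketch checks the equivalence in float64. The function names are hypothetical.

```python
import numpy as np


def softmax(w: np.ndarray) -> np.ndarray:
    # Subtract the row max before exponentiating for numerical stability.
    w = np.exp(w - w.max(axis=-1, keepdims=True))
    return w / w.sum(axis=-1, keepdims=True)


def attention_split_scale(q, k, v):
    # q: [B, T, H, C]; k, v: [B, S, H, C]. Scale q and k each by C ** -0.25.
    bs, n_ctx, _, c = q.shape
    scale = 1 / np.sqrt(np.sqrt(c))
    weight = softmax(np.einsum("bthc,bshc->bhts", q * scale, k * scale))
    return np.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)


def attention_divide_after(q, k, v):
    # Reference version: unscaled dot product, divided by sqrt(C) afterwards.
    bs, n_ctx, _, c = q.shape
    weight = softmax(np.einsum("bthc,bshc->bhts", q, k) / np.sqrt(c))
    return np.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 5, 4, 8))  # batch 2, 5 queries, 4 heads, 8 channels per head
k = rng.standard_normal((2, 7, 4, 8))  # 7 data tokens
v = rng.standard_normal((2, 7, 4, 8))
a = attention_split_scale(q, k, v)
b = attention_divide_after(q, k, v)
print(a.shape, np.allclose(a, b))  # (2, 5, 32) True
```

In the original, `attn_ch` is computed as `width // heads // 2` because `c_kv` emits keys and values concatenated along the channel axis. The result is then split per head with `torch.split`.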


================================================
FILE: shap_e/models/generation/pooled_mlp.py
================================================
import torch
import torch.nn as nn

from .util import timestep_embedding


class PooledMLP(nn.Module):
    def __init__(
        self,
        device: torch.device,
        *,
        input_channels: int = 3,
        output_channels: int = 6,
        hidden_size: int = 256,
        resblocks: int = 4,
        pool_op: str = "max",
    ):
        super().__init__()
        self.input_embed = nn.Conv1d(input_channels, hidden_size, kernel_size=1, device=device)
        self.time_embed = nn.Linear(hidden_size, hidden_size, device=device)

        blocks = []
        for _ in range(resblocks):
            blocks.append(ResBlock(hidden_size, pool_op, device=device))
        self.sequence = nn.Sequential(*blocks)

        self.out = nn.Conv1d(hidden_size, output_channels, kernel_size=1, device=device)
        with torch.no_grad():
            self.out.bias.zero_()
            self.out.weight.zero_()

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        in_embed = self.input_embed(x)
        t_embed = self.time_embed(timestep_embedding(t, in_embed.shape[1]))
        h = in_embed + t_embed[..., None]
        h = self.sequence(h)
        h = self.out(h)
        return h


class ResBlock(nn.Module):
    def __init__(self, hidden_size: int, pool_op: str, device: torch.device):
        super().__init__()
        assert pool_op in ["mean", "max"]
        self.pool_op = pool_op
        self.body = nn.Sequential(
            nn.SiLU(),
            nn.LayerNorm((hidden_size,), device=device),
            nn.Linear(hidden_size, hidden_size, device=device),
            nn.SiLU(),
            nn.LayerNorm((hidden_size,), device=device),
            nn.Linear(hidden_size, hidden_size, device=device),
        )
        self.gate = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, device=device),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor):
        N, C, T = x.shape
        out = self.body(x.permute(0, 2, 1).reshape(N * T, C)).reshape([N, T, C]).permute(0, 2, 1)
        pooled = pool(self.pool_op, x)
        gate = self.gate(pooled)
        return x + out * gate[..., None]


def pool(op_name: str, x: torch.Tensor) -> torch.Tensor:
    if op_name == "max":
        pooled, _ = torch.max(x, dim=-1)
    elif op_name == "mean":
        pooled = torch.mean(x, dim=-1)
    else:
        raise ValueError(f"unknown pool op: {op_name}")
    return pooled
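Each `ResBlock` in `PooledMLP` applies a per-point MLP and multiplies its output by a `tanh` gate computed from features pooled over all points. As a result, every point's update is modulated by a global summary of the cloud. The NumPy sketch below simplifies the body to a single SiLU followed by a bias-free linear layer and drops the LayerNorms. These are assumptions, not the original layers. Unlike `torch.max(x, dim=-1)`, which returns `(values, indices)`, a mean reduction returns a single tensor.

```python
import numpy as np


def pool(op_name: str, x: np.ndarray) -> np.ndarray:
    # Reduce over the last (point) axis: [N, C, T] -> [N, C].
    if op_name == "max":
        return x.max(axis=-1)
    if op_name == "mean":
        return x.mean(axis=-1)
    raise ValueError(f"unknown pool op: {op_name}")


def silu(h: np.ndarray) -> np.ndarray:
    return h / (1.0 + np.exp(-h))


def gated_res_block(x, w_body, w_gate, op_name="max"):
    # x: [N, C, T]; w_body, w_gate: [C, C] (simplified: no biases or LayerNorm).
    N, C, T = x.shape
    # Apply the body to every point independently by folding T into the batch axis.
    flat = x.transpose(0, 2, 1).reshape(N * T, C)
    out = (silu(flat) @ w_body).reshape(N, T, C).transpose(0, 2, 1)
    # One gate value per (example, channel), broadcast across all T points.
    gate = np.tanh(pool(op_name, x) @ w_gate)
    return x + out * gate[..., None]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 16))
w_body = rng.standard_normal((8, 8)) * 0.1
w_gate = rng.standard_normal((8, 8)) * 0.1
y_max = gated_res_block(x, w_body, w_gate, "max")
y_mean = gated_res_block(x, w_body, w_gate, "mean")
print(y_max.shape, y_mean.shape)  # (2, 8, 16) (2, 8, 16)
```

If the gate weights are zero, the gate is `tanh(0) = 0` and the block reduces to the identity. This is one reason such gated residual blocks are easy to stack.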


================================================
FILE: shap_e/models/generation/pretrained_clip.py
================================================
from typing import Iterable, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from shap_e.models.download import default_cache_dir

ImageType = Union[np.ndarray, torch.Tensor, Image.Image]


class ImageCLIP(nn.Module):
    """
    A wrapper around a pre-trained CLIP model that automatically handles
    batches of texts, images, and embeddings.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: Optional[torch.dtype] = torch.float32,
        ensure_used_params: bool = True,
        clip_name: str = "ViT-L/14",
        cache_dir: Optional[str] = None,
    ):
        super().__init__()

        assert clip_name in ["ViT-L/14", "ViT-B/32"]

        self.device = device
        self.ensure_used_params = ensure_used_params

        # Lazy import because of torchvision.
        import clip

        self.clip_model, self.preprocess = clip.load(
            clip_name, device=device, download_root=cache_dir or default_cache_dir()
        )
        self.clip_name = clip_name

        if dtype is not None:
            self.clip_model.to(dtype)
        self._tokenize = clip.tokenize

    @property
    def feature_dim(self) -> int:
        if self.clip_name == "ViT-L/14":
            return 768
        else:
            return 512

    @property
    def grid_size(self) -> int:
        if self.clip_name == "ViT-L/14":
            return 16
        else:
            return 7

    @property
    def grid_feature_dim(self) -> int:
        if self.clip_name == "ViT-L/14":
            return 1024
        else:
            return 768

    def forward(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """
        Generate a batch of embeddings from a mixture of images, texts,
        precomputed embeddings, and possibly empty values.

        For each batch element, at most one of images, texts, and embeddings
        should have a non-None value. Embeddings from multiple modalities
        cannot be mixed for a single batch element. If no modality is provided,
        a zero embedding will be used for the batch element.
        """
        image_seq = [None] * batch_size if images is None else list(images)
        text_seq = [None] * batch_size if texts is None else list(texts)
        embedding_seq = [None] * batch_size if embeddings is None else list(embeddings)
        assert len(image_seq) == batch_size, "number of images should match batch size"
        assert len(text_seq) == batch_size, "number of texts should match batch size"
        assert len(embedding_seq) == batch_size, "number of embeddings should match batch size"

        if self.ensure_used_params:
            return self._static_multimodal_embed(
                images=image_seq, texts=text_seq, embeddings=embedding_seq
            )

        result = torch.zeros((batch_size, self.feature_dim), device=self.device)
        index_images = []
        index_texts = []
        for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):
            assert (
                sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2
            ), "only one modality may be non-None per batch element"
            if image is not None:
                index_images.append((i, image))
            elif text is not None:
                index_texts.append((i, text))
            elif emb is not None:
                result[i] = emb.to(result)

        if len(index_images):
            embs = self.embed_images((img for _, img in index_images))
            for (i, _), emb in zip(index_images, embs):
                result[i] = emb.to(result)
        if len(index_texts):
            embs = self.embed_text((text for _, text in index_texts))
            for (i, _), emb in zip(index_texts, embs):
                result[i] = emb.to(result)

        return result

    def _static_multimodal_embed(
        self,
        images: Optional[List[Optional[ImageType]]] = None,
        texts: Optional[List[Optional[str]]] = None,
        embeddings: Optional[List[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """
        Like forward(), but always runs all encoders to ensure that
        the forward graph looks the same on every rank.
        """
        image_emb = self.embed_images(images)
        text_emb = self.embed_text(t if t else "" for t in texts)
        joined_embs = torch.stack(
            [
                emb.to(device=self.device, dtype=torch.float32)
                if emb is not None
                else torch.zeros(self.feature_dim, device=self.device)
                for emb in embeddings
            ],
            dim=0,
        )

        image_flag = torch.tensor([x is not None for x in images], device=self.device)[
            :, None
        ].expand_as(image_emb)
        text_flag = torch.tensor([x is not None for x in texts], device=self.device)[
            :, None
        ].expand_as(image_emb)
        emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[
            :, None
        ].expand_as(image_emb)

        return (
            image_flag.float() * image_emb
            + text_flag.float() * text_emb
            + emb_flag.float() * joined_embs
            + self.clip_model.logit_scale * 0  # avoid unused parameters
        )

    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
        :param xs: N images, stored as numpy arrays, tensors, or PIL images.
        :return: an [N x D] tensor of unit-norm (L2-normalized) features.
        """
        clip_inputs = self.images_to_tensor(xs)
        results = self.clip_model.encode_image(clip_inputs).float()
        return results / torch.linalg.norm(results, dim=-1, keepdim=True)

    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        """
        Embed text prompts as an [N x D] tensor of unit-norm (L2-normalized) features.
        """
        enc = self.clip_model.encode_text(
            self._tokenize(list(prompts), truncate=True).to(self.device)
        ).float()
        return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)

    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
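        Embed images into grids of CLIP patch features: the final transformer
        hidden states for each patch, with the class token dropped.

        :param xs: an iterable of images to embed.
        :return: a tensor of shape [N x C x L], where C = self.grid_feature_dim
                 and L = self.grid_size**2.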
        """
        if self.ensure_used_params:
            extra_value = 0.0
            for p in self.parameters():
                extra_value = extra_value + p.mean() * 0.0
        else:
            extra_value = 0.0

        x = self.images_to_tensor(xs).to(self.clip_model.dtype)

        # https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225
        vt = self.clip_model.visual
        x = vt.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [
                vt.class_embedding.to(x.dtype)
                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + vt.positional_embedding.to(x.dtype)
        x = vt.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = vt.transformer(x)
        x = x.permute(1, 2, 0)  # LND -> NDL

        return x[..., 1:].contiguous().float() + extra_value

    def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device)


class FrozenImageCLIP:
    def __init__(self, device: torch.device, **kwargs):
        self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs)
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)

    @property
    def feature_dim(self) -> int:
        return self.model.feature_dim

    @property
    def grid_size(self) -> int:
        return self.model.grid_size

    @property
    def grid_feature_dim(self) -> int:
        return self.model.grid_feature_dim

    def __call__(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        # We don't do a no_grad() here so that gradients could still
        # flow to the input embeddings argument.
        # This behavior is currently not used, but it could be.
        return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings)

    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_images(xs)

    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_text(prompts)

    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_images_grid(xs)


def _image_to_pil(obj: Optional[ImageType]) -> Image.Image:
    if obj is None:
        return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))
    if isinstance(obj, np.ndarray):
        return Image.fromarray(obj.astype(np.uint8))
    elif isinstance(obj, torch.Tensor):
        return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8))
    else:
        return obj
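

# Illustrative sketch (hypothetical helper, not part of shap-e's API) of the
# masking used by ImageCLIP._static_multimodal_embed: every encoder runs on the
# whole batch, and per-row boolean flags pick which modality's result survives.
# A row with no modality flagged comes out as all zeros.
def _example_select_by_modality_flags(
    image_emb, text_emb, given_emb, has_image, has_text, has_emb
):
    """
    :param image_emb: an [N x D] tensor of image features.
    :param text_emb: an [N x D] tensor of text features.
    :param given_emb: an [N x D] tensor of precomputed embeddings.
    :param has_image: N bools; at most one of the three flags may be True per row.
    :param has_text: N bools.
    :param has_emb: N bools.
    :return: an [N x D] tensor holding the selected row from each source.
    """
    import torch

    def flag(bits):
        # [N] bools -> [N x 1] floats that broadcast against [N x D].
        return torch.tensor(bits, dtype=torch.float32)[:, None]

    return flag(has_image) * image_emb + flag(has_text) * text_emb + flag(has_emb) * given_emb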


================================================
FILE: shap_e/models/generation/transformer.py
================================================
import math
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint

from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType
from .util import timestep_embedding


def init_linear(l, stddev):
    nn.init.normal_(l.weight, std=stddev)
    if l.bias is not None:
        nn.init.constant_(l.bias, 0.0)


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
        init_linear(self.c_qkv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        x = self.c_qkv(x)
        x = checkpoint(self.attention, (x,), (), True)
        x = self.c_proj(x)
        return x


class MLP(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
        self.gelu = nn.GELU()
        init_linear(self.c_fc, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class QKVMultiheadAttention(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_ctx = n_ctx

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)
        weight = torch.einsum(
            "bthc,bshc->bhts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        wdtype = weight.dtype
        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
        return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
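

# Illustrative sketch (hypothetical helper): QKVMultiheadAttention scales both
# q and k by 1 / d**0.25 instead of dividing q @ k^T by sqrt(d). The logits are
# identical in exact arithmetic; the split only keeps intermediate values
# smaller, which the comment above says is more stable in float16. This checks
# the two formulations against each other in float64.
def _example_split_scale_matches_standard(bs=2, n_ctx=5, heads=3, attn_ch=8, seed=0):
    import math

    import torch

    gen = torch.Generator().manual_seed(seed)
    q, k, v = (
        torch.randn(bs, n_ctx, heads, attn_ch, generator=gen, dtype=torch.float64)
        for _ in range(3)
    )

    # Split scaling, as in QKVMultiheadAttention.forward.
    scale = 1 / math.sqrt(math.sqrt(attn_ch))
    w_split = torch.softmax(torch.einsum("bthc,bshc->bhts", q * scale, k * scale), dim=-1)
    out_split = torch.einsum("bhts,bshc->bthc", w_split, v)

    # Standard scaled dot-product attention: divide the logits by sqrt(d).
    w_std = torch.softmax(torch.einsum("bthc,bshc->bhts", q, k) / math.sqrt(attn_ch), dim=-1)
    out_std = torch.einsum("bhts,bshc->bthc", w_std, v)

    return torch.allclose(out_split, out_std)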


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float = 1.0,
    ):
        super().__init__()

        self.attn = MultiheadAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        init_scale = init_scale * math.sqrt(1.0 / width)
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x


class PointDiffusionTransformer(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        input_channels: int = 3,
        output_channels: int = 3,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        time_token_cond: bool = False,
        use_pos_emb: bool = False,
        pos_emb_init_scale: float = 1.0,
        pos_emb_n_ctx: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.n_ctx = n_ctx
        self.time_token_cond = time_token_cond
        self.use_pos_emb = use_pos_emb
        self.time_embed = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + int(time_token_cond),
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
        with torch.no_grad():
            self.output_proj.weight.zero_()
            self.output_proj.bias.zero_()
        if self.use_pos_emb:
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    pos_emb_init_scale
                    * torch.randn(pos_emb_n_ctx or self.n_ctx, width, device=device, dtype=dtype)
                ),
            )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])

    def _forward_with_cond(
        self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
    ) -> torch.Tensor:
        h = self.input_proj(x.permute(0, 2, 1))  # NCL -> NLC
        for emb, as_token in cond_as_token:
            if not as_token:
                h = h + emb[:, None]
        if self.use_pos_emb:
            h = h + self.pos_emb
        extra_tokens = [
            (emb[:, None] if len(emb.shape) == 2 else emb)
            for emb, as_token in cond_as_token
            if as_token
        ]
        if len(extra_tokens):
            h = torch.cat(extra_tokens + [h], dim=1)

        h = self.ln_pre(h)
        h = self.backbone(h)
        h = self.ln_post(h)
        if len(extra_tokens):
            h = h[:, sum(tok.shape[1] for tok in extra_tokens) :]
        h = self.output_proj(h)
        return h.permute(0, 2, 1)
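

# Illustrative sketch (hypothetical helper): how _forward_with_cond treats its
# conditioning list. Entries with as_token=False are added to every point
# token; entries with as_token=True are prepended as extra sequence positions
# (an [N x C] vector becomes one token, an [N x L x C] tensor becomes L tokens)
# and sliced off again after the backbone. Here the backbone is the identity.
def _example_cond_token_bookkeeping(h, cond_as_token):
    """
    :param h: an [N x T x C] tensor of projected point tokens.
    :param cond_as_token: a list of (embedding, as_token) pairs.
    :return: (backbone_input, output), where output has shape [N x T x C].
    """
    import torch

    for emb, as_token in cond_as_token:
        if not as_token:
            h = h + emb[:, None]  # broadcast [N x C] over the T point tokens
    extra_tokens = [
        (emb[:, None] if emb.dim() == 2 else emb) for emb, as_token in cond_as_token if as_token
    ]
    backbone_input = torch.cat(extra_tokens + [h], dim=1) if extra_tokens else h
    n_extra = sum(tok.shape[1] for tok in extra_tokens)
    output = backbone_input[:, n_extra:]  # identity "backbone", then drop the extra tokens
    return backbone_input, output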


class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 1024,
        token_cond: bool = False,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        super().__init__(
            device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), pos_emb_n_ctx=n_ctx, **kwargs
        )
        self.n_ctx = n_ctx
        self.token_cond = token_cond
        self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        self.clip_embed = nn.Linear(
            self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        with torch.no_grad():
            return dict(embeddings=self.clip(batch_size, **model_kwargs))

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :param images: a batch of images to condition on.
        :param texts: a batch of texts to condition on.
        :param embeddings: a batch of CLIP embeddings to condition on.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx

        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings)
        assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None].to(clip_out)

        # Embeddings from embed_images/embed_text have unit L2 norm, so scaling
        # by sqrt(D) gives each feature roughly unit variance.
        clip_out = math.sqrt(clip_out.shape[1]) * clip_out

        clip_embed = self.clip_embed(clip_out)

        cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]
        return self._forward_with_cond(x, cond)
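

# Illustrative sketch (hypothetical helper) of the two conditioning transforms
# in CLIPImagePointDiffusionTransformer.forward. During training each row's
# CLIP embedding is zeroed with probability cond_drop_prob (presumably so the
# model also learns an unconditional mode for classifier-free guidance), and
# the embedding is then multiplied by sqrt(D). For an L2-normalized [N x D]
# input, that makes the mean squared feature value exactly 1.
def _example_drop_and_rescale(clip_out, cond_drop_prob, generator=None):
    import math

    import torch

    # True means "keep"; each row is kept with probability 1 - cond_drop_prob.
    mask = torch.rand(size=[len(clip_out)], generator=generator) >= cond_drop_prob
    clip_out = clip_out * mask[:, None].to(clip_out)
    return math.sqrt(clip_out.shape[1]) * clip_out, mask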


class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        super().__init__(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + clip.grid_size**2,
            pos_emb_n_ctx=n_ctx,
            **kwargs,
        )
        self.n_ctx = n_ctx
        self.clip = clip
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
            ),
            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        _ = batch_size
        with torch.no_grad():
            return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"]))

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
        :return: an [N x C' x T] tensor.
        """
        assert images is not None or embeddings is not None, "must specify images or embeddings"
        assert images is None or embeddings is None, "cannot specify both images and embeddings"
        assert x.shape[-1] == self.n_ctx

        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True)]
        return self._forward_with_cond(x, cond)
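

# Illustrative sketch (hypothetical helper): sequence-length bookkeeping for
# CLIPImageGridPointDiffusionTransformer. The CLIP patch grid contributes
# grid_size**2 tokens (16**2 for ViT-L/14, 7**2 for ViT-B/32), the timestep
# adds one token only when time_token_cond is set, and the learned positional
# embedding (pos_emb_n_ctx=n_ctx) covers just the point tokens.
def _example_grid_model_context(n_ctx=1024, grid_size=16, time_token_cond=False):
    backbone_ctx = n_ctx + grid_size**2 + int(time_token_cond)
    return dict(point_tokens=n_ctx, cond_tokens=backbone_ctx - n_ctx, backbone_ctx=backbone_ctx)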


class UpsamplePointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        cond_input_channels: Optional[int] = None,
        cond_ctx: int = 1024,
        n_ctx: int = 4096 - 1024,
        channel_scales: Optional[Sequence[float]] = None,
        channel_biases: Optional[Sequence[float]] = None,
        **kwargs,
    ):
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs)
        self.n_ctx = n_ctx
        self.cond_input_channels = cond_input_channels or self.input_channels
        self.cond_point_proj = nn.Linear(
            self.cond_input_channels, self.backbone.width, device=device, dtype=dtype
        )

        self.register_buffer(
            "channel_scales",
            torch.tensor(channel_scales, dtype=dtype, device=device)
            if channel_scales is not None
            else None,
        )
        self.register_buffer(
            "channel_biases",
            torch.tensor(channel_biases, dtype=dtype, device=device)
            if channel_biases is not None
            else None,
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :return: an [N x C3 x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)
        cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)

    def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:
        if self.channel_scales is not None:
            x = x * self.channel_scales[None, :, None]
        if self.channel_biases is not None:
            x = x + self.channel_biases[None, :, None]
        return self.cond_point_proj(x.permute(0, 2, 1))
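

# Illustrative sketch (hypothetical helper): UpsamplePointDiffusionTransformer
# applies a per-channel affine normalization to the [N x C x T'] low-resolution
# cloud before projecting it. Indexing a length-C buffer with [None, :, None]
# gives shape [1 x C x 1], which broadcasts over the batch and point axes.
def _example_normalize_low_res(x, channel_scales=None, channel_biases=None):
    import torch

    if channel_scales is not None:
        x = x * torch.as_tensor(channel_scales, dtype=x.dtype)[None, :, None]
    if channel_biases is not None:
        x = x + torch.as_tensor(channel_biases, dtype=x.dtype)[None, :, None]
    return x.permute(0, 2, 1)  # NCL -> NLC, the layout a Linear projection expects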


class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 4096 - 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)
        self.n_ctx = n_ctx

        self.clip = clip
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
            ),
            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        _ = batch_size
        with torch.no_grad():
            return dict(
                embeddings=self.clip.embed_images_grid(model_kwargs["images"]),
                low_res=model_kwargs["low_res"],
            )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        *,
        low_res: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
        :return: an [N x C3 x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)


================================================
FILE: shap_e/models/generation/util.py
================================================
import math

import torch


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
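

# Illustrative sketch (hypothetical helper) of the frequency ladder used by
# timestep_embedding. With half = dim // 2, frequency i is
# max_period ** (-i / half), so the frequencies decay geometrically from 1
# toward 1 / max_period. At t = 0 every cosine feature is 1 and every sine
# feature is 0.
def _example_timestep_freqs(dim=8, max_period=10000):
    import math

    import torch

    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    # Embedding of a single t = 0: cosines in the first half, sines in the second.
    args = torch.zeros(1)[:, None] * freqs[None]
    emb0 = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return freqs, emb0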


================================================
FILE: shap_e/models/nerf/__init__.py
================================================


================================================
FILE: shap_e/models/nerf/model.py
================================================
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init
from shap_e.models.nn.utils import ArrayType
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict


class NeRFModel(ABC):
    """
    Parametric scene representation whose outputs are integrated by NeRFRenderer.
    """

    @abstractmethod
    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: the points in the field to query.
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: An AttrDict containing at least
            - density: [batch_size x ... x 1]
            - channels: [batch_size x ... x n_channels]
            - aux_losses: [batch_size x ... x 1]
        """


class VoidNeRFModel(MetaModule, NeRFModel):
    """
    Implements the default empty space model where all queries are rendered as
    background.
    """

    def __init__(
        self,
        background: ArrayType,
        trainable: bool = False,
        channel_scale: float = 255.0,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()
        background = nn.Parameter(
            torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device)
            / channel_scale
        )
        if trainable:
            self.register_parameter("background", background)
        else:
            self.register_buffer("background", background)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        _ = params
        default_bg = self.background[None]
        background = options.get("background", default_bg) if options is not None else default_bg

        shape = query.position.shape[:-1]
        ones = [1] * (len(shape) - 1)
        n_channels = background.shape[-1]
        background = torch.broadcast_to(
            background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]
        )
        return background
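

# Illustrative sketch (hypothetical helper): how VoidNeRFModel expands a
# per-batch background color to every query point. A [B x C] background is
# reshaped to [B x 1 x ... x 1 x C] and broadcast to the query's leading shape
# (position.shape[:-1]) plus the channel axis, without copying memory.
def _example_broadcast_background(background, position_shape):
    import torch

    shape = position_shape[:-1]
    ones = [1] * (len(shape) - 1)
    n_channels = background.shape[-1]
    return torch.broadcast_to(
        background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]
    )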


class MLPNeRFModel(MetaModule, NeRFModel):
    def __init__(
        self,
        # Positional encoding parameters
        n_levels: int = 10,
        # MLP parameters
        d_hidden: int = 256,
        n_density_layers: int = 4,
        n_channel_layers: int = 1,
        n_channels: int = 3,
        sh_degree: int = 4,
        activation: str = "relu",
        density_activation: str = "exp",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        output_activation: str = "sigmoid",
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        zero_out: bool = True,
        register_freqs: bool = True,
        posenc_version: str = "v1",
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()

        # Positional encoding
        if register_freqs:
            # not used anymore
            self.register_buffer(
                "freqs",
                2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels),
            )

        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_input = encode_position(posenc_version, position=dummy).shape[-1]

        self.n_levels = n_levels

        self.sh_degree = sh_degree
        d_sh_coeffs = sh_degree**2

        self.meta_parameters = meta_parameters

        mlp_cls = (
            partial(
                MetaMLP,
                meta_scale=False,
                meta_shift=False,
                meta_proj=True,
                meta_bias=True,
                trainable_meta=trainable_meta,
            )
            if meta_parameters
            else MLP
        )

        self.density_mlp = mlp_cls(
            d_input=d_input,
            d_hidden=[d_hidden] * (n_density_layers - 1),
            d_output=d_hidden,
            act_name=activation,
            init_scale=init_scale,
        )

        self.channel_mlp = mlp_cls(
            d_input=d_hidden + d_sh_coeffs,
            d_hidden=[d_hidden] * n_channel_layers,
            d_output=n_channels,
            act_name=activation,
            init_scale=init_scale,
        )

        self.act = get_act(output_activation)
        self.density_act = get_act(density_activation)

        mlp_init(
            list(self.density_mlp.affines) + list(self.channel_mlp.affines),
            init=init,
            init_scale=init_scale,
        )

        if zero_out:
            zero_init(self.channel_mlp.affines[-1])

        self.to(device)

    def encode_position(self, query: Query):
        h = encode_position(self.posenc_version, position=query.position)
        return h

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        params = self.update(params)

        options = AttrDict() if options is None else AttrDict(options)

        query = query.copy()

        h_position = self.encode_position(query)

        if self.meta_parameters:
            density_params = subdict(params, "density_mlp")
            density_mlp = partial(
                self.density_mlp, params=density_params, options=options, log_prefix="density_"
            )
            density_mlp_parameters = list(density_params.values())
        else:
            density_mlp = partial(self.density_mlp, options=options, log_prefix="density_")
            density_mlp_parameters = self.density_mlp.parameters()
        h_density = checkpoint(
            density_mlp,
            (h_position,),
            density_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )
        h_direction = maybe_get_spherical_harmonics_basis(
            sh_degree=self.sh_degree,
            coords_shape=query.position.shape,
            coords=query.direction,
            device=query.position.device,
        )

        if self.meta_parameters:
            channel_params = subdict(params, "channel_mlp")
            channel_mlp = partial(
                self.channel_mlp, params=channel_params, options=options, log_prefix="channel_"
            )
            channel_mlp_parameters = list(channel_params.values())
        else:
            channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_")
            channel_mlp_parameters = self.channel_mlp.parameters()
        h_channel = checkpoint(
            channel_mlp,
            (torch.cat([h_density, h_direction], dim=-1),),
            channel_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        density_logit = h_density[..., :1]

        res = AttrDict(
            density_logit=density_logit,
            density=self.density_act(density_logit),
            channels=self.act(h_channel),
            aux_losses=AttrDict(),
            no_weight_grad_aux_losses=AttrDict(),
        )
        if options.return_h_density:
            res.h_density = h_density

        return res


def maybe_get_spherical_harmonics_basis(
    sh_degree: int,
    coords_shape: Tuple[int],
    coords: Optional[torch.Tensor] = None,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
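    """
    :param sh_degree: spherical harmonics degree.
    :param coords_shape: [*shape, 3]
    :param coords: optional direction tensor of shape coords_shape.
    :param device: device for the zero tensor returned when coords is None.
    :return: [*shape, sh_degree**2] basis values, or zeros if coords is None.
    """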
    if coords is None:
        return torch.zeros(*coords_shape[:-1], sh_degree**2, device=device)

    return spherical_harmonics_basis(coords, sh_degree)


================================================
FILE: shap_e/models/nerf/ray.py
================================================
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Tuple

import torch

from shap_e.models.nn.utils import sample_pmf
from shap_e.models.volume import Volume, VolumeRange
from shap_e.util.collections import AttrDict

from .model import NeRFModel, Query


def render_rays(
    rays: torch.Tensor,
    parts: List["RayVolumeIntegral"],
    void_model: NeRFModel,
    shared: bool = False,
    prev_raw_outputs: Optional[List[AttrDict]] = None,
    render_with_direction: bool = True,
    importance_sampling_options: Optional[Dict[str, Any]] = None,
) -> Tuple["RayVolumeIntegralResults", List["RaySampler"], List[AttrDict]]:
    """
    Perform volumetric rendering over a partition of possible t's in the union
    of rendering volumes (written below with some abuse of notation):

        C(r) := sum(
            transmittance(t[i]) *
            integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t[i], t[i + 1]],
            )
            for i in range(len(parts))
        ) + transmittance(t[-1]) * void_model(t[-1]).channels

    where

    1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the
       probability of light passing through the volume specified by [t[0], s].
       (transmittance of 1 means light can pass freely)
    2) density and channels are obtained by evaluating the appropriate
       part.model at time t.
    3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects
       (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface
       of the shell (if bounded). If the ray does not intersect, the integral over
       this segment is evaluated as 0 and transmittance(t[i + 1]) :=
       transmittance(t[i]).
    4) The last term covers [t[-1], math.inf] and is evaluated by the
       void_model; this space is treated as empty.

    :param rays: [batch_size x ... x 2 x 3] origin and direction.
    :param parts: disjoint volume integrals.
    :param void_model: use this model to integrate over the empty space.
    :param shared: if True, all RayVolumeIntegrals are calculated with the same model.
    :param prev_raw_outputs: raw outputs from the previous rendering step, one per part.
    :param render_with_direction: pass the ray direction to the models when querying them.
    :param importance_sampling_options: keyword arguments forwarded to each
        ImportanceRaySampler (e.g. blur_pool, alpha).

    :return: A tuple of
        - RayVolumeIntegralResults whose `output` contains the rendered `channels`,
          `distances`, and `aux_losses`
        - A list of importance samplers for additional fine-grained rendering
        - A list of raw outputs, one for each interval
    if importance_sampling_options is None:
        importance_sampling_options = {}

    origin, direc = rays[..., 0, :], rays[..., 1, :]

    if prev_raw_outputs is None:
        prev_raw_outputs = [None] * len(parts)

    samplers = []
    raw_outputs = []
    t0 = None
    results = None

    for part_i, prev_raw_i in zip(parts, prev_raw_outputs):

        # Integrate over [t[i], t[i + 1]]
        results_i = part_i.render_rays(
            origin,
            direc,
            t0=t0,
            prev_raw=prev_raw_i,
            shared=shared,
            render_with_direction=render_with_direction,
        )

        # Create an importance sampler for (optional) fine rendering
        samplers.append(
            ImportanceRaySampler(
                results_i.volume_range, results_i.raw, **importance_sampling_options
            )
        )
        raw_outputs.append(results_i.raw)

        # Pass t[i + 1] as the start of integration for the next interval.
        t0 = results_i.volume_range.next_t0()

        # Combine the results from [t[0], t[i]] and [t[i], t[i+1]]
        results = results_i if results is None else results.combine(results_i)

    # While integrating over [t[-1], math.inf] is the correct thing to do, it
    # erases a lot of useful information. Also, void_model is meant to predict
    # the channels at t=math.inf.

    # # Add the void background over [t[-1], math.inf] to complete integration.
    # results = results.combine(
    #     RayVolumeIntegralResults(
    #         output=AttrDict(
    #             channels=void_model(origin, direc),
    #             distances=torch.zeros_like(t0),
    #             aux_losses=AttrDict(),
    #         ),
    #         volume_range=VolumeRange(
    #             t0=t0,
    #             t1=torch.full_like(t0, math.inf),
    #             intersected=torch.full_like(results.volume_range.intersected, True),
    #         ),
    #         # Void space extends to infinity. It is assumed that no light
    #         # passes beyond the void.
    #         transmittance=torch.zeros_like(results_i.transmittance),
    #     )
    # )

    results.output.channels = results.output.channels + results.transmittance * void_model(
        Query(origin, direc)
    )

    return results, samplers, raw_outputs
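

# Illustrative sketch (hypothetical helper, not part of shap-e and not called by
# the renderer): the compositing that render_rays performs, written with plain
# Python floats. It assumes each part has already been reduced to a scalar color
# (its integral over [t[i], t[i + 1]]) and a transmittance through that part, and
# that the void model returns a scalar background color. The running
# transmittance attenuates every later part, and whatever light survives all
# parts picks up the void color.
def _sketch_composite_parts(parts, void_color: float) -> float:
    """
    :param parts: list of (color, transmittance) pairs, ordered near to far.
    :param void_color: background color seen after the last part.
    :return: composited color along the ray.
    """
    color, transmittance = 0.0, 1.0
    for part_color, part_transmittance in parts:
        # Same update as RayVolumeIntegralResults.combine.
        color = color + transmittance * part_color
        transmittance = transmittance * part_transmittance
    # Mirrors `results.transmittance * void_model(...)` at the end of render_rays.
    return color + transmittance * void_color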


@dataclass
class RayVolumeIntegralResults:
    """
    Stores the relevant state and results of

        integrate(
            lambda t: density(t) * channels(t) * transmittance(t),
            [t0, t1],
        )
    """

    # Rendered output and auxiliary losses
    # output.channels has shape [batch_size, *inner_shape, n_channels]
    output: AttrDict

    # Optional values

    # Raw values contain the sampled `ts`, `density`, `channels`, etc.
    raw: Optional[AttrDict] = None

    # Integration range [t0, t1] and the per-ray intersection mask
    volume_range: Optional[VolumeRange] = None

    # If a ray intersects, the transmittance from t0 to t1 (i.e. the
    # probability that the ray passes through this volume).
    # Has shape [batch_size, *inner_shape, 1].
    transmittance: Optional[torch.Tensor] = None

    def combine(self, cur: "RayVolumeIntegralResults") -> "RayVolumeIntegralResults":
        """
        Combines the integration results of `self` over [t0, t1] and
        `cur` over [t1, t2] to produce a new set of results over [t0, t2] by
        using a similar equation to (4) in NeRF++:

            integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t0, t2]
            )

          = integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t0, t1]
            ) + transmittance(t1) * integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t1, t2]
            )
        """
        assert torch.allclose(self.volume_range.next_t0(), cur.volume_range.t0)

        def _combine_fn(
            prev_val: Optional[torch.Tensor],
            cur_val: Optional[torch.Tensor],
            *,
            prev_transmittance: torch.Tensor,
        ):
            assert prev_val is not None
            if cur_val is None:
                # cur.output.aux_losses are empty for the void_model.
                return prev_val
            return prev_val + prev_transmittance * cur_val

        output = self.output.combine(
            cur.output, combine_fn=partial(_combine_fn, prev_transmittance=self.transmittance)
        )

        combined = RayVolumeIntegralResults(
            output=output,
            volume_range=self.volume_range.extend(cur.volume_range),
            transmittance=self.transmittance * cur.transmittance,
        )
        return combined


@dataclass
class RayVolumeIntegral:
    model: NeRFModel
    volume: Volume
    sampler: "RaySampler"
    n_samples: int

    def render_rays(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0: Optional[torch.Tensor] = None,
        prev_raw: Optional[AttrDict] = None,
        shared: bool = False,
        render_with_direction: bool = True,
    ) -> "RayVolumeIntegralResults":
        """
        Perform volumetric rendering over the given volume.

        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3]
        :param t0: Optional [batch_size, *shape, 1]
        :param prev_raw: the raw outputs when using multiple levels with this model.
        :param shared: if True, the same model is used for all RayVolumeIntegrals.
        :param render_with_direction: use the incoming ray direction when querying the model.

        :return: RayVolumeIntegralResults
        """
        # 1. Intersect the rays with the current volume and sample ts to
        # integrate along.
        vrange = self.volume.intersect(origin, direction, t0_lower=t0)
        ts = self.sampler.sample(vrange.t0, vrange.t1, self.n_samples)

        if prev_raw is not None and not shared:
            # Append the previous ts before the forward pass, because the previous
            # rendering used a different model and its outputs cannot be reused.
            ts = torch.sort(torch.cat([ts, prev_raw.ts], dim=-2), dim=-2).values

        # Unpack shapes; batch_size and ts_shape are needed below.
        batch_size, *_shape, _t0_dim = vrange.t0.shape
        _, *ts_shape, _ts_dim = ts.shape

        # 2. Get the points along the ray and query the model
        directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
        positions = origin.unsqueeze(-2) + ts * directions

        optional_directions = directions if render_with_direction else None
        mids = (ts[..., 1:, :] + ts[..., :-1, :]) / 2
        raw = self.model(
            Query(
                position=positions,
                direction=optional_directions,
                t_min=torch.cat([vrange.t0[..., None, :], mids], dim=-2),
                t_max=torch.cat([mids, vrange.t1[..., None, :]], dim=-2),
            )
        )
        raw.ts = ts

        if prev_raw is not None and shared:
            # We can append the additional queries to previous raw outputs
            # before integration
            copy = prev_raw.copy()
            result = torch.sort(torch.cat([raw.pop("ts"), copy.pop("ts")], dim=-2), dim=-2)
            merge_results = partial(self._merge_results, dim=-2, indices=result.indices)
            raw = raw.combine(copy, merge_results)
            raw.ts = result.values

        # 3. Integrate the raw results
        output, transmittance = self.integrate_samples(vrange, raw)

        # 4. Clean up results that do not intersect with the volume.
        transmittance = torch.where(
            vrange.intersected, transmittance, torch.ones_like(transmittance)
        )

        def _mask_fn(_key: str, tensor: torch.Tensor):
            return torch.where(vrange.intersected, tensor, torch.zeros_like(tensor))

        def _is_tensor(_key: str, value: Any):
            return isinstance(value, torch.Tensor)

        output = output.map(map_fn=_mask_fn, should_map=_is_tensor)

        return RayVolumeIntegralResults(
            output=output,
            raw=raw,
            volume_range=vrange,
            transmittance=transmittance,
        )

    def integrate_samples(
        self,
        volume_range: VolumeRange,
        raw: AttrDict,
    ) -> Tuple[AttrDict, torch.Tensor]:
        """
        Integrate the raw.channels along with other aux_losses and values to
        produce the final output dictionary containing rendered `channels`,
        estimated `distances` and `aux_losses`.

        :param volume_range: Specifies the integral range [t0, t1]
        :param raw: Contains a dict of function evaluations at ts. Should have

            density: torch.Tensor [batch_size, *shape, n_samples, 1]
            channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
            aux_losses: {key: torch.Tensor [batch_size, *shape, n_samples, 1] for each key}
            no_weight_grad_aux_losses: an optional set of losses for which the weights
                                       should be detached before integration.

            After the call, integrate_samples also stores intermediate results in
            `raw` for later use, namely

            weights: torch.Tensor [batch_size, *shape, n_samples, 1], the
                discretized (density * transmittance)[i] weight for each rgb
                output at [..., i, :].
        :return: a tuple of (
            a dictionary of rendered outputs and aux_losses,
            transmittance of this volume,
        )
        """

        # 1. Calculate the weights
        _, _, dt = volume_range.partition(raw.ts)
        ddensity = raw.density * dt

        mass = torch.cumsum(ddensity, dim=-2)
        transmittance = torch.exp(-mass[..., -1, :])

        alphas = 1.0 - torch.exp(-ddensity)
        Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
        # This is the probability of light hitting and reflecting off of
        # something at depth [..., i, :].
        weights = alphas * Ts

        # 2. Integrate all results
        def _integrate(key: str, samples: torch.Tensor, weights: torch.Tensor):
            if key == "density":
                # Omit integrating the density, because we don't need it
                return None
            return torch.sum(samples * weights, dim=-2)

        def _is_tensor(_key: str, value: Any):
            return isinstance(value, torch.Tensor)

        if raw.no_weight_grad_aux_losses:
            extra_aux_losses = raw.no_weight_grad_aux_losses.map(
                partial(_integrate, weights=weights.detach()), should_map=_is_tensor
            )
        else:
            extra_aux_losses = {}
        output = raw.map(partial(_integrate, weights=weights), should_map=_is_tensor)
        if "no_weight_grad_aux_losses" in output:
            del output["no_weight_grad_aux_losses"]
        output.aux_losses.update(extra_aux_losses)

        # Integrating the ts yields the distance away from the origin; rename the variable.
        output.distances = output.ts
        del output["ts"]
        del output["density"]

        assert output.distances.shape == (*output.channels.shape[:-1], 1)
        assert output.channels.shape[:-1] == raw.channels.shape[:-2]
        assert output.channels.shape[-1] == raw.channels.shape[-1]

        # 3. Reduce loss
        def _reduce_loss(_key: str, loss: torch.Tensor):
            return loss.view(loss.shape[0], -1).sum(dim=-1)

        # 4. Store other useful calculations
        raw.weights = weights

        output.aux_losses = output.aux_losses.map(_reduce_loss)

        return output, transmittance

    def _merge_results(
        self, a: Optional[torch.Tensor], b: torch.Tensor, dim: int, indices: torch.Tensor
    ):
        """
        :param a: [..., n_a, ...]. The other dictionary (the one holding the b's)
            may contain extra tensors from earlier calculations, so a can be None.
        :param b: [..., n_b, ...]
        :param dim: dimension to merge
        :param indices: how the merged results should be sorted at the end
        :return: a concatenated and sorted tensor of size [..., n_a + n_b, ...],
            or None if a is None.
        """
        if a is None:
            return None

        merged = torch.cat([a, b], dim=dim)
        return torch.gather(merged, dim=dim, index=torch.broadcast_to(indices, merged.shape))
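

# Illustrative sketch (hypothetical helper, not part of shap-e and not called by
# the renderer): the weights computed by RayVolumeIntegral.integrate_samples for
# a single ray, using plain Python floats instead of batched tensors. Each sample
# owns an interval of width dt; alpha is the probability that light stops inside
# that interval, and exp(-mass) is the probability that it reached the interval
# at all. The weights plus the final transmittance telescope to exactly 1, which
# makes a handy sanity check.
def _sketch_volume_weights(densities, dts):
    """
    :param densities: density at each sample along one ray.
    :param dts: width of the interval owned by each sample.
    :return: (weights, transmittance), with weights[i] = alpha[i] * T[i].
    """
    import math

    weights = []
    mass = 0.0  # integral of density over all intervals before sample i
    for density, dt in zip(densities, dts):
        ddensity = density * dt
        alpha = 1.0 - math.exp(-ddensity)
        weights.append(alpha * math.exp(-mass))
        mass += ddensity
    return weights, math.exp(-mass)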


class RaySampler(ABC):
    @abstractmethod
    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
        """
        :param t0: start time has shape [batch_size, *shape, 1]
        :param t1: finish time has shape [batch_size, *shape, 1]
        :param n_samples: number of ts to sample
        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
        """


class StratifiedRaySampler(RaySampler):
    """
    Instead of sampling at fixed points, split [t0, t1] into n_samples bins and
    draw one t uniformly at random from each bin.
    """

    def __init__(self, depth_mode: str = "linear"):
        """
        :param depth_mode: how the sample points are laid out in depth. "linear"
            spaces them evenly in t, "geometric" evenly in log(t), and "harmonic"
            evenly in 1 / t; the latter two sample closer points more densely.
        """
        self.depth_mode = depth_mode
        assert self.depth_mode in ("linear", "geometric", "harmonic")

    def sample(
        self,
        t0: torch.Tensor,
        t1: torch.Tensor,
        n_samples: int,
        epsilon: float = 1e-3,
    ) -> torch.Tensor:
        """
        :param t0: start time has shape [batch_size, *shape, 1]
        :param t1: finish time has shape [batch_size, *shape, 1]
        :param n_samples: number of ts to sample
        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        ones = [1] * (len(t0.shape) - 1)
        ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)

        if self.depth_mode == "linear":
            ts = t0 * (1.0 - ts) + t1 * ts
        elif self.depth_mode == "geometric":
            ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
        elif self.depth_mode == "harmonic":
            # The original NeRF recommends this interpolation scheme for
            # spherical scenes, but there could be some weird edge cases when
            # the observer crosses from the inner to outer volume.
            ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)

        mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
        upper = torch.cat([mids, t1], dim=-1)
        lower = torch.cat([t0, mids], dim=-1)
        t_rand = torch.rand_like(ts)

        ts = lower + (upper - lower) * t_rand
        return ts.unsqueeze(-1)
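

# Illustrative sketch (hypothetical helper, not part of shap-e and not called by
# the renderer): StratifiedRaySampler.sample for a single ray, with Python floats
# and the `random` module standing in for tensors and torch.rand_like. Evenly
# spaced fractions are first mapped into [t0, t1] according to depth_mode; each
# resulting point then owns a bin bounded by the midpoints to its neighbours, and
# one t is drawn uniformly from every bin, so the output is already sorted.
def _sketch_stratified_ts(
    t0: float,
    t1: float,
    n_samples: int,
    depth_mode: str = "linear",
    epsilon: float = 1e-3,
    seed=None,
):
    """
    :param t0: start of the interval along one ray.
    :param t1: end of the interval along one ray.
    :param n_samples: number of ts to sample.
    :param seed: optional seed for the random generator.
    :return: a list of n_samples ts in [t0, t1], nondecreasing when t0 <= t1.
    """
    import math
    import random

    rng = random.Random(seed)
    # Evenly spaced fractions in [0, 1], like torch.linspace(0, 1, n_samples).
    if n_samples > 1:
        fracs = [i / (n_samples - 1) for i in range(n_samples)]
    else:
        fracs = [0.0]

    if depth_mode == "linear":
        ts = [t0 * (1.0 - f) + t1 * f for f in fracs]
    elif depth_mode == "geometric":
        # Evenly spaced in log(t); clamping avoids log(0).
        log0, log1 = math.log(max(t0, epsilon)), math.log(max(t1, epsilon))
        ts = [math.exp(log0 * (1.0 - f) + log1 * f) for f in fracs]
    elif depth_mode == "harmonic":
        # Evenly spaced in 1 / t, so points near the origin are denser.
        inv0, inv1 = 1.0 / max(t0, epsilon), 1.0 / max(t1, epsilon)
        ts = [1.0 / (inv0 * (1.0 - f) + inv1 * f) for f in fracs]
    else:
        raise ValueError(f"unknown depth_mode: {depth_mode}")

    # The first and last bins extend to t0 and t1, as in the tensor version.
    mids = [0.5 * (a + b) for a, b in zip(ts[:-1], ts[1:])]
    lower = [t0] + mids
    upper = mids + [t1]
    # One uniform draw per bin, the analogue of torch.rand_like(ts).
    return [low + (high - low) * rng.random() for low, high in zip(lower, upper)]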


class ImportanceRaySampler(RaySampler):
    """
    Given the initial estimate of densities, this samples more from
    regions/bins expected to have objects.
    """

    def __init__(
        self, volume_range: VolumeRange, raw: AttrDict, blur_pool: bool = False, alpha: float = 1e-5
    ):
        """
        :param volume_range: the range in which a ray intersects the given volume.
        :param raw: dictionary of raw outputs from the NeRF models of shape
            [batch_size, *shape, n_coarse_samples, 1]. Should at least contain

            - ts: earlier samples from the coarse rendering step
            - weights: discretized version of density * transmittance
        :param blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
        :param alpha: small constant added to every weight before normalizing,
            so that bins with zero weight can still be sampled.
        """
        self.volume_range = volume_range
        self.ts = raw.ts.clone().detach()
        self.weights = raw.weights.clone().detach()
        self.blur_pool = blur_pool
        self.alpha = alpha

    @torch.no_grad()
    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
        """
        :param t0: start time has shape [batch_size, *shape, 1]
        :param t1: finish time has shape [batch_size, *shape, 1]
        :param n_samples: number of ts to sample
        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        lower, upper, _ = self.volume_range.partition(self.ts)

        batch_size, *shape, n_coarse_samples, _ = self.ts.shape

        weights = self.weights
        if self.blur_pool:
            padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
            maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
            weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
        weights = weights + self.alpha
        pmf = weights / weights.sum(dim=-2, keepdim=True)
        inds = sample_pmf(pmf, n_samples)
        assert inds.shape == (batch_size, *shape, n_samples, 1)
        assert (inds >= 0).all() and (inds < n_coarse_samples).all()

        t_rand = torch.rand(inds.shape, device=inds.device)
        lower_ = torch.gather(lower, -2, inds)
        upper_ = torch.gather(upper, -2, inds)

        ts = lower_ + (upper_ - lower_) * t_rand
        ts = torch.sort(ts, dim=-2).values
        return ts
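

# Illustrative sketch (hypothetical helper, not part of shap-e and not called by
# the renderer): ImportanceRaySampler.sample for a single ray. `lower`/`upper`
# are the coarse bin edges and `weights` the coarse weights. random.choices is an
# assumed stand-in for shap_e.models.nn.utils.sample_pmf, whose exact procedure
# is not reproduced here. Bins are picked in proportion to their (optionally
# blur-pooled) weight, and one t is drawn uniformly inside each picked bin.
def _sketch_importance_ts(
    lower,
    upper,
    weights,
    n_samples: int,
    blur_pool: bool = False,
    alpha: float = 1e-5,
    seed=None,
):
    """
    :param lower: lower edge of each coarse bin.
    :param upper: upper edge of each coarse bin.
    :param weights: coarse weight of each bin.
    :param n_samples: number of ts to sample.
    :return: a sorted list of n_samples ts.
    """
    import random

    rng = random.Random(seed)
    w = list(weights)
    if blur_pool:
        # 2-tap max followed by a 2-tap average, with edge replication.
        padded = [w[0]] + w + [w[-1]]
        maxes = [max(a, b) for a, b in zip(padded[:-1], padded[1:])]
        w = [0.5 * (a + b) for a, b in zip(maxes[:-1], maxes[1:])]
    # alpha keeps bins with zero weight reachable.
    w = [x + alpha for x in w]
    total = sum(w)
    pmf = [x / total for x in w]
    inds = rng.choices(range(len(pmf)), weights=pmf, k=n_samples)
    ts = [lower[i] + (upper[i] - lower[i]) * rng.random() for i in inds]
    return sorted(ts)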


================================================
FILE: shap_e/models/nerf/renderer.py
================================================
from functools import partial
from typing import Any, Dict, Optional

import torch

from shap_e.models.nn.meta import subdict
from shap_e.models.renderer import RayRenderer
from shap_e.models.volume import Volume
from shap_e.util.collections import AttrDict

from .model import NeRFModel
from .ray import RayVolumeIntegral, StratifiedRaySampler, render_rays


class TwoStepNeRFRenderer(RayRenderer):
    """
    Coarse and fine-grained rendering as proposed by NeRF. This class
    additionally supports background rendering like NeRF++.
    """

    def __init__(
        self,
        n_coarse_samples: int,
        n_fine_samples: int,
        void_model: NeRFModel,
        fine_model: NeRFModel,
        volume: Volume,
        coarse_model: Optional[NeRFModel] = None,
        coarse_background_model: Optional[NeRFModel] = None,
        fine_background_model: Optional[NeRFModel] = None,
        outer_volume: Optional[Volume] = None,
        foreground_stratified_depth_sampling_mode: str = "linear",
        background_stratified_depth_sampling_mode: str = "linear",
        importance_sampling_options: Optional[Dict[str, Any]] = None,
        channel_scale: float = 255,
        device: torch.device = torch.device("cuda"),
        **kwargs,
    ):
        """
        :param outer_volume: the volume in which distant (background) objects are encoded.
        """
        super().__init__(**kwargs)

        if coarse_model is None:
            assert (
                fine_background_model is None or coarse_background_model is None
            ), "models should be shared for both fg and bg"

        self.n_coarse_samples = n_coarse_samples
        self.n_fine_samples = n_fine_samples
        self.void_model = void_model
        self.coarse_model = coarse_model
        self.fine_model = fine_model
        self.volume = volume
        self.coarse_background_model = coarse_background_model
        self.fine_background_model = fine_background_model
        self.outer_volume = outer_volume
        self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
        self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
        self.importance_sampling_options = AttrDict(importance_sampling_options or {})
        self.channel_scale = channel_scale
        self.device = device
        self.to(device)

        if self.coarse_background_model is not None:
            assert self.fine_background_model is not None
            assert self.outer_volume is not None

    def render_rays(
        self,
        batch: Dict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        params = self.update(params)

        batch = AttrDict(batch)
        options = AttrDict() if options is None else AttrDict(options)
        options.setdefault("render_background", True)
        options.setdefault("render_with_direction", True)
        options.setdefault("n_coarse_samples", self.n_coarse_samples)
        options.setdefault("n_fine_samples", self.n_fine_samples)
        options.setdefault(
            "foreground_stratified_depth_sampling_mode",
            self.foreground_stratified_depth_sampling_mode,
        )
        options.setdefault(
            "background_stratified_depth_sampling_mode",
            self.background_stratified_depth_sampling_mode,
        )

        shared = self.coarse_model is None

        # First, render rays using the coarse models with stratified ray samples.
        coarse_model, coarse_key = (
            (self.fine_model, "fine_model") if shared else (self.coarse_model, "coarse_model")
        )
        coarse_model = partial(
            coarse_model,
            params=subdict(params, coarse_key),
            options=options,
        )
        parts = [
            RayVolumeIntegral(
                model=coarse_model,
                volume=self.volume,
                sampler=StratifiedRaySampler(
                    depth_mode=options.foreground_stratified_depth_sampling_mode,
                ),
                n_samples=options.n_coarse_samples,
            ),
        ]
        if options.render_background and self.outer_volume is not None:
            coarse_background_model, coarse_background_key = (
                (self.fine_background_model, "fine_background_model")
                if shared
                else (self.coarse_background_model, "coarse_background_model")
            )
            coarse_background_model = partial(
                coarse_background_model,
                params=subdict(params, coarse_background_key),
                options=options,
            )
            parts.append(
                RayVolumeIntegral(
                    model=coarse_background_model,
                    volume=self.outer_volume,
                    sampler=StratifiedRaySampler(
                        depth_mode=options.background_stratified_depth_sampling_mode,
                    ),
                    n_samples=options.n_coarse_samples,
                )
            )
        coarse_results, samplers, coarse_raw_outputs = render_rays(
            batch.rays,
            parts,
            partial(self.void_model, options=options),
            shared=shared,
            render_with_direction=options.render_with_direction,
            importance_sampling_options=AttrDict(self.importance_sampling_options),
        )

        # Then, render rays using the fine models with importance-weighted ray samples.
        fine_model = partial(
            self.fine_model,
            params=subdict(params, "fine_model"),
            options=options,
        )
        parts = [
            RayVolumeIntegral(
                model=fine_model,
                volume=self.volume,
                sampler=samplers[0],
                n_samples=options.n_fine_samples,
            ),
        ]
        if options.render_background and self.outer_volume is not None:
            fine_background_model = partial(
                self.fine_background_model,
                params=subdict(params, "fine_background_model"),
                options=options,
            )
            parts.append(
                RayVolumeIntegral(
                    model=fine_background_model,
                    volume=self.outer_volume,
                    sampler=samplers[1],
                    n_samples=options.n_fine_samples,
                )
            )
        fine_results, *_ = render_rays(
            batch.rays,
            parts,
            partial(self.void_model, options=options),
            shared=shared,
            prev_raw_outputs=coarse_raw_outputs,
            render_with_direction=options.render_with_direction,
        )

        # Combine results
        aux_losses = fine_results.output.aux_losses.copy()
        for key, val in coarse_results.output.aux_losses.items():
            aux_losses[key + "_coarse"] = val

        return AttrDict(
            channels=fine_results.output.channels * self.channel_scale,
            channels_coarse=coarse_results.output.channels * self.channel_scale,
            distances=fine_results.output.distances,
            transmittance=fine_results.transmittance,
            transmittance_coarse=coarse_results.transmittance,
            t0=fine_results.volume_range.t0,
            t1=fine_results.volume_range.t1,
            intersected=fine_results.volume_range.intersected,
            aux_losses=aux_losses,
        )


class OneStepNeRFRenderer(RayRenderer):
    """
    Renders rays in a single pass with stratified sampling only, unlike vanilla
    NeRF, which adds an importance-sampled fine pass. The same setup as NeRF++.
    """

    def __init__(
        self,
        n_samples: int,
        void_model: NeRFModel,
        foreground_model: NeRFModel,
        volume: Volume,
        background_model: Optional[NeRFModel] = None,
        outer_volume: Optional[Volume] = None,
        foreground_stratified_depth_sampling_mode: str = "linear",
        background_stratified_depth_sampling_mode: str = "linear",
        channel_scale: float = 255,
        device: torch.device = torch.device("cuda"),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_samples = n_samples
        self.void_model = void_model
        self.foreground_model = foreground_model
        self.volume = volume
        self.background_model = background_model
        self.outer_volume = outer_volume
        self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
        self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
        self.channel_scale = channel_scale
        self.device = device
        self.to(device)

    def render_rays(
        self,
        batch: Dict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        params = self.update(params)

        batch = AttrDict(batch)
        options = AttrDict() if options is None else AttrDict(options)
        options.setdefault("render_background", True)
        options.setdefault("render_with_direction", True)
        options.setdefault("n_samples", self.n_samples)
        options.setdefault(
            "foreground_stratified_depth_sampling_mode",
            self.foreground_stratified_depth_sampling_mode,
        )
        options.setdefault(
            "background_stratified_depth_sampling_mode",
            self.background_stratified_depth_sampling_mode,
        )

        foreground_model = partial(
            self.foreground_model,
            params=subdict(params, "foreground_model"),
            options=options,
        )
        parts = [
            RayVolumeIntegral(
                model=foreground_model,
                volume=self.volume,
                sampler=StratifiedRaySampler(
                    depth_mode=options.foreground_stratified_depth_sampling_mode
                ),
                n_samples=options.n_samples,
            ),
        ]
        if options.render_background and self.outer_volume is not None:
            background_model = partial(
                self.background_model,
                params=subdict(params, "background_model"),
                options=options,
            )
            parts.append(
                RayVolumeIntegral(
                    model=background_model,
                    volume=self.outer_volume,
                    sampler=StratifiedRaySampler(
                        depth_mode=options.background_stratified_depth_sampling_mode
                    ),
                    n_samples=options.n_samples,
                )
            )
        results, *_ = render_rays(
            batch.rays,
            parts,
            self.void_model,
            render_with_direction=options.render_with_direction,
        )

        return AttrDict(
            channels=results.output.channels * self.channel_scale,
            distances=results.output.distances,
            transmittance=results.transmittance,
            t0=results.volume_range.t0,
            t1=results.volume_range.t1,
            intersected=results.volume_range.intersected,
            aux_losses=results.output.aux_losses,
        )
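
`OneStepNeRFRenderer.render_rays` above integrates a foreground segment inside `volume`, optionally a background segment inside `outer_volume`, and hands any remaining transmittance to `void_model`. The model channels are then multiplied by `channel_scale` (255 by default). The sketch below illustrates standard front-to-back compositing of such ray segments for a single ray in plain Python. It is a simplified model built on stated assumptions: piecewise-constant density per sample, a constant void color, and the hypothetical helpers `integrate_segment` and `composite`. It is not the actual `render_rays` from `shap_e/models/nerf/ray.py`.

```python
import math
from typing import List, Sequence, Tuple

Color = Tuple[float, float, float]


def integrate_segment(
    densities: Sequence[float],
    colors: Sequence[Color],
    deltas: Sequence[float],
) -> Tuple[Color, float]:
    """Front-to-back alpha compositing of the samples on one ray segment.

    Returns the accumulated color and the transmittance left after the segment.
    """
    acc = [0.0, 0.0, 0.0]
    transmittance = 1.0
    for sigma, color, delta in zip(densities, colors, deltas):
        alpha = 1.0 - math.exp(-sigma * delta)  # opacity of this sample interval
        weight = transmittance * alpha  # light reaching the sample, times its opacity
        for i in range(3):
            acc[i] += weight * color[i]
        transmittance *= 1.0 - alpha
    return (acc[0], acc[1], acc[2]), transmittance


def composite(
    segments: List[Tuple[Color, float]],
    void_color: Color,
    channel_scale: float = 255.0,
) -> Tuple[Color, float]:
    """Chain segments in ray order (e.g. foreground, then background), then the void."""
    out = [0.0, 0.0, 0.0]
    transmittance = 1.0
    for color, seg_transmittance in segments:
        for i in range(3):
            out[i] += transmittance * color[i]  # attenuate by everything in front
        transmittance *= seg_transmittance
    for i in range(3):
        out[i] += transmittance * void_color[i]  # leftover light sees the void color
    scaled = (out[0] * channel_scale, out[1] * channel_scale, out[2] * channel_scale)
    return scaled, transmittance
```

For example, an empty foreground (density 0) in front of an opaque green background yields `(0.0, 255.0, 0.0)` with zero remaining transmittance, so the void color never contributes. Dropping the background segment, as happens when `render_background` is false or `outer_volume` is `None`, simply passes a one-element `segments` list.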


================================================
FILE: shap_e/models/nerstf/mlp.py
================================================
from typing import Any, Dict, Optional, Tuple

import torch

from shap_e.models.nn.ops import get_act
from shap_e.models.query import Query
from shap_e.models.stf.mlp import MLPModel
from shap_e.util.collections import AttrDict


class MLPDensitySDFModel(MLPModel):
    def __init__(
        self,
        initial_bias: float = -0.1,
        sdf_activation="tanh",
        density_activation="exp",
        **kwargs,
    ):
        super().__init__(
            n_output=2,
            output_activation="identity",
            **kwargs,
        )
        self.mlp[-1].bias[0].data.fill_(initial_bias)
        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        # query.direction is typically None for SDF models and during training
        h, _h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        h_sdf, h_density = h.split(1, dim=-1)
        return AttrDict(
            density=self.density_activation(h_density),
            signed_distance=self.sdf_activation(h_sdf),
        )


class MLPNeRSTFModel(MLPModel):
    def __init__(
        self,
        sdf_activation="tanh",
        density_activation="exp",
        channel_activation="sigmoid",
        direction_dependent_shape: bool = True,  # True keeps old checkpoints loadable; new models should set this to False.
        separate_nerf_channels: bool = False,
        separate_coarse_channels: bool = False,
        initial_density_bias: float = 0.0,
        initial_sdf_bias: float = -0.1,
        **kwargs,
    ):
        h_map, h_directionless_map = indices_for_output_mode(
            direction_dependent_shape=direction_dependent_shape,
            separate_nerf_channels=separate_nerf_channels,
            separate_coarse_channels=separate_coarse_channels,
        )
        n_output = index_mapping_max(h_map)
        super().__init__(
            n_output=n_output,
            output_activation="identity",
            **kwargs,
        )
        self.direction_dependent_shape = direction_dependent_shape
        self.separate_nerf_channels = separate_nerf_channels
        self.separate_coarse_channels = separate_coarse_channels
        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)
        self.channel_activation = get_act(channel_activation)
        self.h_map = h_map
        self.h_directionless_map = h_directionless_map
        self.mlp[-1].bias.data.zero_()
        layer = -1 if self.direction_dependent_shape else self.insert_direction_at
        self.mlp[layer].bias[0].data.fill_(initial_sdf_bias)
        self.mlp[layer].bias[1].data.fill_(initial_density_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        options = AttrDict() if options is None else AttrDict(options)
        h, h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        activations = map_indices_to_keys(self.h_map, h)
        activations.update(map_indices_to_keys(self.h_directionless_map, h_directionless))

        if options.nerf_level == "coarse":
            h_density = activations.density_coarse
        else:
            h_density = activations.density_fine

        if options.get("rendering_mode", "stf") == "nerf":
            if options.nerf_level == "coarse":
                h_channels = activations.nerf_coarse
            else:
                h_channels = activations.nerf_fine
        else:
            h_channels = activations.stf
        return AttrDict(
            density=self.density_activation(h_density),
            signed_distance=self.sdf_activation(activations.sdf),
            channels=self.channel_activation(h_channels),
        )


IndexMapping = AttrDict[str, Tuple[int, int]]


def indices_for_output_mode(
    direction_dependent_shape: bool,
    separate_nerf_channels: bool,
    separate_coarse_channels: bool,
) -> Tuple[IndexMapping, IndexMapping]:
    """
    Get output mappings for (h, h_directionless).
    """
    h_map = AttrDict()
    h_directionless_map = AttrDict()
    if direction_dependent_shape:
        h_map.sdf = (0, 1)
        if separate_coarse_channels:
            assert separate_nerf_channels
            h_map.density_coarse = (1, 2)
            h_map.density_fine = (2, 3)
            h_map.stf = (3, 6)
            h_map.nerf_coarse = (6, 9)
            h_map.nerf_fine = (9, 12)
        else:
            h_map.density_coarse = (1, 2)
            h_map.density_fine = (1, 2)
            if separate_nerf_channels:
                h_map.stf = (2, 5)
                h_map.nerf_coarse = (5, 8)
                h_map.nerf_fine = (5, 8)
            else:
                h_map.stf = (2, 5)
                h_map.nerf_coarse = (2, 5)
                h_map.nerf_fine = (2, 5)
    else:
        h_directionless_map.sdf = (0, 1)
        h_directionless_map.density_coarse = (1, 2)
        if separate_coarse_channels:
            h_directionless_map.density_fine = (2, 3)
        else:
            h_directionless_map.density_fine = h_directionless_map.density_coarse
        h_map.stf = (0, 3)
        if separate_coarse_channels:
            assert separate_nerf_channels
            h_map.nerf_coarse = (3, 6)
            h_map.nerf_fine = (6, 9)
        else:
            if separate_nerf_channels:
                h_map.nerf_coarse = (3, 6)
            else:
                h_map.nerf_coarse = (0, 3)
            h_map.nerf_fine = h_map.nerf_coarse
    return h_map, h_directionless_map


def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]:
    return AttrDict({k: data[..., start:end] for k, (start, end) in mapping.items()})


def index_mapping_max(mapping: IndexMapping) -> int:
    return max(end for _, (_, end) in mapping.items())
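
The helpers above describe how `MLPNeRSTFModel` slices its raw output `h` into named heads. The sketch below mirrors, using plain Python lists instead of tensors, the `direction_dependent_shape=True` layouts from `indices_for_output_mode` together with `index_mapping_max` and `map_indices_to_keys`. It then applies the selection logic of `MLPNeRSTFModel.forward` with that class's default activations (tanh for the SDF, exp for density, sigmoid for channels). The `_sketch` functions are hypothetical stand-ins, and the sketch omits the `direction_dependent_shape=False` branch, in which SDF and density are read from `h_directionless` instead.

```python
import math
from typing import Dict, List, Tuple

IndexMapping = Dict[str, Tuple[int, int]]


def layout_sketch(separate_nerf_channels: bool, separate_coarse_channels: bool) -> IndexMapping:
    """Direction-dependent layouts copied from indices_for_output_mode."""
    if separate_coarse_channels:
        assert separate_nerf_channels
        return {"sdf": (0, 1), "density_coarse": (1, 2), "density_fine": (2, 3),
                "stf": (3, 6), "nerf_coarse": (6, 9), "nerf_fine": (9, 12)}
    nerf = (5, 8) if separate_nerf_channels else (2, 5)  # shared with stf unless separate
    return {"sdf": (0, 1), "density_coarse": (1, 2), "density_fine": (1, 2),
            "stf": (2, 5), "nerf_coarse": nerf, "nerf_fine": nerf}


def index_mapping_max_sketch(mapping: IndexMapping) -> int:
    # Width of the final MLP layer: the largest exclusive end index.
    return max(end for _, end in mapping.values())


def map_indices_to_keys_sketch(mapping: IndexMapping, h: List[float]) -> Dict[str, List[float]]:
    # Equivalent of data[..., start:end] for a single 1-D output vector.
    return {k: h[start:end] for k, (start, end) in mapping.items()}


def decode_sketch(h: List[float], mapping: IndexMapping, nerf_level: str = "fine",
                  rendering_mode: str = "stf") -> Dict[str, List[float]]:
    """Mirror MLPNeRSTFModel.forward: any nerf_level other than "coarse" picks fine heads."""
    heads = map_indices_to_keys_sketch(mapping, h)
    density = heads["density_coarse"] if nerf_level == "coarse" else heads["density_fine"]
    if rendering_mode == "nerf":
        channels = heads["nerf_coarse"] if nerf_level == "coarse" else heads["nerf_fine"]
    else:
        channels = heads["stf"]
    return {
        "density": [math.exp(x) for x in density],
        "signed_distance": [math.tanh(x) for x in heads["sdf"]],
        "channels": [1.0 / (1.0 + math.exp(-x)) for x in channels],
    }
```

With both flags off, the head is 5 channels wide; separating NeRF channels widens it to 8, and additionally separating coarse and fine heads widens it to 12. Decoding `[-0.1, 0.0, 0.0, 0.0, 0.0]` in that 5-channel layout gives a density of `1.0`, a signed distance of `tanh(-0.1)`, and channels of `0.5`.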


================================================
FILE: shap_e/models/nerstf/renderer.py
================================================
from functools import partial
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch

from shap_e.models.nerf.model import NeRFModel
from shap_e.models.nerf.ray import RayVolumeIntegral, StratifiedRaySampler, render_rays
from shap_e.models.nn.meta import subdict
from shap_e.models.nn.utils import to_torch
from shap_e.models.query import Query
from shap_e.models.renderer import RayRenderer, render_views_from_rays
from shap_e.models.stf.base import Model
from shap_e.models.stf.renderer import STFRendererBase, render_views_from_stf
from shap_e.models.volume import BoundingBoxVolume, Volume
from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR
from shap_e.util.collections import AttrDict


class NeRSTFRenderer(RayRenderer, STFRendererBase):
    def __init__(
        self,
        sdf: Optional[Model],
        tf: Optional[Model],
        nerstf: Optional[Model],
        void: NeRFModel,
        volume: Volume,
        grid_size: int,
        n_coarse_samples: int,
        n_fine_samples: int,
        importance_sampling_options: Optional[Dict[str, Any]] = None,
        separate_shared_samples: bool = False,
        texture_channels: Sequence[str] = ("R", "G", "B"),
        channel_scale: Sequence[float] = (255.0, 255.0, 255.0),
        ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,
        diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,
        specular_color: Union[float, Tuple[float]] = 0.0,
        output_srgb: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume"
        assert (nerstf is not None) ^ (sdf is not None and tf is not None)
        self.sdf = sdf
        self.tf = tf
        self.nerstf = nerstf
        self.void = void
        self.volume = volume
        self.grid_size = grid_size
        self.n_coarse_samples = n_coarse_samples
        self.n_fine_samples = n_fine_samples
        self.importance_sampling_options = AttrDict(importance_sampling_options or {})
        self.separate_shared_samples = separate_shared_samples
        self.texture_channels = texture_channels
        self.channel_scale = to_torch(channel_scale).to(device)
        self.ambient_color = ambient_color
        self.diffuse_color = diffuse_color
        self.specular_color = specular_color
        self.output_srgb = output_srgb
        self.device = device
        self.to(device)

    def _query(
        self,
        query: Query,
        params: AttrDict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> AttrDict:
        no_dir_query = query.copy()
        no_dir_query.direction = None

        if options.get("rendering_mode", "stf") == "stf":
            assert query.direction is None

        if self.nerstf is not None:
      
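
The `NeRSTFRenderer` constructor shown earlier asserts `(nerstf is not None) ^ (sdf is not None and tf is not None)`. In other words, it must receive either a combined `nerstf` model or a complete `sdf` and `tf` pair, but not both. The hypothetical helper below reproduces that predicate and enumerates which combinations pass. `object()` placeholders stand in for real models.

```python
from itertools import product
from typing import Optional


def valid_model_config(
    sdf: Optional[object], tf: Optional[object], nerstf: Optional[object]
) -> bool:
    # Same predicate as the assertion in NeRSTFRenderer.__init__.
    return (nerstf is not None) ^ (sdf is not None and tf is not None)


m = object()  # placeholder standing in for a real model
valid = [
    (sdf is not None, tf is not None, nerstf is not None)
    for sdf, tf, nerstf in product([None, m], repeat=3)
    if valid_model_config(sdf, tf, nerstf)
]
print(valid)  # tuples of (has_sdf, has_tf, has_nerstf)
```

The printed list is `[(False, False, True), (False, True, True), (True, False, True), (True, True, False)]`. Note that a lone `sdf` or `tf` passed alongside `nerstf` still satisfies the assertion, because only a complete pair conflicts with `nerstf`.
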
SYMBOL INDEX (656 symbols across 54 files)

FILE: shap_e/diffusion/gaussian_diffusion.py
  function diffusion_from_config (line 14) | def diffusion_from_config(config: Union[str, Dict[str, Any]]) -> "Gaussi...
  function get_beta_schedule (line 45) | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffus...
  function get_named_beta_schedule (line 59) | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, **ex...
  function betas_for_alpha_bar (line 102) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
  function space_timesteps (line 122) | def space_timesteps(num_timesteps, section_counts):
  class GaussianDiffusion (line 175) | class GaussianDiffusion:
    method __init__ (line 192) | def __init__(
    method get_sigmas (line 246) | def get_sigmas(self, t):
    method q_mean_variance (line 249) | def q_mean_variance(self, x_start, t):
    method q_sample (line 262) | def q_sample(self, x_start, t, noise=None):
    method q_posterior_mean_variance (line 281) | def q_posterior_mean_variance(self, x_start, x_t, t):
    method p_mean_variance (line 305) | def p_mean_variance(
    method _predict_xstart_from_eps (line 400) | def _predict_xstart_from_eps(self, x_t, t, eps):
    method _predict_xstart_from_xprev (line 407) | def _predict_xstart_from_xprev(self, x_t, t, xprev):
    method _predict_eps_from_xstart (line 417) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    method condition_mean (line 422) | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    method condition_score (line 435) | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    method p_sample (line 455) | def p_sample(
    method p_sample_loop (line 499) | def p_sample_loop(
    method p_sample_loop_progressive (line 547) | def p_sample_loop_progressive(
    method ddim_sample (line 598) | def ddim_sample(
    method ddim_reverse_sample (line 648) | def ddim_reverse_sample(
    method ddim_sample_loop (line 686) | def ddim_sample_loop(
    method ddim_sample_loop_progressive (line 722) | def ddim_sample_loop_progressive(
    method _vb_terms_bpd (line 773) | def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False, m...
    method training_losses (line 810) | def training_losses(
    method _prior_bpd (line 901) | def _prior_bpd(self, x_start):
    method calc_bpd_loop (line 917) | def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwa...
    method scale_channels (line 974) | def scale_channels(self, x: th.Tensor) -> th.Tensor:
    method unscale_channels (line 985) | def unscale_channels(self, x: th.Tensor) -> th.Tensor:
    method unscale_out_dict (line 996) | def unscale_out_dict(
  class SpacedDiffusion (line 1004) | class SpacedDiffusion(GaussianDiffusion):
    method __init__ (line 1012) | def __init__(self, use_timesteps: Iterable[int], **kwargs):
    method p_mean_variance (line 1028) | def p_mean_variance(self, model, *args, **kwargs):
    method training_losses (line 1031) | def training_losses(self, model, *args, **kwargs):
    method condition_mean (line 1034) | def condition_mean(self, cond_fn, *args, **kwargs):
    method condition_score (line 1037) | def condition_score(self, cond_fn, *args, **kwargs):
    method _wrap_model (line 1040) | def _wrap_model(self, model):
  class _WrappedModel (line 1046) | class _WrappedModel:
    method __init__ (line 1047) | def __init__(self, model, timestep_map, original_num_steps):
    method __call__ (line 1052) | def __call__(self, x, ts, **kwargs):
  function _extract_into_tensor (line 1058) | def _extract_into_tensor(arr, timesteps, broadcast_shape):
  function normal_kl (line 1074) | def normal_kl(mean1, logvar1, mean2, logvar2):
  function approx_standard_normal_cdf (line 1102) | def approx_standard_normal_cdf(x):
  function discretized_gaussian_log_likelihood (line 1110) | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
  function mean_flat (line 1139) | def mean_flat(tensor):

FILE: shap_e/diffusion/k_diffusion.py
  class KarrasDenoiser (line 31) | class KarrasDenoiser:
    method __init__ (line 32) | def __init__(self, sigma_data: float = 0.5):
    method get_snr (line 35) | def get_snr(self, sigmas):
    method get_sigmas (line 38) | def get_sigmas(self, sigmas):
    method get_scalings (line 41) | def get_scalings(self, sigma):
    method training_losses (line 47) | def training_losses(self, model, x_start, sigmas, model_kwargs=None, n...
    method denoise (line 71) | def denoise(self, model, x_t, sigmas, **model_kwargs):
  class GaussianToKarrasDenoiser (line 79) | class GaussianToKarrasDenoiser:
    method __init__ (line 80) | def __init__(self, model, diffusion):
    method sigma_to_t (line 89) | def sigma_to_t(self, sigma):
    method denoise (line 98) | def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None):
  function karras_sample (line 111) | def karras_sample(*args, **kwargs):
  function karras_sample_progressive (line 118) | def karras_sample_progressive(
  function get_sigmas_karras (line 194) | def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
  function to_d (line 203) | def to_d(x, sigma, denoised):
  function get_ancestral_step (line 208) | def get_ancestral_step(sigma_from, sigma_to):
  function sample_euler_ancestral (line 217) | def sample_euler_ancestral(model, x, sigmas, progress=False):
  function sample_heun (line 239) | def sample_heun(
  function sample_dpm (line 283) | def sample_dpm(
  function append_dims (line 323) | def append_dims(x, target_dims):
  function append_zero (line 331) | def append_zero(x):

FILE: shap_e/diffusion/sample.py
  function uncond_guide_model (line 15) | def uncond_guide_model(
  function sample_latents (line 31) | def sample_latents(

FILE: shap_e/models/configs.py
  function model_from_config (line 38) | def model_from_config(config: Union[str, Dict[str, Any]], device: torch....

FILE: shap_e/models/download.py
  function default_cache_dir (line 45) | def default_cache_dir() -> str:
  function fetch_file_cached (line 49) | def fetch_file_cached(
  function check_hash (line 85) | def check_hash(path: str, expected_hash: str):
  function hash_file (line 94) | def hash_file(path: str) -> str:
  function load_config (line 105) | def load_config(
  function load_checkpoint (line 122) | def load_checkpoint(
  function load_model (line 139) | def load_model(

FILE: shap_e/models/generation/latent_diffusion.py
  class SplitVectorDiffusion (line 7) | class SplitVectorDiffusion(nn.Module):
    method __init__ (line 8) | def __init__(self, *, device: torch.device, wrapped: nn.Module, n_ctx:...
    method forward (line 18) | def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs):

FILE: shap_e/models/generation/perceiver.py
  class MultiheadCrossAttention (line 13) | class MultiheadCrossAttention(nn.Module):
    method __init__ (line 14) | def __init__(
    method forward (line 42) | def forward(self, x, data):
  class QKVMultiheadCrossAttention (line 50) | class QKVMultiheadCrossAttention(nn.Module):
    method __init__ (line 51) | def __init__(
    method forward (line 61) | def forward(self, q, kv):
  class ResidualCrossAttentionBlock (line 77) | class ResidualCrossAttentionBlock(nn.Module):
    method __init__ (line 78) | def __init__(
    method forward (line 110) | def forward(self, x: torch.Tensor, data: torch.Tensor):
  class SimplePerceiver (line 116) | class SimplePerceiver(nn.Module):
    method __init__ (line 121) | def __init__(
    method forward (line 155) | def forward(self, x: torch.Tensor, data: torch.Tensor):
  class PointDiffusionPerceiver (line 161) | class PointDiffusionPerceiver(nn.Module):
    method __init__ (line 162) | def __init__(
    method forward (line 224) | def forward(self, x: torch.Tensor, t: torch.Tensor):

FILE: shap_e/models/generation/pooled_mlp.py
  class PooledMLP (line 7) | class PooledMLP(nn.Module):
    method __init__ (line 8) | def __init__(
    method forward (line 32) | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
  class ResBlock (line 41) | class ResBlock(nn.Module):
    method __init__ (line 42) | def __init__(self, hidden_size: int, pool_op: str, device: torch.device):
    method forward (line 59) | def forward(self, x: torch.Tensor):
  function pool (line 67) | def pool(op_name: str, x: torch.Tensor) -> torch.Tensor:

FILE: shap_e/models/generation/pretrained_clip.py
  class ImageCLIP (line 13) | class ImageCLIP(nn.Module):
    method __init__ (line 19) | def __init__(
    method feature_dim (line 47) | def feature_dim(self) -> int:
    method grid_size (line 54) | def grid_size(self) -> int:
    method grid_feature_dim (line 61) | def grid_feature_dim(self) -> int:
    method forward (line 67) | def forward(
    method _static_multimodal_embed (line 120) | def _static_multimodal_embed(
    method embed_images (line 159) | def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Ten...
    method embed_text (line 168) | def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
    method embed_images_grid (line 177) | def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torc...
    method images_to_tensor (line 215) | def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch...
  class FrozenImageCLIP (line 219) | class FrozenImageCLIP:
    method __init__ (line 220) | def __init__(self, device: torch.device, **kwargs):
    method feature_dim (line 226) | def feature_dim(self) -> int:
    method grid_size (line 230) | def grid_size(self) -> int:
    method grid_feature_dim (line 234) | def grid_feature_dim(self) -> int:
    method __call__ (line 237) | def __call__(
    method embed_images (line 249) | def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Ten...
    method embed_text (line 253) | def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
    method embed_images_grid (line 257) | def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torc...
  function _image_to_pil (line 262) | def _image_to_pil(obj: Optional[ImageType]) -> Image.Image:

FILE: shap_e/models/generation/transformer.py
  function init_linear (line 13) | def init_linear(l, stddev):
  class MultiheadAttention (line 19) | class MultiheadAttention(nn.Module):
    method __init__ (line 20) | def __init__(
    method forward (line 40) | def forward(self, x):
  class MLP (line 47) | class MLP(nn.Module):
    method __init__ (line 48) | def __init__(self, *, device: torch.device, dtype: torch.dtype, width:...
    method forward (line 57) | def forward(self, x):
  class QKVMultiheadAttention (line 61) | class QKVMultiheadAttention(nn.Module):
    method __init__ (line 62) | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads:...
    method forward (line 69) | def forward(self, qkv):
  class ResidualAttentionBlock (line 83) | class ResidualAttentionBlock(nn.Module):
    method __init__ (line 84) | def __init__(
    method forward (line 108) | def forward(self, x: torch.Tensor):
  class Transformer (line 114) | class Transformer(nn.Module):
    method __init__ (line 115) | def __init__(
    method forward (line 145) | def forward(self, x: torch.Tensor):
  class PointDiffusionTransformer (line 151) | class PointDiffusionTransformer(nn.Module):
    method __init__ (line 152) | def __init__(
    method forward (line 203) | def forward(self, x: torch.Tensor, t: torch.Tensor):
    method _forward_with_cond (line 213) | def _forward_with_cond(
  class CLIPImagePointDiffusionTransformer (line 239) | class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):
    method __init__ (line 240) | def __init__(
    method cached_model_kwargs (line 262) | def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str,...
    method forward (line 266) | def forward(
  class CLIPImageGridPointDiffusionTransformer (line 301) | class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
    method __init__ (line 302) | def __init__(
    method cached_model_kwargs (line 330) | def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str,...
    method forward (line 335) | def forward(
  class UpsamplePointDiffusionTransformer (line 371) | class UpsamplePointDiffusionTransformer(PointDiffusionTransformer):
    method __init__ (line 372) | def __init__(
    method forward (line 404) | def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch....
    method _embed_low_res (line 417) | def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:
  class CLIPImageGridUpsamplePointDiffusionTransformer (line 425) | class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffus...
    method __init__ (line 426) | def __init__(
    method cached_model_kwargs (line 449) | def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str,...
    method forward (line 457) | def forward(

FILE: shap_e/models/generation/util.py
  function timestep_embedding (line 6) | def timestep_embedding(timesteps, dim, max_period=10000):

FILE: shap_e/models/nerf/model.py
  class NeRFModel (line 18) | class NeRFModel(ABC):
    method forward (line 24) | def forward(
  class VoidNeRFModel (line 41) | class VoidNeRFModel(MetaModule, NeRFModel):
    method __init__ (line 47) | def __init__(
    method forward (line 64) | def forward(
  class MLPNeRFModel (line 83) | class MLPNeRFModel(MetaModule, NeRFModel):
    method __init__ (line 84) | def __init__(
    method encode_position (line 170) | def encode_position(self, query: Query):
    method forward (line 174) | def forward(
  function maybe_get_spherical_harmonics_basis (line 241) | def maybe_get_spherical_harmonics_basis(

FILE: shap_e/models/nerf/ray.py
  function render_rays (line 15) | def render_rays(
  class RayVolumeIntegralResults (line 133) | class RayVolumeIntegralResults:
    method combine (line 162) | def combine(self, cur: "RayVolumeIntegralResults") -> "RayVolumeIntegr...
  class RayVolumeIntegral (line 208) | class RayVolumeIntegral:
    method render_rays (line 214) | def render_rays(
    method integrate_samples (line 297) | def integrate_samples(
    method _merge_results (line 381) | def _merge_results(
  class RaySampler (line 399) | class RaySampler(ABC):
    method sample (line 401) | def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -...
  class StratifiedRaySampler (line 410) | class StratifiedRaySampler(RaySampler):
    method __init__ (line 416) | def __init__(self, depth_mode: str = "linear"):
    method sample (line 424) | def sample(
  class ImportanceRaySampler (line 459) | class ImportanceRaySampler(RaySampler):
    method __init__ (line 465) | def __init__(
    method sample (line 485) | def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -...

FILE: shap_e/models/nerf/renderer.py
  class TwoStepNeRFRenderer (line 15) | class TwoStepNeRFRenderer(RayRenderer):
    method __init__ (line 21) | def __init__(
    method render_rays (line 69) | def render_rays(
  class OneStepNeRFRenderer (line 199) | class OneStepNeRFRenderer(RayRenderer):
    method __init__ (line 205) | def __init__(
    method render_rays (line 232) | def render_rays(

FILE: shap_e/models/nerstf/mlp.py
  class MLPDensitySDFModel (line 11) | class MLPDensitySDFModel(MLPModel):
    method __init__ (line 12) | def __init__(
    method forward (line 28) | def forward(
  class MLPNeRSTFModel (line 45) | class MLPNeRSTFModel(MLPModel):
    method __init__ (line 46) | def __init__(
    method forward (line 82) | def forward(
  function indices_for_output_mode (line 117) | def indices_for_output_mode(
  function map_indices_to_keys (line 168) | def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> At...
  function index_mapping_max (line 172) | def index_mapping_max(mapping: IndexMapping) -> int:

FILE: shap_e/models/nerstf/renderer.py
  class NeRSTFRenderer (line 19) | class NeRSTFRenderer(RayRenderer, STFRendererBase):
    method __init__ (line 20) | def __init__(
    method _query (line 63) | def _query(
    method render_rays (line 92) | def render_rays(
    method render_views (line 185) | def render_views(
    method get_signed_distance (line 269) | def get_signed_distance(
    method get_texture (line 280) | def get_texture(

FILE: shap_e/models/nn/camera.py
  class DifferentiableCamera (line 12) | class DifferentiableCamera(ABC):
    method camera_rays (line 18) | def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
    method resize_image (line 29) | def resize_image(self, width: int, height: int) -> "DifferentiableCame...
  class DifferentiableProjectiveCamera (line 37) | class DifferentiableProjectiveCamera(DifferentiableCamera):
    method __post_init__ (line 51) | def __post_init__(self):
    method resolution (line 62) | def resolution(self):
    method fov (line 65) | def fov(self):
    method image_coords (line 68) | def image_coords(self) -> torch.Tensor:
    method camera_rays (line 82) | def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
    method resize_image (line 112) | def resize_image(self, width: int, height: int) -> "DifferentiableProj...
  class DifferentiableCameraBatch (line 130) | class DifferentiableCameraBatch(ABC):
  function normalize (line 139) | def normalize(vec: torch.Tensor) -> torch.Tensor:
  function project_out (line 143) | def project_out(vec1: torch.Tensor, vec2: torch.Tensor) -> torch.Tensor:
  function camera_orientation (line 152) | def camera_orientation(toward: torch.Tensor, up: Optional[torch.Tensor] ...
  function projective_camera_frame (line 175) | def projective_camera_frame(
  function get_image_coords (line 202) | def get_image_coords(width, height) -> torch.Tensor:

FILE: shap_e/models/nn/checkpoint.py
  function checkpoint (line 7) | def checkpoint(
  class CheckpointFunction (line 29) | class CheckpointFunction(torch.autograd.Function):
    method forward (line 32) | def forward(ctx, run_function, length, *args):
    method backward (line 44) | def backward(ctx, *output_grads):
  class CheckpointFunctionGradFunction (line 59) | class CheckpointFunctionGradFunction(torch.autograd.Function):
    method forward (line 62) | def forward(ctx, run_function, length_1, length_2, *args):
    method backward (line 87) | def backward(ctx, *all_output_grads):

FILE: shap_e/models/nn/encoding.py
  function encode_position (line 9) | def encode_position(version: str, *, position: torch.Tensor):
  function encode_channels (line 20) | def encode_channels(version: str, *, channels: torch.Tensor):
  function position_encoding_channels (line 31) | def position_encoding_channels(version: Optional[str] = None) -> int:
  function channel_encoding_channels (line 37) | def channel_encoding_channels(version: Optional[str] = None) -> int:
  class PosEmbLinear (line 43) | class PosEmbLinear(nn.Linear):
    method __init__ (line 44) | def __init__(
    method forward (line 54) | def forward(self, x: torch.Tensor):
  class MultiviewPoseEmbedding (line 60) | class MultiviewPoseEmbedding(nn.Conv2d):
    method __init__ (line 61) | def __init__(
    method forward (line 84) | def forward(
  class MultiviewPointCloudEmbedding (line 116) | class MultiviewPointCloudEmbedding(nn.Conv2d):
    method __init__ (line 117) | def __init__(
    method forward (line 144) | def forward(
  function maybe_encode_direction (line 180) | def maybe_encode_direction(
  function posenc_nerf (line 200) | def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) ->...
  function get_scales (line 217) | def get_scales(
  function spherical_harmonics_basis (line 226) | def spherical_harmonics_basis(
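
`posenc_nerf(x, min_deg=0, max_deg=15)` names the NeRF-style sinusoidal positional encoding. A minimal pure-Python sketch of that technique for a single point, assuming octave-spaced frequencies 2^k for k in [min_deg, max_deg) and that the raw input is concatenated in front of the sine and cosine features; the exact feature layout in shap-e's tensor version may differ.

```python
import math
from typing import List, Sequence


def posenc_nerf(x: Sequence[float], min_deg: int = 0, max_deg: int = 15) -> List[float]:
    # No frequencies requested: return the input unchanged.
    if min_deg == max_deg:
        return list(x)
    scales = [2.0 ** k for k in range(min_deg, max_deg)]
    # One scaled copy per (frequency, coordinate) pair, frequency-major order.
    xb = [s * xi for s in scales for xi in x]
    # sin(t + pi/2) == cos(t), so a single sin covers both halves.
    emb = [math.sin(t) for t in xb] + [math.sin(t + math.pi / 2.0) for t in xb]
    return list(x) + emb
```

For a 3-D point with the default degrees, the output has 3 + 2 * 3 * 15 = 93 features.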

FILE: shap_e/models/nn/meta.py
  function subdict (line 46) | def subdict(dictionary, key=None):
  function superdict (line 61) | def superdict(dictionary, key=None):
  function leveldict (line 69) | def leveldict(dictionary, depth=0):
  function leveliter (line 73) | def leveliter(dictionary, depth=0):
  class MetaModule (line 82) | class MetaModule(nn.Module):
    method __init__ (line 103) | def __init__(self, *args, **kwargs):
    method register_meta_buffer (line 108) | def register_meta_buffer(self, name: str, param: nn.Parameter):
    method register_meta_parameter (line 116) | def register_meta_parameter(self, name: str, parameter: nn.Parameter):
    method register_meta (line 125) | def register_meta(self, name: str, parameter: nn.Parameter, trainable:...
    method register (line 131) | def register(self, name: str, parameter: nn.Parameter, meta: bool, tra...
    method named_meta_parameters (line 143) | def named_meta_parameters(self, prefix="", recurse=True):
    method named_nonmeta_parameters (line 162) | def named_nonmeta_parameters(self, prefix="", recurse=True):
    method nonmeta_parameters (line 177) | def nonmeta_parameters(self, prefix="", recurse=True):
    method meta_state_dict (line 181) | def meta_state_dict(self, prefix="", recurse=True):
    method update (line 203) | def update(self, params=None):
  function batch_meta_parameters (line 221) | def batch_meta_parameters(net, batch_size):
  function batch_meta_state_dict (line 228) | def batch_meta_state_dict(net, batch_size):

FILE: shap_e/models/nn/ops.py
  function gelu (line 15) | def gelu(x):
  function swish (line 19) | def swish(x):
  function quick_gelu (line 23) | def quick_gelu(x):
  function torch_gelu (line 27) | def torch_gelu(x):
  function geglu (line 31) | def geglu(x):
  class SirenSin (line 36) | class SirenSin:
    method __init__ (line 37) | def __init__(self, w0=30.0):
    method __call__ (line 40) | def __call__(self, x):
  function get_act (line 44) | def get_act(name):
  function zero_init (line 64) | def zero_init(affine):
  function siren_init_first_layer (line 70) | def siren_init_first_layer(affine, init_scale: float = 1.0):
  function siren_init (line 78) | def siren_init(affine, coeff=1.0, init_scale: float = 1.0):
  function siren_init_30 (line 86) | def siren_init_30(affine, init_scale: float = 1.0):
  function std_init (line 90) | def std_init(affine, init_scale: float = 1.0):
  function mlp_init (line 98) | def mlp_init(affines, init: Optional[str] = None, init_scale: float = 1.0):
  class MetaLinear (line 114) | class MetaLinear(MetaModule):
    method __init__ (line 115) | def __init__(
    method reset_parameters (line 148) | def reset_parameters(self) -> None:
    method _bcast (line 161) | def _bcast(self, op, left, right):
    method forward (line 167) | def forward(self, x, params=None):
  function Conv (line 191) | def Conv(n_dim, d_in, d_out, kernel, stride=1, padding=0, dilation=1, **...
  function flatten (line 200) | def flatten(x):
  function unflatten (line 208) | def unflatten(x, info):
  function torchify (line 213) | def torchify(x):
  function untorchify (line 218) | def untorchify(x):
  class MLP (line 223) | class MLP(nn.Module):
    method __init__ (line 224) | def __init__(
    method forward (line 251) | def forward(self, h, options: Optional[AttrDict] = None, log_prefix: s...
  class MetaMLP (line 260) | class MetaMLP(MetaModule):
    method __init__ (line 261) | def __init__(
    method forward (line 300) | def forward(self, h, params=None, options: Optional[AttrDict] = None, ...
  class LayerNorm (line 311) | class LayerNorm(nn.LayerNorm):
    method __init__ (line 312) | def __init__(
    method forward (line 319) | def forward(self, input):
  class PointSetEmbedding (line 328) | class PointSetEmbedding(nn.Module):
    method __init__ (line 329) | def __init__(
    method forward (line 370) | def forward(self, xyz, points):
    method apply_conv (line 403) | def apply_conv(self, points: torch.Tensor, conv: nn.Module):
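
The activation helpers listed for `ops.py` correspond to well-known formulas. A scalar pure-Python sketch of the standard definitions follows; whether shap-e's `gelu` uses the tanh approximation shown here (with `torch_gelu` presumably delegating to PyTorch's exact form) is an assumption, and the helper name `exact_gelu` is hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (no overflow for large |x|).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def gelu(x: float) -> float:
    # Tanh approximation of GELU.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def exact_gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def swish(x: float) -> float:
    return x * sigmoid(x)


def quick_gelu(x: float) -> float:
    # Sigmoid-based GELU approximation popularized by CLIP.
    return x * sigmoid(1.702 * x)


class SirenSin:
    # SIREN activation sin(w0 * x); w0 = 30 matches the listed default.
    def __init__(self, w0: float = 30.0):
        self.w0 = w0

    def __call__(self, x: float) -> float:
        return math.sin(self.w0 * x)
```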

FILE: shap_e/models/nn/pointnet2_utils.py
  function timeit (line 35) | def timeit(tag, t):
  function pc_normalize (line 40) | def pc_normalize(pc):
  function square_distance (line 49) | def square_distance(src, dst):
  function index_points (line 73) | def index_points(points, idx):
  function farthest_point_sample (line 95) | def farthest_point_sample(xyz, npoint, deterministic=False):
  function query_ball_point (line 122) | def query_ball_point(radius, nsample, xyz, new_xyz):
  function sample_and_group (line 145) | def sample_and_group(
  function sample_and_group_all (line 192) | def sample_and_group_all(xyz, points):
  class PointNetSetAbstraction (line 212) | class PointNetSetAbstraction(nn.Module):
    method __init__ (line 213) | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
    method forward (line 227) | def forward(self, xyz, points):
  class PointNetSetAbstractionMsg (line 258) | class PointNetSetAbstractionMsg(nn.Module):
    method __init__ (line 259) | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_...
    method forward (line 277) | def forward(self, xyz, points):
  class PointNetFeaturePropagation (line 318) | class PointNetFeaturePropagation(nn.Module):
    method __init__ (line 319) | def __init__(self, in_channel, mlp):
    method forward (line 329) | def forward(self, xyz1, xyz2, points1, points2):
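
`farthest_point_sample(xyz, npoint, deterministic=False)` is the greedy farthest point sampling step used by PointNet++: repeatedly pick the point whose distance to the already-chosen set is largest. A pure-Python O(npoint * N) sketch for a single cloud, assuming that `deterministic=True` seeds the search at index 0 and otherwise at a random index; the batched tensor implementation in the file is not reproduced here.

```python
import random
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def square_distance(a: Point, b: Point) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def farthest_point_sample(xyz: Sequence[Point], npoint: int, deterministic: bool = False) -> List[int]:
    n = len(xyz)
    if not 0 < npoint <= n:
        raise ValueError("npoint must be in [1, len(xyz)]")
    # Assumed seeding rule: index 0 when deterministic, otherwise random.
    farthest = 0 if deterministic else random.randrange(n)
    # dist[i] = squared distance from point i to its nearest selected point.
    dist = [float("inf")] * n
    selected: List[int] = []
    for _ in range(npoint):
        selected.append(farthest)
        centroid = xyz[farthest]
        for i, p in enumerate(xyz):
            d = square_distance(p, centroid)
            if d < dist[i]:
                dist[i] = d
        # Next pick: the point currently farthest from the selected set.
        farthest = max(range(n), key=dist.__getitem__)
    return selected
```

Tracking only the running minimum distance per point is what keeps each iteration linear in N instead of recomputing distances to every selected point.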

FILE: shap_e/models/nn/utils.py
  function to_torch (line 9) | def to_torch(arr: ArrayType, dtype=torch.float):
  function sample_pmf (line 15) | def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
  function safe_divide (line 36) | def safe_divide(a, b, epsilon=1e-6):
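
`sample_pmf(pmf, n_samples)` draws indices from a discrete distribution, which is commonly done by inverse-CDF sampling: accumulate the probabilities, draw a uniform number, and binary-search the cumulative sums. A single-distribution pure-Python sketch, assuming non-negative and possibly unnormalized weights; the optional `rng` parameter is a hypothetical addition for reproducibility, and the listed function operates on tensors.

```python
import bisect
import random
from itertools import accumulate
from typing import List, Optional, Sequence


def sample_pmf(pmf: Sequence[float], n_samples: int, rng: Optional[random.Random] = None) -> List[int]:
    if rng is None:
        rng = random.Random()
    cdf = list(accumulate(pmf))
    total = cdf[-1]
    samples = []
    for _ in range(n_samples):
        # Uniform draw in [0, total); the first bin whose CDF exceeds it wins,
        # so zero-probability bins are never returned.
        u = rng.random() * total
        idx = bisect.bisect_right(cdf, u)
        samples.append(min(idx, len(cdf) - 1))  # guard against float round-off
    return samples
```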

FILE: shap_e/models/query.py
  class Query (line 8) | class Query:
    method copy (line 16) | def copy(self) -> "Query":
    method map_tensors (line 24) | def map_tensors(self, f: Callable[[torch.Tensor], torch.Tensor]) -> "Q...

FILE: shap_e/models/renderer.py
  class Renderer (line 17) | class Renderer(MetaModule):
    method render_views (line 25) | def render_views(
  class RayRenderer (line 55) | class RayRenderer(Renderer):
    method render_rays (line 63) | def render_rays(
    method render_views (line 76) | def render_views(
    method forward (line 92) | def forward(
  function get_camera_from_batch (line 146) | def get_camera_from_batch(batch: AttrDict) -> Tuple[DifferentiableCamera...
  function append_tensor (line 166) | def append_tensor(val_list: Optional[List[torch.Tensor]], output: Option...
  function render_views_from_rays (line 172) | def render_views_from_rays(

FILE: shap_e/models/stf/base.py
  class Model (line 11) | class Model(ABC):
    method forward (line 13) | def forward(
    method forward_batched (line 23) | def forward_batched(

FILE: shap_e/models/stf/mlp.py
  class MLPModel (line 17) | class MLPModel(MetaModule, Model):
    method __init__ (line 18) | def __init__(
    method forward (line 107) | def forward(
    method _run_mlp (line 125) | def _run_mlp(
    method _mlp (line 152) | def _mlp(
  class MLPSDFModel (line 183) | class MLPSDFModel(MLPModel):
    method __init__ (line 184) | def __init__(self, initial_bias: float = -0.1, **kwargs):
    method forward (line 188) | def forward(
  class MLPTextureFieldModel (line 198) | class MLPTextureFieldModel(MLPModel):
    method __init__ (line 199) | def __init__(
    method forward (line 206) | def forward(

FILE: shap_e/models/stf/renderer.py
  class STFRendererBase (line 25) | class STFRendererBase(ABC):
    method get_signed_distance (line 27) | def get_signed_distance(
    method get_texture (line 36) | def get_texture(
  class STFRenderer (line 45) | class STFRenderer(Renderer, STFRendererBase):
    method __init__ (line 46) | def __init__(
    method render_views (line 76) | def render_views(
    method get_signed_distance (line 106) | def get_signed_distance(
    method get_texture (line 118) | def get_texture(
  function render_views_from_stf (line 131) | def render_views_from_stf(
  function _render_with_pytorch3d (line 315) | def _render_with_pytorch3d(
  function _render_with_raycast (line 387) | def _render_with_raycast(
  function _convert_srgb_to_linear (line 458) | def _convert_srgb_to_linear(u: torch.Tensor) -> torch.Tensor:
  function _convert_linear_to_srgb (line 462) | def _convert_linear_to_srgb(u: torch.Tensor) -> torch.Tensor:
  function cross_entropy_sdf_loss (line 466) | def cross_entropy_sdf_loss(fields: torch.Tensor):
  function slice_fields (line 484) | def slice_fields(fields: torch.Tensor, dim: int, start: int, end: int):
  function volume_query_points (line 495) | def volume_query_points(
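
`_convert_srgb_to_linear` and `_convert_linear_to_srgb` suggest the standard sRGB transfer function (IEC 61966-2-1), a linear segment near black joined to a 2.4-power curve. A scalar sketch of those standard formulas; that the private helpers use exactly these constants is an assumption.

```python
def srgb_to_linear(u: float) -> float:
    # Decode sRGB-encoded value in [0, 1] to linear light.
    return u / 12.92 if u <= 0.04045 else ((u + 0.055) / 1.055) ** 2.4


def linear_to_srgb(u: float) -> float:
    # Encode linear light in [0, 1] back to sRGB.
    return 12.92 * u if u <= 0.0031308 else 1.055 * u ** (1.0 / 2.4) - 0.055
```

Mid-gray 0.5 in sRGB decodes to roughly 0.214 in linear light, which is why blending or lighting should happen in the linear space.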

FILE: shap_e/models/transmitter/base.py
  class Encoder (line 14) | class Encoder(nn.Module, ABC):
    method __init__ (line 15) | def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tu...
    method forward (line 26) | def forward(self, batch: AttrDict, options: Optional[AttrDict] = None)...
  class VectorEncoder (line 32) | class VectorEncoder(Encoder):
    method __init__ (line 33) | def __init__(
    method forward (line 57) | def forward(self, batch: AttrDict, options: Optional[AttrDict] = None)...
    method encode_to_bottleneck (line 61) | def encode_to_bottleneck(
    method encode_to_vector (line 70) | def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict...
    method bottleneck_to_params (line 75) | def bottleneck_to_params(
  class ChannelsEncoder (line 82) | class ChannelsEncoder(VectorEncoder):
    method __init__ (line 83) | def __init__(
    method encode_to_channels (line 105) | def encode_to_channels(
    method encode_to_vector (line 113) | def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict...
    method bottleneck_to_channels (line 116) | def bottleneck_to_channels(
    method bottleneck_to_params (line 122) | def bottleneck_to_params(
  class Transmitter (line 131) | class Transmitter(nn.Module):
    method __init__ (line 132) | def __init__(self, encoder: Encoder, renderer: Renderer):
    method forward (line 137) | def forward(self, batch: AttrDict, options: Optional[AttrDict] = None)...
  class VectorDecoder (line 145) | class VectorDecoder(nn.Module):
    method __init__ (line 146) | def __init__(
    method bottleneck_to_params (line 169) | def bottleneck_to_params(
  class ChannelsDecoder (line 176) | class ChannelsDecoder(VectorDecoder):
    method __init__ (line 177) | def __init__(
    method bottleneck_to_channels (line 186) | def bottleneck_to_channels(
    method bottleneck_to_params (line 192) | def bottleneck_to_params(

FILE: shap_e/models/transmitter/bottleneck.py
  class LatentBottleneck (line 12) | class LatentBottleneck(nn.Module, ABC):
    method __init__ (line 13) | def __init__(self, *, device: torch.device, d_latent: int):
    method forward (line 19) | def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None)...
  class LatentWarp (line 23) | class LatentWarp(nn.Module, ABC):
    method __init__ (line 24) | def __init__(self, *, device: torch.device):
    method warp (line 29) | def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) ->...
    method unwarp (line 33) | def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) ...
  class IdentityLatentWarp (line 37) | class IdentityLatentWarp(LatentWarp):
    method warp (line 38) | def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) ->...
    method unwarp (line 42) | def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) ...
  class Tan2LatentWarp (line 47) | class Tan2LatentWarp(LatentWarp):
    method __init__ (line 48) | def __init__(self, *, coeff1: float = 1.0, device: torch.device):
    method warp (line 53) | def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) ->...
    method unwarp (line 57) | def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) ...
  class IdentityLatentBottleneck (line 62) | class IdentityLatentBottleneck(LatentBottleneck):
    method forward (line 63) | def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None)...
  class ClampNoiseBottleneck (line 68) | class ClampNoiseBottleneck(LatentBottleneck):
    method __init__ (line 69) | def __init__(self, *, device: torch.device, d_latent: int, noise_scale...
    method forward (line 73) | def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None)...
  class ClampDiffusionNoiseBottleneck (line 81) | class ClampDiffusionNoiseBottleneck(LatentBottleneck):
    method __init__ (line 82) | def __init__(
    method forward (line 94) | def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None)...
  function latent_bottleneck_from_config (line 106) | def latent_bottleneck_from_config(config: Dict[str, Any], device: torch....
  function latent_warp_from_config (line 118) | def latent_warp_from_config(config: Dict[str, Any], device: torch.device):

FILE: shap_e/models/transmitter/channels_encoder.py
  class TransformerChannelsEncoder (line 29) | class TransformerChannelsEncoder(ChannelsEncoder, ABC):
    method __init__ (line 35) | def __init__(
    method encode_input (line 85) | def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = ...
    method encode_to_channels (line 88) | def encode_to_channels(
  class PerceiverChannelsEncoder (line 101) | class PerceiverChannelsEncoder(ChannelsEncoder, ABC):
    method __init__ (line 107) | def __init__(
    method get_h_and_iterator (line 180) | def get_h_and_iterator(
    method encode_to_channels (line 190) | def encode_to_channels(
    method get_n_unrolls (line 208) | def get_n_unrolls(self):
  class DatasetIterator (line 221) | class DatasetIterator:
    method __iter__ (line 226) | def __iter__(self):
    method __next__ (line 230) | def __next__(self):
    method _reset (line 243) | def _reset(self):
    method _shuffle (line 247) | def _shuffle(self):
  class PointCloudTransformerChannelsEncoder (line 261) | class PointCloudTransformerChannelsEncoder(TransformerChannelsEncoder):
    method __init__ (line 267) | def __init__(
    method encode_input (line 279) | def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = ...
  class PointCloudPerceiverChannelsEncoder (line 286) | class PointCloudPerceiverChannelsEncoder(PerceiverChannelsEncoder):
    method __init__ (line 292) | def __init__(
    method get_h_and_iterator (line 464) | def get_h_and_iterator(
    method sample_pcl_fps (line 501) | def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
    method get_pcl_dataset (line 504) | def get_pcl_dataset(
    method get_multiview_dataset (line 518) | def get_multiview_dataset(
    method get_dense_pose_multiview_dataset (line 545) | def get_dense_pose_multiview_dataset(
    method get_pcl_and_multiview_pcl_dataset (line 572) | def get_pcl_and_multiview_pcl_dataset(
    method get_multiview_pcl_dataset (line 606) | def get_multiview_pcl_dataset(
    method encode_views (line 638) | def encode_views(self, batch: AttrDict) -> torch.Tensor:
    method encode_dense_pose_views (line 669) | def encode_dense_pose_views(self, batch: AttrDict) -> torch.Tensor:
    method encode_multiview_pcl (line 696) | def encode_multiview_pcl(self, batch: AttrDict, use_distance: bool = T...
    method views_to_tensor (line 728) | def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.I...
    method depths_to_tensor (line 750) | def depths_to_tensor(
    method view_alphas_to_tensor (line 776) | def view_alphas_to_tensor(
    method raw_depths_to_tensor (line 804) | def raw_depths_to_tensor(
    method cameras_to_tensor (line 829) | def cameras_to_tensor(
    method dense_pose_cameras_to_tensor (line 855) | def dense_pose_cameras_to_tensor(
  function sample_pcl_fps (line 916) | def sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = "f...
  function sample_fps (line 939) | def sample_fps(example: torch.Tensor, n_samples: int) -> torch.Tensor:

FILE: shap_e/models/transmitter/multiview_encoder.py
  class MultiviewTransformerEncoder (line 16) | class MultiviewTransformerEncoder(VectorEncoder):
    method __init__ (line 22) | def __init__(
    method encode_to_vector (line 99) | def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict...
    method views_to_tensor (line 131) | def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.I...
    method depths_to_tensor (line 152) | def depths_to_tensor(
    method cameras_to_tensor (line 177) | def cameras_to_tensor(

FILE: shap_e/models/transmitter/params_proj.py
  function flatten_param_shapes (line 13) | def flatten_param_shapes(param_shapes: Dict[str, Tuple[int]]):
  class ParamsProj (line 21) | class ParamsProj(nn.Module, ABC):
    method __init__ (line 22) | def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tu...
    method forward (line 29) | def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None)...
  class LinearParamsProj (line 33) | class LinearParamsProj(ParamsProj):
    method __init__ (line 34) | def __init__(
    method forward (line 55) | def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None)...
  class MLPParamsProj (line 63) | class MLPParamsProj(ParamsProj):
    method __init__ (line 64) | def __init__(
    method forward (line 84) | def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None)...
  class ChannelsProj (line 92) | class ChannelsProj(nn.Module):
    method __init__ (line 93) | def __init__(
    method forward (line 125) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class ChannelsParamsProj (line 138) | class ChannelsParamsProj(ParamsProj):
    method __init__ (line 139) | def __init__(
    method forward (line 166) | def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None)...
  function params_proj_from_config (line 178) | def params_proj_from_config(
  function _sanitize_name (line 196) | def _sanitize_name(x: str) -> str:

FILE: shap_e/models/transmitter/pc_encoder.py
  class PointCloudTransformerEncoder (line 21) | class PointCloudTransformerEncoder(VectorEncoder):
    method __init__ (line 27) | def __init__(
    method encode_to_vector (line 77) | def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict...
  class PerceiverEncoder (line 90) | class PerceiverEncoder(VectorEncoder):
    method __init__ (line 96) | def __init__(
    method get_h_and_iterator (line 162) | def get_h_and_iterator(
    method encode_to_vector (line 172) | def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict...
    method get_n_unrolls (line 184) | def get_n_unrolls(self):
  class PointCloudPerceiverEncoder (line 196) | class PointCloudPerceiverEncoder(PerceiverEncoder):
    method __init__ (line 202) | def __init__(
    method get_h_and_iterator (line 262) | def get_h_and_iterator(
    method sample_pcl_fps (line 291) | def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
    method get_pcl_dataset (line 294) | def get_pcl_dataset(
    method get_multiview_dataset (line 302) | def get_multiview_dataset(
    method encode_views (line 323) | def encode_views(self, batch: AttrDict) -> torch.Tensor:
    method views_to_tensor (line 354) | def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.I...
    method depths_to_tensor (line 376) | def depths_to_tensor(
    method cameras_to_tensor (line 402) | def cameras_to_tensor(

FILE: shap_e/models/volume.py
  class VolumeRange (line 12) | class VolumeRange:
    method __post_init__ (line 17) | def __post_init__(self):
    method next_t0 (line 20) | def next_t0(self):
    method extend (line 28) | def extend(self, another: "VolumeRange") -> "VolumeRange":
    method partition (line 39) | def partition(self, ts) -> Tuple[torch.Tensor, torch.Tensor, torch.Ten...
  class Volume (line 61) | class Volume(ABC):
    method intersect (line 67) | def intersect(
  class BoundingBoxVolume (line 89) | class BoundingBoxVolume(MetaModule, Volume):
    method __init__ (line 94) | def __init__(
    method intersect (line 120) | def intersect(
  class UnboundedVolume (line 169) | class UnboundedVolume(MetaModule, Volume):
    method __init__ (line 176) | def __init__(
    method intersect (line 192) | def intersect(
  class SphericalVolume (line 221) | class SphericalVolume(MetaModule, Volume):
    method __init__ (line 227) | def __init__(
    method intersect (line 246) | def intersect(
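
`BoundingBoxVolume.intersect` presumably finds the parametric interval over which a ray lies inside an axis-aligned box, which the classic slab method computes one axis at a time. A sketch of that technique for one ray; the function name `intersect_aabb`, the `min_t` clamp, and the `None` miss value are assumptions, and the class's actual return type is not shown in this index.

```python
import math
from typing import Optional, Sequence, Tuple


def intersect_aabb(
    origin: Sequence[float],
    direction: Sequence[float],
    bbox_min: Sequence[float],
    bbox_max: Sequence[float],
    min_t: float = 0.0,
    eps: float = 1e-12,
) -> Optional[Tuple[float, float]]:
    t0, t1 = min_t, math.inf
    for o, d, lo, hi in zip(origin, direction, bbox_min, bbox_max):
        if abs(d) < eps:
            # Ray parallel to this slab: it must already lie between the planes.
            if o < lo or o > hi:
                return None
            continue
        ta, tb = (lo - o) / d, (hi - o) / d
        if ta > tb:
            ta, tb = tb, ta
        # Shrink the interval to the overlap of all slabs.
        t0 = max(t0, ta)
        t1 = min(t1, tb)
    if t0 > t1:
        return None  # slabs do not overlap: the ray misses the box
    return t0, t1
```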

FILE: shap_e/rendering/blender/blender_script.py
  function clear_scene (line 29) | def clear_scene():
  function clear_lights (line 34) | def clear_lights():
  function import_model (line 42) | def import_model(path):
  function scene_root_objects (line 62) | def scene_root_objects():
  function scene_bbox (line 68) | def scene_bbox(single_obj=None, ignore_matrix=False):
  function scene_meshes (line 85) | def scene_meshes():
  function normalize_scene (line 91) | def normalize_scene():
  function create_camera (line 119) | def create_camera():
  function set_camera (line 127) | def set_camera(direction, camera_dist=2.0):
  function randomize_camera (line 138) | def randomize_camera(camera_dist=2.0):
  function pan_camera (line 143) | def pan_camera(time, axis="Z", camera_dist=2.0, elevation=0.1):
  function place_camera (line 155) | def place_camera(time, camera_pose_mode="random", camera_dist_min=2.0, c...
  function create_light (line 167) | def create_light(location, energy=1.0, angle=0.5 * math.pi / 180):
  function create_random_lights (line 183) | def create_random_lights(count=4, distance=2.0, energy=1.5):
  function create_camera_light (line 189) | def create_camera_light():
  function create_uniform_light (line 194) | def create_uniform_light(backend):
  function create_vertex_color_shaders (line 203) | def create_vertex_color_shaders():
  function create_default_materials (line 249) | def create_default_materials():
  function find_materials (line 258) | def find_materials():
  function delete_all_materials (line 268) | def delete_all_materials():
  function setup_material_extraction_shaders (line 275) | def setup_material_extraction_shaders(capturing_material_alpha: bool):
  function setup_material_extraction_shader_for_material (line 291) | def setup_material_extraction_shader_for_material(mat, capturing_materia...
  function get_socket_value (line 346) | def get_socket_value(tree, socket):
  function clear_socket_input (line 356) | def clear_socket_input(tree, socket):
  function set_socket_value (line 362) | def set_socket_value(tree, socket, socket_and_default):
  function setup_nodes (line 374) | def setup_nodes(output_path, capturing_material_alpha: bool = False, bas...
  function render_scene (line 469) | def render_scene(output_path, fast_mode: bool, extract_material: bool, b...
  function scene_fov (line 538) | def scene_fov():
  function write_camera_metadata (line 550) | def write_camera_metadata(path):
  function save_rendering_dataset (line 571) | def save_rendering_dataset(
  function main (line 638) | def main():

FILE: shap_e/rendering/blender/render.py
  function render_model (line 18) | def render_model(
  function render_mesh (line 100) | def render_mesh(
  function _combine_rgba (line 119) | def _combine_rgba(out_dir: str):
  function _blender_binary_path (line 134) | def _blender_binary_path() -> str:

FILE: shap_e/rendering/blender/view_data.py
  class BlenderViewData (line 12) | class BlenderViewData(ViewData):
    method __init__ (line 17) | def __init__(self, f_obj: BinaryIO):
    method num_views (line 35) | def num_views(self) -> int:
    method channel_names (line 39) | def channel_names(self) -> List[str]:
    method load_view (line 42) | def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, ...
    method camera (line 73) | def camera(self, index: int, width: int, height: int) -> ProjectiveCam...

FILE: shap_e/rendering/mc.py
  function marching_cubes (line 11) | def marching_cubes(
  function _create_flat_edge_indices (line 124) | def _create_flat_edge_indices(
  class McLookupTable (line 206) | class McLookupTable:
  function _lookup_table (line 229) | def _lookup_table(device: torch.device) -> McLookupTable:

FILE: shap_e/rendering/mesh.py
  class TriMesh (line 11) | class TriMesh:
    method load (line 30) | def load(cls, f: Union[str, BinaryIO]) -> "TriMesh":
    method save (line 58) | def save(self, f: Union[str, BinaryIO]):
    method has_vertex_colors (line 75) | def has_vertex_colors(self) -> bool:
    method write_ply (line 78) | def write_ply(self, raw_f: BinaryIO):
    method write_obj (line 90) | def write_obj(self, raw_f: BinaryIO):

FILE: shap_e/rendering/ply_util.py
  function write_ply (line 9) | def write_ply(

FILE: shap_e/rendering/point_cloud.py
  function preprocess (line 16) | def preprocess(data, channel):
  class PointCloud (line 23) | class PointCloud:
    method from_rgbd (line 36) | def from_rgbd(cls, vd: ViewData, num_views: Optional[int] = None) -> "...
    method load (line 95) | def load(cls, f: Union[str, BinaryIO]) -> "PointCloud":
    method save (line 110) | def save(self, f: Union[str, BinaryIO]):
    method write_ply (line 120) | def write_ply(self, raw_f: BinaryIO):
    method random_sample (line 131) | def random_sample(self, num_points: int, **subsample_kwargs) -> "Point...
    method farthest_point_sample (line 145) | def farthest_point_sample(
    method subsample (line 189) | def subsample(self, indices: np.ndarray, average_neighbors: bool = Fal...
    method select_channels (line 213) | def select_channels(self, channel_names: List[str]) -> np.ndarray:
    method nearest_points (line 217) | def nearest_points(self, points: np.ndarray, batch_size: int = 16384) ...
    method combine (line 236) | def combine(self, other: "PointCloud") -> "PointCloud":

FILE: shap_e/rendering/pytorch3d_util.py
  function render_images (line 33) | def render_images(
  function _deconstruct_tensor_props (line 114) | def _deconstruct_tensor_props(
  function convert_meshes (line 144) | def convert_meshes(raw_meshes: Sequence[TorchMesh], default_brightness=0...
  function convert_cameras (line 165) | def convert_cameras(
  function convert_cameras_torch (line 188) | def convert_cameras_torch(
  function blender_uniform_lights (line 207) | def blender_uniform_lights(
  class BidirectionalLights (line 233) | class BidirectionalLights(DirectionalLights):
    method diffuse (line 239) | def diffuse(self, normals, points=None) -> torch.Tensor:
    method specular (line 244) | def specular(self, normals, points, camera_position, shininess) -> tor...

FILE: shap_e/rendering/raycast/_utils.py
  function normalize (line 4) | def normalize(v: torch.Tensor) -> torch.Tensor:
  function cross_product (line 8) | def cross_product(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:

FILE: shap_e/rendering/raycast/cast.py
  function cast_camera (line 12) | def cast_camera(
  function cast_rays (line 30) | def cast_rays(rays: Rays, mesh: TriMesh, checkpoint: bool = False) -> Ra...
  class RayCollisionFunction (line 96) | class RayCollisionFunction(torch.autograd.Function):
    method forward (line 98) | def forward(
    method backward (line 111) | def backward(

FILE: shap_e/rendering/raycast/render.py
  function render_diffuse_mesh (line 16) | def render_diffuse_mesh(

FILE: shap_e/rendering/raycast/types.py
  class Rays (line 13) | class Rays:
    method normalized_directions (line 21) | def normalized_directions(self) -> torch.Tensor:
  class RayCollisions (line 26) | class RayCollisions:
    method collect (line 38) | def collect(cls, it: Iterable["RayCollisions"]) -> "RayCollisions":
  class TriMesh (line 57) | class TriMesh:
    method normals (line 63) | def normals(self) -> torch.Tensor:
    method from_numpy (line 74) | def from_numpy(cls, x: shap_e.rendering.mesh.TriMesh) -> "TriMesh":
    method to (line 86) | def to(self, *args, **kwargs) -> "TriMesh":

FILE: shap_e/rendering/torch_mesh.py
  class TorchMesh (line 10) | class TorchMesh:
    method tri_mesh (line 25) | def tri_mesh(self) -> TriMesh:

FILE: shap_e/rendering/view_data.py
  class Camera (line 9) | class Camera(ABC):
    method image_coords (line 15) | def image_coords(self) -> np.ndarray:
    method camera_rays (line 21) | def camera_rays(self, coords: np.ndarray) -> np.ndarray:
    method depth_directions (line 31) | def depth_directions(self, coords: np.ndarray) -> np.ndarray:
    method center_crop (line 46) | def center_crop(self) -> "Camera":
    method resize_image (line 53) | def resize_image(self, width: int, height: int) -> "Camera":
    method scale_scene (line 60) | def scale_scene(self, factor: float) -> "Camera":
  class ProjectiveCamera (line 68) | class ProjectiveCamera(Camera):
    method image_coords (line 86) | def image_coords(self) -> np.ndarray:
    method camera_rays (line 91) | def camera_rays(self, coords: np.ndarray) -> np.ndarray:
    method depth_directions (line 98) | def depth_directions(self, coords: np.ndarray) -> np.ndarray:
    method resize_image (line 101) | def resize_image(self, width: int, height: int) -> "ProjectiveCamera":
    method center_crop (line 117) | def center_crop(self) -> "ProjectiveCamera":
    method scale_scene (line 134) | def scale_scene(self, factor: float) -> "ProjectiveCamera":
  class ViewData (line 151) | class ViewData(ABC):
    method num_views (line 161) | def num_views(self) -> int:
    method channel_names (line 168) | def channel_names(self) -> List[str]:
    method load_view (line 177) | def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, ...
  class MemoryViewData (line 186) | class MemoryViewData(ViewData):
    method __init__ (line 191) | def __init__(self, channels: Dict[str, np.ndarray], cameras: List[Came...
    method num_views (line 197) | def num_views(self) -> int:
    method channel_names (line 201) | def channel_names(self) -> List[str]:
    method load_view (line 204) | def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, ...

FILE: shap_e/util/collections.py
  class AttrDict (line 8) | class AttrDict(OrderedDict[K, V], Generic[K, V]):
    method __init__ (line 18) | def __init__(self, *args, **kwargs):
    method __contains__ (line 28) | def __contains__(self, key):
    method __setitem__ (line 35) | def __setitem__(self, key, value):
    method __getitem__ (line 50) | def __getitem__(self, key):
    method all_keys (line 61) | def all_keys(
    method dumpable (line 75) | def dumpable(self, strip=True):
    method map (line 91) | def map(
    method __eq__ (line 113) | def __eq__(self, other):
    method combine (line 116) | def combine(

FILE: shap_e/util/data_util.py
  function load_or_create_multimodal_batch (line 19) | def load_or_create_multimodal_batch(
  function load_or_create_pc (line 85) | def load_or_create_pc(
  function load_or_create_multiview (line 129) | def load_or_create_multiview(
  function mv_to_pc (line 195) | def mv_to_pc(multiview: ViewData, random_sample_count: int, point_count:...
  function normalize_input_batch (line 215) | def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_sca...
  function process_depth (line 229) | def process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray:
  function process_image (line 235) | def process_image(

FILE: shap_e/util/image_util.py
  function center_crop (line 11) | def center_crop(
  function resize (line 31) | def resize(
  function get_alpha (line 78) | def get_alpha(img: Image.Image) -> Image.Image:
  function remove_alpha (line 91) | def remove_alpha(img: Image.Image, mode: str = "random") -> Image.Image:
  function _black_bg (line 119) | def _black_bg(h: int, w: int) -> np.ndarray:
  function _gray_bg (line 123) | def _gray_bg(h: int, w: int) -> np.ndarray:
  function _checker_bg (line 127) | def _checker_bg(h: int, w: int) -> np.ndarray:
  function _noise_bg (line 139) | def _noise_bg(h: int, w: int) -> np.ndarray:
  function load_image (line 143) | def load_image(image_path: str) -> Image.Image:
  function make_tile (line 150) | def make_tile(images: List[Union[np.ndarray, Image.Image]], columns=8) -...
  function round_up (line 169) | def round_up(n: int, b: int) -> int:
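
`make_tile(images, columns=8)` and `round_up(n, b)` appear here only as signatures. One plausible reading, assumed rather than confirmed by this listing, is that `round_up` rounds `n` up to the next multiple of `b` and `make_tile` packs equally sized images into a grid `columns` wide, padding the last row with blank tiles. The following NumPy sketch follows that reading. The `_sketch` names are hypothetical, and the sketch assumes a non-empty list of same-shaped `(H, W, C)` arrays.

```python
from typing import List

import numpy as np


def round_up_sketch(n: int, b: int) -> int:
    # Smallest multiple of b that is >= n (assumes n >= 0 and b > 0).
    return (n + b - 1) // b * b


def make_tile_sketch(images: List[np.ndarray], columns: int = 8) -> np.ndarray:
    # Assumes a non-empty list of (H, W, C) arrays with identical shape and dtype.
    h, w, c = images[0].shape
    columns = min(columns, len(images))
    total = round_up_sketch(len(images), columns)
    # Pad with black tiles so the images fill complete rows.
    padded = list(images) + [np.zeros_like(images[0])] * (total - len(images))
    rows = total // columns
    grid = np.stack(padded).reshape(rows, columns, h, w, c)
    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    return grid.transpose(0, 2, 1, 3, 4).reshape(rows * h, columns * w, c)
```

For example, five 2x3 images with `columns=2` produce a 6x6 tile with three rows, and the last slot in the final row is black padding.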

FILE: shap_e/util/io.py
  function read_config (line 11) | def read_config(path_or_file: Union[str, io.IOBase]) -> Any:
  function buffered_writer (line 28) | def buffered_writer(raw_f: BinaryIO) -> Iterator[io.BufferedIOBase]:
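
`buffered_writer(raw_f: BinaryIO) -> Iterator[io.BufferedIOBase]` has the shape of a context manager, and the ply_util.py preview imports it for binary output. The following stdlib sketch implements that contract. It assumes, without confirmation from this listing, that already-buffered streams pass through unchanged and that raw streams are wrapped in `io.BufferedWriter`. The name `buffered_writer_sketch` is hypothetical. Detaching the wrapper on exit is a choice made here so that the caller's stream stays open.

```python
import io
from contextlib import contextmanager
from typing import BinaryIO, Iterator


@contextmanager
def buffered_writer_sketch(raw_f: BinaryIO) -> Iterator[io.BufferedIOBase]:
    if isinstance(raw_f, io.BufferedIOBase):
        # Already buffered (e.g. io.BytesIO or open(path, "wb")): use as-is.
        yield raw_f
        return
    # Wrap an unbuffered stream so many small writes are batched.
    f = io.BufferedWriter(raw_f)
    try:
        yield f
    finally:
        # Push any buffered bytes to the raw stream, even on error.
        f.flush()
        # Detach so dropping the wrapper does not close the caller's stream.
        f.detach()
```

Detaching matters because a `BufferedWriter` closes its raw stream when it is closed or garbage-collected.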

FILE: shap_e/util/notebooks.py
  function create_pan_cameras (line 16) | def create_pan_cameras(size: int, device: torch.device) -> Differentiabl...
  function decode_latent_images (line 47) | def decode_latent_images(
  function decode_latent_mesh (line 65) | def decode_latent_mesh(
  function gif_widget (line 79) | def gif_widget(images):
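
`create_pan_cameras(size, device)` returns a `DifferentiableCameraBatch`, presumably for the turntable renders produced by `decode_latent_images`. The pose geometry behind such a pan can be sketched in NumPy. Cameras sit on a circle around the origin, each is tilted to look slightly down at the object, and each gets an orthonormal (x, y, z) frame. The view count, distance, and tilt below are illustrative assumptions. `pan_camera_frames` is a hypothetical helper that returns raw frames, not shap-e camera objects.

```python
import numpy as np


def pan_camera_frames(num_views: int = 20, distance: float = 4.0, tilt: float = -0.5):
    """Return (origins, x_axes, y_axes, z_axes), each of shape (num_views, 3)."""
    origins, xs, ys, zs = [], [], [], []
    # endpoint=False avoids repeating the first view at theta = 2*pi.
    for theta in np.linspace(0.0, 2.0 * np.pi, num=num_views, endpoint=False):
        # Viewing direction: rotates around the vertical axis, tilted by `tilt`.
        z = np.array([np.sin(theta), np.cos(theta), tilt])
        z /= np.linalg.norm(z)
        # Camera sits `distance` units back along -z, so it looks at the origin.
        origin = -z * distance
        # Horizontal image axis, perpendicular to z.
        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
        # Completes a right-handed frame with cross(x, y) == z.
        y = np.cross(z, x)
        origins.append(origin)
        xs.append(x)
        ys.append(y)
        zs.append(z)
    return tuple(np.stack(a) for a in (origins, xs, ys, zs))
```

Because x is built perpendicular to z and y is their cross product, every frame is orthonormal without a separate Gram-Schmidt step.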
Condensed preview: 79 files, each showing path, character count, and a content snippet. The full structured content is 541K characters.
[
  {
    "path": ".gitignore",
    "chars": 35,
    "preview": "__pycache__/\n.DS_Store\n*.egg-info/\n"
  },
  {
    "path": "LICENSE",
    "chars": 1062,
    "preview": "MIT License\n\nCopyright (c) 2023 OpenAI\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof "
  },
  {
    "path": "README.md",
    "chars": 3252,
    "preview": "# Shap-E\n\nThis is the official code and model release for [Shap-E: Generating Conditional 3D Implicit Functions](https:/"
  },
  {
    "path": "model-card.md",
    "chars": 6363,
    "preview": "# Model Card: Shap-E\n\nThis is the official codebase for running the latent diffusion models described in [Shap-E: Genera"
  },
  {
    "path": "samples.md",
    "chars": 16044,
    "preview": "# Samples\n\nHere is a collection of prompts and four random text-conditional samples for each prompt. Samples are rendere"
  },
  {
    "path": "setup.py",
    "chars": 801,
    "preview": "from setuptools import setup\n\nsetup(\n    name=\"shap-e\",\n    packages=[\n        \"shap_e\",\n        \"shap_e.diffusion\",\n   "
  },
  {
    "path": "shap_e/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/diffusion/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/diffusion/gaussian_diffusion.py",
    "chars": 44427,
    "preview": "\"\"\"\nBased on https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py\n\"\"\"\n\nimport math\nfrom"
  },
  {
    "path": "shap_e/diffusion/k_diffusion.py",
    "chars": 11095,
    "preview": "\"\"\"\nBased on: https://github.com/crowsonkb/k-diffusion\n\nCopyright (c) 2022 Katherine Crowson\n\nPermission is hereby grant"
  },
  {
    "path": "shap_e/diffusion/sample.py",
    "chars": 2871,
    "preview": "from typing import Any, Callable, Dict, Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom .gaussian_diffusion import Ga"
  },
  {
    "path": "shap_e/examples/encode_model.ipynb",
    "chars": 2360,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n "
  },
  {
    "path": "shap_e/examples/example_data/cactus/material.mtl",
    "chars": 1095,
    "preview": "newmtl mat0\nKa 0.0000 0.7000 0.0000\nKd 0.0000 0.7000 0.0000\nKs 0.0000 0.0000 0.0000\nnewmtl mat1\nKa 0.6600 0.4400 0.2000\n"
  },
  {
    "path": "shap_e/examples/sample_image_to_3d.ipynb",
    "chars": 2920,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"964ccced\",\n   \"metadata\": {},\n   \"output"
  },
  {
    "path": "shap_e/examples/sample_text_to_3d.ipynb",
    "chars": 3240,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"964ccced\",\n   \"metadata\": {},\n   \"output"
  },
  {
    "path": "shap_e/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/models/configs.py",
    "chars": 7485,
    "preview": "from typing import Any, Dict, Union\n\nimport blobfile as bf\nimport torch\nimport torch.nn as nn\nimport yaml\n\nfrom shap_e.m"
  },
  {
    "path": "shap_e/models/download.py",
    "chars": 5837,
    "preview": "\"\"\"\nAdapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/do"
  },
  {
    "path": "shap_e/models/generation/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/models/generation/latent_diffusion.py",
    "chars": 1043,
    "preview": "from typing import Any, Dict\n\nimport torch\nimport torch.nn as nn\n\n\nclass SplitVectorDiffusion(nn.Module):\n    def __init"
  },
  {
    "path": "shap_e/models/generation/perceiver.py",
    "chars": 7551,
    "preview": "import math\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom shap_e.models.nn.checkpoint import che"
  },
  {
    "path": "shap_e/models/generation/pooled_mlp.py",
    "chars": 2436,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom .util import timestep_embedding\n\n\nclass PooledMLP(nn.Module):\n    def __init__("
  },
  {
    "path": "shap_e/models/generation/pretrained_clip.py",
    "chars": 9845,
    "preview": "from typing import Iterable, List, Optional, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom PIL impor"
  },
  {
    "path": "shap_e/models/generation/transformer.py",
    "chars": 16928,
    "preview": "import math\nfrom typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple\n\nimport torch\nimport torch.nn as nn\n"
  },
  {
    "path": "shap_e/models/generation/util.py",
    "chars": 868,
    "preview": "import math\n\nimport torch\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000):\n    \"\"\"\n    Create sinusoidal time"
  },
  {
    "path": "shap_e/models/nerf/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/models/nerf/model.py",
    "chars": 8080,
    "preview": "from abc import ABC, abstractmethod\nfrom functools import partial\nfrom typing import Any, Dict, Optional, Tuple\n\nimport "
  },
  {
    "path": "shap_e/models/nerf/ray.py",
    "chars": 19663,
    "preview": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import A"
  },
  {
    "path": "shap_e/models/nerf/renderer.py",
    "chars": 11180,
    "preview": "from functools import partial\nfrom typing import Any, Dict, Optional\n\nimport torch\n\nfrom shap_e.models.nn.meta import su"
  },
  {
    "path": "shap_e/models/nerstf/mlp.py",
    "chars": 6206,
    "preview": "from typing import Any, Dict, Optional, Tuple\n\nimport torch\n\nfrom shap_e.models.nn.ops import get_act\nfrom shap_e.models"
  },
  {
    "path": "shap_e/models/nerstf/renderer.py",
    "chars": 10225,
    "preview": "from functools import partial\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union\n\nimport torch\n\nfrom shap_e."
  },
  {
    "path": "shap_e/models/nn/__init__.py",
    "chars": 39,
    "preview": "from .meta import *\nfrom .ops import *\n"
  },
  {
    "path": "shap_e/models/nn/camera.py",
    "chars": 6478,
    "preview": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple, Union\n\nimport "
  },
  {
    "path": "shap_e/models/nn/checkpoint.py",
    "chars": 4241,
    "preview": "from typing import Callable, Iterable, Sequence, Union\n\nimport torch\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n"
  },
  {
    "path": "shap_e/models/nn/encoding.py",
    "chars": 18642,
    "preview": "import math\nfrom functools import lru_cache\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\n\n\ndef encode"
  },
  {
    "path": "shap_e/models/nn/meta.py",
    "chars": 8141,
    "preview": "\"\"\"\nMeta-learning modules based on: https://github.com/tristandeleu/pytorch-meta\n\nMIT License\n\nCopyright (c) 2019-2020 T"
  },
  {
    "path": "shap_e/models/nn/ops.py",
    "chars": 12784,
    "preview": "import math\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimpor"
  },
  {
    "path": "shap_e/models/nn/pointnet2_utils.py",
    "chars": 12931,
    "preview": "\"\"\"\nBased on https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet2_utils.py\n\nMIT License\n\nCo"
  },
  {
    "path": "shap_e/models/nn/utils.py",
    "chars": 1069,
    "preview": "from typing import Iterable, Union\n\nimport numpy as np\nimport torch\n\nArrayType = Union[np.ndarray, Iterable[int], torch."
  },
  {
    "path": "shap_e/models/query.py",
    "chars": 893,
    "preview": "from dataclasses import dataclass\nfrom typing import Callable, Optional\n\nimport torch\n\n\n@dataclass\nclass Query:\n    # Bo"
  },
  {
    "path": "shap_e/models/renderer.py",
    "chars": 8989,
    "preview": "from abc import abstractmethod\nfrom typing import Callable, Dict, List, Optional, Tuple\n\nimport numpy as np\nimport torch"
  },
  {
    "path": "shap_e/models/stf/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/models/stf/base.py",
    "chars": 1575,
    "preview": "from abc import ABC, abstractmethod\nfrom typing import Any, Dict, Optional\n\nimport torch\n\nfrom shap_e.models.query impor"
  },
  {
    "path": "shap_e/models/stf/mlp.py",
    "chars": 7451,
    "preview": "from functools import partial\nfrom typing import Any, Dict, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\n\nfrom sh"
  },
  {
    "path": "shap_e/models/stf/renderer.py",
    "chars": 18620,
    "preview": "import warnings\nfrom abc import ABC, abstractmethod\nfrom functools import partial\nfrom typing import Any, Callable, Dict"
  },
  {
    "path": "shap_e/models/transmitter/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/models/transmitter/base.py",
    "chars": 6787,
    "preview": "from abc import ABC, abstractmethod\nfrom typing import Any, Dict, Optional, Tuple\n\nimport torch.nn as nn\nfrom torch impo"
  },
  {
    "path": "shap_e/models/transmitter/bottleneck.py",
    "chars": 4250,
    "preview": "from abc import ABC, abstractmethod\nfrom typing import Any, Dict, Optional\n\nimport numpy as np\nimport torch.nn as nn\nfro"
  },
  {
    "path": "shap_e/models/transmitter/channels_encoder.py",
    "chars": 34486,
    "preview": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import A"
  },
  {
    "path": "shap_e/models/transmitter/multiview_encoder.py",
    "chars": 7150,
    "preview": "from typing import Any, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport"
  },
  {
    "path": "shap_e/models/transmitter/params_proj.py",
    "chars": 6877,
    "preview": "import math\nfrom abc import ABC, abstractmethod\nfrom collections import OrderedDict\nfrom typing import Any, Dict, Option"
  },
  {
    "path": "shap_e/models/transmitter/pc_encoder.py",
    "chars": 15280,
    "preview": "from abc import abstractmethod\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple, Union\n\nimport numpy as np\n"
  },
  {
    "path": "shap_e/models/volume.py",
    "chars": 9057,
    "preview": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Dict, Optional, Tuple\n\nimport t"
  },
  {
    "path": "shap_e/rendering/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/rendering/_mc_table.py",
    "chars": 20716,
    "preview": "# Treat a cube as a bitmap, and create the index into this array in order of\n# ZYX (note Z is the most significant digit"
  },
  {
    "path": "shap_e/rendering/blender/__init__.py",
    "chars": 132,
    "preview": "from .render import render_mesh, render_model\nfrom .view_data import BlenderViewData\n\n__all__ = [\"BlenderViewData\", \"ren"
  },
  {
    "path": "shap_e/rendering/blender/blender_script.py",
    "chars": 26144,
    "preview": "\"\"\"\nScript to run within blender.\n\nProvide arguments after `--`.\nFor example: `blender -b -P blender_script.py -- --help"
  },
  {
    "path": "shap_e/rendering/blender/constants.py",
    "chars": 116,
    "preview": "UNIFORM_LIGHT_DIRECTION = [0.09387503, -0.63953443, -0.7630093]\nBASIC_AMBIENT_COLOR = 0.3\nBASIC_DIFFUSE_COLOR = 0.7\n"
  },
  {
    "path": "shap_e/rendering/blender/render.py",
    "chars": 4777,
    "preview": "import os\nimport platform\nimport subprocess\nimport tempfile\nimport zipfile\n\nimport blobfile as bf\nimport numpy as np\nfro"
  },
  {
    "path": "shap_e/rendering/blender/view_data.py",
    "chars": 3109,
    "preview": "import itertools\nimport json\nimport zipfile\nfrom typing import BinaryIO, List, Tuple\n\nimport numpy as np\nfrom PIL import"
  },
  {
    "path": "shap_e/rendering/mc.py",
    "chars": 10027,
    "preview": "from dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Tuple\n\nimport torch\n\nfrom ._mc_tabl"
  },
  {
    "path": "shap_e/rendering/mesh.py",
    "chars": 3517,
    "preview": "from dataclasses import dataclass, field\nfrom typing import BinaryIO, Dict, Optional, Union\n\nimport blobfile as bf\nimpor"
  },
  {
    "path": "shap_e/rendering/ply_util.py",
    "chars": 1921,
    "preview": "import struct\nfrom typing import BinaryIO, Optional\n\nimport numpy as np\n\nfrom shap_e.util.io import buffered_writer\n\n\nde"
  },
  {
    "path": "shap_e/rendering/point_cloud.py",
    "chars": 9474,
    "preview": "import random\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import BinaryIO, Dict, L"
  },
  {
    "path": "shap_e/rendering/pytorch3d_util.py",
    "chars": 8261,
    "preview": "import copy\nimport inspect\nfrom typing import Any, Callable, List, Sequence, Tuple, Union\n\nimport numpy as np\nimport tor"
  },
  {
    "path": "shap_e/rendering/raycast/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/rendering/raycast/_utils.py",
    "chars": 453,
    "preview": "import torch\n\n\ndef normalize(v: torch.Tensor) -> torch.Tensor:\n    return v / torch.linalg.norm(v, dim=-1, keepdim=True)"
  },
  {
    "path": "shap_e/rendering/raycast/cast.py",
    "chars": 5032,
    "preview": "from typing import Iterator, Optional, Tuple\n\nimport numpy as np\nimport torch\n\nfrom shap_e.rendering.view_data import Pr"
  },
  {
    "path": "shap_e/rendering/raycast/render.py",
    "chars": 2001,
    "preview": "from typing import Optional, Sequence\n\nimport torch\n\nfrom shap_e.rendering.blender.constants import (\n    BASIC_AMBIENT_"
  },
  {
    "path": "shap_e/rendering/raycast/types.py",
    "chars": 2822,
    "preview": "from dataclasses import dataclass\nfrom typing import Iterable, Optional\n\nimport numpy as np\nimport torch\n\nimport shap_e."
  },
  {
    "path": "shap_e/rendering/torch_mesh.py",
    "chars": 1242,
    "preview": "from dataclasses import dataclass, field\nfrom typing import Dict, Optional\n\nimport torch\n\nfrom .mesh import TriMesh\n\n\n@d"
  },
  {
    "path": "shap_e/rendering/view_data.py",
    "chars": 6597,
    "preview": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Tuple\n\nimport numpy"
  },
  {
    "path": "shap_e/util/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "shap_e/util/collections.py",
    "chars": 4777,
    "preview": "from collections import OrderedDict\nfrom typing import Any, Callable, Dict, List, Optional\nfrom typing import OrderedDic"
  },
  {
    "path": "shap_e/util/data_util.py",
    "chars": 8600,
    "preview": "import tempfile\nfrom contextlib import contextmanager\nfrom typing import Iterator, Optional, Union\n\nimport blobfile as b"
  },
  {
    "path": "shap_e/util/image_util.py",
    "chars": 5218,
    "preview": "import random\nfrom typing import Any, List, Optional, Union\n\nimport blobfile as bf\nimport numpy as np\nimport torch\nimpor"
  },
  {
    "path": "shap_e/util/io.py",
    "chars": 950,
    "preview": "import io\nfrom contextlib import contextmanager\nfrom typing import Any, BinaryIO, Iterator, Union\n\nimport blobfile as bf"
  },
  {
    "path": "shap_e/util/notebooks.py",
    "chars": 2839,
    "preview": "import base64\nimport io\nfrom typing import Union\n\nimport ipywidgets as widgets\nimport numpy as np\nimport torch\nfrom PIL "
  }
]
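
The `shap_e/models/generation/util.py` preview shows `timestep_embedding(timesteps, dim, max_period=10000)` with a docstring that begins "Create sinusoidal time...". The following dependency-free sketch implements the sinusoidal embedding common in GLIDE-style diffusion code, which gaussian_diffusion.py cites as its basis. Frequencies are spaced geometrically from 1 toward 1/max_period, cosines come before sines, and odd widths are zero-padded. The ordering and padding are assumptions, because the preview truncates the file, and `timestep_embedding_sketch` is a hypothetical name.

```python
import math
from typing import List, Sequence


def timestep_embedding_sketch(
    timesteps: Sequence[float], dim: int, max_period: float = 10000.0
) -> List[List[float]]:
    half = dim // 2
    # Geometric frequencies: 1, max_period**(-1/half), ..., approaching 1/max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    out = []
    for t in timesteps:
        args = [t * f for f in freqs]
        # First half cosines, second half sines of the same arguments.
        row = [math.cos(a) for a in args] + [math.sin(a) for a in args]
        if dim % 2:
            row.append(0.0)  # pad odd widths so every row has exactly `dim` values
        out.append(row)
    return out
```

Low-index channels oscillate quickly with `t` and high-index channels slowly, so nearby timesteps get similar but distinguishable vectors.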

// ... and 1 more file (not included in this preview)

About this extraction

This page contains the full source code of the openai/shap-e GitHub repository, extracted and formatted as plain text. The extraction includes 79 files (17.7 MB), approximately 141.3k tokens, and a symbol index with 656 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract (built by Nikandr Surkov).