[
  {
    "path": ".gitignore",
    "content": "__pycache__/\n.DS_Store\n*.egg-info/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 OpenAI\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# Shap-E\n\nThis is the official code and model release for [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463).\n\n * See [Usage](#usage) for guidance on how to use this repository.\n * See [Samples](#samples) for examples of what our text-conditional model can generate.\n\n# Samples\n\nHere are some highlighted samples from our text-conditional model. For random samples on selected prompts, see [samples.md](samples.md).\n\n<table>\n    <tbody>\n        <tr>\n            <td align=\"center\">\n                <img src=\"samples/a_chair_that_looks_like_an_avocado/2.gif\" alt=\"A chair that looks like an avocado\">\n            </td>\n            <td align=\"center\">\n                <img src=\"samples/an_airplane_that_looks_like_a_banana/3.gif\" alt=\"An airplane that looks like a banana\">\n            </td>\n            <td align=\"center\">\n                <img src=\"samples/a_spaceship/0.gif\" alt=\"A spaceship\">\n            </td>\n        </tr>\n        <tr>\n            <td align=\"center\">A chair that looks<br>like an avocado</td>\n            <td align=\"center\">An airplane that looks<br>like a banana</td>\n            <td align=\"center\">A spaceship</td>\n        </tr>\n        <tr>\n            <td align=\"center\">\n                <img src=\"samples/a_birthday_cupcake/3.gif\" alt=\"A birthday cupcake\">\n            </td>\n            <td align=\"center\">\n                <img src=\"samples/a_chair_that_looks_like_a_tree/2.gif\" alt=\"A chair that looks like a tree\">\n            </td>\n            <td align=\"center\">\n                <img src=\"samples/a_green_boot/3.gif\" alt=\"A green boot\">\n            </td>\n        </tr>\n        <tr>\n            <td align=\"center\">A birthday cupcake</td>\n            <td align=\"center\">A chair that looks<br>like a tree</td>\n            <td align=\"center\">A green boot</td>\n        </tr>\n        <tr>\n            <td 
align=\"center\">\n                <img src=\"samples/a_penguin/1.gif\" alt=\"A penguin\">\n            </td>\n            <td align=\"center\">\n                <img src=\"samples/ube_ice_cream_cone/3.gif\" alt=\"Ube ice cream cone\">\n            </td>\n            <td align=\"center\">\n                <img src=\"samples/a_bowl_of_vegetables/2.gif\" alt=\"A bowl of vegetables\">\n            </td>\n        </tr>\n        <tr>\n            <td align=\"center\">A penguin</td>\n            <td align=\"center\">Ube ice cream cone</td>\n            <td align=\"center\">A bowl of vegetables</td>\n        </tr>\n    </tbody>\n</table>\n\n# Usage\n\nInstall with `pip install -e .`.\n\nTo get started with examples, see the following notebooks:\n\n* [sample_text_to_3d.ipynb](shap_e/examples/sample_text_to_3d.ipynb) - sample a 3D model, conditioned on a text prompt.\n* [sample_image_to_3d.ipynb](shap_e/examples/sample_image_to_3d.ipynb) - sample a 3D model, conditioned on a synthetic view image. For best results, remove the background from the input image.\n* [encode_model.ipynb](shap_e/examples/encode_model.ipynb) - load a 3D model or a trimesh, create a batch of multiview renders and a point cloud, encode them into a latent, and render it back. For this to work, install Blender version 3.3.1 or higher, and set the environment variable `BLENDER_PATH` to the path of the Blender executable.\n"
  },
  {
    "path": "model-card.md",
    "content": "# Model Card: Shap-E\n\nThis is the official codebase for running the latent diffusion models described in [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463). These models were trained and released by OpenAI. Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about how the models were trained and evaluated.\n\n# Model Details\n\nShap-E includes two kinds of models: an encoder and a latent diffusion model.\n\n 1. **The encoder** converts 3D assets into the parameters of small neural networks which represent the 3D shape and texture as an implicit function. The resulting implicit function can be rendered from arbitrary viewpoints or imported into downstream applications as a mesh.\n 2. **The latent diffusion model** generates novel implicit functions conditioned on either images or text descriptions. As above, these samples can be rendered or exported as a mesh. Specifically, these models produce latents which must be linearly projected to get the final implicit function parameters. The final projection layer of the encoder is used for this purpose.\n\nLike [Point-E](https://github.com/openai/point-e/blob/main/model-card.md), Shap-E can often generate coherent 3D objects when conditioned on a rendering from a single viewpoint. 
When conditioned on text prompts directly, Shap-E is also often capable of producing recognizable objects, although it sometimes struggles to combine multiple objects or concepts.\n\nSamples from Shap-E are typically lower fidelity than professional 3D assets and often have rough edges, holes, or blurry surface textures.\n\n# Model Date\n\nApril 2023\n\n# Model Versions\n\nThe following model checkpoints are available in this repository:\n\n * `transmitter` - the encoder and corresponding projection layers for converting encoder outputs into implicit neural representations.\n * `decoder` - just the final projection layer component of `transmitter`. This is a smaller checkpoint than `transmitter` since it does not include parameters for encoding 3D assets. This is the minimum required model to convert diffusion outputs into implicit neural representations.\n * `text300M` - the text-conditional latent diffusion model.\n * `image300M` - the image-conditional latent diffusion model.\n\n# Paper & Samples\n\n[Paper link](https://arxiv.org/abs/2305.02463) / [Samples](samples.md)\n\n# Training data\n\nThe encoder and image-conditional diffusion models are trained on the [same dataset as Point-E](https://github.com/openai/point-e/blob/main/model-card.md#training-data). However, a few changes to the post-processing were made:\n\n * We rendered 60 views (instead of 20) of each model when computing point clouds, to avoid small cracks.\n * We produced 16K points in each point cloud instead of 4K.\n * We simplified the lighting and material setup to only include diffuse materials.\n\nFor our text-conditional diffusion model, we expanded our dataset with roughly a million more 3D assets. Additionally, we collected 120K captions from human annotators for a high-quality subset of our 3D assets.\n\n# Evaluated Use\n\nWe release these models with the intention of furthering progress in the field of generative modeling. 
However, we acknowledge that our models have certain constraints and biases, which is why we advise against employing them for commercial purposes at this time. We are aware that the utilization of our models could extend to areas beyond our expectations, and defining specific criteria for what is considered suitable for \"research\" purposes presents a challenge. In particular, we advise caution when using these models in contexts that demand high accuracy, where minor imperfections in the generated 3D assets could have adverse consequences.\n\nSpecifically, these models have been evaluated on the following tasks for research purposes:\n\n * Generating 3D renderings or meshes conditioned on single, synthetic images\n * Generating 3D renderings or meshes conditioned on text descriptions\n\n# Performance & Limitations\n\nOur image-conditional model has only been evaluated on a highly specific distribution of synthetic renderings. Even in these cases, the model still sometimes fails to infer the correct occluded parts of an object or produces geometry that is inconsistent with the given rendered images. These failure modes are similar to those of Point-E. The resulting 3D assets often have rough edges, holes, or blurry surface textures.\n\nOur text-conditional model can also produce objects from a somewhat large and diverse vocabulary. This model is often capable of producing objects with requested colors and textures, and sometimes even combining multiple objects. However, it often fails for more complex prompts that require placing multiple objects in a scene or binding attributes to objects. It also typically fails to produce the requested number of objects when a certain quantity is specified.\n\nWe find that our text-conditional model can sometimes produce samples which reflect gender biases. For example, samples for \"a nurse\" typically have a different body shape than samples for \"a doctor\". 
When probing for potential misuses, we also found that our text-conditional model is capable of producing 3D assets related to violence, such as guns or tanks. However, the resulting quality of these samples is poor enough that they look unrealistic and toy-like.\n\nAs with Point-E, our dataset consists of many simple, cartoonish 3D assets, and our generative models are prone to imitating this style.\n\nWe believe our models will have many potential use cases. For example, our text-conditional model could enable users to quickly produce many 3D assets, allowing for rapid prototyping for computer graphics applications or 3D printing.\n\nThe use of 3D printing in concert with our models could potentially be harmful, for example if used to create dangerous objects or fabricate tools or parts that are deployed without external validation.\n\nGenerative 3D models share many challenges and constraints with image generation models. This includes the tendency to generate content that may be biased or detrimental, as well as the potential for dual-use applications. As the capabilities of these models evolve, further investigation is required to gain a clearer understanding of how these risks manifest.\n"
  },
  {
    "path": "samples.md",
    "content": "# Samples\n\nHere is a collection of prompts and four random text-conditional samples for each prompt. Samples are rendered at 128x128 resolution with NeRF.\n\n<table><tbody><tr><th align=\"center\">Prompt</th><th></th><th></th><th></th><th></th></tr><tr><td align=\"center\">a penguin</td><td align=\"center\"><img src=\"samples/a_penguin/0.gif\" alt=\"a penguin\"></td><td align=\"center\"><img src=\"samples/a_penguin/1.gif\" alt=\"a penguin\"></td><td align=\"center\"><img src=\"samples/a_penguin/2.gif\" alt=\"a penguin\"></td><td align=\"center\"><img src=\"samples/a_penguin/3.gif\" alt=\"a penguin\"></td></tr><tr><td align=\"center\">a campfire</td><td align=\"center\"><img src=\"samples/a_campfire/0.gif\" alt=\"a campfire\"></td><td align=\"center\"><img src=\"samples/a_campfire/1.gif\" alt=\"a campfire\"></td><td align=\"center\"><img src=\"samples/a_campfire/2.gif\" alt=\"a campfire\"></td><td align=\"center\"><img src=\"samples/a_campfire/3.gif\" alt=\"a campfire\"></td></tr><tr><td align=\"center\">an elephant</td><td align=\"center\"><img src=\"samples/an_elephant/0.gif\" alt=\"an elephant\"></td><td align=\"center\"><img src=\"samples/an_elephant/1.gif\" alt=\"an elephant\"></td><td align=\"center\"><img src=\"samples/an_elephant/2.gif\" alt=\"an elephant\"></td><td align=\"center\"><img src=\"samples/an_elephant/3.gif\" alt=\"an elephant\"></td></tr><tr><td align=\"center\">a donut with pink icing</td><td align=\"center\"><img src=\"samples/a_donut_with_pink_icing/0.gif\" alt=\"a donut with pink icing\"></td><td align=\"center\"><img src=\"samples/a_donut_with_pink_icing/1.gif\" alt=\"a donut with pink icing\"></td><td align=\"center\"><img src=\"samples/a_donut_with_pink_icing/2.gif\" alt=\"a donut with pink icing\"></td><td align=\"center\"><img src=\"samples/a_donut_with_pink_icing/3.gif\" alt=\"a donut with pink icing\"></td></tr><tr><td align=\"center\">a voxelized dog</td><td align=\"center\"><img src=\"samples/a_voxelized_dog/0.gif\" 
alt=\"a voxelized dog\"></td><td align=\"center\"><img src=\"samples/a_voxelized_dog/1.gif\" alt=\"a voxelized dog\"></td><td align=\"center\"><img src=\"samples/a_voxelized_dog/2.gif\" alt=\"a voxelized dog\"></td><td align=\"center\"><img src=\"samples/a_voxelized_dog/3.gif\" alt=\"a voxelized dog\"></td></tr><tr><td align=\"center\">ube ice cream cone</td><td align=\"center\"><img src=\"samples/ube_ice_cream_cone/0.gif\" alt=\"ube ice cream cone\"></td><td align=\"center\"><img src=\"samples/ube_ice_cream_cone/1.gif\" alt=\"ube ice cream cone\"></td><td align=\"center\"><img src=\"samples/ube_ice_cream_cone/2.gif\" alt=\"ube ice cream cone\"></td><td align=\"center\"><img src=\"samples/ube_ice_cream_cone/3.gif\" alt=\"ube ice cream cone\"></td></tr><tr><td align=\"center\">a birthday cupcake</td><td align=\"center\"><img src=\"samples/a_birthday_cupcake/0.gif\" alt=\"a birthday cupcake\"></td><td align=\"center\"><img src=\"samples/a_birthday_cupcake/1.gif\" alt=\"a birthday cupcake\"></td><td align=\"center\"><img src=\"samples/a_birthday_cupcake/2.gif\" alt=\"a birthday cupcake\"></td><td align=\"center\"><img src=\"samples/a_birthday_cupcake/3.gif\" alt=\"a birthday cupcake\"></td></tr><tr><td align=\"center\">shepherds pie</td><td align=\"center\"><img src=\"samples/shepherds_pie/0.gif\" alt=\"shepherds pie\"></td><td align=\"center\"><img src=\"samples/shepherds_pie/1.gif\" alt=\"shepherds pie\"></td><td align=\"center\"><img src=\"samples/shepherds_pie/2.gif\" alt=\"shepherds pie\"></td><td align=\"center\"><img src=\"samples/shepherds_pie/3.gif\" alt=\"shepherds pie\"></td></tr><tr><td align=\"center\">a bowl of vegetables</td><td align=\"center\"><img src=\"samples/a_bowl_of_vegetables/0.gif\" alt=\"a bowl of vegetables\"></td><td align=\"center\"><img src=\"samples/a_bowl_of_vegetables/1.gif\" alt=\"a bowl of vegetables\"></td><td align=\"center\"><img src=\"samples/a_bowl_of_vegetables/2.gif\" alt=\"a bowl of vegetables\"></td><td align=\"center\"><img 
src=\"samples/a_bowl_of_vegetables/3.gif\" alt=\"a bowl of vegetables\"></td></tr><tr><td align=\"center\">a cheeseburger</td><td align=\"center\"><img src=\"samples/a_cheeseburger/0.gif\" alt=\"a cheeseburger\"></td><td align=\"center\"><img src=\"samples/a_cheeseburger/1.gif\" alt=\"a cheeseburger\"></td><td align=\"center\"><img src=\"samples/a_cheeseburger/2.gif\" alt=\"a cheeseburger\"></td><td align=\"center\"><img src=\"samples/a_cheeseburger/3.gif\" alt=\"a cheeseburger\"></td></tr><tr><td align=\"center\">a plate of mushy green peas</td><td align=\"center\"><img src=\"samples/a_plate_of_mushy_green_peas/0.gif\" alt=\"a plate of mushy green peas\"></td><td align=\"center\"><img src=\"samples/a_plate_of_mushy_green_peas/1.gif\" alt=\"a plate of mushy green peas\"></td><td align=\"center\"><img src=\"samples/a_plate_of_mushy_green_peas/2.gif\" alt=\"a plate of mushy green peas\"></td><td align=\"center\"><img src=\"samples/a_plate_of_mushy_green_peas/3.gif\" alt=\"a plate of mushy green peas\"></td></tr><tr><td align=\"center\">a traffic cone</td><td align=\"center\"><img src=\"samples/a_traffic_cone/0.gif\" alt=\"a traffic cone\"></td><td align=\"center\"><img src=\"samples/a_traffic_cone/1.gif\" alt=\"a traffic cone\"></td><td align=\"center\"><img src=\"samples/a_traffic_cone/2.gif\" alt=\"a traffic cone\"></td><td align=\"center\"><img src=\"samples/a_traffic_cone/3.gif\" alt=\"a traffic cone\"></td></tr><tr><td align=\"center\">a car that looks like an avocado</td><td align=\"center\"><img src=\"samples/a_car_that_looks_like_an_avocado/0.gif\" alt=\"a car that looks like an avocado\"></td><td align=\"center\"><img src=\"samples/a_car_that_looks_like_an_avocado/1.gif\" alt=\"a car that looks like an avocado\"></td><td align=\"center\"><img src=\"samples/a_car_that_looks_like_an_avocado/2.gif\" alt=\"a car that looks like an avocado\"></td><td align=\"center\"><img src=\"samples/a_car_that_looks_like_an_avocado/3.gif\" alt=\"a car that looks like an 
avocado\"></td></tr><tr><td align=\"center\">an airplane that looks like a banana</td><td align=\"center\"><img src=\"samples/an_airplane_that_looks_like_a_banana/0.gif\" alt=\"an airplane that looks like a banana\"></td><td align=\"center\"><img src=\"samples/an_airplane_that_looks_like_a_banana/1.gif\" alt=\"an airplane that looks like a banana\"></td><td align=\"center\"><img src=\"samples/an_airplane_that_looks_like_a_banana/2.gif\" alt=\"an airplane that looks like a banana\"></td><td align=\"center\"><img src=\"samples/an_airplane_that_looks_like_a_banana/3.gif\" alt=\"an airplane that looks like a banana\"></td></tr><tr><td align=\"center\">a stop sign</td><td align=\"center\"><img src=\"samples/a_stop_sign/0.gif\" alt=\"a stop sign\"></td><td align=\"center\"><img src=\"samples/a_stop_sign/1.gif\" alt=\"a stop sign\"></td><td align=\"center\"><img src=\"samples/a_stop_sign/2.gif\" alt=\"a stop sign\"></td><td align=\"center\"><img src=\"samples/a_stop_sign/3.gif\" alt=\"a stop sign\"></td></tr><tr><td align=\"center\">a spaceship</td><td align=\"center\"><img src=\"samples/a_spaceship/0.gif\" alt=\"a spaceship\"></td><td align=\"center\"><img src=\"samples/a_spaceship/1.gif\" alt=\"a spaceship\"></td><td align=\"center\"><img src=\"samples/a_spaceship/2.gif\" alt=\"a spaceship\"></td><td align=\"center\"><img src=\"samples/a_spaceship/3.gif\" alt=\"a spaceship\"></td></tr><tr><td align=\"center\">a race car</td><td align=\"center\"><img src=\"samples/a_race_car/0.gif\" alt=\"a race car\"></td><td align=\"center\"><img src=\"samples/a_race_car/1.gif\" alt=\"a race car\"></td><td align=\"center\"><img src=\"samples/a_race_car/2.gif\" alt=\"a race car\"></td><td align=\"center\"><img src=\"samples/a_race_car/3.gif\" alt=\"a race car\"></td></tr><tr><td align=\"center\">a schoolbus</td><td align=\"center\"><img src=\"samples/a_schoolbus/0.gif\" alt=\"a schoolbus\"></td><td align=\"center\"><img src=\"samples/a_schoolbus/1.gif\" alt=\"a schoolbus\"></td><td 
align=\"center\"><img src=\"samples/a_schoolbus/2.gif\" alt=\"a schoolbus\"></td><td align=\"center\"><img src=\"samples/a_schoolbus/3.gif\" alt=\"a schoolbus\"></td></tr><tr><td align=\"center\">a firetruck</td><td align=\"center\"><img src=\"samples/a_firetruck/0.gif\" alt=\"a firetruck\"></td><td align=\"center\"><img src=\"samples/a_firetruck/1.gif\" alt=\"a firetruck\"></td><td align=\"center\"><img src=\"samples/a_firetruck/2.gif\" alt=\"a firetruck\"></td><td align=\"center\"><img src=\"samples/a_firetruck/3.gif\" alt=\"a firetruck\"></td></tr><tr><td align=\"center\">a rusty old car</td><td align=\"center\"><img src=\"samples/a_rusty_old_car/0.gif\" alt=\"a rusty old car\"></td><td align=\"center\"><img src=\"samples/a_rusty_old_car/1.gif\" alt=\"a rusty old car\"></td><td align=\"center\"><img src=\"samples/a_rusty_old_car/2.gif\" alt=\"a rusty old car\"></td><td align=\"center\"><img src=\"samples/a_rusty_old_car/3.gif\" alt=\"a rusty old car\"></td></tr><tr><td align=\"center\">a fast car</td><td align=\"center\"><img src=\"samples/a_fast_car/0.gif\" alt=\"a fast car\"></td><td align=\"center\"><img src=\"samples/a_fast_car/1.gif\" alt=\"a fast car\"></td><td align=\"center\"><img src=\"samples/a_fast_car/2.gif\" alt=\"a fast car\"></td><td align=\"center\"><img src=\"samples/a_fast_car/3.gif\" alt=\"a fast car\"></td></tr><tr><td align=\"center\">a chair that looks like an avocado</td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_an_avocado/0.gif\" alt=\"a chair that looks like an avocado\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_an_avocado/1.gif\" alt=\"a chair that looks like an avocado\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_an_avocado/2.gif\" alt=\"a chair that looks like an avocado\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_an_avocado/3.gif\" alt=\"a chair that looks like an avocado\"></td></tr><tr><td align=\"center\">a chair that looks like 
fruit</td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_fruit/0.gif\" alt=\"a chair that looks like fruit\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_fruit/1.gif\" alt=\"a chair that looks like fruit\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_fruit/2.gif\" alt=\"a chair that looks like fruit\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_fruit/3.gif\" alt=\"a chair that looks like fruit\"></td></tr><tr><td align=\"center\">a chair that looks like a tree</td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_tree/0.gif\" alt=\"a chair that looks like a tree\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_tree/1.gif\" alt=\"a chair that looks like a tree\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_tree/2.gif\" alt=\"a chair that looks like a tree\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_tree/3.gif\" alt=\"a chair that looks like a tree\"></td></tr><tr><td align=\"center\">a chair that looks like a zebra</td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_zebra/0.gif\" alt=\"a chair that looks like a zebra\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_zebra/1.gif\" alt=\"a chair that looks like a zebra\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_zebra/2.gif\" alt=\"a chair that looks like a zebra\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_zebra/3.gif\" alt=\"a chair that looks like a zebra\"></td></tr><tr><td align=\"center\">a chair that looks like a swimming pool</td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_swimming_pool/0.gif\" alt=\"a chair that looks like a swimming pool\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_swimming_pool/1.gif\" alt=\"a chair that looks like a swimming pool\"></td><td align=\"center\"><img 
src=\"samples/a_chair_that_looks_like_a_swimming_pool/2.gif\" alt=\"a chair that looks like a swimming pool\"></td><td align=\"center\"><img src=\"samples/a_chair_that_looks_like_a_swimming_pool/3.gif\" alt=\"a chair that looks like a swimming pool\"></td></tr><tr><td align=\"center\">the person is running</td><td align=\"center\"><img src=\"samples/the_person_is_running/0.gif\" alt=\"the person is running\"></td><td align=\"center\"><img src=\"samples/the_person_is_running/1.gif\" alt=\"the person is running\"></td><td align=\"center\"><img src=\"samples/the_person_is_running/2.gif\" alt=\"the person is running\"></td><td align=\"center\"><img src=\"samples/the_person_is_running/3.gif\" alt=\"the person is running\"></td></tr><tr><td align=\"center\">the person is sitting</td><td align=\"center\"><img src=\"samples/the_person_is_sitting/0.gif\" alt=\"the person is sitting\"></td><td align=\"center\"><img src=\"samples/the_person_is_sitting/1.gif\" alt=\"the person is sitting\"></td><td align=\"center\"><img src=\"samples/the_person_is_sitting/2.gif\" alt=\"the person is sitting\"></td><td align=\"center\"><img src=\"samples/the_person_is_sitting/3.gif\" alt=\"the person is sitting\"></td></tr><tr><td align=\"center\">the person is lying down</td><td align=\"center\"><img src=\"samples/the_person_is_lying_down/0.gif\" alt=\"the person is lying down\"></td><td align=\"center\"><img src=\"samples/the_person_is_lying_down/1.gif\" alt=\"the person is lying down\"></td><td align=\"center\"><img src=\"samples/the_person_is_lying_down/2.gif\" alt=\"the person is lying down\"></td><td align=\"center\"><img src=\"samples/the_person_is_lying_down/3.gif\" alt=\"the person is lying down\"></td></tr><tr><td align=\"center\">a person that looks like a zebra</td><td align=\"center\"><img src=\"samples/a_person_that_looks_like_a_zebra/0.gif\" alt=\"a person that looks like a zebra\"></td><td align=\"center\"><img src=\"samples/a_person_that_looks_like_a_zebra/1.gif\" alt=\"a 
person that looks like a zebra\"></td><td align=\"center\"><img src=\"samples/a_person_that_looks_like_a_zebra/2.gif\" alt=\"a person that looks like a zebra\"></td><td align=\"center\"><img src=\"samples/a_person_that_looks_like_a_zebra/3.gif\" alt=\"a person that looks like a zebra\"></td></tr><tr><td align=\"center\">a person that looks like a leopard</td><td align=\"center\"><img src=\"samples/a_person_that_looks_like_a_leopard/0.gif\" alt=\"a person that looks like a leopard\"></td><td align=\"center\"><img src=\"samples/a_person_that_looks_like_a_leopard/1.gif\" alt=\"a person that looks like a leopard\"></td><td align=\"center\"><img src=\"samples/a_person_that_looks_like_a_leopard/2.gif\" alt=\"a person that looks like a leopard\"></td><td align=\"center\"><img src=\"samples/a_person_that_looks_like_a_leopard/3.gif\" alt=\"a person that looks like a leopard\"></td></tr><tr><td align=\"center\">a pair of shorts</td><td align=\"center\"><img src=\"samples/a_pair_of_shorts/0.gif\" alt=\"a pair of shorts\"></td><td align=\"center\"><img src=\"samples/a_pair_of_shorts/1.gif\" alt=\"a pair of shorts\"></td><td align=\"center\"><img src=\"samples/a_pair_of_shorts/2.gif\" alt=\"a pair of shorts\"></td><td align=\"center\"><img src=\"samples/a_pair_of_shorts/3.gif\" alt=\"a pair of shorts\"></td></tr><tr><td align=\"center\">a designer dress</td><td align=\"center\"><img src=\"samples/a_designer_dress/0.gif\" alt=\"a designer dress\"></td><td align=\"center\"><img src=\"samples/a_designer_dress/1.gif\" alt=\"a designer dress\"></td><td align=\"center\"><img src=\"samples/a_designer_dress/2.gif\" alt=\"a designer dress\"></td><td align=\"center\"><img src=\"samples/a_designer_dress/3.gif\" alt=\"a designer dress\"></td></tr><tr><td align=\"center\">banana shoes</td><td align=\"center\"><img src=\"samples/banana_shoes/0.gif\" alt=\"banana shoes\"></td><td align=\"center\"><img src=\"samples/banana_shoes/1.gif\" alt=\"banana shoes\"></td><td align=\"center\"><img 
src=\"samples/banana_shoes/2.gif\" alt=\"banana shoes\"></td><td align=\"center\"><img src=\"samples/banana_shoes/3.gif\" alt=\"banana shoes\"></td></tr><tr><td align=\"center\">a green boot</td><td align=\"center\"><img src=\"samples/a_green_boot/0.gif\" alt=\"a green boot\"></td><td align=\"center\"><img src=\"samples/a_green_boot/1.gif\" alt=\"a green boot\"></td><td align=\"center\"><img src=\"samples/a_green_boot/2.gif\" alt=\"a green boot\"></td><td align=\"center\"><img src=\"samples/a_green_boot/3.gif\" alt=\"a green boot\"></td></tr><tr><td align=\"center\">a pair of sunglasses</td><td align=\"center\"><img src=\"samples/a_pair_of_sunglasses/0.gif\" alt=\"a pair of sunglasses\"></td><td align=\"center\"><img src=\"samples/a_pair_of_sunglasses/1.gif\" alt=\"a pair of sunglasses\"></td><td align=\"center\"><img src=\"samples/a_pair_of_sunglasses/2.gif\" alt=\"a pair of sunglasses\"></td><td align=\"center\"><img src=\"samples/a_pair_of_sunglasses/3.gif\" alt=\"a pair of sunglasses\"></td></tr></tbody></table>\n\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup\n\nsetup(\n    name=\"shap-e\",\n    packages=[\n        \"shap_e\",\n        \"shap_e.diffusion\",\n        \"shap_e.models\",\n        \"shap_e.models.generation\",\n        \"shap_e.models.nerf\",\n        \"shap_e.models.nerstf\",\n        \"shap_e.models.nn\",\n        \"shap_e.models.stf\",\n        \"shap_e.models.transmitter\",\n        \"shap_e.rendering\",\n        \"shap_e.rendering.blender\",\n        \"shap_e.rendering.raycast\",\n        \"shap_e.util\",\n    ],\n    install_requires=[\n        \"filelock\",\n        \"Pillow\",\n        \"torch\",\n        \"fire\",\n        \"humanize\",\n        \"requests\",\n        \"tqdm\",\n        \"matplotlib\",\n        \"scikit-image\",\n        \"scipy\",\n        \"numpy\",\n        \"blobfile\",\n        \"pyyaml\",\n        \"clip @ git+https://github.com/openai/CLIP.git\",\n    ],\n    author=\"OpenAI\",\n)\n"
  },
  {
    "path": "shap_e/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/diffusion/gaussian_diffusion.py",
    "content": "\"\"\"\nBased on https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py\n\"\"\"\n\nimport math\nfrom typing import Any, Dict, Iterable, Optional, Sequence, Union\n\nimport blobfile as bf\nimport numpy as np\nimport torch as th\nimport yaml\n\n\ndef diffusion_from_config(config: Union[str, Dict[str, Any]]) -> \"GaussianDiffusion\":\n    if isinstance(config, str):\n        with bf.BlobFile(config, \"rb\") as f:\n            obj = yaml.load(f, Loader=yaml.SafeLoader)\n        return diffusion_from_config(obj)\n\n    schedule = config[\"schedule\"]\n    steps = config[\"timesteps\"]\n    respace = config.get(\"respacing\", None)\n    mean_type = config.get(\"mean_type\", \"epsilon\")\n    betas = get_named_beta_schedule(schedule, steps, **config.get(\"schedule_args\", {}))\n    channel_scales = config.get(\"channel_scales\", None)\n    channel_biases = config.get(\"channel_biases\", None)\n    if channel_scales is not None:\n        channel_scales = np.array(channel_scales)\n    if channel_biases is not None:\n        channel_biases = np.array(channel_biases)\n    kwargs = dict(\n        betas=betas,\n        model_mean_type=mean_type,\n        model_var_type=\"learned_range\",\n        loss_type=\"mse\",\n        channel_scales=channel_scales,\n        channel_biases=channel_biases,\n    )\n    if respace is None:\n        return GaussianDiffusion(**kwargs)\n    else:\n        return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs)\n\n\ndef get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):\n    \"\"\"\n    This is the deprecated API for creating beta schedules.\n\n    See get_named_beta_schedule() for the new library of schedules.\n    \"\"\"\n    if beta_schedule == \"linear\":\n        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)\n    else:\n        raise NotImplementedError(beta_schedule)\n    assert betas.shape == 
(num_diffusion_timesteps,)\n    return betas\n\n\ndef get_named_beta_schedule(schedule_name, num_diffusion_timesteps, **extra_args: float):\n    \"\"\"\n    Get a pre-defined beta schedule for the given name.\n\n    The beta schedule library consists of beta schedules which remain similar\n    in the limit of num_diffusion_timesteps.\n    Beta schedules may be added, but should not be removed or changed once\n    they are committed to maintain backwards compatibility.\n    \"\"\"\n    if schedule_name == \"linear\":\n        # Linear schedule from Ho et al, extended to work for any number of\n        # diffusion steps.\n        scale = 1000 / num_diffusion_timesteps\n        return get_beta_schedule(\n            \"linear\",\n            beta_start=scale * 0.0001,\n            beta_end=scale * 0.02,\n            num_diffusion_timesteps=num_diffusion_timesteps,\n        )\n    elif schedule_name == \"cosine\":\n        return betas_for_alpha_bar(\n            num_diffusion_timesteps,\n            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,\n        )\n    elif schedule_name == \"inv_parabola\":\n        exponent = extra_args.get(\"power\", 2.0)\n        return betas_for_alpha_bar(\n            num_diffusion_timesteps,\n            lambda t: 1 - t**exponent,\n        )\n    elif schedule_name == \"translated_parabola\":\n        exponent = extra_args.get(\"power\", 2.0)\n        return betas_for_alpha_bar(\n            num_diffusion_timesteps,\n            lambda t: (1 - t) ** exponent,\n        )\n    elif schedule_name == \"exp\":\n        coefficient = extra_args.get(\"coefficient\", -12.0)\n        return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: math.exp(t * coefficient))\n    else:\n        raise NotImplementedError(f\"unknown beta schedule: {schedule_name}\")\n\n\ndef betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function,\n    
which defines the cumulative product of (1-beta) over time from t = [0,1].\n\n    :param num_diffusion_timesteps: the number of betas to produce.\n    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and\n                      produces the cumulative product of (1-beta) up to that\n                      part of the diffusion process.\n    :param max_beta: the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n    \"\"\"\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))\n    return np.array(betas)\n\n\ndef space_timesteps(num_timesteps, section_counts):\n    \"\"\"\n    Create a set of timesteps to use from an original diffusion process,\n    given the number of timesteps we want to take from equally-sized portions\n    of the original process.\n    For example, if there are 300 timesteps and the section counts are [10,15,20],\n    then the first 100 timesteps are strided to be 10 timesteps, the second 100\n    are strided to be 15 timesteps, and the final 100 are strided to be 20.\n\n    :param num_timesteps: the number of diffusion steps in the original\n                          process to divide up.\n    :param section_counts: either a list of numbers, or a string containing\n                           comma-separated numbers, indicating the step count\n                           per section. 
As a special case, use \"ddimN\" where N\n                           is a number of steps to use the striding from the\n                           DDIM paper. Alternatively, use \"exact\" followed\n                           by comma-separated timesteps to keep exactly those\n                           steps.\n    :return: a set of diffusion steps from the original process to use.\n    \"\"\"\n    if isinstance(section_counts, str):\n        if section_counts.startswith(\"ddim\"):\n            desired_count = int(section_counts[len(\"ddim\") :])\n            for i in range(1, num_timesteps):\n                if len(range(0, num_timesteps, i)) == desired_count:\n                    return set(range(0, num_timesteps, i))\n            raise ValueError(f\"cannot create exactly {desired_count} steps with an integer stride\")\n        elif section_counts.startswith(\"exact\"):\n            res = set(int(x) for x in section_counts[len(\"exact\") :].split(\",\"))\n            for x in res:\n                if x < 0 or x >= num_timesteps:\n                    raise ValueError(f\"timestep out of bounds: {x}\")\n            return res\n        section_counts = [int(x) for x in section_counts.split(\",\")]\n    size_per = num_timesteps // len(section_counts)\n    extra = num_timesteps % len(section_counts)\n    start_idx = 0\n    all_steps = []\n    for i, section_count in enumerate(section_counts):\n        size = size_per + (1 if i < extra else 0)\n        if size < section_count:\n            raise ValueError(f\"cannot divide section of {size} steps into {section_count}\")\n        if section_count <= 1:\n            frac_stride = 1\n        else:\n            frac_stride = (size - 1) / (section_count - 1)\n        cur_idx = 0.0\n        taken_steps = []\n        for _ in range(section_count):\n            taken_steps.append(start_idx + round(cur_idx))\n            cur_idx += frac_stride\n        all_steps += taken_steps\n        start_idx += size\n    return set(all_steps)\n\n\nclass GaussianDiffusion:\n    \"\"\"\n    Utilities for training and sampling diffusion models.\n\n    Ported directly from here:\n    
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42\n\n    :param betas: a 1-D array of betas, one per diffusion timestep, ordered\n                  from the first step (index 0) to the last (index T-1).\n    :param model_mean_type: a string determining what the model outputs.\n    :param model_var_type: a string determining how variance is output.\n    :param loss_type: a string determining the loss function to use.\n    :param discretized_t0: if True, use a discretized Gaussian likelihood for\n                           t=0. Only makes sense for images.\n    :param channel_scales: a per-channel multiplier applied to x_start in\n                           training_losses; sampling functions divide it\n                           back out of their outputs.\n    :param channel_biases: a per-channel offset added to x_start after\n                           channel_scales; sampling functions subtract it\n                           from their outputs.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        betas: Sequence[float],\n        model_mean_type: str,\n        model_var_type: str,\n        loss_type: str,\n        discretized_t0: bool = False,\n        channel_scales: Optional[np.ndarray] = None,\n        channel_biases: Optional[np.ndarray] = None,\n    ):\n        self.model_mean_type = model_mean_type\n        self.model_var_type = model_var_type\n        self.loss_type = loss_type\n        self.discretized_t0 = discretized_t0\n        self.channel_scales = channel_scales\n        self.channel_biases = channel_biases\n\n        # Use float64 for accuracy.\n        betas = np.array(betas, dtype=np.float64)\n        self.betas = betas\n        assert len(betas.shape) == 1, \"betas must be 1-D\"\n        assert (betas > 0).all() and (betas <= 1).all()\n\n        self.num_timesteps = int(betas.shape[0])\n\n        alphas = 1.0 - betas\n        self.alphas_cumprod = np.cumprod(alphas, axis=0)\n        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])\n        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)\n        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.sqrt_alphas_cumprod = 
np.sqrt(self.alphas_cumprod)\n        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)\n        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)\n        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)\n        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        self.posterior_variance = (\n            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)\n        )\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        self.posterior_log_variance_clipped = np.log(\n            np.append(self.posterior_variance[1], self.posterior_variance[1:])\n        )\n        self.posterior_mean_coef1 = (\n            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)\n        )\n        self.posterior_mean_coef2 = (\n            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)\n        )\n\n    def get_sigmas(self, t):\n        return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape)\n\n    def q_mean_variance(self, x_start, t):\n        \"\"\"\n        Get the distribution q(x_t | x_0).\n\n        :param x_start: the [N x C x ...] tensor of noiseless inputs.\n        :param t: the number of diffusion steps (minus 1). 
Here, 0 means one step.\n        :return: A tuple (mean, variance, log_variance), all of x_start's shape.\n        \"\"\"\n        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)\n        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)\n        return mean, variance, log_variance\n\n    def q_sample(self, x_start, t, noise=None):\n        \"\"\"\n        Diffuse the data for a given number of diffusion steps.\n\n        In other words, sample from q(x_t | x_0).\n\n        :param x_start: the initial data batch.\n        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.\n        :param noise: if specified, the split-out normal noise.\n        :return: A noisy version of x_start.\n        \"\"\"\n        if noise is None:\n            noise = th.randn_like(x_start)\n        assert noise.shape == x_start.shape\n        return (\n            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n        )\n\n    def q_posterior_mean_variance(self, x_start, x_t, t):\n        \"\"\"\n        Compute the mean and variance of the diffusion posterior:\n\n            q(x_{t-1} | x_t, x_0)\n\n        \"\"\"\n        assert x_start.shape == x_t.shape\n        posterior_mean = (\n            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start\n            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = _extract_into_tensor(\n            self.posterior_log_variance_clipped, t, x_t.shape\n        )\n        assert (\n            posterior_mean.shape[0]\n            == posterior_variance.shape[0]\n   
         == posterior_log_variance_clipped.shape[0]\n            == x_start.shape[0]\n        )\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(\n        self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None\n    ):\n        \"\"\"\n        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of\n        the initial x, x_0.\n\n        :param model: the model, which takes a signal and a batch of timesteps\n                      as input.\n        :param x: the [N x C x ...] tensor at time t.\n        :param t: a 1-D Tensor of timesteps.\n        :param clip_denoised: if True, clip the denoised signal into [-1, 1].\n        :param denoised_fn: if not None, a function which applies to the\n            x_start prediction before it is used to sample. Applies before\n            clip_denoised.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. 
This can be used for conditioning.\n        :return: a dict with the following keys:\n                 - 'mean': the model mean output.\n                 - 'variance': the model variance output.\n                 - 'log_variance': the log of 'variance'.\n                 - 'pred_xstart': the prediction for x_0.\n                 - 'extra': extra model outputs if the model returns a\n                   tuple, else None.\n        \"\"\"\n        if model_kwargs is None:\n            model_kwargs = {}\n\n        B, C = x.shape[:2]\n        assert t.shape == (B,)\n        model_output = model(x, t, **model_kwargs)\n        if isinstance(model_output, tuple):\n            model_output, extra = model_output\n        else:\n            extra = None\n\n        if self.model_var_type in [\"learned\", \"learned_range\"]:\n            assert model_output.shape == (B, C * 2, *x.shape[2:])\n            model_output, model_var_values = th.split(model_output, C, dim=1)\n            if self.model_var_type == \"learned\":\n                model_log_variance = model_var_values\n                model_variance = th.exp(model_log_variance)\n            else:\n                min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)\n                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)\n                # model_var_values in [-1, 1] maps to [min_var, max_var].\n                frac = (model_var_values + 1) / 2\n                model_log_variance = frac * max_log + (1 - frac) * min_log\n                model_variance = th.exp(model_log_variance)\n        else:\n            model_variance, model_log_variance = {\n                # for fixed_large, we set the initial (log-)variance like so\n                # to get a better decoder log likelihood.\n                \"fixed_large\": (\n                    np.append(self.posterior_variance[1], self.betas[1:]),\n                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),\n                ),\n                \"fixed_small\": (\n                    self.posterior_variance,\n            
        self.posterior_log_variance_clipped,\n                ),\n            }[self.model_var_type]\n            model_variance = _extract_into_tensor(model_variance, t, x.shape)\n            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)\n\n        def process_xstart(x):\n            if denoised_fn is not None:\n                x = denoised_fn(x)\n            if clip_denoised:\n                return x.clamp(-1, 1)\n            return x\n\n        if self.model_mean_type == \"x_prev\":\n            pred_xstart = process_xstart(\n                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)\n            )\n            model_mean = model_output\n        elif self.model_mean_type in [\"x_start\", \"epsilon\"]:\n            if self.model_mean_type == \"x_start\":\n                pred_xstart = process_xstart(model_output)\n            else:\n                pred_xstart = process_xstart(\n                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)\n                )\n            model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)\n        else:\n            raise NotImplementedError(self.model_mean_type)\n\n        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape\n        return {\n            \"mean\": model_mean,\n            \"variance\": model_variance,\n            \"log_variance\": model_log_variance,\n            \"pred_xstart\": pred_xstart,\n            \"extra\": extra,\n        }\n\n    def _predict_xstart_from_eps(self, x_t, t, eps):\n        assert x_t.shape == eps.shape\n        return (\n            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t\n            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps\n        )\n\n    def _predict_xstart_from_xprev(self, x_t, t, xprev):\n        assert x_t.shape == xprev.shape\n        return (  # (xprev - coef2*x_t) / coef1\n       
     _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev\n            - _extract_into_tensor(\n                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape\n            )\n            * x_t\n        )\n\n    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):\n        return (\n            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart\n        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n\n    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):\n        \"\"\"\n        Compute the mean for the previous step, given a function cond_fn that\n        computes the gradient of a conditional log probability with respect to\n        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to\n        condition on y.\n\n        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).\n        \"\"\"\n        gradient = cond_fn(x, t, **(model_kwargs or {}))\n        new_mean = p_mean_var[\"mean\"].float() + p_mean_var[\"variance\"] * gradient.float()\n        return new_mean\n\n    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):\n        \"\"\"\n        Compute what the p_mean_variance output would have been, should the\n        model's score function be conditioned by cond_fn.\n\n        See condition_mean() for details on cond_fn.\n\n        Unlike condition_mean(), this instead uses the conditioning strategy\n        from Song et al (2020).\n        \"\"\"\n        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)\n\n        eps = self._predict_eps_from_xstart(x, t, p_mean_var[\"pred_xstart\"])\n        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **(model_kwargs or {}))\n\n        out = p_mean_var.copy()\n        out[\"pred_xstart\"] = self._predict_xstart_from_eps(x, t, eps)\n        out[\"mean\"], _, _ = 
self.q_posterior_mean_variance(x_start=out[\"pred_xstart\"], x_t=x, t=t)\n        return out\n\n    def p_sample(\n        self,\n        model,\n        x,\n        t,\n        clip_denoised=False,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n    ):\n        \"\"\"\n        Sample x_{t-1} from the model at the given timestep.\n\n        :param model: the model to sample from.\n        :param x: the current tensor at x_t.\n        :param t: the value of t, starting at 0 for the first diffusion step.\n        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].\n        :param denoised_fn: if not None, a function which applies to the\n            x_start prediction before it is used to sample.\n        :param cond_fn: if not None, a gradient function computing\n                        grad(log(p(y|x))), used to shift the mean; see\n                        condition_mean().\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. This can be used for conditioning.\n        :return: a dict containing the following keys:\n                 - 'sample': a random sample from the model.\n                 - 'pred_xstart': a prediction of x_0.\n        \"\"\"\n        out = self.p_mean_variance(\n            model,\n            x,\n            t,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            model_kwargs=model_kwargs,\n        )\n        noise = th.randn_like(x)\n        nonzero_mask = (\n            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))\n        )  # no noise when t == 0\n        if cond_fn is not None:\n            out[\"mean\"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)\n        sample = out[\"mean\"] + nonzero_mask * th.exp(0.5 * out[\"log_variance\"]) * noise\n        return {\"sample\": sample, \"pred_xstart\": out[\"pred_xstart\"]}\n\n    def p_sample_loop(\n        self,\n        model,\n        shape,\n        noise=None,\n    
    clip_denoised=False,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n        temp=1.0,\n    ):\n        \"\"\"\n        Generate samples from the model.\n\n        :param model: the model module.\n        :param shape: the shape of the samples, (N, C, H, W).\n        :param noise: if specified, the initial noise to start sampling from.\n                      Should be of the same shape as `shape`.\n        :param clip_denoised: if True, clip x_start predictions to [-1, 1].\n        :param denoised_fn: if not None, a function which applies to the\n            x_start prediction before it is used to sample.\n        :param cond_fn: if not None, a gradient function computing\n                        grad(log(p(y|x))), used to shift the mean; see\n                        condition_mean().\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. This can be used for conditioning.\n        :param device: if specified, the device to create the samples on.\n                       If not specified, use a model parameter's device.\n        :param progress: if True, show a tqdm progress bar.\n        :param temp: a multiplier on freshly drawn initial noise; ignored\n                     when noise is given.\n        :return: a non-differentiable batch of samples.\n        \"\"\"\n        final = None\n        for sample in self.p_sample_loop_progressive(\n            model,\n            shape,\n            noise=noise,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            cond_fn=cond_fn,\n            model_kwargs=model_kwargs,\n            device=device,\n            progress=progress,\n            temp=temp,\n        ):\n            final = sample\n        return final[\"sample\"]\n\n    def p_sample_loop_progressive(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=False,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n        temp=1.0,\n    ):\n        
\"\"\"\n        Generate samples from the model and yield intermediate samples from\n        each timestep of diffusion.\n\n        Arguments are the same as p_sample_loop().\n        Returns a generator over dicts, where each dict is the return value of\n        p_sample().\n        \"\"\"\n        if device is None:\n            device = next(model.parameters()).device\n        assert isinstance(shape, (tuple, list))\n        if noise is not None:\n            img = noise\n        else:\n            img = th.randn(*shape, device=device) * temp\n        indices = list(range(self.num_timesteps))[::-1]\n\n        if progress:\n            # Lazy import so that we don't depend on tqdm.\n            from tqdm.auto import tqdm\n\n            indices = tqdm(indices)\n\n        for i in indices:\n            t = th.tensor([i] * shape[0], device=device)\n            with th.no_grad():\n                out = self.p_sample(\n                    model,\n                    img,\n                    t,\n                    clip_denoised=clip_denoised,\n                    denoised_fn=denoised_fn,\n                    cond_fn=cond_fn,\n                    model_kwargs=model_kwargs,\n                )\n                yield self.unscale_out_dict(out)\n                img = out[\"sample\"]\n\n    def ddim_sample(\n        self,\n        model,\n        x,\n        t,\n        clip_denoised=False,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        eta=0.0,\n    ):\n        \"\"\"\n        Sample x_{t-1} from the model using DDIM.\n\n        Same usage as p_sample().\n        \"\"\"\n        out = self.p_mean_variance(\n            model,\n            x,\n            t,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            model_kwargs=model_kwargs,\n        )\n        if cond_fn is not None:\n            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)\n\n        # Usually our 
model outputs epsilon, but we re-derive it\n        # in case we used x_start or x_prev prediction.\n        eps = self._predict_eps_from_xstart(x, t, out[\"pred_xstart\"])\n\n        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)\n        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)\n        sigma = (\n            eta\n            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))\n            * th.sqrt(1 - alpha_bar / alpha_bar_prev)\n        )\n        # Equation 12.\n        noise = th.randn_like(x)\n        mean_pred = (\n            out[\"pred_xstart\"] * th.sqrt(alpha_bar_prev)\n            + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps\n        )\n        nonzero_mask = (\n            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))\n        )  # no noise when t == 0\n        sample = mean_pred + nonzero_mask * sigma * noise\n        return {\"sample\": sample, \"pred_xstart\": out[\"pred_xstart\"]}\n\n    def ddim_reverse_sample(\n        self,\n        model,\n        x,\n        t,\n        clip_denoised=False,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        eta=0.0,\n    ):\n        \"\"\"\n        Sample x_{t+1} from the model using DDIM reverse ODE.\n        \"\"\"\n        assert eta == 0.0, \"Reverse ODE only for deterministic path\"\n        out = self.p_mean_variance(\n            model,\n            x,\n            t,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            model_kwargs=model_kwargs,\n        )\n        if cond_fn is not None:\n            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)\n        # Usually our model outputs epsilon, but we re-derive it\n        # in case we used x_start or x_prev prediction.\n        eps = (\n            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x\n            - out[\"pred_xstart\"]\n        ) / 
_extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)\n        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)\n\n        # Equation 12. reversed\n        mean_pred = out[\"pred_xstart\"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps\n\n        return {\"sample\": mean_pred, \"pred_xstart\": out[\"pred_xstart\"]}\n\n    def ddim_sample_loop(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=False,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n        eta=0.0,\n        temp=1.0,\n    ):\n        \"\"\"\n        Generate samples from the model using DDIM.\n\n        Same usage as p_sample_loop().\n        \"\"\"\n        final = None\n        for sample in self.ddim_sample_loop_progressive(\n            model,\n            shape,\n            noise=noise,\n            clip_denoised=clip_denoised,\n            denoised_fn=denoised_fn,\n            cond_fn=cond_fn,\n            model_kwargs=model_kwargs,\n            device=device,\n            progress=progress,\n            eta=eta,\n            temp=temp,\n        ):\n            final = sample\n        return final[\"sample\"]\n\n    def ddim_sample_loop_progressive(\n        self,\n        model,\n        shape,\n        noise=None,\n        clip_denoised=False,\n        denoised_fn=None,\n        cond_fn=None,\n        model_kwargs=None,\n        device=None,\n        progress=False,\n        eta=0.0,\n        temp=1.0,\n    ):\n        \"\"\"\n        Use DDIM to sample from the model and yield intermediate samples from\n        each timestep of DDIM.\n\n        Same usage as p_sample_loop_progressive().\n        \"\"\"\n        if device is None:\n            device = next(model.parameters()).device\n        assert isinstance(shape, (tuple, list))\n        if noise is not None:\n            img = noise\n        else:\n            img = 
th.randn(*shape, device=device) * temp\n        indices = list(range(self.num_timesteps))[::-1]\n\n        if progress:\n            # Lazy import so that we don't depend on tqdm.\n            from tqdm.auto import tqdm\n\n            indices = tqdm(indices)\n\n        for i in indices:\n            t = th.tensor([i] * shape[0], device=device)\n            with th.no_grad():\n                out = self.ddim_sample(\n                    model,\n                    img,\n                    t,\n                    clip_denoised=clip_denoised,\n                    denoised_fn=denoised_fn,\n                    cond_fn=cond_fn,\n                    model_kwargs=model_kwargs,\n                    eta=eta,\n                )\n                yield self.unscale_out_dict(out)\n                img = out[\"sample\"]\n\n    def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False, model_kwargs=None):\n        \"\"\"\n        Get a term for the variational lower-bound.\n\n        The resulting units are bits (rather than nats, as one might expect).\n        This allows for comparison to other papers.\n\n        :return: a dict with the following keys:\n                 - 'output': a shape [N] tensor of NLLs or KLs.\n                 - 'pred_xstart': the x_0 predictions.\n        \"\"\"\n        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(\n            x_start=x_start, x_t=x_t, t=t\n        )\n        out = self.p_mean_variance(\n            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs\n        )\n        kl = normal_kl(true_mean, true_log_variance_clipped, out[\"mean\"], out[\"log_variance\"])\n        kl = mean_flat(kl) / np.log(2.0)\n\n        decoder_nll = -discretized_gaussian_log_likelihood(\n            x_start, means=out[\"mean\"], log_scales=0.5 * out[\"log_variance\"]\n        )\n        if not self.discretized_t0:\n            decoder_nll = th.zeros_like(decoder_nll)\n        assert decoder_nll.shape == 
x_start.shape\n        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)\n\n        # At the first timestep return the decoder NLL,\n        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))\n        output = th.where((t == 0), decoder_nll, kl)\n        return {\n            \"output\": output,\n            \"pred_xstart\": out[\"pred_xstart\"],\n            \"extra\": out[\"extra\"],\n        }\n\n    def training_losses(\n        self, model, x_start, t, model_kwargs=None, noise=None\n    ) -> Dict[str, th.Tensor]:\n        \"\"\"\n        Compute training losses for a single timestep.\n\n        :param model: the model to evaluate loss on.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :param t: a batch of timestep indices.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. This can be used for conditioning.\n        :param noise: if specified, the specific Gaussian noise to try to remove.\n        :return: a dict with the key \"loss\" containing a tensor of shape [N].\n                 Some mean or variance settings may also have other keys.\n        \"\"\"\n        x_start = self.scale_channels(x_start)\n        if model_kwargs is None:\n            model_kwargs = {}\n        if noise is None:\n            noise = th.randn_like(x_start)\n        x_t = self.q_sample(x_start, t, noise=noise)\n\n        terms = {}\n\n        if self.loss_type == \"kl\" or self.loss_type == \"rescaled_kl\":\n            vb_terms = self._vb_terms_bpd(\n                model=model,\n                x_start=x_start,\n                x_t=x_t,\n                t=t,\n                clip_denoised=False,\n                model_kwargs=model_kwargs,\n            )\n            terms[\"loss\"] = vb_terms[\"output\"]\n            if self.loss_type == \"rescaled_kl\":\n                terms[\"loss\"] *= self.num_timesteps\n            extra = vb_terms[\"extra\"]\n        elif self.loss_type == 
\"mse\" or self.loss_type == \"rescaled_mse\":\n            model_output = model(x_t, t, **model_kwargs)\n            if isinstance(model_output, tuple):\n                model_output, extra = model_output\n            else:\n                extra = {}\n\n            if self.model_var_type in [\n                \"learned\",\n                \"learned_range\",\n            ]:\n                B, C = x_t.shape[:2]\n                assert model_output.shape == (\n                    B,\n                    C * 2,\n                    *x_t.shape[2:],\n                ), f\"{model_output.shape} != {(B, C * 2, *x_t.shape[2:])}\"\n                model_output, model_var_values = th.split(model_output, C, dim=1)\n                # Learn the variance using the variational bound, but don't let\n                # it affect our mean prediction.\n                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)\n                terms[\"vb\"] = self._vb_terms_bpd(\n                    model=lambda *args, r=frozen_out: r,\n                    x_start=x_start,\n                    x_t=x_t,\n                    t=t,\n                    clip_denoised=False,\n                )[\"output\"]\n                if self.loss_type == \"rescaled_mse\":\n                    # Divide by 1000 for equivalence with initial implementation.\n                    # Without a factor of 1/1000, the VB term hurts the MSE term.\n                    terms[\"vb\"] *= self.num_timesteps / 1000.0\n\n            target = {\n                \"x_prev\": self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],\n                \"x_start\": x_start,\n                \"epsilon\": noise,\n            }[self.model_mean_type]\n            assert model_output.shape == target.shape == x_start.shape\n            terms[\"mse\"] = mean_flat((target - model_output) ** 2)\n            if \"vb\" in terms:\n                terms[\"loss\"] = terms[\"mse\"] + terms[\"vb\"]\n            else:\n       
         terms[\"loss\"] = terms[\"mse\"]\n        else:\n            raise NotImplementedError(self.loss_type)\n\n        if extra is not None and \"losses\" in extra:\n            terms.update({k: loss for k, (loss, _scale) in extra[\"losses\"].items()})\n            for loss, scale in extra[\"losses\"].values():\n                terms[\"loss\"] = terms[\"loss\"] + loss * scale\n\n        return terms\n\n    def _prior_bpd(self, x_start):\n        \"\"\"\n        Get the prior KL term for the variational lower-bound, measured in\n        bits-per-dim.\n\n        This term can't be optimized, as it only depends on the encoder.\n\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :return: a batch of [N] KL values (in bits), one per batch element.\n        \"\"\"\n        batch_size = x_start.shape[0]\n        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)\n        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)\n        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)\n        return mean_flat(kl_prior) / np.log(2.0)\n\n    def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None):\n        \"\"\"\n        Compute the entire variational lower-bound, measured in bits-per-dim,\n        as well as other related quantities.\n\n        :param model: the model to evaluate loss on.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :param clip_denoised: if True, clip denoised samples.\n        :param model_kwargs: if not None, a dict of extra keyword arguments to\n            pass to the model. 
This can be used for conditioning.\n\n        :return: a dict containing the following keys:\n                 - total_bpd: the total variational lower-bound, per batch element.\n                 - prior_bpd: the prior term in the lower-bound.\n                 - vb: an [N x T] tensor of terms in the lower-bound.\n                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.\n                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.\n        \"\"\"\n        device = x_start.device\n        batch_size = x_start.shape[0]\n\n        vb = []\n        xstart_mse = []\n        mse = []\n        for t in list(range(self.num_timesteps))[::-1]:\n            t_batch = th.tensor([t] * batch_size, device=device)\n            noise = th.randn_like(x_start)\n            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)\n            # Calculate VLB term at the current timestep\n            with th.no_grad():\n                out = self._vb_terms_bpd(\n                    model,\n                    x_start=x_start,\n                    x_t=x_t,\n                    t=t_batch,\n                    clip_denoised=clip_denoised,\n                    model_kwargs=model_kwargs,\n                )\n            vb.append(out[\"output\"])\n            xstart_mse.append(mean_flat((out[\"pred_xstart\"] - x_start) ** 2))\n            eps = self._predict_eps_from_xstart(x_t, t_batch, out[\"pred_xstart\"])\n            mse.append(mean_flat((eps - noise) ** 2))\n\n        vb = th.stack(vb, dim=1)\n        xstart_mse = th.stack(xstart_mse, dim=1)\n        mse = th.stack(mse, dim=1)\n\n        prior_bpd = self._prior_bpd(x_start)\n        total_bpd = vb.sum(dim=1) + prior_bpd\n        return {\n            \"total_bpd\": total_bpd,\n            \"prior_bpd\": prior_bpd,\n            \"vb\": vb,\n            \"xstart_mse\": xstart_mse,\n            \"mse\": mse,\n        }\n\n    def scale_channels(self, x: th.Tensor) -> th.Tensor:\n        if 
self.channel_scales is not None:\n            x = x * th.from_numpy(self.channel_scales).to(x).reshape(\n                [1, -1, *([1] * (len(x.shape) - 2))]\n            )\n        if self.channel_biases is not None:\n            x = x + th.from_numpy(self.channel_biases).to(x).reshape(\n                [1, -1, *([1] * (len(x.shape) - 2))]\n            )\n        return x\n\n    def unscale_channels(self, x: th.Tensor) -> th.Tensor:\n        if self.channel_biases is not None:\n            x = x - th.from_numpy(self.channel_biases).to(x).reshape(\n                [1, -1, *([1] * (len(x.shape) - 2))]\n            )\n        if self.channel_scales is not None:\n            x = x / th.from_numpy(self.channel_scales).to(x).reshape(\n                [1, -1, *([1] * (len(x.shape) - 2))]\n            )\n        return x\n\n    def unscale_out_dict(\n        self, out: Dict[str, Union[th.Tensor, Any]]\n    ) -> Dict[str, Union[th.Tensor, Any]]:\n        return {\n            k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v) for k, v in out.items()\n        }\n\n\nclass SpacedDiffusion(GaussianDiffusion):\n    \"\"\"\n    A diffusion process which can skip steps in a base diffusion process.\n    :param use_timesteps: (unordered) timesteps from the original diffusion\n                          process to retain.\n    :param kwargs: the kwargs to create the base diffusion process.\n    \"\"\"\n\n    def __init__(self, use_timesteps: Iterable[int], **kwargs):\n        self.use_timesteps = set(use_timesteps)\n        self.timestep_map = []\n        self.original_num_steps = len(kwargs[\"betas\"])\n\n        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa\n        last_alpha_cumprod = 1.0\n        new_betas = []\n        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):\n            if i in self.use_timesteps:\n                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)\n                last_alpha_cumprod 
= alpha_cumprod\n                self.timestep_map.append(i)\n        kwargs[\"betas\"] = np.array(new_betas)\n        super().__init__(**kwargs)\n\n    def p_mean_variance(self, model, *args, **kwargs):\n        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)\n\n    def training_losses(self, model, *args, **kwargs):\n        return super().training_losses(self._wrap_model(model), *args, **kwargs)\n\n    def condition_mean(self, cond_fn, *args, **kwargs):\n        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)\n\n    def condition_score(self, cond_fn, *args, **kwargs):\n        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)\n\n    def _wrap_model(self, model):\n        if isinstance(model, _WrappedModel):\n            return model\n        return _WrappedModel(model, self.timestep_map, self.original_num_steps)\n\n\nclass _WrappedModel:\n    def __init__(self, model, timestep_map, original_num_steps):\n        self.model = model\n        self.timestep_map = timestep_map\n        self.original_num_steps = original_num_steps\n\n    def __call__(self, x, ts, **kwargs):\n        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)\n        new_ts = map_tensor[ts]\n        return self.model(x, new_ts, **kwargs)\n\n\ndef _extract_into_tensor(arr, timesteps, broadcast_shape):\n    \"\"\"\n    Extract values from a 1-D numpy array for a batch of indices.\n\n    :param arr: the 1-D numpy array.\n    :param timesteps: a tensor of indices into the array to extract.\n    :param broadcast_shape: a larger shape of K dimensions with the batch\n                            dimension equal to the length of timesteps.\n    :return: a tensor of shape [batch_size, 1, ...] 
where the shape has K dims.\n    \"\"\"\n    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()\n    while len(res.shape) < len(broadcast_shape):\n        res = res[..., None]\n    return res + th.zeros(broadcast_shape, device=timesteps.device)\n\n\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n    \"\"\"\n    Compute the KL divergence between two gaussians.\n    Shapes are automatically broadcasted, so batches can be compared to\n    scalars, among other use cases.\n    \"\"\"\n    tensor = None\n    for obj in (mean1, logvar1, mean2, logvar2):\n        if isinstance(obj, th.Tensor):\n            tensor = obj\n            break\n    assert tensor is not None, \"at least one argument must be a Tensor\"\n\n    # Force variances to be Tensors. Broadcasting helps convert scalars to\n    # Tensors, but it does not work for th.exp().\n    logvar1, logvar2 = [\n        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)\n    ]\n\n    return 0.5 * (\n        -1.0\n        + logvar2\n        - logvar1\n        + th.exp(logvar1 - logvar2)\n        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)\n    )\n\n\ndef approx_standard_normal_cdf(x):\n    \"\"\"\n    A fast approximation of the cumulative distribution function of the\n    standard normal.\n    \"\"\"\n    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))\n\n\ndef discretized_gaussian_log_likelihood(x, *, means, log_scales):\n    \"\"\"\n    Compute the log-likelihood of a Gaussian distribution discretizing to a\n    given image.\n    :param x: the target images. 
It is assumed that this was uint8 values,\n              rescaled to the range [-1, 1].\n    :param means: the Gaussian mean Tensor.\n    :param log_scales: the Gaussian log stddev Tensor.\n    :return: a tensor like x of log probabilities (in nats).\n    \"\"\"\n    assert x.shape == means.shape == log_scales.shape\n    centered_x = x - means\n    inv_stdv = th.exp(-log_scales)\n    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)\n    cdf_plus = approx_standard_normal_cdf(plus_in)\n    min_in = inv_stdv * (centered_x - 1.0 / 255.0)\n    cdf_min = approx_standard_normal_cdf(min_in)\n    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))\n    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))\n    cdf_delta = cdf_plus - cdf_min\n    log_probs = th.where(\n        x < -0.999,\n        log_cdf_plus,\n        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),\n    )\n    assert log_probs.shape == x.shape\n    return log_probs\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.flatten(1).mean(1)\n"
  },
  {
    "path": "shap_e/diffusion/k_diffusion.py",
    "content": "\"\"\"\nBased on: https://github.com/crowsonkb/k-diffusion\n\nCopyright (c) 2022 Katherine Crowson\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\nTHE SOFTWARE.\n\"\"\"\n\nimport numpy as np\nimport torch as th\n\nfrom .gaussian_diffusion import GaussianDiffusion, mean_flat\n\n\nclass KarrasDenoiser:\n    def __init__(self, sigma_data: float = 0.5):\n        self.sigma_data = sigma_data\n\n    def get_snr(self, sigmas):\n        return sigmas**-2\n\n    def get_sigmas(self, sigmas):\n        return sigmas\n\n    def get_scalings(self, sigma):\n        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)\n        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5\n        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5\n        return c_skip, c_out, c_in\n\n    def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None):\n        if model_kwargs is None:\n            model_kwargs = {}\n        if noise is None:\n            noise = 
th.randn_like(x_start)\n\n        terms = {}\n\n        dims = x_start.ndim\n        x_t = x_start + noise * append_dims(sigmas, dims)\n        c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)]\n        model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)\n        target = (x_start - c_skip * x_t) / c_out\n\n        terms[\"mse\"] = mean_flat((model_output - target) ** 2)\n        terms[\"xs_mse\"] = mean_flat((denoised - x_start) ** 2)\n\n        if \"vb\" in terms:\n            terms[\"loss\"] = terms[\"mse\"] + terms[\"vb\"]\n        else:\n            terms[\"loss\"] = terms[\"mse\"]\n\n        return terms\n\n    def denoise(self, model, x_t, sigmas, **model_kwargs):\n        c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)]\n        rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)\n        model_output = model(c_in * x_t, rescaled_t, **model_kwargs)\n        denoised = c_out * model_output + c_skip * x_t\n        return model_output, denoised\n\n\nclass GaussianToKarrasDenoiser:\n    def __init__(self, model, diffusion):\n        from scipy import interpolate\n\n        self.model = model\n        self.diffusion = diffusion\n        self.alpha_cumprod_to_t = interpolate.interp1d(\n            diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps)\n        )\n\n    def sigma_to_t(self, sigma):\n        alpha_cumprod = 1.0 / (sigma**2 + 1)\n        if alpha_cumprod > self.diffusion.alphas_cumprod[0]:\n            return 0\n        elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]:\n            return self.diffusion.num_timesteps - 1\n        else:\n            return float(self.alpha_cumprod_to_t(alpha_cumprod))\n\n    def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None):\n        t = th.tensor(\n            [self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()],\n            dtype=th.long,\n            device=sigmas.device,\n        )\n        
c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim)\n        out = self.diffusion.p_mean_variance(\n            self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs\n        )\n        return None, out[\"pred_xstart\"]\n\n\ndef karras_sample(*args, **kwargs):\n    last = None\n    for x in karras_sample_progressive(*args, **kwargs):\n        last = x[\"x\"]\n    return last\n\n\ndef karras_sample_progressive(\n    diffusion,\n    model,\n    shape,\n    steps,\n    clip_denoised=True,\n    progress=False,\n    model_kwargs=None,\n    device=None,\n    sigma_min=0.002,\n    sigma_max=80,  # higher for highres?\n    rho=7.0,\n    sampler=\"heun\",\n    s_churn=0.0,\n    s_tmin=0.0,\n    s_tmax=float(\"inf\"),\n    s_noise=1.0,\n    guidance_scale=0.0,\n):\n    sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)\n    x_T = th.randn(*shape, device=device) * sigma_max\n    sample_fn = {\"heun\": sample_heun, \"dpm\": sample_dpm, \"ancestral\": sample_euler_ancestral}[\n        sampler\n    ]\n\n    if sampler != \"ancestral\":\n        sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)\n    else:\n        sampler_args = {}\n\n    if isinstance(diffusion, KarrasDenoiser):\n\n        def denoiser(x_t, sigma):\n            _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)\n            if clip_denoised:\n                denoised = denoised.clamp(-1, 1)\n            return denoised\n\n    elif isinstance(diffusion, GaussianDiffusion):\n        model = GaussianToKarrasDenoiser(model, diffusion)\n\n        def denoiser(x_t, sigma):\n            _, denoised = model.denoise(\n                x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs\n            )\n            return denoised\n\n    else:\n        raise NotImplementedError\n\n    if guidance_scale != 0 and guidance_scale != 1:\n\n        def guided_denoiser(x_t, sigma):\n            x_t = 
th.cat([x_t, x_t], dim=0)\n            sigma = th.cat([sigma, sigma], dim=0)\n            x_0 = denoiser(x_t, sigma)\n            cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)\n            x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)\n            return x_0\n\n    else:\n        guided_denoiser = denoiser\n\n    for obj in sample_fn(\n        guided_denoiser,\n        x_T,\n        sigmas,\n        progress=progress,\n        **sampler_args,\n    ):\n        if isinstance(diffusion, GaussianDiffusion):\n            yield diffusion.unscale_out_dict(obj)\n        else:\n            yield obj\n\n\ndef get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device=\"cpu\"):\n    \"\"\"Constructs the noise schedule of Karras et al. (2022).\"\"\"\n    ramp = th.linspace(0, 1, n)\n    min_inv_rho = sigma_min ** (1 / rho)\n    max_inv_rho = sigma_max ** (1 / rho)\n    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho\n    return append_zero(sigmas).to(device)\n\n\ndef to_d(x, sigma, denoised):\n    \"\"\"Converts a denoiser output to a Karras ODE derivative.\"\"\"\n    return (x - denoised) / append_dims(sigma, x.ndim)\n\n\ndef get_ancestral_step(sigma_from, sigma_to):\n    \"\"\"Calculates the noise level (sigma_down) to step down to and the amount\n    of noise to add (sigma_up) when doing an ancestral sampling step.\"\"\"\n    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5\n    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5\n    return sigma_down, sigma_up\n\n\n@th.no_grad()\ndef sample_euler_ancestral(model, x, sigmas, progress=False):\n    \"\"\"Ancestral sampling with Euler method steps.\"\"\"\n    s_in = x.new_ones([x.shape[0]])\n    indices = range(len(sigmas) - 1)\n    if progress:\n        from tqdm.auto import tqdm\n\n        indices = tqdm(indices)\n\n    for i in indices:\n        denoised = model(x, sigmas[i] * s_in)\n        sigma_down, sigma_up = get_ancestral_step(sigmas[i], 
sigmas[i + 1])\n        yield {\"x\": x, \"i\": i, \"sigma\": sigmas[i], \"sigma_hat\": sigmas[i], \"pred_xstart\": denoised}\n        d = to_d(x, sigmas[i], denoised)\n        # Euler method\n        dt = sigma_down - sigmas[i]\n        x = x + d * dt\n        x = x + th.randn_like(x) * sigma_up\n    yield {\"x\": x, \"pred_xstart\": x}\n\n\n@th.no_grad()\ndef sample_heun(\n    denoiser,\n    x,\n    sigmas,\n    progress=False,\n    s_churn=0.0,\n    s_tmin=0.0,\n    s_tmax=float(\"inf\"),\n    s_noise=1.0,\n):\n    \"\"\"Implements Algorithm 2 (Heun steps) from Karras et al. (2022).\"\"\"\n    s_in = x.new_ones([x.shape[0]])\n    indices = range(len(sigmas) - 1)\n    if progress:\n        from tqdm.auto import tqdm\n\n        indices = tqdm(indices)\n\n    for i in indices:\n        gamma = (\n            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0\n        )\n        eps = th.randn_like(x) * s_noise\n        sigma_hat = sigmas[i] * (gamma + 1)\n        if gamma > 0:\n            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5\n        denoised = denoiser(x, sigma_hat * s_in)\n        d = to_d(x, sigma_hat, denoised)\n        yield {\"x\": x, \"i\": i, \"sigma\": sigmas[i], \"sigma_hat\": sigma_hat, \"pred_xstart\": denoised}\n        dt = sigmas[i + 1] - sigma_hat\n        if sigmas[i + 1] == 0:\n            # Euler method\n            x = x + d * dt\n        else:\n            # Heun's method\n            x_2 = x + d * dt\n            denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)\n            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)\n            d_prime = (d + d_2) / 2\n            x = x + d_prime * dt\n    yield {\"x\": x, \"pred_xstart\": denoised}\n\n\n@th.no_grad()\ndef sample_dpm(\n    denoiser,\n    x,\n    sigmas,\n    progress=False,\n    s_churn=0.0,\n    s_tmin=0.0,\n    s_tmax=float(\"inf\"),\n    s_noise=1.0,\n):\n    \"\"\"A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. 
(2022).\"\"\"\n    s_in = x.new_ones([x.shape[0]])\n    indices = range(len(sigmas) - 1)\n    if progress:\n        from tqdm.auto import tqdm\n\n        indices = tqdm(indices)\n\n    for i in indices:\n        gamma = (\n            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0\n        )\n        eps = th.randn_like(x) * s_noise\n        sigma_hat = sigmas[i] * (gamma + 1)\n        if gamma > 0:\n            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5\n        denoised = denoiser(x, sigma_hat * s_in)\n        d = to_d(x, sigma_hat, denoised)\n        yield {\"x\": x, \"i\": i, \"sigma\": sigmas[i], \"sigma_hat\": sigma_hat, \"denoised\": denoised}\n        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule\n        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3\n        dt_1 = sigma_mid - sigma_hat\n        dt_2 = sigmas[i + 1] - sigma_hat\n        x_2 = x + d * dt_1\n        denoised_2 = denoiser(x_2, sigma_mid * s_in)\n        d_2 = to_d(x_2, sigma_mid, denoised_2)\n        x = x + d_2 * dt_2\n    yield {\"x\": x, \"pred_xstart\": denoised}\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\ndef append_zero(x):\n    return th.cat([x, x.new_zeros([1])])\n"
  },
  {
    "path": "shap_e/diffusion/sample.py",
    "content": "from typing import Any, Callable, Dict, Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom .gaussian_diffusion import GaussianDiffusion\nfrom .k_diffusion import karras_sample\n\nDEFAULT_KARRAS_STEPS = 64\nDEFAULT_KARRAS_SIGMA_MIN = 1e-3\nDEFAULT_KARRAS_SIGMA_MAX = 160\nDEFAULT_KARRAS_S_CHURN = 0.0\n\n\ndef uncond_guide_model(\n    model: Callable[..., torch.Tensor], scale: float\n) -> Callable[..., torch.Tensor]:\n    def model_fn(x_t, ts, **kwargs):\n        half = x_t[: len(x_t) // 2]\n        combined = torch.cat([half, half], dim=0)\n        model_out = model(combined, ts, **kwargs)\n        eps, rest = model_out[:, :3], model_out[:, 3:]\n        cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)\n        half_eps = uncond_eps + scale * (cond_eps - uncond_eps)\n        eps = torch.cat([half_eps, half_eps], dim=0)\n        return torch.cat([eps, rest], dim=1)\n\n    return model_fn\n\n\ndef sample_latents(\n    *,\n    batch_size: int,\n    model: nn.Module,\n    diffusion: GaussianDiffusion,\n    model_kwargs: Dict[str, Any],\n    guidance_scale: float,\n    clip_denoised: bool,\n    use_fp16: bool,\n    use_karras: bool,\n    karras_steps: int,\n    sigma_min: float,\n    sigma_max: float,\n    s_churn: float,\n    device: Optional[torch.device] = None,\n    progress: bool = False,\n) -> torch.Tensor:\n    sample_shape = (batch_size, model.d_latent)\n\n    if device is None:\n        device = next(model.parameters()).device\n\n    if hasattr(model, \"cached_model_kwargs\"):\n        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)\n    if guidance_scale != 1.0 and guidance_scale != 0.0:\n        for k, v in model_kwargs.copy().items():\n            model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)\n\n    sample_shape = (batch_size, model.d_latent)\n    with torch.autocast(device_type=device.type, enabled=use_fp16):\n        if use_karras:\n            samples = karras_sample(\n                
diffusion=diffusion,\n                model=model,\n                shape=sample_shape,\n                steps=karras_steps,\n                clip_denoised=clip_denoised,\n                model_kwargs=model_kwargs,\n                device=device,\n                sigma_min=sigma_min,\n                sigma_max=sigma_max,\n                s_churn=s_churn,\n                guidance_scale=guidance_scale,\n                progress=progress,\n            )\n        else:\n            internal_batch_size = batch_size\n            if guidance_scale != 1.0:\n                model = uncond_guide_model(model, guidance_scale)\n                internal_batch_size *= 2\n            samples = diffusion.p_sample_loop(\n                model,\n                shape=(internal_batch_size, *sample_shape[1:]),\n                model_kwargs=model_kwargs,\n                device=device,\n                clip_denoised=clip_denoised,\n                progress=progress,\n            )\n\n    return samples\n"
  },
  {
    "path": "shap_e/examples/encode_model.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"\\n\",\n    \"from shap_e.models.download import load_model\\n\",\n    \"from shap_e.util.data_util import load_or_create_multimodal_batch\\n\",\n    \"from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"xm = load_model('transmitter', device=device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_path = \\\"example_data/cactus/object.obj\\\"\\n\",\n    \"\\n\",\n    \"# This may take a few minutes, since it requires rendering the model twice\\n\",\n    \"# in two different modes.\\n\",\n    \"batch = load_or_create_multimodal_batch(\\n\",\n    \"    device,\\n\",\n    \"    model_path=model_path,\\n\",\n    \"    mv_light_mode=\\\"basic\\\",\\n\",\n    \"    mv_image_size=256,\\n\",\n    \"    cache_dir=\\\"example_data/cactus/cached\\\",\\n\",\n    \"    verbose=True, # this will show Blender output during renders\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"with torch.no_grad():\\n\",\n    \"    latent = xm.encoder.encode_to_bottleneck(batch)\\n\",\n    \"\\n\",\n    \"    render_mode = 'stf' # you can change this to 'nerf'\\n\",\n    \"    size = 128 # recommended that you lower resolution when using nerf\\n\",\n    \"\\n\",\n    \"    cameras = create_pan_cameras(size, 
device)\\n\",\n    \"    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\\n\",\n    \"    display(gif_widget(images))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "shap_e/examples/example_data/cactus/material.mtl",
    "content": "newmtl mat0\nKa 0.0000 0.7000 0.0000\nKd 0.0000 0.7000 0.0000\nKs 0.0000 0.0000 0.0000\nnewmtl mat1\nKa 0.6600 0.4400 0.2000\nKd 0.6600 0.4400 0.2000\nKs 0.0000 0.0000 0.0000\nnewmtl mat2\nKa 0.3000 0.3000 0.3000\nKd 0.3000 0.3000 0.3000\nKs 0.0000 0.0000 0.0000\nnewmtl mat3\nKa 0.0000 0.5000 0.0000\nKd 0.0000 0.5000 0.0000\nKs 0.0000 0.0000 0.0000\nnewmtl mat4\nKa 0.0000 0.5667 0.0000\nKd 0.0000 0.5667 0.0000\nKs 0.0000 0.0000 0.0000\nnewmtl mat5\nKa 0.5400 0.3933 0.2333\nKd 0.5400 0.3933 0.2333\nKs 0.0000 0.0000 0.0000\nnewmtl mat6\nKa 0.0000 0.6333 0.0000\nKd 0.0000 0.6333 0.0000\nKs 0.0000 0.0000 0.0000\nnewmtl mat7\nKa 0.2000 0.3667 0.2000\nKd 0.2000 0.3667 0.2000\nKs 0.0000 0.0000 0.0000\nnewmtl mat8\nKa 0.4200 0.3467 0.2667\nKd 0.4200 0.3467 0.2667\nKs 0.0000 0.0000 0.0000\nnewmtl mat9\nKa 0.1000 0.4333 0.1000\nKd 0.1000 0.4333 0.1000\nKs 0.0000 0.0000 0.0000\nnewmtl mat10\nKa 0.1000 0.5667 0.1000\nKd 0.1000 0.5667 0.1000\nKs 0.0000 0.0000 0.0000\nnewmtl mat11\nKa 0.2000 0.4333 0.2000\nKd 0.2000 0.4333 0.2000\nKs 0.0000 0.0000 0.0000\nnewmtl mat12\nKa 0.1000 0.5000 0.1000\nKd 0.1000 0.5000 0.1000\nKs 0.0000 0.0000 0.0000\n"
  },
  {
    "path": "shap_e/examples/sample_image_to_3d.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"964ccced\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"\\n\",\n    \"from shap_e.diffusion.sample import sample_latents\\n\",\n    \"from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\\n\",\n    \"from shap_e.models.download import load_model, load_config\\n\",\n    \"from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget\\n\",\n    \"from shap_e.util.image_util import load_image\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8eed3a76\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"2d922637\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"xm = load_model('transmitter', device=device)\\n\",\n    \"model = load_model('image300M', device=device)\\n\",\n    \"diffusion = diffusion_from_config(load_config('diffusion'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"53d329d0\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"batch_size = 4\\n\",\n    \"guidance_scale = 3.0\\n\",\n    \"\\n\",\n    \"# To get the best result, you should remove the background and show only the object of interest to the model.\\n\",\n    \"image = load_image(\\\"example_data/corgi.png\\\")\\n\",\n    \"\\n\",\n    \"latents = sample_latents(\\n\",\n    \"    batch_size=batch_size,\\n\",\n    \"    model=model,\\n\",\n    \"    diffusion=diffusion,\\n\",\n    \"    guidance_scale=guidance_scale,\\n\",\n    \"    model_kwargs=dict(images=[image] * batch_size),\\n\",\n    \"    progress=True,\\n\",\n    \"    clip_denoised=True,\\n\",\n    \"    use_fp16=True,\\n\",\n   
 \"    use_karras=True,\\n\",\n    \"    karras_steps=64,\\n\",\n    \"    sigma_min=1e-3,\\n\",\n    \"    sigma_max=160,\\n\",\n    \"    s_churn=0,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"633da2ec\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"render_mode = 'nerf' # you can change this to 'stf' for mesh rendering\\n\",\n    \"size = 64 # this is the size of the renders; higher values take longer to render.\\n\",\n    \"\\n\",\n    \"cameras = create_pan_cameras(size, device)\\n\",\n    \"for i, latent in enumerate(latents):\\n\",\n    \"    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\\n\",\n    \"    display(gif_widget(images))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "shap_e/examples/sample_text_to_3d.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"964ccced\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"\\n\",\n    \"from shap_e.diffusion.sample import sample_latents\\n\",\n    \"from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\\n\",\n    \"from shap_e.models.download import load_model, load_config\\n\",\n    \"from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8eed3a76\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"2d922637\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"xm = load_model('transmitter', device=device)\\n\",\n    \"model = load_model('text300M', device=device)\\n\",\n    \"diffusion = diffusion_from_config(load_config('diffusion'))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"53d329d0\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"batch_size = 4\\n\",\n    \"guidance_scale = 15.0\\n\",\n    \"prompt = \\\"a shark\\\"\\n\",\n    \"\\n\",\n    \"latents = sample_latents(\\n\",\n    \"    batch_size=batch_size,\\n\",\n    \"    model=model,\\n\",\n    \"    diffusion=diffusion,\\n\",\n    \"    guidance_scale=guidance_scale,\\n\",\n    \"    model_kwargs=dict(texts=[prompt] * batch_size),\\n\",\n    \"    progress=True,\\n\",\n    \"    clip_denoised=True,\\n\",\n    \"    use_fp16=True,\\n\",\n    \"    use_karras=True,\\n\",\n    \"    karras_steps=64,\\n\",\n    \"    sigma_min=1e-3,\\n\",\n    \"    sigma_max=160,\\n\",\n    \"    s_churn=0,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"id\": \"633da2ec\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"render_mode = 'nerf' # you can change this to 'stf'\\n\",\n    \"size = 64 # this is the size of the renders; higher values take longer to render.\\n\",\n    \"\\n\",\n    \"cameras = create_pan_cameras(size, device)\\n\",\n    \"for i, latent in enumerate(latents):\\n\",\n    \"    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\\n\",\n    \"    display(gif_widget(images))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"85a4dce4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Example of saving the latents as meshes.\\n\",\n    \"from shap_e.util.notebooks import decode_latent_mesh\\n\",\n    \"\\n\",\n    \"for i, latent in enumerate(latents):\\n\",\n    \"    t = decode_latent_mesh(xm, latent).tri_mesh()\\n\",\n    \"    with open(f'example_mesh_{i}.ply', 'wb') as f:\\n\",\n    \"        t.write_ply(f)\\n\",\n    \"    with open(f'example_mesh_{i}.obj', 'w') as f:\\n\",\n    \"        t.write_obj(f)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "shap_e/models/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/models/configs.py",
    "content": "from typing import Any, Dict, Union\n\nimport blobfile as bf\nimport torch\nimport torch.nn as nn\nimport yaml\n\nfrom shap_e.models.generation.latent_diffusion import SplitVectorDiffusion\nfrom shap_e.models.generation.perceiver import PointDiffusionPerceiver\nfrom shap_e.models.generation.pooled_mlp import PooledMLP\nfrom shap_e.models.generation.transformer import (\n    CLIPImageGridPointDiffusionTransformer,\n    CLIPImageGridUpsamplePointDiffusionTransformer,\n    CLIPImagePointDiffusionTransformer,\n    PointDiffusionTransformer,\n    UpsamplePointDiffusionTransformer,\n)\nfrom shap_e.models.nerf.model import MLPNeRFModel, VoidNeRFModel\nfrom shap_e.models.nerf.renderer import OneStepNeRFRenderer, TwoStepNeRFRenderer\nfrom shap_e.models.nerstf.mlp import MLPDensitySDFModel, MLPNeRSTFModel\nfrom shap_e.models.nerstf.renderer import NeRSTFRenderer\nfrom shap_e.models.nn.meta import batch_meta_state_dict\nfrom shap_e.models.stf.mlp import MLPSDFModel, MLPTextureFieldModel\nfrom shap_e.models.stf.renderer import STFRenderer\nfrom shap_e.models.transmitter.base import ChannelsDecoder, Transmitter, VectorDecoder\nfrom shap_e.models.transmitter.channels_encoder import (\n    PointCloudPerceiverChannelsEncoder,\n    PointCloudTransformerChannelsEncoder,\n)\nfrom shap_e.models.transmitter.multiview_encoder import MultiviewTransformerEncoder\nfrom shap_e.models.transmitter.pc_encoder import (\n    PointCloudPerceiverEncoder,\n    PointCloudTransformerEncoder,\n)\nfrom shap_e.models.volume import BoundingBoxVolume, SphericalVolume, UnboundedVolume\n\n\ndef model_from_config(config: Union[str, Dict[str, Any]], device: torch.device) -> nn.Module:\n    if isinstance(config, str):\n        with bf.BlobFile(config, \"rb\") as f:\n            obj = yaml.load(f, Loader=yaml.SafeLoader)\n        return model_from_config(obj, device=device)\n\n    config = config.copy()\n    name = config.pop(\"name\")\n\n    if name == \"PointCloudTransformerEncoder\":\n        
return PointCloudTransformerEncoder(device=device, dtype=torch.float32, **config)\n    elif name == \"PointCloudPerceiverEncoder\":\n        return PointCloudPerceiverEncoder(device=device, dtype=torch.float32, **config)\n    elif name == \"PointCloudTransformerChannelsEncoder\":\n        return PointCloudTransformerChannelsEncoder(device=device, dtype=torch.float32, **config)\n    elif name == \"PointCloudPerceiverChannelsEncoder\":\n        return PointCloudPerceiverChannelsEncoder(device=device, dtype=torch.float32, **config)\n    elif name == \"MultiviewTransformerEncoder\":\n        return MultiviewTransformerEncoder(device=device, dtype=torch.float32, **config)\n    elif name == \"Transmitter\":\n        renderer = model_from_config(config.pop(\"renderer\"), device=device)\n        param_shapes = {\n            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()\n        }\n        encoder_config = config.pop(\"encoder\").copy()\n        encoder_config[\"param_shapes\"] = param_shapes\n        encoder = model_from_config(encoder_config, device=device)\n        return Transmitter(encoder=encoder, renderer=renderer, **config)\n    elif name == \"VectorDecoder\":\n        renderer = model_from_config(config.pop(\"renderer\"), device=device)\n        param_shapes = {\n            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()\n        }\n        return VectorDecoder(param_shapes=param_shapes, renderer=renderer, device=device, **config)\n    elif name == \"ChannelsDecoder\":\n        renderer = model_from_config(config.pop(\"renderer\"), device=device)\n        param_shapes = {\n            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()\n        }\n        return ChannelsDecoder(\n            param_shapes=param_shapes, renderer=renderer, device=device, **config\n        )\n    elif name == \"OneStepNeRFRenderer\":\n        config = config.copy()\n        for field in 
[\n            # Required\n            \"void_model\",\n            \"foreground_model\",\n            \"volume\",\n            # Optional to use NeRF++\n            \"background_model\",\n            \"outer_volume\",\n        ]:\n            if field in config:\n                config[field] = model_from_config(config.pop(field).copy(), device)\n        return OneStepNeRFRenderer(device=device, **config)\n    elif name == \"TwoStepNeRFRenderer\":\n        config = config.copy()\n        for field in [\n            # Required\n            \"void_model\",\n            \"coarse_model\",\n            \"fine_model\",\n            \"volume\",\n            # Optional to use NeRF++\n            \"coarse_background_model\",\n            \"fine_background_model\",\n            \"outer_volume\",\n        ]:\n            if field in config:\n                config[field] = model_from_config(config.pop(field).copy(), device)\n        return TwoStepNeRFRenderer(device=device, **config)\n    elif name == \"PooledMLP\":\n        return PooledMLP(device, **config)\n    elif name == \"PointDiffusionTransformer\":\n        return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)\n    elif name == \"PointDiffusionPerceiver\":\n        return PointDiffusionPerceiver(device=device, dtype=torch.float32, **config)\n    elif name == \"CLIPImagePointDiffusionTransformer\":\n        return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)\n    elif name == \"CLIPImageGridPointDiffusionTransformer\":\n        return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)\n    elif name == \"UpsamplePointDiffusionTransformer\":\n        return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)\n    elif name == \"CLIPImageGridUpsamplePointDiffusionTransformer\":\n        return CLIPImageGridUpsamplePointDiffusionTransformer(\n            device=device, dtype=torch.float32, 
**config\n        )\n    elif name == \"SplitVectorDiffusion\":\n        inner_config = config.pop(\"inner\")\n        d_latent = config.pop(\"d_latent\")\n        latent_ctx = config.pop(\"latent_ctx\", 1)\n        inner_config[\"input_channels\"] = d_latent // latent_ctx\n        inner_config[\"n_ctx\"] = latent_ctx\n        inner_config[\"output_channels\"] = d_latent // latent_ctx * 2\n        inner_model = model_from_config(inner_config, device)\n        return SplitVectorDiffusion(\n            device=device, wrapped=inner_model, n_ctx=latent_ctx, d_latent=d_latent\n        )\n    elif name == \"STFRenderer\":\n        config = config.copy()\n        for field in [\"sdf\", \"tf\", \"volume\"]:\n            config[field] = model_from_config(config.pop(field), device)\n        return STFRenderer(device=device, **config)\n    elif name == \"NeRSTFRenderer\":\n        config = config.copy()\n        for field in [\"sdf\", \"tf\", \"nerstf\", \"void\", \"volume\"]:\n            if field not in config:\n                continue\n            config[field] = model_from_config(config.pop(field), device)\n        config.setdefault(\"sdf\", None)\n        config.setdefault(\"tf\", None)\n        config.setdefault(\"nerstf\", None)\n        return NeRSTFRenderer(device=device, **config)\n\n    model_cls = {\n        \"MLPSDFModel\": MLPSDFModel,\n        \"MLPTextureFieldModel\": MLPTextureFieldModel,\n        \"MLPNeRFModel\": MLPNeRFModel,\n        \"MLPDensitySDFModel\": MLPDensitySDFModel,\n        \"MLPNeRSTFModel\": MLPNeRSTFModel,\n        \"VoidNeRFModel\": VoidNeRFModel,\n        \"BoundingBoxVolume\": BoundingBoxVolume,\n        \"SphericalVolume\": SphericalVolume,\n        \"UnboundedVolume\": UnboundedVolume,\n    }[name]\n    return model_cls(device=device, **config)\n"
  },
  {
    "path": "shap_e/models/download.py",
    "content": "\"\"\"\nAdapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py\n\"\"\"\n\nimport hashlib\nimport os\nfrom functools import lru_cache\nfrom typing import Dict, Optional\n\nimport requests\nimport torch\nimport yaml\nfrom filelock import FileLock\nfrom tqdm.auto import tqdm\n\nMODEL_PATHS = {\n    \"transmitter\": \"https://openaipublic.azureedge.net/main/shap-e/transmitter.pt\",\n    \"decoder\": \"https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt\",\n    \"text300M\": \"https://openaipublic.azureedge.net/main/shap-e/text_cond.pt\",\n    \"image300M\": \"https://openaipublic.azureedge.net/main/shap-e/image_cond.pt\",\n}\n\nCONFIG_PATHS = {\n    \"transmitter\": \"https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml\",\n    \"decoder\": \"https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml\",\n    \"text300M\": \"https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml\",\n    \"image300M\": \"https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml\",\n    \"diffusion\": \"https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml\",\n}\n\nURL_HASHES = {\n    \"https://openaipublic.azureedge.net/main/shap-e/transmitter.pt\": \"af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b\",\n    \"https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt\": \"d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98\",\n    \"https://openaipublic.azureedge.net/main/shap-e/text_cond.pt\": \"e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4\",\n    \"https://openaipublic.azureedge.net/main/shap-e/image_cond.pt\": \"cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa\",\n    \"https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml\": \"ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e\",\n    
\"https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml\": \"e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c\",\n    \"https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml\": \"f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1\",\n    \"https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml\": \"4e0745605a533c543c72add803a78d233e2a6401e0abfa0cad58afb4d74ad0b0\",\n    \"https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml\": \"efcb2cd7ee545b2d27223979d41857802448143990572a42645cd09c2942ed57\",\n}\n\n\n@lru_cache()\ndef default_cache_dir() -> str:\n    return os.path.join(os.path.abspath(os.getcwd()), \"shap_e_model_cache\")\n\n\ndef fetch_file_cached(\n    url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096\n) -> str:\n    \"\"\"\n    Download the file at the given URL into a local file and return the path.\n    If cache_dir is specified, it will be used to download the files.\n    Otherwise, default_cache_dir() is used.\n    \"\"\"\n    expected_hash = URL_HASHES[url]\n\n    if cache_dir is None:\n        cache_dir = default_cache_dir()\n    os.makedirs(cache_dir, exist_ok=True)\n    local_path = os.path.join(cache_dir, url.split(\"/\")[-1])\n    if os.path.exists(local_path):\n        check_hash(local_path, expected_hash)\n        return local_path\n\n    response = requests.get(url, stream=True)\n    size = int(response.headers.get(\"content-length\", \"0\"))\n    with FileLock(local_path + \".lock\"):\n        if progress:\n            pbar = tqdm(total=size, unit=\"iB\", unit_scale=True)\n        tmp_path = local_path + \".tmp\"\n        with open(tmp_path, \"wb\") as f:\n            for chunk in response.iter_content(chunk_size):\n                if progress:\n                    pbar.update(len(chunk))\n                f.write(chunk)\n        os.rename(tmp_path, local_path)\n        if progress:\n            pbar.close()\n    
    check_hash(local_path, expected_hash)\n        return local_path\n\n\ndef check_hash(path: str, expected_hash: str):\n    actual_hash = hash_file(path)\n    if actual_hash != expected_hash:\n        raise RuntimeError(\n            f\"The file {path} should have hash {expected_hash} but has {actual_hash}. \"\n            \"Try deleting it and running this call again.\"\n        )\n\n\ndef hash_file(path: str) -> str:\n    sha256_hash = hashlib.sha256()\n    with open(path, \"rb\") as file:\n        while True:\n            data = file.read(4096)\n            if not len(data):\n                break\n            sha256_hash.update(data)\n    return sha256_hash.hexdigest()\n\n\ndef load_config(\n    config_name: str,\n    progress: bool = False,\n    cache_dir: Optional[str] = None,\n    chunk_size: int = 4096,\n):\n    if config_name not in CONFIG_PATHS:\n        raise ValueError(\n            f\"Unknown config name {config_name}. Known names are: {CONFIG_PATHS.keys()}.\"\n        )\n    path = fetch_file_cached(\n        CONFIG_PATHS[config_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size\n    )\n    with open(path, \"r\") as f:\n        return yaml.safe_load(f)\n\n\ndef load_checkpoint(\n    checkpoint_name: str,\n    device: torch.device,\n    progress: bool = True,\n    cache_dir: Optional[str] = None,\n    chunk_size: int = 4096,\n) -> Dict[str, torch.Tensor]:\n    if checkpoint_name not in MODEL_PATHS:\n        raise ValueError(\n            f\"Unknown checkpoint name {checkpoint_name}. 
Known names are: {MODEL_PATHS.keys()}.\"\n        )\n    path = fetch_file_cached(\n        MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size\n    )\n    return torch.load(path, map_location=device)\n\n\ndef load_model(\n    model_name: str,\n    device: torch.device,\n    **kwargs,\n) -> torch.nn.Module:\n    from .configs import model_from_config\n\n    model = model_from_config(load_config(model_name, **kwargs), device=device)\n    model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs))\n    model.eval()\n    return model\n"
  },
  {
    "path": "shap_e/models/generation/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/models/generation/latent_diffusion.py",
    "content": "from typing import Any, Dict\n\nimport torch\nimport torch.nn as nn\n\n\nclass SplitVectorDiffusion(nn.Module):\n    def __init__(self, *, device: torch.device, wrapped: nn.Module, n_ctx: int, d_latent: int):\n        super().__init__()\n        self.device = device\n        self.n_ctx = n_ctx\n        self.d_latent = d_latent\n        self.wrapped = wrapped\n\n        if hasattr(self.wrapped, \"cached_model_kwargs\"):\n            self.cached_model_kwargs = self.wrapped.cached_model_kwargs\n\n    def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs):\n        h = x.reshape(x.shape[0], self.n_ctx, -1).permute(0, 2, 1)\n        pre_channels = h.shape[1]\n        h = self.wrapped(h, t, **kwargs)\n        assert (\n            h.shape[1] == pre_channels * 2\n        ), \"expected twice as many outputs for variance prediction\"\n        eps, var = torch.chunk(h, 2, dim=1)\n        return torch.cat(\n            [\n                eps.permute(0, 2, 1).flatten(1),\n                var.permute(0, 2, 1).flatten(1),\n            ],\n            dim=1,\n        )\n"
  },
  {
    "path": "shap_e/models/generation/perceiver.py",
    "content": "import math\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom shap_e.models.nn.checkpoint import checkpoint\n\nfrom .transformer import MLP, Transformer, init_linear\nfrom .util import timestep_embedding\n\n\nclass MultiheadCrossAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        n_ctx: int,\n        n_data: int,\n        width: int,\n        heads: int,\n        init_scale: float,\n        data_width: Optional[int] = None,\n    ):\n        super().__init__()\n        self.n_ctx = n_ctx\n        self.n_data = n_data\n        self.width = width\n        self.heads = heads\n        self.data_width = width if data_width is None else data_width\n        self.c_q = nn.Linear(width, width, device=device, dtype=dtype)\n        self.c_kv = nn.Linear(self.data_width, width * 2, device=device, dtype=dtype)\n        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)\n        self.attention = QKVMultiheadCrossAttention(\n            device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, n_data=n_data\n        )\n        init_linear(self.c_q, init_scale)\n        init_linear(self.c_kv, init_scale)\n        init_linear(self.c_proj, init_scale)\n\n    def forward(self, x, data):\n        x = self.c_q(x)\n        data = self.c_kv(data)\n        x = checkpoint(self.attention, (x, data), (), True)\n        x = self.c_proj(x)\n        return x\n\n\nclass QKVMultiheadCrossAttention(nn.Module):\n    def __init__(\n        self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, n_data: int\n    ):\n        super().__init__()\n        self.device = device\n        self.dtype = dtype\n        self.heads = heads\n        self.n_ctx = n_ctx\n        self.n_data = n_data\n\n    def forward(self, q, kv):\n        _, n_ctx, _ = q.shape\n        bs, n_data, width = kv.shape\n        attn_ch = width // self.heads // 2\n        scale = 1 / 
math.sqrt(math.sqrt(attn_ch))\n        q = q.view(bs, n_ctx, self.heads, -1)\n        kv = kv.view(bs, n_data, self.heads, -1)\n        k, v = torch.split(kv, attn_ch, dim=-1)\n        weight = torch.einsum(\n            \"bthc,bshc->bhts\", q * scale, k * scale\n        )  # More stable with f16 than dividing afterwards\n        wdtype = weight.dtype\n        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)\n        return torch.einsum(\"bhts,bshc->bthc\", weight, v).reshape(bs, n_ctx, -1)\n\n\nclass ResidualCrossAttentionBlock(nn.Module):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        n_ctx: int,\n        n_data: int,\n        width: int,\n        heads: int,\n        data_width: Optional[int] = None,\n        init_scale: float = 1.0,\n    ):\n        super().__init__()\n\n        if data_width is None:\n            data_width = width\n\n        self.attn = MultiheadCrossAttention(\n            device=device,\n            dtype=dtype,\n            n_ctx=n_ctx,\n            n_data=n_data,\n            width=width,\n            heads=heads,\n            data_width=data_width,\n            init_scale=init_scale,\n        )\n        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)\n        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)\n        self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)\n\n    def forward(self, x: torch.Tensor, data: torch.Tensor):\n        x = x + self.attn(self.ln_1(x), self.ln_2(data))\n        x = x + self.mlp(self.ln_3(x))\n        return x\n\n\nclass SimplePerceiver(nn.Module):\n    \"\"\"\n    Only does cross attention\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        n_ctx: int,\n        n_data: int,\n        width: int,\n        layers: int,\n        
heads: int,\n        init_scale: float = 0.25,\n        data_width: Optional[int] = None,\n    ):\n        super().__init__()\n        self.n_ctx = n_ctx\n        self.width = width\n        self.layers = layers\n        init_scale = init_scale * math.sqrt(1.0 / width)\n        self.resblocks = nn.ModuleList(\n            [\n                ResidualCrossAttentionBlock(\n                    device=device,\n                    dtype=dtype,\n                    n_ctx=n_ctx,\n                    n_data=n_data,\n                    width=width,\n                    heads=heads,\n                    init_scale=init_scale,\n                    data_width=data_width,\n                )\n                for _ in range(layers)\n            ]\n        )\n\n    def forward(self, x: torch.Tensor, data: torch.Tensor):\n        for block in self.resblocks:\n            x = block(x, data)\n        return x\n\n\nclass PointDiffusionPerceiver(nn.Module):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        input_channels: int = 3,\n        output_channels: int = 3,\n        n_ctx: int = 1024,\n        n_latent: int = 128,\n        width: int = 512,\n        encoder_layers: int = 12,\n        latent_layers: int = 12,\n        decoder_layers: int = 12,\n        heads: int = 8,\n        init_scale: float = 0.25,\n    ):\n        super().__init__()\n        self.time_embed = MLP(\n            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)\n        )\n        self.latent_embed = MLP(\n            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)\n        )\n        self.n_latent = n_latent\n\n        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.encoder = SimplePerceiver(\n            device=device,\n            dtype=dtype,\n            n_ctx=n_latent,\n            n_data=n_ctx,\n            width=width,\n            
layers=encoder_layers,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.processor = Transformer(\n            device=device,\n            dtype=dtype,\n            n_ctx=n_latent,\n            width=width,\n            layers=latent_layers,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.decoder = SimplePerceiver(\n            device=device,\n            dtype=dtype,\n            n_ctx=n_ctx,\n            n_data=n_latent,\n            width=width,\n            layers=decoder_layers,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)\n        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)\n        with torch.no_grad():\n            self.output_proj.weight.zero_()\n            self.output_proj.bias.zero_()\n\n    def forward(self, x: torch.Tensor, t: torch.Tensor):\n        \"\"\"\n        :param x: an [N x C x T] tensor.\n        :param t: an [N] tensor.\n        :return: an [N x C' x T] tensor.\n        \"\"\"\n        assert x.shape[-1] == self.decoder.n_ctx\n        t_embed = self.time_embed(timestep_embedding(t, self.encoder.width))\n        data = self.input_proj(x.permute(0, 2, 1)) + t_embed[:, None]\n        data = self.ln_pre(data)\n\n        l = torch.arange(self.n_latent).to(x.device)\n        h = self.latent_embed(timestep_embedding(l, self.decoder.width))\n        h = h.unsqueeze(0).repeat(x.shape[0], 1, 1)\n\n        h = self.encoder(h, data)\n        h = self.processor(h)\n        h = self.decoder(data, h)\n        h = self.ln_post(h)\n        h = self.output_proj(h)\n        return h.permute(0, 2, 1)\n"
  },
  {
    "path": "shap_e/models/generation/pooled_mlp.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom .util import timestep_embedding\n\n\nclass PooledMLP(nn.Module):\n    def __init__(\n        self,\n        device: torch.device,\n        *,\n        input_channels: int = 3,\n        output_channels: int = 6,\n        hidden_size: int = 256,\n        resblocks: int = 4,\n        pool_op: str = \"max\",\n    ):\n        super().__init__()\n        self.input_embed = nn.Conv1d(input_channels, hidden_size, kernel_size=1, device=device)\n        self.time_embed = nn.Linear(hidden_size, hidden_size, device=device)\n\n        blocks = []\n        for _ in range(resblocks):\n            blocks.append(ResBlock(hidden_size, pool_op, device=device))\n        self.sequence = nn.Sequential(*blocks)\n\n        self.out = nn.Conv1d(hidden_size, output_channels, kernel_size=1, device=device)\n        with torch.no_grad():\n            self.out.bias.zero_()\n            self.out.weight.zero_()\n\n    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n        in_embed = self.input_embed(x)\n        t_embed = self.time_embed(timestep_embedding(t, in_embed.shape[1]))\n        h = in_embed + t_embed[..., None]\n        h = self.sequence(h)\n        h = self.out(h)\n        return h\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, hidden_size: int, pool_op: str, device: torch.device):\n        super().__init__()\n        assert pool_op in [\"mean\", \"max\"]\n        self.pool_op = pool_op\n        self.body = nn.Sequential(\n            nn.SiLU(),\n            nn.LayerNorm((hidden_size,), device=device),\n            nn.Linear(hidden_size, hidden_size, device=device),\n            nn.SiLU(),\n            nn.LayerNorm((hidden_size,), device=device),\n            nn.Linear(hidden_size, hidden_size, device=device),\n        )\n        self.gate = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size, device=device),\n            nn.Tanh(),\n        )\n\n    def forward(self, x: 
torch.Tensor):\n        N, C, T = x.shape\n        out = self.body(x.permute(0, 2, 1).reshape(N * T, C)).reshape([N, T, C]).permute(0, 2, 1)\n        pooled = pool(self.pool_op, x)\n        gate = self.gate(pooled)\n        return x + out * gate[..., None]\n\n\ndef pool(op_name: str, x: torch.Tensor) -> torch.Tensor:\n    if op_name == \"max\":\n        # torch.max with a dim returns (values, indices); keep only the values.\n        pooled, _ = torch.max(x, dim=-1)\n    elif op_name == \"mean\":\n        # torch.mean returns a single tensor, so there is nothing to unpack.\n        pooled = torch.mean(x, dim=-1)\n    else:\n        raise ValueError(f\"unknown pool op: {op_name}\")\n    return pooled\n"
  },
  {
    "path": "shap_e/models/generation/pretrained_clip.py",
    "content": "from typing import Iterable, List, Optional, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom PIL import Image\n\nfrom shap_e.models.download import default_cache_dir\n\nImageType = Union[np.ndarray, torch.Tensor, Image.Image]\n\n\nclass ImageCLIP(nn.Module):\n    \"\"\"\n    A wrapper around a pre-trained CLIP model that automatically handles\n    batches of texts, images, and embeddings.\n    \"\"\"\n\n    def __init__(\n        self,\n        device: torch.device,\n        dtype: Optional[torch.dtype] = torch.float32,\n        ensure_used_params: bool = True,\n        clip_name: str = \"ViT-L/14\",\n        cache_dir: Optional[str] = None,\n    ):\n        super().__init__()\n\n        assert clip_name in [\"ViT-L/14\", \"ViT-B/32\"]\n\n        self.device = device\n        self.ensure_used_params = ensure_used_params\n\n        # Lazy import because of torchvision.\n        import clip\n\n        self.clip_model, self.preprocess = clip.load(\n            clip_name, device=device, download_root=cache_dir or default_cache_dir()\n        )\n        self.clip_name = clip_name\n\n        if dtype is not None:\n            self.clip_model.to(dtype)\n        self._tokenize = clip.tokenize\n\n    @property\n    def feature_dim(self) -> int:\n        if self.clip_name == \"ViT-L/14\":\n            return 768\n        else:\n            return 512\n\n    @property\n    def grid_size(self) -> int:\n        if self.clip_name == \"ViT-L/14\":\n            return 16\n        else:\n            return 7\n\n    @property\n    def grid_feature_dim(self) -> int:\n        if self.clip_name == \"ViT-L/14\":\n            return 1024\n        else:\n            return 768\n\n    def forward(\n        self,\n        batch_size: int,\n        images: Optional[Iterable[Optional[ImageType]]] = None,\n        texts: Optional[Iterable[Optional[str]]] = None,\n        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,\n    ) -> 
torch.Tensor:\n        \"\"\"\n        Generate a batch of embeddings from a mixture of images, texts,\n        precomputed embeddings, and possibly empty values.\n\n        For each batch element, at most one of images, texts, and embeddings\n        should have a non-None value. Embeddings from multiple modalities\n        cannot be mixed for a single batch element. If no modality is provided,\n        a zero embedding will be used for the batch element.\n        \"\"\"\n        image_seq = [None] * batch_size if images is None else list(images)\n        text_seq = [None] * batch_size if texts is None else list(texts)\n        embedding_seq = [None] * batch_size if embeddings is None else list(embeddings)\n        assert len(image_seq) == batch_size, \"number of images should match batch size\"\n        assert len(text_seq) == batch_size, \"number of texts should match batch size\"\n        assert len(embedding_seq) == batch_size, \"number of embeddings should match batch size\"\n\n        if self.ensure_used_params:\n            return self._static_multimodal_embed(\n                images=image_seq, texts=text_seq, embeddings=embedding_seq\n            )\n\n        result = torch.zeros((batch_size, self.feature_dim), device=self.device)\n        index_images = []\n        index_texts = []\n        for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):\n            assert (\n                sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2\n            ), \"only one modality may be non-None per batch element\"\n            if image is not None:\n                index_images.append((i, image))\n            elif text is not None:\n                index_texts.append((i, text))\n            elif emb is not None:\n                result[i] = emb.to(result)\n\n        if len(index_images):\n            embs = self.embed_images((img for _, img in index_images))\n            for (i, _), emb in zip(index_images, 
embs):\n                result[i] = emb.to(result)\n        if len(index_texts):\n            embs = self.embed_text((text for _, text in index_texts))\n            for (i, _), emb in zip(index_texts, embs):\n                result[i] = emb.to(result)\n\n        return result\n\n    def _static_multimodal_embed(\n        self,\n        images: List[Optional[ImageType]] = None,\n        texts: List[Optional[str]] = None,\n        embeddings: List[Optional[torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Like forward(), but always runs all encoders to ensure that\n        the forward graph looks the same on every rank.\n        \"\"\"\n        image_emb = self.embed_images(images)\n        text_emb = self.embed_text(t if t else \"\" for t in texts)\n        joined_embs = torch.stack(\n            [\n                emb.to(device=self.device, dtype=torch.float32)\n                if emb is not None\n                else torch.zeros(self.feature_dim, device=self.device)\n                for emb in embeddings\n            ],\n            dim=0,\n        )\n\n        image_flag = torch.tensor([x is not None for x in images], device=self.device)[\n            :, None\n        ].expand_as(image_emb)\n        text_flag = torch.tensor([x is not None for x in texts], device=self.device)[\n            :, None\n        ].expand_as(image_emb)\n        emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[\n            :, None\n        ].expand_as(image_emb)\n\n        return (\n            image_flag.float() * image_emb\n            + text_flag.float() * text_emb\n            + emb_flag.float() * joined_embs\n            + self.clip_model.logit_scale * 0  # avoid unused parameters\n        )\n\n    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:\n        \"\"\"\n        :param xs: N images, stored as numpy arrays, tensors, or PIL images.\n        :return: an [N x D] tensor of features.\n        \"\"\"\n 
       clip_inputs = self.images_to_tensor(xs)\n        results = self.clip_model.encode_image(clip_inputs).float()\n        return results / torch.linalg.norm(results, dim=-1, keepdim=True)\n\n    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:\n        \"\"\"\n        Embed text prompts as an [N x D] tensor.\n        \"\"\"\n        enc = self.clip_model.encode_text(\n            self._tokenize(list(prompts), truncate=True).to(self.device)\n        ).float()\n        return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)\n\n    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:\n        \"\"\"\n        Embed images into latent grids.\n\n        :param xs: an iterable of images to embed.\n        :return: a tensor of shape [N x C x L], where L = self.grid_size**2.\n        \"\"\"\n        if self.ensure_used_params:\n            extra_value = 0.0\n            for p in self.parameters():\n                extra_value = extra_value + p.mean() * 0.0\n        else:\n            extra_value = 0.0\n\n        x = self.images_to_tensor(xs).to(self.clip_model.dtype)\n\n        # https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225\n        vt = self.clip_model.visual\n        x = vt.conv1(x)  # shape = [*, width, grid, grid]\n        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]\n        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]\n        x = torch.cat(\n            [\n                vt.class_embedding.to(x.dtype)\n                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),\n                x,\n            ],\n            dim=1,\n        )  # shape = [*, grid ** 2 + 1, width]\n        x = x + vt.positional_embedding.to(x.dtype)\n        x = vt.ln_pre(x)\n\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = vt.transformer(x)\n        x = x.permute(1, 2, 0)  # LND -> NDL\n\n        return x[..., 
1:].contiguous().float() + extra_value\n\n    def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:\n        return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device)\n\n\nclass FrozenImageCLIP:\n    def __init__(self, device: torch.device, **kwargs):\n        self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs)\n        for parameter in self.model.parameters():\n            parameter.requires_grad_(False)\n\n    @property\n    def feature_dim(self) -> int:\n        return self.model.feature_dim\n\n    @property\n    def grid_size(self) -> int:\n        return self.model.grid_size\n\n    @property\n    def grid_feature_dim(self) -> int:\n        return self.model.grid_feature_dim\n\n    def __call__(\n        self,\n        batch_size: int,\n        images: Optional[Iterable[Optional[ImageType]]] = None,\n        texts: Optional[Iterable[Optional[str]]] = None,\n        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,\n    ) -> torch.Tensor:\n        # We don't do a no_grad() here so that gradients could still\n        # flow to the input embeddings argument.\n        # This behavior is currently not used, but it could be.\n        return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings)\n\n    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:\n        with torch.no_grad():\n            return self.model.embed_images(xs)\n\n    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:\n        with torch.no_grad():\n            return self.model.embed_text(prompts)\n\n    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:\n        with torch.no_grad():\n            return self.model.embed_images_grid(xs)\n\n\ndef _image_to_pil(obj: Optional[ImageType]) -> Image.Image:\n    if obj is None:\n        return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))\n    if 
isinstance(obj, np.ndarray):\n        return Image.fromarray(obj.astype(np.uint8))\n    elif isinstance(obj, torch.Tensor):\n        return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8))\n    else:\n        return obj\n"
  },
  {
    "path": "shap_e/models/generation/transformer.py",
    "content": "import math\nfrom typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple\n\nimport torch\nimport torch.nn as nn\n\nfrom shap_e.models.nn.checkpoint import checkpoint\n\nfrom .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType\nfrom .util import timestep_embedding\n\n\ndef init_linear(l, stddev):\n    nn.init.normal_(l.weight, std=stddev)\n    if l.bias is not None:\n        nn.init.constant_(l.bias, 0.0)\n\n\nclass MultiheadAttention(nn.Module):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        n_ctx: int,\n        width: int,\n        heads: int,\n        init_scale: float,\n    ):\n        super().__init__()\n        self.n_ctx = n_ctx\n        self.width = width\n        self.heads = heads\n        self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)\n        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)\n        self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)\n        init_linear(self.c_qkv, init_scale)\n        init_linear(self.c_proj, init_scale)\n\n    def forward(self, x):\n        x = self.c_qkv(x)\n        x = checkpoint(self.attention, (x,), (), True)\n        x = self.c_proj(x)\n        return x\n\n\nclass MLP(nn.Module):\n    def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):\n        super().__init__()\n        self.width = width\n        self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)\n        self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)\n        self.gelu = nn.GELU()\n        init_linear(self.c_fc, init_scale)\n        init_linear(self.c_proj, init_scale)\n\n    def forward(self, x):\n        return self.c_proj(self.gelu(self.c_fc(x)))\n\n\nclass QKVMultiheadAttention(nn.Module):\n    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):\n        
super().__init__()\n        self.device = device\n        self.dtype = dtype\n        self.heads = heads\n        self.n_ctx = n_ctx\n\n    def forward(self, qkv):\n        bs, n_ctx, width = qkv.shape\n        attn_ch = width // self.heads // 3\n        scale = 1 / math.sqrt(math.sqrt(attn_ch))\n        qkv = qkv.view(bs, n_ctx, self.heads, -1)\n        q, k, v = torch.split(qkv, attn_ch, dim=-1)\n        weight = torch.einsum(\n            \"bthc,bshc->bhts\", q * scale, k * scale\n        )  # More stable with f16 than dividing afterwards\n        wdtype = weight.dtype\n        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)\n        return torch.einsum(\"bhts,bshc->bthc\", weight, v).reshape(bs, n_ctx, -1)\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        n_ctx: int,\n        width: int,\n        heads: int,\n        init_scale: float = 1.0,\n    ):\n        super().__init__()\n\n        self.attn = MultiheadAttention(\n            device=device,\n            dtype=dtype,\n            n_ctx=n_ctx,\n            width=width,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)\n        self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)\n\n    def forward(self, x: torch.Tensor):\n        x = x + self.attn(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        n_ctx: int,\n        width: int,\n        layers: int,\n        heads: int,\n        init_scale: float = 0.25,\n    ):\n        super().__init__()\n        self.n_ctx = n_ctx\n        self.width = width\n        self.layers = layers\n        
init_scale = init_scale * math.sqrt(1.0 / width)\n        self.resblocks = nn.ModuleList(\n            [\n                ResidualAttentionBlock(\n                    device=device,\n                    dtype=dtype,\n                    n_ctx=n_ctx,\n                    width=width,\n                    heads=heads,\n                    init_scale=init_scale,\n                )\n                for _ in range(layers)\n            ]\n        )\n\n    def forward(self, x: torch.Tensor):\n        for block in self.resblocks:\n            x = block(x)\n        return x\n\n\nclass PointDiffusionTransformer(nn.Module):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        input_channels: int = 3,\n        output_channels: int = 3,\n        n_ctx: int = 1024,\n        width: int = 512,\n        layers: int = 12,\n        heads: int = 8,\n        init_scale: float = 0.25,\n        time_token_cond: bool = False,\n        use_pos_emb: bool = False,\n        pos_emb_init_scale: float = 1.0,\n        pos_emb_n_ctx: Optional[int] = None,\n    ):\n        super().__init__()\n        self.input_channels = input_channels\n        self.output_channels = output_channels\n        self.n_ctx = n_ctx\n        self.time_token_cond = time_token_cond\n        self.use_pos_emb = use_pos_emb\n        self.time_embed = MLP(\n            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)\n        )\n        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.backbone = Transformer(\n            device=device,\n            dtype=dtype,\n            n_ctx=n_ctx + int(time_token_cond),\n            width=width,\n            layers=layers,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)\n        
self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)\n        with torch.no_grad():\n            self.output_proj.weight.zero_()\n            self.output_proj.bias.zero_()\n        if self.use_pos_emb:\n            self.register_parameter(\n                \"pos_emb\",\n                nn.Parameter(\n                    pos_emb_init_scale\n                    * torch.randn(pos_emb_n_ctx or self.n_ctx, width, device=device, dtype=dtype)\n                ),\n            )\n\n    def forward(self, x: torch.Tensor, t: torch.Tensor):\n        \"\"\"\n        :param x: an [N x C x T] tensor.\n        :param t: an [N] tensor.\n        :return: an [N x C' x T] tensor.\n        \"\"\"\n        assert x.shape[-1] == self.n_ctx\n        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))\n        return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])\n\n    def _forward_with_cond(\n        self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]\n    ) -> torch.Tensor:\n        h = self.input_proj(x.permute(0, 2, 1))  # NCL -> NLC\n        for emb, as_token in cond_as_token:\n            if not as_token:\n                h = h + emb[:, None]\n        if self.use_pos_emb:\n            h = h + self.pos_emb\n        extra_tokens = [\n            (emb[:, None] if len(emb.shape) == 2 else emb)\n            for emb, as_token in cond_as_token\n            if as_token\n        ]\n        if len(extra_tokens):\n            h = torch.cat(extra_tokens + [h], dim=1)\n\n        h = self.ln_pre(h)\n        h = self.backbone(h)\n        h = self.ln_post(h)\n        if len(extra_tokens):\n            h = h[:, sum(h.shape[1] for h in extra_tokens) :]\n        h = self.output_proj(h)\n        return h.permute(0, 2, 1)\n\n\nclass CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        n_ctx: int = 
1024,\n        token_cond: bool = False,\n        cond_drop_prob: float = 0.0,\n        frozen_clip: bool = True,\n        **kwargs,\n    ):\n        super().__init__(\n            device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), pos_emb_n_ctx=n_ctx, **kwargs\n        )\n        self.n_ctx = n_ctx\n        self.token_cond = token_cond\n        self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)\n        self.clip_embed = nn.Linear(\n            self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype\n        )\n        self.cond_drop_prob = cond_drop_prob\n\n    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:\n        with torch.no_grad():\n            return dict(embeddings=self.clip(batch_size, **model_kwargs))\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        t: torch.Tensor,\n        images: Optional[Iterable[Optional[ImageType]]] = None,\n        texts: Optional[Iterable[Optional[str]]] = None,\n        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,\n    ):\n        \"\"\"\n        :param x: an [N x C x T] tensor.\n        :param t: an [N] tensor.\n        :param images: a batch of images to condition on.\n        :param texts: a batch of texts to condition on.\n        :param embeddings: a batch of CLIP embeddings to condition on.\n        :return: an [N x C' x T] tensor.\n        \"\"\"\n        assert x.shape[-1] == self.n_ctx\n\n        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))\n        clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings)\n        assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]\n\n        if self.training:\n            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob\n            clip_out = clip_out * mask[:, None].to(clip_out)\n\n        # Rescale the features to have unit variance\n        clip_out = 
math.sqrt(clip_out.shape[1]) * clip_out\n\n        clip_embed = self.clip_embed(clip_out)\n\n        cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]\n        return self._forward_with_cond(x, cond)\n\n\nclass CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        n_ctx: int = 1024,\n        cond_drop_prob: float = 0.0,\n        frozen_clip: bool = True,\n        **kwargs,\n    ):\n        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)\n        super().__init__(\n            device=device,\n            dtype=dtype,\n            n_ctx=n_ctx + clip.grid_size**2,\n            pos_emb_n_ctx=n_ctx,\n            **kwargs,\n        )\n        self.n_ctx = n_ctx\n        self.clip = clip\n        self.clip_embed = nn.Sequential(\n            nn.LayerNorm(\n                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype\n            ),\n            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),\n        )\n        self.cond_drop_prob = cond_drop_prob\n\n    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:\n        _ = batch_size\n        with torch.no_grad():\n            return dict(embeddings=self.clip.embed_images_grid(model_kwargs[\"images\"]))\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        t: torch.Tensor,\n        images: Optional[Iterable[ImageType]] = None,\n        embeddings: Optional[Iterable[torch.Tensor]] = None,\n    ):\n        \"\"\"\n        :param x: an [N x C x T] tensor.\n        :param t: an [N] tensor.\n        :param images: a batch of images to condition on.\n        :param embeddings: a batch of CLIP latent grids to condition on.\n        :return: an [N x C' x T] tensor.\n        \"\"\"\n        assert images is not None or embeddings is not None, 
\"must specify images or embeddings\"\n        assert images is None or embeddings is None, \"cannot specify both images and embeddings\"\n        assert x.shape[-1] == self.n_ctx\n\n        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))\n\n        if images is not None:\n            clip_out = self.clip.embed_images_grid(images)\n        else:\n            clip_out = embeddings\n\n        if self.training:\n            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob\n            clip_out = clip_out * mask[:, None, None].to(clip_out)\n\n        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC\n        clip_embed = self.clip_embed(clip_out)\n\n        cond = [(t_embed, self.time_token_cond), (clip_embed, True)]\n        return self._forward_with_cond(x, cond)\n\n\nclass UpsamplePointDiffusionTransformer(PointDiffusionTransformer):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        cond_input_channels: Optional[int] = None,\n        cond_ctx: int = 1024,\n        n_ctx: int = 4096 - 1024,\n        channel_scales: Optional[Sequence[float]] = None,\n        channel_biases: Optional[Sequence[float]] = None,\n        **kwargs,\n    ):\n        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs)\n        self.n_ctx = n_ctx\n        self.cond_input_channels = cond_input_channels or self.input_channels\n        self.cond_point_proj = nn.Linear(\n            self.cond_input_channels, self.backbone.width, device=device, dtype=dtype\n        )\n\n        self.register_buffer(\n            \"channel_scales\",\n            torch.tensor(channel_scales, dtype=dtype, device=device)\n            if channel_scales is not None\n            else None,\n        )\n        self.register_buffer(\n            \"channel_biases\",\n            torch.tensor(channel_biases, dtype=dtype, device=device)\n            if channel_biases is not None\n            else None,\n    
    )\n\n    def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):\n        \"\"\"\n        :param x: an [N x C1 x T] tensor.\n        :param t: an [N] tensor.\n        :param low_res: an [N x C2 x T'] tensor of conditioning points.\n        :return: an [N x C3 x T] tensor.\n        \"\"\"\n        assert x.shape[-1] == self.n_ctx\n        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))\n        low_res_embed = self._embed_low_res(low_res)\n        cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]\n        return self._forward_with_cond(x, cond)\n\n    def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:\n        if self.channel_scales is not None:\n            x = x * self.channel_scales[None, :, None]\n        if self.channel_biases is not None:\n            x = x + self.channel_biases[None, :, None]\n        return self.cond_point_proj(x.permute(0, 2, 1))\n\n\nclass CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        n_ctx: int = 4096 - 1024,\n        cond_drop_prob: float = 0.0,\n        frozen_clip: bool = True,\n        **kwargs,\n    ):\n        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)\n        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)\n        self.n_ctx = n_ctx\n\n        self.clip = clip\n        self.clip_embed = nn.Sequential(\n            nn.LayerNorm(\n                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype\n            ),\n            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),\n        )\n        self.cond_drop_prob = cond_drop_prob\n\n    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:\n        _ = batch_size\n        with torch.no_grad():\n        
    return dict(\n                embeddings=self.clip.embed_images_grid(model_kwargs[\"images\"]),\n                low_res=model_kwargs[\"low_res\"],\n            )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        t: torch.Tensor,\n        *,\n        low_res: torch.Tensor,\n        images: Optional[Iterable[ImageType]] = None,\n        embeddings: Optional[Iterable[torch.Tensor]] = None,\n    ):\n        \"\"\"\n        :param x: an [N x C1 x T] tensor.\n        :param t: an [N] tensor.\n        :param low_res: an [N x C2 x T'] tensor of conditioning points.\n        :param images: a batch of images to condition on.\n        :param embeddings: a batch of CLIP latent grids to condition on.\n        :return: an [N x C3 x T] tensor.\n        \"\"\"\n        assert x.shape[-1] == self.n_ctx\n        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))\n        low_res_embed = self._embed_low_res(low_res)\n\n        if images is not None:\n            clip_out = self.clip.embed_images_grid(images)\n        else:\n            clip_out = embeddings\n\n        if self.training:\n            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob\n            clip_out = clip_out * mask[:, None, None].to(clip_out)\n\n        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC\n        clip_embed = self.clip_embed(clip_out)\n\n        cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)]\n        return self._forward_with_cond(x, cond)\n"
  },
  {
    "path": "shap_e/models/generation/util.py",
    "content": "import math\n\nimport torch\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    half = dim // 2\n    freqs = torch.exp(\n        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n    ).to(device=timesteps.device)\n    args = timesteps[:, None].to(timesteps.dtype) * freqs[None]\n    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n    if dim % 2:\n        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n    return embedding\n"
  },
  {
    "path": "shap_e/models/nerf/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/models/nerf/model.py",
    "content": "from abc import ABC, abstractmethod\nfrom functools import partial\nfrom typing import Any, Dict, Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom shap_e.models.nn.checkpoint import checkpoint\nfrom shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis\nfrom shap_e.models.nn.meta import MetaModule, subdict\nfrom shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init\nfrom shap_e.models.nn.utils import ArrayType\nfrom shap_e.models.query import Query\nfrom shap_e.util.collections import AttrDict\n\n\nclass NeRFModel(ABC):\n    \"\"\"\n    Parametric scene representation whose outputs are integrated by NeRFRenderer\n    \"\"\"\n\n    @abstractmethod\n    def forward(\n        self,\n        query: Query,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict:\n        \"\"\"\n        :param query: the points in the field to query.\n        :param params: Meta parameters\n        :param options: Optional hyperparameters\n        :return: An AttrDict containing at least\n            - density: [batch_size x ... x 1]\n            - channels: [batch_size x ... x n_channels]\n            - aux_losses: [batch_size x ... 
x 1]\n        \"\"\"\n\n\nclass VoidNeRFModel(MetaModule, NeRFModel):\n    \"\"\"\n    Implements the default empty space model where all queries are rendered as\n    background.\n    \"\"\"\n\n    def __init__(\n        self,\n        background: ArrayType,\n        trainable: bool = False,\n        channel_scale: float = 255.0,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__()\n        background = nn.Parameter(\n            torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device)\n            / channel_scale\n        )\n        if trainable:\n            self.register_parameter(\"background\", background)\n        else:\n            self.register_buffer(\"background\", background)\n\n    def forward(\n        self,\n        query: Query,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict:\n        _ = params\n        default_bg = self.background[None]\n        background = options.get(\"background\", default_bg) if options is not None else default_bg\n\n        shape = query.position.shape[:-1]\n        ones = [1] * (len(shape) - 1)\n        n_channels = background.shape[-1]\n        background = torch.broadcast_to(\n            background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]\n        )\n        return background\n\n\nclass MLPNeRFModel(MetaModule, NeRFModel):\n    def __init__(\n        self,\n        # Positional encoding parameters\n        n_levels: int = 10,\n        # MLP parameters\n        d_hidden: int = 256,\n        n_density_layers: int = 4,\n        n_channel_layers: int = 1,\n        n_channels: int = 3,\n        sh_degree: int = 4,\n        activation: str = \"relu\",\n        density_activation: str = \"exp\",\n        init: Optional[str] = None,\n        init_scale: float = 1.0,\n        output_activation: str = \"sigmoid\",\n        meta_parameters: bool = False,\n        
trainable_meta: bool = False,\n        zero_out: bool = True,\n        register_freqs: bool = True,\n        posenc_version: str = \"v1\",\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__()\n\n        # Positional encoding\n        if register_freqs:\n            # not used anymore\n            self.register_buffer(\n                \"freqs\",\n                2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels),\n            )\n\n        self.posenc_version = posenc_version\n        dummy = torch.eye(1, 3)\n        d_input = encode_position(posenc_version, position=dummy).shape[-1]\n\n        self.n_levels = n_levels\n\n        self.sh_degree = sh_degree\n        d_sh_coeffs = sh_degree**2\n\n        self.meta_parameters = meta_parameters\n\n        mlp_cls = (\n            partial(\n                MetaMLP,\n                meta_scale=False,\n                meta_shift=False,\n                meta_proj=True,\n                meta_bias=True,\n                trainable_meta=trainable_meta,\n            )\n            if meta_parameters\n            else MLP\n        )\n\n        self.density_mlp = mlp_cls(\n            d_input=d_input,\n            d_hidden=[d_hidden] * (n_density_layers - 1),\n            d_output=d_hidden,\n            act_name=activation,\n            init_scale=init_scale,\n        )\n\n        self.channel_mlp = mlp_cls(\n            d_input=d_hidden + d_sh_coeffs,\n            d_hidden=[d_hidden] * n_channel_layers,\n            d_output=n_channels,\n            act_name=activation,\n            init_scale=init_scale,\n        )\n\n        self.act = get_act(output_activation)\n        self.density_act = get_act(density_activation)\n\n        mlp_init(\n            list(self.density_mlp.affines) + list(self.channel_mlp.affines),\n            init=init,\n            init_scale=init_scale,\n        )\n\n        if zero_out:\n            
zero_init(self.channel_mlp.affines[-1])\n\n        self.to(device)\n\n    def encode_position(self, query: Query):\n        h = encode_position(self.posenc_version, position=query.position)\n        return h\n\n    def forward(\n        self,\n        query: Query,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict:\n        params = self.update(params)\n\n        options = AttrDict() if options is None else AttrDict(options)\n\n        query = query.copy()\n\n        h_position = self.encode_position(query)\n\n        if self.meta_parameters:\n            density_params = subdict(params, \"density_mlp\")\n            density_mlp = partial(\n                self.density_mlp, params=density_params, options=options, log_prefix=\"density_\"\n            )\n            density_mlp_parameters = list(density_params.values())\n        else:\n            density_mlp = partial(self.density_mlp, options=options, log_prefix=\"density_\")\n            density_mlp_parameters = self.density_mlp.parameters()\n        h_density = checkpoint(\n            density_mlp,\n            (h_position,),\n            density_mlp_parameters,\n            options.checkpoint_nerf_mlp,\n        )\n        h_direction = maybe_get_spherical_harmonics_basis(\n            sh_degree=self.sh_degree,\n            coords_shape=query.position.shape,\n            coords=query.direction,\n            device=query.position.device,\n        )\n\n        if self.meta_parameters:\n            channel_params = subdict(params, \"channel_mlp\")\n            channel_mlp = partial(\n                self.channel_mlp, params=channel_params, options=options, log_prefix=\"channel_\"\n            )\n            channel_mlp_parameters = list(channel_params.values())\n        else:\n            channel_mlp = partial(self.channel_mlp, options=options, log_prefix=\"channel_\")\n            channel_mlp_parameters = self.channel_mlp.parameters()\n  
      h_channel = checkpoint(\n            channel_mlp,\n            (torch.cat([h_density, h_direction], dim=-1),),\n            channel_mlp_parameters,\n            options.checkpoint_nerf_mlp,\n        )\n\n        density_logit = h_density[..., :1]\n\n        res = AttrDict(\n            density_logit=density_logit,\n            density=self.density_act(density_logit),\n            channels=self.act(h_channel),\n            aux_losses=AttrDict(),\n            no_weight_grad_aux_losses=AttrDict(),\n        )\n        if options.return_h_density:\n            res.h_density = h_density\n\n        return res\n\n\ndef maybe_get_spherical_harmonics_basis(\n    sh_degree: int,\n    coords_shape: Tuple[int],\n    coords: Optional[torch.Tensor] = None,\n    device: torch.device = torch.device(\"cuda\"),\n) -> torch.Tensor:\n    \"\"\"\n    :param sh_degree: Spherical harmonics degree\n    :param coords_shape: [*shape, 3]\n    :param coords: optional coordinate tensor of coords_shape\n    \"\"\"\n    if coords is None:\n        return torch.zeros(*coords_shape[:-1], sh_degree**2).to(device)\n\n    return spherical_harmonics_basis(coords, sh_degree)\n"
  },
  {
    "path": "shap_e/models/nerf/ray.py",
    "content": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport torch\n\nfrom shap_e.models.nn.utils import sample_pmf\nfrom shap_e.models.volume import Volume, VolumeRange\nfrom shap_e.util.collections import AttrDict\n\nfrom .model import NeRFModel, Query\n\n\ndef render_rays(\n    rays: torch.Tensor,\n    parts: List[\"RayVolumeIntegral\"],\n    void_model: NeRFModel,\n    shared: bool = False,\n    prev_raw_outputs: Optional[List[AttrDict]] = None,\n    render_with_direction: bool = True,\n    importance_sampling_options: Optional[Dict[str, Any]] = None,\n) -> Tuple[\"RayVolumeIntegralResults\", List[\"RaySampler\"], List[AttrDict]]:\n    \"\"\"\n    Perform volumetric rendering over a partition of possible t's in the union\n    of rendering volumes (written below with some abuse of notation)\n\n        C(r) := sum(\n            transmittance(t[i]) *\n            integrate(\n                lambda t: density(t) * channels(t) * transmittance(t),\n                [t[i], t[i + 1]],\n            )\n            for i in range(len(parts))\n        ) + transmittance(t[-1]) * void_model(t[-1]).channels\n\n    where\n\n    1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the\n       probability of light passing through the volume specified by [t[0], s].\n       (transmittance of 1 means light can pass freely)\n    2) density and channels are obtained by evaluating the appropriate\n       part.model at time t.\n    3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects\n       (parts[i].volume \\\\ union(part.volume for part in parts[:i])) at the surface\n       of the shell (if bounded). If the ray does not intersect, the integral over\n       this segment is evaluated as 0 and transmittance(t[i + 1]) :=\n       transmittance(t[i]).\n    4) The last term is integration to infinity (i.e. 
[t[-1], math.inf]) that\n       is evaluated by the void_model (i.e. we consider this space to be empty).\n\n    :param rays: [batch_size x ... x 2 x 3] origin and direction.\n    :param parts: disjoint volume integrals.\n    :param void_model: use this model to integrate over the empty space\n    :param shared: All RayVolumeIntegrals are calculated with the same model.\n    :param prev_raw_outputs: Raw outputs from the previous rendering step\n\n    :return: A tuple of\n        - AttrDict containing the rendered `channels`, `distances`, and the `aux_losses`\n        - A list of importance samplers for additional fine-grained rendering\n        - A list of raw output for each interval\n    \"\"\"\n    if importance_sampling_options is None:\n        importance_sampling_options = {}\n\n    origin, direc = rays[..., 0, :], rays[..., 1, :]\n\n    if prev_raw_outputs is None:\n        prev_raw_outputs = [None] * len(parts)\n\n    samplers = []\n    raw_outputs = []\n    t0 = None\n    results = None\n\n    for part_i, prev_raw_i in zip(parts, prev_raw_outputs):\n\n        # Integrate over [t[i], t[i + 1]]\n        results_i = part_i.render_rays(\n            origin,\n            direc,\n            t0=t0,\n            prev_raw=prev_raw_i,\n            shared=shared,\n            render_with_direction=render_with_direction,\n        )\n\n        # Create an importance sampler for (optional) fine rendering\n        samplers.append(\n            ImportanceRaySampler(\n                results_i.volume_range, results_i.raw, **importance_sampling_options\n            )\n        )\n        raw_outputs.append(results_i.raw)\n\n        # Pass t[i + 1] as the start of integration for the next interval.\n        t0 = results_i.volume_range.next_t0()\n\n        # Combine the results from [t[0], t[i]] and [t[i], t[i+1]]\n        results = results_i if results is None else results.combine(results_i)\n\n    # While integrating out [t[-1], math.inf] is the correct thing to do, this\n  
  # erases a lot of useful information. Also, void_model is meant to predict\n    # the channels at t=math.inf.\n\n    # # Add the void background over [t[-1], math.inf] to complete integration.\n    # results = results.combine(\n    #     RayVolumeIntegralResults(\n    #         output=AttrDict(\n    #             channels=void_model(origin, direc),\n    #             distances=torch.zeros_like(t0),\n    #             aux_losses=AttrDict(),\n    #         ),\n    #         volume_range=VolumeRange(\n    #             t0=t0,\n    #             t1=torch.full_like(t0, math.inf),\n    #             intersected=torch.full_like(results.volume_range.intersected, True),\n    #         ),\n    #         # Void space extends to infinity. It is assumed that no light\n    #         # passes beyond the void.\n    #         transmittance=torch.zeros_like(results_i.transmittance),\n    #     )\n    # )\n\n    results.output.channels = results.output.channels + results.transmittance * void_model(\n        Query(origin, direc)\n    )\n\n    return results, samplers, raw_outputs\n\n\n@dataclass\nclass RayVolumeIntegralResults:\n    \"\"\"\n    Stores the relevant state and results of\n\n        integrate(\n            lambda t: density(t) * channels(t) * transmittance(t),\n            [t0, t1],\n        )\n    \"\"\"\n\n    # Rendered output and auxiliary losses\n    # output.channels has shape [batch_size, *inner_shape, n_channels]\n    output: AttrDict\n\n    \"\"\"\n    Optional values\n    \"\"\"\n\n    # Raw values contain the sampled `ts`, `density`, `channels`, etc.\n    raw: Optional[AttrDict] = None\n\n    # Integration\n    volume_range: Optional[VolumeRange] = None\n\n    # If a ray intersects, the transmittance from t0 to t1 (e.g. 
the\n    # probability that the ray passes through this volume).\n    # has shape [batch_size, *inner_shape, 1]\n    transmittance: Optional[torch.Tensor] = None\n\n    def combine(self, cur: \"RayVolumeIntegralResults\") -> \"RayVolumeIntegralResults\":\n        \"\"\"\n        Combines the integration results of `self` over [t0, t1] and\n        `cur` over [t1, t2] to produce a new set of results over [t0, t2] by\n        using a similar equation to (4) in NeRF++:\n\n            integrate(\n                lambda t: density(t) * channels(t) * transmittance(t),\n                [t0, t2]\n            )\n\n          = integrate(\n                lambda t: density(t) * channels(t) * transmittance(t),\n                [t0, t1]\n            ) + transmittance(t1) * integrate(\n                lambda t: density(t) * channels(t) * transmittance(t),\n                [t1, t2]\n            )\n        \"\"\"\n        assert torch.allclose(self.volume_range.next_t0(), cur.volume_range.t0)\n\n        def _combine_fn(\n            prev_val: Optional[torch.Tensor],\n            cur_val: Optional[torch.Tensor],\n            *,\n            prev_transmittance: torch.Tensor,\n        ):\n            assert prev_val is not None\n            if cur_val is None:\n                # cur_output.aux_losses are empty for the void_model.\n                return prev_val\n            return prev_val + prev_transmittance * cur_val\n\n        output = self.output.combine(\n            cur.output, combine_fn=partial(_combine_fn, prev_transmittance=self.transmittance)\n        )\n\n        combined = RayVolumeIntegralResults(\n            output=output,\n            volume_range=self.volume_range.extend(cur.volume_range),\n            transmittance=self.transmittance * cur.transmittance,\n        )\n        return combined\n\n\n@dataclass\nclass RayVolumeIntegral:\n    model: NeRFModel\n    volume: Volume\n    sampler: \"RaySampler\"\n    n_samples: int\n\n    def render_rays(\n        self,\n    
    origin: torch.Tensor,\n        direction: torch.Tensor,\n        t0: Optional[torch.Tensor] = None,\n        prev_raw: Optional[AttrDict] = None,\n        shared: bool = False,\n        render_with_direction: bool = True,\n    ) -> \"RayVolumeIntegralResults\":\n        \"\"\"\n        Perform volumetric rendering over the given volume.\n\n        :param origin: [batch_size, *shape, 3]\n        :param direction: [batch_size, *shape, 3]\n        :param t0: optional lower bound on t, [batch_size, *shape, 1]\n        :param prev_raw: the raw outputs when using multiple levels with this model.\n        :param shared: if True, the same model is used for all RayVolumeIntegrals.\n        :param render_with_direction: use the incoming ray direction when querying the model.\n\n        :return: RayVolumeIntegralResults\n        \"\"\"\n        # 1. Intersect the rays with the current volume and sample ts to\n        # integrate along.\n        vrange = self.volume.intersect(origin, direction, t0_lower=t0)\n        ts = self.sampler.sample(vrange.t0, vrange.t1, self.n_samples)\n\n        if prev_raw is not None and not shared:\n            # Append the previous ts now before fprop because previous\n            # rendering used a different model and we can't reuse the output.\n            ts = torch.sort(torch.cat([ts, prev_raw.ts], dim=-2), dim=-2).values\n\n        # Shape sanity checks\n        batch_size, *_shape, _t0_dim = vrange.t0.shape\n        _, *ts_shape, _ts_dim = ts.shape\n\n        # 2. 
Get the points along the ray and query the model\n        directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])\n        positions = origin.unsqueeze(-2) + ts * directions\n\n        optional_directions = directions if render_with_direction else None\n        mids = (ts[..., 1:, :] + ts[..., :-1, :]) / 2\n        raw = self.model(\n            Query(\n                position=positions,\n                direction=optional_directions,\n                t_min=torch.cat([vrange.t0[..., None, :], mids], dim=-2),\n                t_max=torch.cat([mids, vrange.t1[..., None, :]], dim=-2),\n            )\n        )\n        raw.ts = ts\n\n        if prev_raw is not None and shared:\n            # We can append the additional queries to previous raw outputs\n            # before integration\n            copy = prev_raw.copy()\n            result = torch.sort(torch.cat([raw.pop(\"ts\"), copy.pop(\"ts\")], dim=-2), dim=-2)\n            merge_results = partial(self._merge_results, dim=-2, indices=result.indices)\n            raw = raw.combine(copy, merge_results)\n            raw.ts = result.values\n\n        # 3. Integrate the raw results\n        output, transmittance = self.integrate_samples(vrange, raw)\n\n        # 4. 
Clean up results that do not intersect with the volume.\n        transmittance = torch.where(\n            vrange.intersected, transmittance, torch.ones_like(transmittance)\n        )\n\n        def _mask_fn(_key: str, tensor: torch.Tensor):\n            return torch.where(vrange.intersected, tensor, torch.zeros_like(tensor))\n\n        def _is_tensor(_key: str, value: Any):\n            return isinstance(value, torch.Tensor)\n\n        output = output.map(map_fn=_mask_fn, should_map=_is_tensor)\n\n        return RayVolumeIntegralResults(\n            output=output,\n            raw=raw,\n            volume_range=vrange,\n            transmittance=transmittance,\n        )\n\n    def integrate_samples(\n        self,\n        volume_range: VolumeRange,\n        raw: AttrDict,\n    ) -> Tuple[AttrDict, torch.Tensor]:\n        \"\"\"\n        Integrate the raw.channels along with other aux_losses and values to\n        produce the final output dictionary containing rendered `channels`,\n        estimated `distances` and `aux_losses`.\n\n        :param volume_range: Specifies the integral range [t0, t1]\n        :param raw: Contains a dict of function evaluations at ts. 
Should have\n\n            density: torch.Tensor [batch_size, *shape, n_samples, 1]\n            channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]\n            aux_losses: {key: torch.Tensor [batch_size, *shape, n_samples, 1] for each key}\n            no_weight_grad_aux_losses: an optional set of losses for which the weights\n                                       should be detached before integration.\n\n            after the call, integrate_samples populates some intermediate calculations\n            for later use like\n\n            weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density *\n                transmittance)[i] weight for each rgb output at [..., i, :].\n        :returns: a tuple of (\n            a dictionary of rendered outputs and aux_losses,\n            transmittance of this volume,\n        )\n        \"\"\"\n\n        # 1. Calculate the weights\n        _, _, dt = volume_range.partition(raw.ts)\n        ddensity = raw.density * dt\n\n        mass = torch.cumsum(ddensity, dim=-2)\n        transmittance = torch.exp(-mass[..., -1, :])\n\n        alphas = 1.0 - torch.exp(-ddensity)\n        Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))\n        # This is the probability of light hitting and reflecting off of\n        # something at depth [..., i, :].\n        weights = alphas * Ts\n\n        # 2. 
Integrate all results\n        def _integrate(key: str, samples: torch.Tensor, weights: torch.Tensor):\n            if key == \"density\":\n                # Omit integrating the density, because we don't need it\n                return None\n            return torch.sum(samples * weights, dim=-2)\n\n        def _is_tensor(_key: str, value: Any):\n            return isinstance(value, torch.Tensor)\n\n        if raw.no_weight_grad_aux_losses:\n            extra_aux_losses = raw.no_weight_grad_aux_losses.map(\n                partial(_integrate, weights=weights.detach()), should_map=_is_tensor\n            )\n        else:\n            extra_aux_losses = {}\n        output = raw.map(partial(_integrate, weights=weights), should_map=_is_tensor)\n        if \"no_weight_grad_aux_losses\" in output:\n            del output[\"no_weight_grad_aux_losses\"]\n        output.aux_losses.update(extra_aux_losses)\n\n        # Integrating the ts yields the distance away from the origin; rename the variable.\n        output.distances = output.ts\n        del output[\"ts\"]\n        del output[\"density\"]\n\n        assert output.distances.shape == (*output.channels.shape[:-1], 1)\n        assert output.channels.shape[:-1] == raw.channels.shape[:-2]\n        assert output.channels.shape[-1] == raw.channels.shape[-1]\n\n        # 3. Reduce loss\n        def _reduce_loss(_key: str, loss: torch.Tensor):\n            return loss.view(loss.shape[0], -1).sum(dim=-1)\n\n        # 4. Store other useful calculations\n        raw.weights = weights\n\n        output.aux_losses = output.aux_losses.map(_reduce_loss)\n\n        return output, transmittance\n\n    def _merge_results(\n        self, a: Optional[torch.Tensor], b: torch.Tensor, dim: int, indices: torch.Tensor\n    ):\n        \"\"\"\n        :param a: [..., n_a, ...]. 
The other dictionary containing the b's may\n            contain extra tensors from earlier calculations, so a can be None.\n        :param b: [..., n_b, ...]\n        :param dim: dimension to merge\n        :param indices: how the merged results should be sorted at the end\n        :return: a concatenated and sorted tensor of size [..., n_a + n_b, ...]\n        \"\"\"\n        if a is None:\n            return None\n\n        merged = torch.cat([a, b], dim=dim)\n        return torch.gather(merged, dim=dim, index=torch.broadcast_to(indices, merged.shape))\n\n\nclass RaySampler(ABC):\n    @abstractmethod\n    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:\n        \"\"\"\n        :param t0: start time has shape [batch_size, *shape, 1]\n        :param t1: finish time has shape [batch_size, *shape, 1]\n        :param n_samples: number of ts to sample\n        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]\n        \"\"\"\n\n\nclass StratifiedRaySampler(RaySampler):\n    \"\"\"\n    Instead of fixed intervals, a sample is drawn uniformly at random from each\n    interval.\n    \"\"\"\n\n    def __init__(self, depth_mode: str = \"linear\"):\n        \"\"\"\n        :param depth_mode: linear samples ts linearly in depth, geometric samples them\n            uniformly in log-depth, and 
harmonic ensures\n            closer points are sampled more densely.\n        \"\"\"\n        self.depth_mode = depth_mode\n        assert self.depth_mode in (\"linear\", \"geometric\", \"harmonic\")\n\n    def sample(\n        self,\n        t0: torch.Tensor,\n        t1: torch.Tensor,\n        n_samples: int,\n        epsilon: float = 1e-3,\n    ) -> torch.Tensor:\n        \"\"\"\n        :param t0: start time has shape [batch_size, *shape, 1]\n        :param t1: finish time has shape [batch_size, *shape, 1]\n        :param n_samples: number of ts to sample\n        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]\n        \"\"\"\n        ones = [1] * (len(t0.shape) - 1)\n        ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)\n\n        if self.depth_mode == \"linear\":\n            ts = t0 * (1.0 - ts) + t1 * ts\n        elif self.depth_mode == \"geometric\":\n            ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()\n        elif self.depth_mode == \"harmonic\":\n            # The original NeRF recommends this interpolation scheme for\n            # spherical scenes, but there could be some weird edge cases when\n            # the observer crosses from the inner to outer volume.\n            ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)\n\n        mids = 0.5 * (ts[..., 1:] + ts[..., :-1])\n        upper = torch.cat([mids, t1], dim=-1)\n        lower = torch.cat([t0, mids], dim=-1)\n        t_rand = torch.rand_like(ts)\n\n        ts = lower + (upper - lower) * t_rand\n        return ts.unsqueeze(-1)\n\n\nclass ImportanceRaySampler(RaySampler):\n    \"\"\"\n    Given the initial estimate of densities, this samples more from\n    regions/bins expected to have objects.\n    \"\"\"\n\n    def __init__(\n        self, volume_range: VolumeRange, raw: AttrDict, blur_pool: bool = False, alpha: float = 1e-5\n    ):\n        \"\"\"\n        :param 
volume_range: the range in which a ray intersects the given volume.\n        :param raw: dictionary of raw outputs from the NeRF models of shape\n            [batch_size, *shape, n_coarse_samples, 1]. Should at least contain\n\n            :param ts: earlier samples from the coarse rendering step\n            :param weights: discretized version of density * transmittance\n        :param blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.\n        :param alpha: small value to add to weights.\n        \"\"\"\n        self.volume_range = volume_range\n        self.ts = raw.ts.clone().detach()\n        self.weights = raw.weights.clone().detach()\n        self.blur_pool = blur_pool\n        self.alpha = alpha\n\n    @torch.no_grad()\n    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:\n        \"\"\"\n        :param t0: start time has shape [batch_size, *shape, 1]\n        :param t1: finish time has shape [batch_size, *shape, 1]\n        :param n_samples: number of ts to sample\n        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]\n        \"\"\"\n        lower, upper, _ = self.volume_range.partition(self.ts)\n\n        batch_size, *shape, n_coarse_samples, _ = self.ts.shape\n\n        weights = self.weights\n        if self.blur_pool:\n            padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)\n            maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])\n            weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])\n        weights = weights + self.alpha\n        pmf = weights / weights.sum(dim=-2, keepdim=True)\n        inds = sample_pmf(pmf, n_samples)\n        assert inds.shape == (batch_size, *shape, n_samples, 1)\n        assert (inds >= 0).all() and (inds < n_coarse_samples).all()\n\n        t_rand = torch.rand(inds.shape, device=inds.device)\n        lower_ = torch.gather(lower, -2, inds)\n        upper_ = torch.gather(upper, -2, 
inds)\n\n        ts = lower_ + (upper_ - lower_) * t_rand\n        ts = torch.sort(ts, dim=-2).values\n        return ts\n"
  },
  {
    "path": "shap_e/models/nerf/renderer.py",
    "content": "from functools import partial\nfrom typing import Any, Dict, Optional\n\nimport torch\n\nfrom shap_e.models.nn.meta import subdict\nfrom shap_e.models.renderer import RayRenderer\nfrom shap_e.models.volume import Volume\nfrom shap_e.util.collections import AttrDict\n\nfrom .model import NeRFModel\nfrom .ray import RayVolumeIntegral, StratifiedRaySampler, render_rays\n\n\nclass TwoStepNeRFRenderer(RayRenderer):\n    \"\"\"\n    Coarse and fine-grained rendering as proposed by NeRF. This class\n    additionally supports background rendering like NeRF++.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_coarse_samples: int,\n        n_fine_samples: int,\n        void_model: NeRFModel,\n        fine_model: NeRFModel,\n        volume: Volume,\n        coarse_model: Optional[NeRFModel] = None,\n        coarse_background_model: Optional[NeRFModel] = None,\n        fine_background_model: Optional[NeRFModel] = None,\n        outer_volume: Optional[Volume] = None,\n        foreground_stratified_depth_sampling_mode: str = \"linear\",\n        background_stratified_depth_sampling_mode: str = \"linear\",\n        importance_sampling_options: Optional[Dict[str, Any]] = None,\n        channel_scale: float = 255,\n        device: torch.device = torch.device(\"cuda\"),\n        **kwargs,\n    ):\n        \"\"\"\n        :param outer_volume: is where distant objects are encoded.\n        \"\"\"\n        super().__init__(**kwargs)\n\n        if coarse_model is None:\n            assert (\n                fine_background_model is None or coarse_background_model is None\n            ), \"models should be shared for both fg and bg\"\n\n        self.n_coarse_samples = n_coarse_samples\n        self.n_fine_samples = n_fine_samples\n        self.void_model = void_model\n        self.coarse_model = coarse_model\n        self.fine_model = fine_model\n        self.volume = volume\n        self.coarse_background_model = coarse_background_model\n        
self.fine_background_model = fine_background_model\n        self.outer_volume = outer_volume\n        self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode\n        self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode\n        self.importance_sampling_options = AttrDict(importance_sampling_options or {})\n        self.channel_scale = channel_scale\n        self.device = device\n        self.to(device)\n\n        if self.coarse_background_model is not None:\n            assert self.fine_background_model is not None\n            assert self.outer_volume is not None\n\n    def render_rays(\n        self,\n        batch: Dict,\n        params: Optional[Dict] = None,\n        options: Optional[Dict] = None,\n    ) -> AttrDict:\n        params = self.update(params)\n\n        batch = AttrDict(batch)\n        if options is None:\n            options = AttrDict()\n        options.setdefault(\"render_background\", True)\n        options.setdefault(\"render_with_direction\", True)\n        options.setdefault(\"n_coarse_samples\", self.n_coarse_samples)\n        options.setdefault(\"n_fine_samples\", self.n_fine_samples)\n        options.setdefault(\n            \"foreground_stratified_depth_sampling_mode\",\n            self.foreground_stratified_depth_sampling_mode,\n        )\n        options.setdefault(\n            \"background_stratified_depth_sampling_mode\",\n            self.background_stratified_depth_sampling_mode,\n        )\n\n        shared = self.coarse_model is None\n\n        # First, render rays using the coarse models with stratified ray samples.\n        coarse_model, coarse_key = (\n            (self.fine_model, \"fine_model\") if shared else (self.coarse_model, \"coarse_model\")\n        )\n        coarse_model = partial(\n            coarse_model,\n            params=subdict(params, coarse_key),\n            options=options,\n        )\n        parts = [\n            
RayVolumeIntegral(\n                model=coarse_model,\n                volume=self.volume,\n                sampler=StratifiedRaySampler(\n                    depth_mode=options.foreground_stratified_depth_sampling_mode,\n                ),\n                n_samples=options.n_coarse_samples,\n            ),\n        ]\n        if options.render_background and self.outer_volume is not None:\n            coarse_background_model, coarse_background_key = (\n                (self.fine_background_model, \"fine_background_model\")\n                if shared\n                else (self.coarse_background_model, \"coarse_background_model\")\n            )\n            coarse_background_model = partial(\n                coarse_background_model,\n                params=subdict(params, coarse_background_key),\n                options=options,\n            )\n            parts.append(\n                RayVolumeIntegral(\n                    model=coarse_background_model,\n                    volume=self.outer_volume,\n                    sampler=StratifiedRaySampler(\n                        depth_mode=options.background_stratified_depth_sampling_mode,\n                    ),\n                    n_samples=options.n_coarse_samples,\n                )\n            )\n        coarse_results, samplers, coarse_raw_outputs = render_rays(\n            batch.rays,\n            parts,\n            partial(self.void_model, options=options),\n            shared=shared,\n            render_with_direction=options.render_with_direction,\n            importance_sampling_options=AttrDict(self.importance_sampling_options),\n        )\n\n        # Then, render rays using the fine models with importance-weighted ray samples.\n        fine_model = partial(\n            self.fine_model,\n            params=subdict(params, \"fine_model\"),\n            options=options,\n        )\n        parts = [\n            RayVolumeIntegral(\n                model=fine_model,\n                
volume=self.volume,\n                sampler=samplers[0],\n                n_samples=options.n_fine_samples,\n            ),\n        ]\n        if options.render_background and self.outer_volume is not None:\n            fine_background_model = partial(\n                self.fine_background_model,\n                params=subdict(params, \"fine_background_model\"),\n                options=options,\n            )\n            parts.append(\n                RayVolumeIntegral(\n                    model=fine_background_model,\n                    volume=self.outer_volume,\n                    sampler=samplers[1],\n                    n_samples=options.n_fine_samples,\n                )\n            )\n        fine_results, *_ = render_rays(\n            batch.rays,\n            parts,\n            partial(self.void_model, options=options),\n            shared=shared,\n            prev_raw_outputs=coarse_raw_outputs,\n            render_with_direction=options.render_with_direction,\n        )\n\n        # Combine results\n        aux_losses = fine_results.output.aux_losses.copy()\n        for key, val in coarse_results.output.aux_losses.items():\n            aux_losses[key + \"_coarse\"] = val\n\n        return AttrDict(\n            channels=fine_results.output.channels * self.channel_scale,\n            channels_coarse=coarse_results.output.channels * self.channel_scale,\n            distances=fine_results.output.distances,\n            transmittance=fine_results.transmittance,\n            transmittance_coarse=coarse_results.transmittance,\n            t0=fine_results.volume_range.t0,\n            t1=fine_results.volume_range.t1,\n            intersected=fine_results.volume_range.intersected,\n            aux_losses=aux_losses,\n        )\n\n\nclass OneStepNeRFRenderer(RayRenderer):\n    \"\"\"\n    Renders rays using stratified sampling only, unlike vanilla NeRF.\n    This is the same setup as NeRF++.\n    \"\"\"\n\n    def __init__(\n        self,\n        n_samples: 
int,\n        void_model: NeRFModel,\n        foreground_model: NeRFModel,\n        volume: Volume,\n        background_model: Optional[NeRFModel] = None,\n        outer_volume: Optional[Volume] = None,\n        foreground_stratified_depth_sampling_mode: str = \"linear\",\n        background_stratified_depth_sampling_mode: str = \"linear\",\n        channel_scale: float = 255,\n        device: torch.device = torch.device(\"cuda\"),\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.n_samples = n_samples\n        self.void_model = void_model\n        self.foreground_model = foreground_model\n        self.volume = volume\n        self.background_model = background_model\n        self.outer_volume = outer_volume\n        self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode\n        self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode\n        self.channel_scale = channel_scale\n        self.device = device\n        self.to(device)\n\n    def render_rays(\n        self,\n        batch: Dict,\n        params: Optional[Dict] = None,\n        options: Optional[Dict] = None,\n    ) -> AttrDict:\n        params = self.update(params)\n\n        batch = AttrDict(batch)\n        if options is None:\n            options = AttrDict()\n        options.setdefault(\"render_background\", True)\n        options.setdefault(\"render_with_direction\", True)\n        options.setdefault(\"n_samples\", self.n_samples)\n        options.setdefault(\n            \"foreground_stratified_depth_sampling_mode\",\n            self.foreground_stratified_depth_sampling_mode,\n        )\n        options.setdefault(\n            \"background_stratified_depth_sampling_mode\",\n            self.background_stratified_depth_sampling_mode,\n        )\n\n        foreground_model = partial(\n            self.foreground_model,\n            params=subdict(params, \"foreground_model\"),\n            
options=options,\n        )\n        parts = [\n            RayVolumeIntegral(\n                model=foreground_model,\n                volume=self.volume,\n                sampler=StratifiedRaySampler(\n                    depth_mode=options.foreground_stratified_depth_sampling_mode\n                ),\n                n_samples=options.n_samples,\n            ),\n        ]\n        if options.render_background and self.outer_volume is not None:\n            background_model = partial(\n                self.background_model,\n                params=subdict(params, \"background_model\"),\n                options=options,\n            )\n            parts.append(\n                RayVolumeIntegral(\n                    model=background_model,\n                    volume=self.outer_volume,\n                    sampler=StratifiedRaySampler(\n                        depth_mode=options.background_stratified_depth_sampling_mode\n                    ),\n                    n_samples=options.n_samples,\n                )\n            )\n        results, *_ = render_rays(\n            batch.rays,\n            parts,\n            self.void_model,\n            render_with_direction=options.render_with_direction,\n        )\n\n        return AttrDict(\n            channels=results.output.channels * self.channel_scale,\n            distances=results.output.distances,\n            transmittance=results.transmittance,\n            t0=results.volume_range.t0,\n            t1=results.volume_range.t1,\n            intersected=results.volume_range.intersected,\n            aux_losses=results.output.aux_losses,\n        )\n"
  },
  {
    "path": "shap_e/models/nerstf/mlp.py",
    "content": "from typing import Any, Dict, Optional, Tuple\n\nimport torch\n\nfrom shap_e.models.nn.ops import get_act\nfrom shap_e.models.query import Query\nfrom shap_e.models.stf.mlp import MLPModel\nfrom shap_e.util.collections import AttrDict\n\n\nclass MLPDensitySDFModel(MLPModel):\n    def __init__(\n        self,\n        initial_bias: float = -0.1,\n        sdf_activation=\"tanh\",\n        density_activation=\"exp\",\n        **kwargs,\n    ):\n        super().__init__(\n            n_output=2,\n            output_activation=\"identity\",\n            **kwargs,\n        )\n        self.mlp[-1].bias[0].data.fill_(initial_bias)\n        self.sdf_activation = get_act(sdf_activation)\n        self.density_activation = get_act(density_activation)\n\n    def forward(\n        self,\n        query: Query,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict[str, Any]:\n        # query.direction is None typically for SDF models and training\n        h, _h_directionless = self._mlp(\n            query.position, query.direction, params=params, options=options\n        )\n        h_sdf, h_density = h.split(1, dim=-1)\n        return AttrDict(\n            density=self.density_activation(h_density),\n            signed_distance=self.sdf_activation(h_sdf),\n        )\n\n\nclass MLPNeRSTFModel(MLPModel):\n    def __init__(\n        self,\n        sdf_activation=\"tanh\",\n        density_activation=\"exp\",\n        channel_activation=\"sigmoid\",\n        direction_dependent_shape: bool = True,  # To be able to load old models. 
Set this to be False in future models.\n        separate_nerf_channels: bool = False,\n        separate_coarse_channels: bool = False,\n        initial_density_bias: float = 0.0,\n        initial_sdf_bias: float = -0.1,\n        **kwargs,\n    ):\n        h_map, h_directionless_map = indices_for_output_mode(\n            direction_dependent_shape=direction_dependent_shape,\n            separate_nerf_channels=separate_nerf_channels,\n            separate_coarse_channels=separate_coarse_channels,\n        )\n        n_output = index_mapping_max(h_map)\n        super().__init__(\n            n_output=n_output,\n            output_activation=\"identity\",\n            **kwargs,\n        )\n        self.direction_dependent_shape = direction_dependent_shape\n        self.separate_nerf_channels = separate_nerf_channels\n        self.separate_coarse_channels = separate_coarse_channels\n        self.sdf_activation = get_act(sdf_activation)\n        self.density_activation = get_act(density_activation)\n        self.channel_activation = get_act(channel_activation)\n        self.h_map = h_map\n        self.h_directionless_map = h_directionless_map\n        self.mlp[-1].bias.data.zero_()\n        layer = -1 if self.direction_dependent_shape else self.insert_direction_at\n        self.mlp[layer].bias[0].data.fill_(initial_sdf_bias)\n        self.mlp[layer].bias[1].data.fill_(initial_density_bias)\n\n    def forward(\n        self,\n        query: Query,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict[str, Any]:\n        options = AttrDict() if options is None else AttrDict(options)\n        h, h_directionless = self._mlp(\n            query.position, query.direction, params=params, options=options\n        )\n        activations = map_indices_to_keys(self.h_map, h)\n        activations.update(map_indices_to_keys(self.h_directionless_map, h_directionless))\n\n        if options.nerf_level == 
\"coarse\":\n            h_density = activations.density_coarse\n        else:\n            h_density = activations.density_fine\n\n        if options.get(\"rendering_mode\", \"stf\") == \"nerf\":\n            if options.nerf_level == \"coarse\":\n                h_channels = activations.nerf_coarse\n            else:\n                h_channels = activations.nerf_fine\n        else:\n            h_channels = activations.stf\n        return AttrDict(\n            density=self.density_activation(h_density),\n            signed_distance=self.sdf_activation(activations.sdf),\n            channels=self.channel_activation(h_channels),\n        )\n\n\nIndexMapping = AttrDict[str, Tuple[int, int]]\n\n\ndef indices_for_output_mode(\n    direction_dependent_shape: bool,\n    separate_nerf_channels: bool,\n    separate_coarse_channels: bool,\n) -> Tuple[IndexMapping, IndexMapping]:\n    \"\"\"\n    Get output mappings for (h, h_directionless).\n    \"\"\"\n    h_map = AttrDict()\n    h_directionless_map = AttrDict()\n    if direction_dependent_shape:\n        h_map.sdf = (0, 1)\n        if separate_coarse_channels:\n            assert separate_nerf_channels\n            h_map.density_coarse = (1, 2)\n            h_map.density_fine = (2, 3)\n            h_map.stf = (3, 6)\n            h_map.nerf_coarse = (6, 9)\n            h_map.nerf_fine = (9, 12)\n        else:\n            h_map.density_coarse = (1, 2)\n            h_map.density_fine = (1, 2)\n            if separate_nerf_channels:\n                h_map.stf = (2, 5)\n                h_map.nerf_coarse = (5, 8)\n                h_map.nerf_fine = (5, 8)\n            else:\n                h_map.stf = (2, 5)\n                h_map.nerf_coarse = (2, 5)\n                h_map.nerf_fine = (2, 5)\n    else:\n        h_directionless_map.sdf = (0, 1)\n        h_directionless_map.density_coarse = (1, 2)\n        if separate_coarse_channels:\n            h_directionless_map.density_fine = (2, 3)\n        else:\n            
h_directionless_map.density_fine = h_directionless_map.density_coarse\n        h_map.stf = (0, 3)\n        if separate_coarse_channels:\n            assert separate_nerf_channels\n            h_map.nerf_coarse = (3, 6)\n            h_map.nerf_fine = (6, 9)\n        else:\n            if separate_nerf_channels:\n                h_map.nerf_coarse = (3, 6)\n            else:\n                h_map.nerf_coarse = (0, 3)\n            h_map.nerf_fine = h_map.nerf_coarse\n    return h_map, h_directionless_map\n\n\ndef map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]:\n    return AttrDict({k: data[..., start:end] for k, (start, end) in mapping.items()})\n\n\ndef index_mapping_max(mapping: IndexMapping) -> int:\n    return max(end for _, (_, end) in mapping.items())\n"
  },
  {
    "path": "shap_e/models/nerstf/renderer.py",
    "content": "from functools import partial\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union\n\nimport torch\n\nfrom shap_e.models.nerf.model import NeRFModel\nfrom shap_e.models.nerf.ray import RayVolumeIntegral, StratifiedRaySampler, render_rays\nfrom shap_e.models.nn.meta import subdict\nfrom shap_e.models.nn.utils import to_torch\nfrom shap_e.models.query import Query\nfrom shap_e.models.renderer import RayRenderer, render_views_from_rays\nfrom shap_e.models.stf.base import Model\nfrom shap_e.models.stf.renderer import STFRendererBase, render_views_from_stf\nfrom shap_e.models.volume import BoundingBoxVolume, Volume\nfrom shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR\nfrom shap_e.util.collections import AttrDict\n\n\nclass NeRSTFRenderer(RayRenderer, STFRendererBase):\n    def __init__(\n        self,\n        sdf: Optional[Model],\n        tf: Optional[Model],\n        nerstf: Optional[Model],\n        void: NeRFModel,\n        volume: Volume,\n        grid_size: int,\n        n_coarse_samples: int,\n        n_fine_samples: int,\n        importance_sampling_options: Optional[Dict[str, Any]] = None,\n        separate_shared_samples: bool = False,\n        texture_channels: Sequence[str] = (\"R\", \"G\", \"B\"),\n        channel_scale: Sequence[float] = (255.0, 255.0, 255.0),\n        ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,\n        diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,\n        specular_color: Union[float, Tuple[float]] = 0.0,\n        output_srgb: bool = True,\n        device: torch.device = torch.device(\"cuda\"),\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        assert isinstance(volume, BoundingBoxVolume), \"cannot sample points in unknown volume\"\n        assert (nerstf is not None) ^ (sdf is not None and tf is not None)\n        self.sdf = sdf\n        self.tf = tf\n        self.nerstf = nerstf\n        self.void = void\n        
self.volume = volume\n        self.grid_size = grid_size\n        self.n_coarse_samples = n_coarse_samples\n        self.n_fine_samples = n_fine_samples\n        self.importance_sampling_options = AttrDict(importance_sampling_options or {})\n        self.separate_shared_samples = separate_shared_samples\n        self.texture_channels = texture_channels\n        self.channel_scale = to_torch(channel_scale).to(device)\n        self.ambient_color = ambient_color\n        self.diffuse_color = diffuse_color\n        self.specular_color = specular_color\n        self.output_srgb = output_srgb\n        self.device = device\n        self.to(device)\n\n    def _query(\n        self,\n        query: Query,\n        params: AttrDict[str, torch.Tensor],\n        options: AttrDict[str, Any],\n    ) -> AttrDict:\n        no_dir_query = query.copy()\n        no_dir_query.direction = None\n\n        if options.get(\"rendering_mode\", \"stf\") == \"stf\":\n            assert query.direction is None\n\n        if self.nerstf is not None:\n            sdf = tf = self.nerstf(\n                query,\n                params=subdict(params, \"nerstf\"),\n                options=options,\n            )\n        else:\n            sdf = self.sdf(no_dir_query, params=subdict(params, \"sdf\"), options=options)\n            tf = self.tf(query, params=subdict(params, \"tf\"), options=options)\n\n        return AttrDict(\n            density=sdf.density,\n            signed_distance=sdf.signed_distance,\n            channels=tf.channels,\n            aux_losses=dict(),\n        )\n\n    def render_rays(\n        self,\n        batch: AttrDict,\n        params: Optional[Dict] = None,\n        options: Optional[AttrDict] = None,\n    ) -> AttrDict:\n        \"\"\"\n        :param batch: has\n\n            - rays: [batch_size x ... 
x 2 x 3] specify the origin and direction of each ray.\n        :param options: Optional[Dict]\n        \"\"\"\n        params = self.update(params)\n        options = AttrDict() if options is None else AttrDict(options)\n\n        # Necessary to tell the TF to use specific NeRF channels.\n        options.rendering_mode = \"nerf\"\n\n        model = partial(self._query, params=params, options=options)\n\n        # First, render rays with coarse, stratified samples.\n        options.nerf_level = \"coarse\"\n        parts = [\n            RayVolumeIntegral(\n                model=model,\n                volume=self.volume,\n                sampler=StratifiedRaySampler(),\n                n_samples=self.n_coarse_samples,\n            ),\n        ]\n        coarse_results, samplers, coarse_raw_outputs = render_rays(\n            batch.rays,\n            parts,\n            self.void,\n            shared=not self.separate_shared_samples,\n            render_with_direction=options.render_with_direction,\n            importance_sampling_options=self.importance_sampling_options,\n        )\n\n        # Then, render with additional importance-weighted ray samples.\n        options.nerf_level = \"fine\"\n        parts = [\n            RayVolumeIntegral(\n                model=model,\n                volume=self.volume,\n                sampler=samplers[0],\n                n_samples=self.n_fine_samples,\n            ),\n        ]\n        fine_results, _, raw_outputs = render_rays(\n            batch.rays,\n            parts,\n            self.void,\n            shared=not self.separate_shared_samples,\n            prev_raw_outputs=coarse_raw_outputs,\n            render_with_direction=options.render_with_direction,\n        )\n        raw = raw_outputs[0]\n\n        aux_losses = fine_results.output.aux_losses.copy()\n        if self.separate_shared_samples:\n            for key, val in coarse_results.output.aux_losses.items():\n                aux_losses[key + \"_coarse\"] 
= val\n\n        channels = fine_results.output.channels\n        shape = [1] * (channels.ndim - 1) + [len(self.texture_channels)]\n        channels = channels * self.channel_scale.view(*shape)\n\n        res = AttrDict(\n            channels=channels,\n            transmittance=fine_results.transmittance,\n            raw_signed_distance=raw.signed_distance,\n            raw_density=raw.density,\n            distances=fine_results.output.distances,\n            t0=fine_results.volume_range.t0,\n            t1=fine_results.volume_range.t1,\n            intersected=fine_results.volume_range.intersected,\n            aux_losses=aux_losses,\n        )\n\n        if self.separate_shared_samples:\n            res.update(\n                dict(\n                    channels_coarse=(\n                        coarse_results.output.channels * self.channel_scale.view(*shape)\n                    ),\n                    distances_coarse=coarse_results.output.distances,\n                    transmittance_coarse=coarse_results.transmittance,\n                )\n            )\n\n        return res\n\n    def render_views(\n        self,\n        batch: AttrDict,\n        params: Optional[Dict] = None,\n        options: Optional[AttrDict] = None,\n    ) -> AttrDict:\n        \"\"\"\n        Returns a backproppable rendering of a view\n\n        :param batch: contains either [\"poses\", \"camera\"], or [\"cameras\"]. 
Can\n            optionally contain any of [\"height\", \"width\", \"query_batch_size\"]\n\n        :param params: Meta parameters\n            contains rendering_mode in [\"stf\", \"nerf\"]\n        :param options: controls checkpointing, caching, and rendering.\n            Can provide a `rendering_mode` in [\"stf\", \"nerf\"]\n        \"\"\"\n        params = self.update(params)\n        options = AttrDict() if options is None else AttrDict(options)\n\n        if options.cache is None:\n            created_cache = True\n            options.cache = AttrDict()\n        else:\n            created_cache = False\n\n        rendering_mode = options.get(\"rendering_mode\", \"stf\")\n\n        if rendering_mode == \"nerf\":\n\n            output = render_views_from_rays(\n                self.render_rays,\n                batch,\n                params=params,\n                options=options,\n                device=self.device,\n            )\n\n        elif rendering_mode == \"stf\":\n\n            sdf_fn = tf_fn = nerstf_fn = None\n            if self.nerstf is not None:\n                nerstf_fn = partial(\n                    self.nerstf.forward_batched,\n                    params=subdict(params, \"nerstf\"),\n                    options=options,\n                )\n            else:\n                sdf_fn = partial(\n                    self.sdf.forward_batched,\n                    params=subdict(params, \"sdf\"),\n                    options=options,\n                )\n                tf_fn = partial(\n                    self.tf.forward_batched,\n                    params=subdict(params, \"tf\"),\n                    options=options,\n                )\n            output = render_views_from_stf(\n                batch,\n                options,\n                sdf_fn=sdf_fn,\n                tf_fn=tf_fn,\n                nerstf_fn=nerstf_fn,\n                volume=self.volume,\n                grid_size=self.grid_size,\n                
channel_scale=self.channel_scale,\n                texture_channels=self.texture_channels,\n                ambient_color=self.ambient_color,\n                diffuse_color=self.diffuse_color,\n                specular_color=self.specular_color,\n                output_srgb=self.output_srgb,\n                device=self.device,\n            )\n\n        else:\n\n            raise NotImplementedError\n\n        if created_cache:\n            del options[\"cache\"]\n\n        return output\n\n    def get_signed_distance(\n        self,\n        query: Query,\n        params: Dict[str, torch.Tensor],\n        options: AttrDict[str, Any],\n    ) -> torch.Tensor:\n        if self.sdf is not None:\n            return self.sdf(query, params=subdict(params, \"sdf\"), options=options).signed_distance\n        assert self.nerstf is not None\n        return self.nerstf(query, params=subdict(params, \"nerstf\"), options=options).signed_distance\n\n    def get_texture(\n        self,\n        query: Query,\n        params: Dict[str, torch.Tensor],\n        options: AttrDict[str, Any],\n    ) -> torch.Tensor:\n        if self.tf is not None:\n            return self.tf(query, params=subdict(params, \"tf\"), options=options).channels\n        assert self.nerstf is not None\n        return self.nerstf(query, params=subdict(params, \"nerstf\"), options=options).channels\n"
  },
  {
    "path": "shap_e/models/nn/__init__.py",
    "content": "from .meta import *\nfrom .ops import *\n"
  },
  {
    "path": "shap_e/models/nn/camera.py",
    "content": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple, Union\n\nimport numpy as np\nimport torch\n\nfrom shap_e.rendering.view_data import ProjectiveCamera\n\n\n@dataclass\nclass DifferentiableCamera(ABC):\n    \"\"\"\n    An object describing how a camera corresponds to pixels in an image.\n    \"\"\"\n\n    @abstractmethod\n    def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        For every (x, y) coordinate in a rendered image, compute the ray of the\n        corresponding pixel.\n\n        :param coords: an [N x ... x 2] integer array of 2D image coordinates.\n        :return: an [N x ... x 2 x 3] array of [2 x 3] (origin, direction) tuples.\n                 The direction should always be unit length.\n        \"\"\"\n\n    @abstractmethod\n    def resize_image(self, width: int, height: int) -> \"DifferentiableCamera\":\n        \"\"\"\n        Creates a new camera with the same intrinsics and direction as this one,\n        but with resized image dimensions.\n        \"\"\"\n\n\n@dataclass\nclass DifferentiableProjectiveCamera(DifferentiableCamera):\n    \"\"\"\n    Implements a batch, differentiable, standard pinhole camera\n    \"\"\"\n\n    origin: torch.Tensor  # [batch_size x 3]\n    x: torch.Tensor  # [batch_size x 3]\n    y: torch.Tensor  # [batch_size x 3]\n    z: torch.Tensor  # [batch_size x 3]\n    width: int\n    height: int\n    x_fov: float\n    y_fov: float\n\n    def __post_init__(self):\n        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]\n        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3\n        assert (\n            len(self.x.shape)\n            == len(self.y.shape)\n            == len(self.z.shape)\n            == len(self.origin.shape)\n            == 2\n        )\n\n    def resolution(self):\n        return torch.from_numpy(np.array([self.width, 
self.height], dtype=np.float32))\n\n    def fov(self):\n        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))\n\n    def image_coords(self) -> torch.Tensor:\n        \"\"\"\n        :return: coords of shape (width * height, 2)\n        \"\"\"\n        pixel_indices = torch.arange(self.height * self.width)\n        coords = torch.stack(\n            [\n                pixel_indices % self.width,\n                torch.div(pixel_indices, self.width, rounding_mode=\"trunc\"),\n            ],\n            axis=1,\n        )\n        return coords\n\n    def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:\n        batch_size, *shape, n_coords = coords.shape\n        assert n_coords == 2\n        assert batch_size == self.origin.shape[0]\n        flat = coords.view(batch_size, -1, 2)\n\n        res = self.resolution().to(flat.device)\n        fov = self.fov().to(flat.device)\n\n        fracs = (flat.float() / (res - 1)) * 2 - 1\n        fracs = fracs * torch.tan(fov / 2)\n\n        fracs = fracs.view(batch_size, -1, 2)\n        directions = (\n            self.z.view(batch_size, 1, 3)\n            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]\n            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]\n        )\n        directions = directions / directions.norm(dim=-1, keepdim=True)\n        rays = torch.stack(\n            [\n                torch.broadcast_to(\n                    self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]\n                ),\n                directions,\n            ],\n            dim=2,\n        )\n        return rays.view(batch_size, *shape, 2, 3)\n\n    def resize_image(self, width: int, height: int) -> \"DifferentiableProjectiveCamera\":\n        \"\"\"\n        Creates a new camera for the resized view assuming the aspect ratio does not change.\n        \"\"\"\n        assert width * self.height == height * self.width, \"The aspect ratio should not change.\"\n        
return DifferentiableProjectiveCamera(\n            origin=self.origin,\n            x=self.x,\n            y=self.y,\n            z=self.z,\n            width=width,\n            height=height,\n            x_fov=self.x_fov,\n            y_fov=self.y_fov,\n        )\n\n\n@dataclass\nclass DifferentiableCameraBatch(ABC):\n    \"\"\"\n    Annotate a differentiable camera with a multi-dimensional batch shape.\n    \"\"\"\n\n    shape: Tuple[int]\n    flat_camera: DifferentiableCamera\n\n\ndef normalize(vec: torch.Tensor) -> torch.Tensor:\n    return vec / vec.norm(dim=-1, keepdim=True)\n\n\ndef project_out(vec1: torch.Tensor, vec2: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Removes the vec2 component from vec1\n    \"\"\"\n    vec2 = normalize(vec2)\n    proj = (vec1 * vec2).sum(dim=-1, keepdim=True)\n    return vec1 - proj * vec2\n\n\ndef camera_orientation(toward: torch.Tensor, up: Optional[torch.Tensor] = None) -> torch.Tensor:\n    \"\"\"\n    :param toward: [batch_size x 3] unit vector from camera position to the object\n    :param up: Optional [batch_size x 3] specifying the physical up direction in the world frame.\n    :return: [batch_size x 3 x 3]\n    \"\"\"\n\n    if up is None:\n        up = torch.zeros_like(toward)\n        up[:, 2] = 1\n\n    assert len(toward.shape) == 2\n    assert toward.shape[1] == 3\n\n    assert len(up.shape) == 2\n    assert up.shape[1] == 3\n\n    z = toward / toward.norm(dim=-1, keepdim=True)\n    y = -normalize(project_out(up, toward))\n    x = torch.cross(y, z, dim=1)\n    return torch.stack([x, y, z], dim=1)\n\n\ndef projective_camera_frame(\n    origin: torch.Tensor,\n    toward: torch.Tensor,\n    camera_params: Union[ProjectiveCamera, DifferentiableProjectiveCamera],\n) -> DifferentiableProjectiveCamera:\n    \"\"\"\n    Given the origin and the direction of a view, return a differentiable\n    projective camera with the given parameters.\n\n    TODO: We need to support the rotation of the camera frame about the\n    
`toward` vector to fully implement 6 degrees of freedom.\n    \"\"\"\n    rot = camera_orientation(toward)\n    camera = DifferentiableProjectiveCamera(\n        origin=origin,\n        x=rot[:, 0],\n        y=rot[:, 1],\n        z=rot[:, 2],\n        width=camera_params.width,\n        height=camera_params.height,\n        x_fov=camera_params.x_fov,\n        y_fov=camera_params.y_fov,\n    )\n    return camera\n\n\n@torch.no_grad()\ndef get_image_coords(width, height) -> torch.Tensor:\n    pixel_indices = torch.arange(height * width)\n    # torch throws warnings for pixel_indices // width\n    pixel_indices_div = torch.div(pixel_indices, width, rounding_mode=\"trunc\")\n    coords = torch.stack([pixel_indices % width, pixel_indices_div], dim=1)\n    return coords\n"
  },
  {
    "path": "shap_e/models/nn/checkpoint.py",
    "content": "from typing import Callable, Iterable, Sequence, Union\n\nimport torch\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n\ndef checkpoint(\n    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],\n    inputs: Sequence[torch.Tensor],\n    params: Iterable[torch.Tensor],\n    flag: bool,\n):\n    \"\"\"\n    Evaluate a function without caching intermediate activations, allowing for\n    reduced memory at the expense of extra compute in the backward pass.\n    :param func: the function to evaluate.\n    :param inputs: the argument sequence to pass to `func`.\n    :param params: a sequence of parameters `func` depends on but does not\n                   explicitly take as arguments.\n    :param flag: if False, disable gradient checkpointing.\n    \"\"\"\n    if flag:\n        args = tuple(inputs) + tuple(params)\n        return CheckpointFunction.apply(func, len(inputs), *args)\n    else:\n        return func(*inputs)\n\n\nclass CheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.length = length\n        input_tensors = list(args[:length])\n        input_params = list(args[length:])\n        ctx.save_for_backward(*input_tensors, *input_params)\n        with torch.no_grad():\n            output_tensors = ctx.run_function(*input_tensors)\n        return output_tensors\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, *output_grads):\n        inputs = ctx.saved_tensors\n        input_tensors = inputs[: ctx.length]\n        input_params = inputs[ctx.length :]\n        res = CheckpointFunctionGradFunction.apply(\n            ctx.run_function,\n            len(input_tensors),\n            len(input_params),\n            *input_tensors,\n            *input_params,\n            *output_grads\n        )\n        return (None, None) + res\n\n\nclass 
CheckpointFunctionGradFunction(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, run_function, length_1, length_2, *args):\n        ctx.run_function = run_function\n        ctx.length_1 = length_1\n        ctx.length_2 = length_2\n        input_tensors = [x.detach().requires_grad_(True) for x in args[:length_1]]\n        input_params = list(args[length_1 : length_1 + length_2])\n        output_grads = list(args[length_1 + length_2 :])\n        ctx.save_for_backward(*input_tensors, *input_params, *output_grads)\n\n        with torch.enable_grad():\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = [x.view_as(x) for x in input_tensors]\n            output_tensors = ctx.run_function(*shallow_copies)\n        input_grads = torch.autograd.grad(\n            output_tensors,\n            input_tensors + input_params,\n            output_grads,\n            allow_unused=True,\n        )\n        return input_grads\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, *all_output_grads):\n        args = ctx.saved_tensors\n        input_tensors = [x.detach().requires_grad_(True) for x in args[: ctx.length_1]]\n        input_params = list(args[ctx.length_1 : ctx.length_1 + ctx.length_2])\n        output_grads = [\n            x.detach().requires_grad_(True) for x in args[ctx.length_1 + ctx.length_2 :]\n        ]\n\n        with torch.enable_grad():\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = [x.view_as(x) for x in input_tensors]\n            output_tensors = ctx.run_function(*shallow_copies)\n            input_grads = torch.autograd.grad(\n                output_tensors,\n                input_tensors + input_params,\n         
       output_grads,\n                allow_unused=True,\n                create_graph=True,\n                retain_graph=True,\n            )\n        input_grads_grads = torch.autograd.grad(\n            input_grads,\n            input_tensors + input_params + output_grads,\n            all_output_grads,\n            allow_unused=True,\n        )\n        del input_grads\n        return (None, None, None) + input_grads_grads\n"
  },
  {
    "path": "shap_e/models/nn/encoding.py",
    "content": "import math\nfrom functools import lru_cache\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\n\n\ndef encode_position(version: str, *, position: torch.Tensor):\n    if version == \"v1\":\n        freqs = get_scales(0, 10, position.dtype, position.device).view(1, -1)\n        freqs = position.reshape(-1, 1) * freqs\n        return torch.cat([freqs.cos(), freqs.sin()], dim=1).reshape(*position.shape[:-1], -1)\n    elif version == \"nerf\":\n        return posenc_nerf(position, min_deg=0, max_deg=15)\n    else:\n        raise ValueError(version)\n\n\ndef encode_channels(version: str, *, channels: torch.Tensor):\n    if version == \"v1\":\n        freqs = get_scales(0, 10, channels.dtype, channels.device).view(1, -1)\n        freqs = channels.reshape(-1, 1) * freqs\n        return torch.cat([freqs.cos(), freqs.sin()], dim=1).reshape(*channels.shape[:-1], -1)\n    elif version == \"nerf\":\n        return posenc_nerf(channels, min_deg=0, max_deg=15)\n    else:\n        raise ValueError(version)\n\n\ndef position_encoding_channels(version: Optional[str] = None) -> int:\n    if version is None:\n        return 1\n    return encode_position(version, position=torch.zeros(1, 1)).shape[-1]\n\n\ndef channel_encoding_channels(version: Optional[str] = None) -> int:\n    if version is None:\n        return 1\n    return encode_channels(version, channels=torch.zeros(1, 1)).shape[-1]\n\n\nclass PosEmbLinear(nn.Linear):\n    def __init__(\n        self, posemb_version: Optional[str], in_features: int, out_features: int, **kwargs\n    ):\n        super().__init__(\n            in_features * position_encoding_channels(posemb_version),\n            out_features,\n            **kwargs,\n        )\n        self.posemb_version = posemb_version\n\n    def forward(self, x: torch.Tensor):\n        if self.posemb_version is not None:\n            x = encode_position(self.posemb_version, position=x)\n        return super().forward(x)\n\n\nclass 
MultiviewPoseEmbedding(nn.Conv2d):\n    def __init__(\n        self,\n        posemb_version: Optional[str],\n        n_channels: int,\n        out_features: int,\n        stride: int = 1,\n        **kwargs,\n    ):\n        in_features = (\n            n_channels * channel_encoding_channels(version=posemb_version)\n            + 3 * position_encoding_channels(version=posemb_version)\n            + 3 * position_encoding_channels(version=posemb_version)\n        )\n        super().__init__(\n            in_features,\n            out_features,\n            kernel_size=3,\n            stride=stride,\n            padding=1,\n            **kwargs,\n        )\n        self.posemb_version = posemb_version\n\n    def forward(\n        self, channels: torch.Tensor, position: torch.Tensor, direction: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        :param channels: [batch_shape, inner_batch_shape, n_channels, height, width]\n        :param position: [batch_shape, inner_batch_shape, 3, height, width]\n        :param direction: [batch_shape, inner_batch_shape, 3, height, width]\n        :return: [*batch_shape, out_features, height, width]\n        \"\"\"\n\n        if self.posemb_version is not None:\n            channels = channels.permute(0, 1, 3, 4, 2)\n            position = position.permute(0, 1, 3, 4, 2)\n            direction = direction.permute(0, 1, 3, 4, 2)\n            channels = encode_channels(self.posemb_version, channels=channels).permute(\n                0, 1, 4, 2, 3\n            )\n            direction = maybe_encode_direction(\n                self.posemb_version, position=position, direction=direction\n            ).permute(0, 1, 4, 2, 3)\n            position = encode_position(self.posemb_version, position=position).permute(\n                0, 1, 4, 2, 3\n            )\n        x = torch.cat([channels, position, direction], dim=-3)\n        *batch_shape, in_features, height, width = x.shape\n        return (\n            super()\n           
 .forward(x.view(-1, in_features, height, width))\n            .view(*batch_shape, -1, height, width)\n        )\n\n\nclass MultiviewPointCloudEmbedding(nn.Conv2d):\n    def __init__(\n        self,\n        posemb_version: Optional[str],\n        n_channels: int,\n        out_features: int,\n        stride: int = 1,\n        **kwargs,\n    ):\n        in_features = (\n            n_channels * channel_encoding_channels(version=posemb_version)\n            + 3 * position_encoding_channels(version=posemb_version)\n            + 3 * position_encoding_channels(version=posemb_version)\n        )\n        super().__init__(\n            in_features,\n            out_features,\n            kernel_size=3,\n            stride=stride,\n            padding=1,\n            **kwargs,\n        )\n        self.posemb_version = posemb_version\n        self.register_parameter(\n            \"unk_token\", nn.Parameter(torch.randn(in_features, **kwargs) * 0.01)\n        )\n        self.unk_token: torch.Tensor\n\n    def forward(\n        self,\n        channels: torch.Tensor,\n        origin: torch.Tensor,\n        position: torch.Tensor,\n        mask: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"\n        :param channels: [batch_shape, inner_batch_shape, n_channels, height, width]\n        :param origin: [batch_shape, inner_batch_shape, 3, height, width]\n        :param position: [batch_shape, inner_batch_shape, 3, height, width]\n        :return: [*batch_shape, out_features, height, width]\n        \"\"\"\n\n        if self.posemb_version is not None:\n            channels = channels.permute(0, 1, 3, 4, 2)\n            origin = origin.permute(0, 1, 3, 4, 2)\n            position = position.permute(0, 1, 3, 4, 2)\n            channels = encode_channels(self.posemb_version, channels=channels).permute(\n                0, 1, 4, 2, 3\n            )\n            origin = encode_position(self.posemb_version, position=origin).permute(0, 1, 4, 2, 3)\n            position = 
encode_position(self.posemb_version, position=position).permute(\n                0, 1, 4, 2, 3\n            )\n        x = torch.cat([channels, origin, position], dim=-3)\n        unk_token = torch.broadcast_to(self.unk_token.view(1, 1, -1, 1, 1), x.shape)\n        x = torch.where(mask, x, unk_token)\n        *batch_shape, in_features, height, width = x.shape\n        return (\n            super()\n            .forward(x.view(-1, in_features, height, width))\n            .view(*batch_shape, -1, height, width)\n        )\n\n\ndef maybe_encode_direction(\n    version: str,\n    *,\n    position: torch.Tensor,\n    direction: Optional[torch.Tensor] = None,\n):\n\n    if version == \"v1\":\n        sh_degree = 4\n        if direction is None:\n            return torch.zeros(*position.shape[:-1], sh_degree**2).to(position)\n        return spherical_harmonics_basis(direction, sh_degree=sh_degree)\n    elif version == \"nerf\":\n        if direction is None:\n            return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))\n        return posenc_nerf(direction, min_deg=0, max_deg=8)\n    else:\n        raise ValueError(version)\n\n\ndef posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:\n    \"\"\"\n    Concatenate x and its positional encodings, following NeRF.\n\n    Reference: https://arxiv.org/pdf/2210.04628.pdf\n    \"\"\"\n    if min_deg == max_deg:\n        return x\n    scales = get_scales(min_deg, max_deg, x.dtype, x.device)\n    *shape, dim = x.shape\n    xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)\n    assert xb.shape[-1] == dim * (max_deg - min_deg)\n    emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin()\n    return torch.cat([x, emb], dim=-1)\n\n\n@lru_cache\ndef get_scales(\n    min_deg: int,\n    max_deg: int,\n    dtype: torch.dtype,\n    device: torch.device,\n) -> torch.Tensor:\n    return 2.0 ** torch.arange(min_deg, max_deg, device=device, dtype=dtype)\n\n\ndef 
spherical_harmonics_basis(\n    coords: torch.Tensor,\n    sh_degree: int,\n) -> torch.Tensor:\n    \"\"\"\n    Calculate the spherical harmonics basis\n\n    :param coords: [batch_size, *shape, 3] of unit norm\n    :param sh_degree: Spherical harmonics degree\n    :return: [batch_size, *shape, sh_degree**2]\n    \"\"\"\n    if sh_degree > 8:\n        raise NotImplementedError\n\n    batch_size, *shape, _ = coords.shape\n    x, y, z = coords.reshape(-1, 3).split(1, dim=-1)\n    x = x.squeeze(dim=-1)\n    y = y.squeeze(dim=-1)\n    z = z.squeeze(dim=-1)\n\n    xy, xz, yz = x * y, x * z, y * z\n    x2, y2, z2 = x * x, y * y, z * z\n    x4, y4, z4 = x2 * x2, y2 * y2, z2 * z2\n    x6, y6, z6 = x4 * x2, y4 * y2, z4 * z2\n    xyz = xy * z\n\n    # https://github.com/NVlabs/tiny-cuda-nn/blob/8575542682cb67cddfc748cc3d3cfc12593799aa/include/tiny-cuda-nn/encodings/spherical_harmonics.h#L76\n\n    out = torch.zeros(x.shape[0], sh_degree**2, dtype=x.dtype, device=x.device)\n\n    def _sh():\n        out[:, 0] = 0.28209479177387814  # 1/(2*sqrt(pi))\n        if sh_degree <= 1:\n            return\n        out[:, 1] = -0.48860251190291987 * y  # -sqrt(3)*y/(2*sqrt(pi))\n        out[:, 2] = 0.48860251190291987 * z  # sqrt(3)*z/(2*sqrt(pi))\n        out[:, 3] = -0.48860251190291987 * x  # -sqrt(3)*x/(2*sqrt(pi))\n        if sh_degree <= 2:\n            return\n        out[:, 4] = 1.0925484305920792 * xy  # sqrt(15)*xy/(2*sqrt(pi))\n        out[:, 5] = -1.0925484305920792 * yz  # -sqrt(15)*yz/(2*sqrt(pi))\n        out[:, 6] = (\n            0.94617469575755997 * z2 - 0.31539156525251999\n        )  # sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))\n        out[:, 7] = -1.0925484305920792 * xz  # -sqrt(15)*xz/(2*sqrt(pi))\n        out[:, 8] = (\n            0.54627421529603959 * x2 - 0.54627421529603959 * y2\n        )  # sqrt(15)*(x2 - y2)/(4*sqrt(pi))\n        if sh_degree <= 3:\n            return\n        out[:, 9] = (\n            0.59004358992664352 * y * (-3.0 * x2 + y2)\n        )  # 
sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))\n        out[:, 10] = 2.8906114426405538 * xy * z  # sqrt(105)*xy*z/(2*sqrt(pi))\n        out[:, 11] = (\n            0.45704579946446572 * y * (1.0 - 5.0 * z2)\n        )  # sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))\n        out[:, 12] = 0.3731763325901154 * z * (5.0 * z2 - 3.0)  # sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))\n        out[:, 13] = (\n            0.45704579946446572 * x * (1.0 - 5.0 * z2)\n        )  # sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))\n        out[:, 14] = 1.4453057213202769 * z * (x2 - y2)  # sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))\n        out[:, 15] = (\n            0.59004358992664352 * x * (-x2 + 3.0 * y2)\n        )  # sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))\n        if sh_degree <= 4:\n            return\n        out[:, 16] = 2.5033429417967046 * xy * (x2 - y2)  # 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))\n        out[:, 17] = (\n            1.7701307697799304 * yz * (-3.0 * x2 + y2)\n        )  # 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))\n        out[:, 18] = (\n            0.94617469575756008 * xy * (7.0 * z2 - 1.0)\n        )  # 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))\n        out[:, 19] = (\n            0.66904654355728921 * yz * (3.0 - 7.0 * z2)\n        )  # 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))\n        out[:, 20] = (\n            -3.1735664074561294 * z2 + 3.7024941420321507 * z4 + 0.31735664074561293\n        )  # 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))\n        out[:, 21] = (\n            0.66904654355728921 * xz * (3.0 - 7.0 * z2)\n        )  # 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))\n        out[:, 22] = (\n            0.47308734787878004 * (x2 - y2) * (7.0 * z2 - 1.0)\n        )  # 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))\n        out[:, 23] = (\n            1.7701307697799304 * xz * (-x2 + 3.0 * y2)\n        )  # 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))\n        out[:, 24] = (\n            -3.7550144126950569 * x2 * y2 + 0.62583573544917614 * x4 + 0.62583573544917614 * y4\n        )  # 3*sqrt(35)*(-6*x2*y2 + x4 + 
y4)/(16*sqrt(pi))\n        if sh_degree <= 5:\n            return\n        out[:, 25] = (\n            0.65638205684017015 * y * (10.0 * x2 * y2 - 5.0 * x4 - y4)\n        )  # 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n        out[:, 26] = (\n            8.3026492595241645 * xy * z * (x2 - y2)\n        )  # 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))\n        out[:, 27] = (\n            -0.48923829943525038 * y * (3.0 * x2 - y2) * (9.0 * z2 - 1.0)\n        )  # -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))\n        out[:, 28] = (\n            4.7935367849733241 * xy * z * (3.0 * z2 - 1.0)\n        )  # sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))\n        out[:, 29] = (\n            0.45294665119569694 * y * (14.0 * z2 - 21.0 * z4 - 1.0)\n        )  # sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n        out[:, 30] = (\n            0.1169503224534236 * z * (-70.0 * z2 + 63.0 * z4 + 15.0)\n        )  # sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))\n        out[:, 31] = (\n            0.45294665119569694 * x * (14.0 * z2 - 21.0 * z4 - 1.0)\n        )  # sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n        out[:, 32] = (\n            2.3967683924866621 * z * (x2 - y2) * (3.0 * z2 - 1.0)\n        )  # sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))\n        out[:, 33] = (\n            -0.48923829943525038 * x * (x2 - 3.0 * y2) * (9.0 * z2 - 1.0)\n        )  # -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))\n        out[:, 34] = (\n            2.0756623148810411 * z * (-6.0 * x2 * y2 + x4 + y4)\n        )  # 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n        out[:, 35] = (\n            0.65638205684017015 * x * (10.0 * x2 * y2 - x4 - 5.0 * y4)\n        )  # 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n        if sh_degree <= 6:\n            return\n        out[:, 36] = (\n            1.3663682103838286 * xy * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4)\n        )  # sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n        out[:, 37] = (\n        
    2.3666191622317521 * yz * (10.0 * x2 * y2 - 5.0 * x4 - y4)\n        )  # 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n        out[:, 38] = (\n            2.0182596029148963 * xy * (x2 - y2) * (11.0 * z2 - 1.0)\n        )  # 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))\n        out[:, 39] = (\n            -0.92120525951492349 * yz * (3.0 * x2 - y2) * (11.0 * z2 - 3.0)\n        )  # -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))\n        out[:, 40] = (\n            0.92120525951492349 * xy * (-18.0 * z2 + 33.0 * z4 + 1.0)\n        )  # sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n        out[:, 41] = (\n            0.58262136251873131 * yz * (30.0 * z2 - 33.0 * z4 - 5.0)\n        )  # sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n        out[:, 42] = (\n            6.6747662381009842 * z2\n            - 20.024298714302954 * z4\n            + 14.684485723822165 * z6\n            - 0.31784601133814211\n        )  # sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))\n        out[:, 43] = (\n            0.58262136251873131 * xz * (30.0 * z2 - 33.0 * z4 - 5.0)\n        )  # sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n        out[:, 44] = (\n            0.46060262975746175 * (x2 - y2) * (11.0 * z2 * (3.0 * z2 - 1.0) - 7.0 * z2 + 1.0)\n        )  # sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))\n        out[:, 45] = (\n            -0.92120525951492349 * xz * (x2 - 3.0 * y2) * (11.0 * z2 - 3.0)\n        )  # -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))\n        out[:, 46] = (\n            0.50456490072872406 * (11.0 * z2 - 1.0) * (-6.0 * x2 * y2 + x4 + y4)\n        )  # 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))\n        out[:, 47] = (\n            2.3666191622317521 * xz * (10.0 * x2 * y2 - x4 - 5.0 * y4)\n        )  # 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n        out[:, 48] = (\n            10.247761577878714 * x2 * y4\n            - 10.247761577878714 * x4 * y2\n            + 
0.6831841051919143 * x6\n            - 0.6831841051919143 * y6\n        )  # sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))\n        if sh_degree <= 7:\n            return\n        out[:, 49] = (\n            0.70716273252459627 * y * (-21.0 * x2 * y4 + 35.0 * x4 * y2 - 7.0 * x6 + y6)\n        )  # 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))\n        out[:, 50] = (\n            5.2919213236038001 * xy * z * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4)\n        )  # 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n        out[:, 51] = (\n            -0.51891557872026028 * y * (13.0 * z2 - 1.0) * (-10.0 * x2 * y2 + 5.0 * x4 + y4)\n        )  # -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))\n        out[:, 52] = (\n            4.1513246297620823 * xy * z * (x2 - y2) * (13.0 * z2 - 3.0)\n        )  # 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))\n        out[:, 53] = (\n            -0.15645893386229404\n            * y\n            * (3.0 * x2 - y2)\n            * (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0)\n        )  # -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))\n        out[:, 54] = (\n            0.44253269244498261 * xy * z * (-110.0 * z2 + 143.0 * z4 + 15.0)\n        )  # 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))\n        out[:, 55] = (\n            0.090331607582517306 * y * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0)\n        )  # sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n        out[:, 56] = (\n            0.068284276912004949 * z * (315.0 * z2 - 693.0 * z4 + 429.0 * z6 - 35.0)\n        )  # sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))\n        out[:, 57] = (\n            0.090331607582517306 * x * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0)\n        )  # sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n        out[:, 58] = (\n            0.07375544874083044\n            * z\n            * (x2 - y2)\n            * 
(143.0 * z2 * (3.0 * z2 - 1.0) - 187.0 * z2 + 45.0)\n        )  # sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))\n        out[:, 59] = (\n            -0.15645893386229404\n            * x\n            * (x2 - 3.0 * y2)\n            * (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0)\n        )  # -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))\n        out[:, 60] = (\n            1.0378311574405206 * z * (13.0 * z2 - 3.0) * (-6.0 * x2 * y2 + x4 + y4)\n        )  # 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))\n        out[:, 61] = (\n            -0.51891557872026028 * x * (13.0 * z2 - 1.0) * (-10.0 * x2 * y2 + x4 + 5.0 * y4)\n        )  # -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))\n        out[:, 62] = (\n            2.6459606618019 * z * (15.0 * x2 * y4 - 15.0 * x4 * y2 + x6 - y6)\n        )  # 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))\n        out[:, 63] = (\n            0.70716273252459627 * x * (-35.0 * x2 * y4 + 21.0 * x4 * y2 - x6 + 7.0 * y6)\n        )  # 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))\n\n    _sh()\n    return out.view(batch_size, *shape, sh_degree**2)\n"
  },
  {
    "path": "shap_e/models/nn/meta.py",
    "content": "\"\"\"\nMeta-learning modules based on: https://github.com/tristandeleu/pytorch-meta\n\nMIT License\n\nCopyright (c) 2019-2020 Tristan Deleu\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\nimport itertools\nimport re\nfrom collections import OrderedDict\n\nimport torch.nn as nn\n\nfrom shap_e.util.collections import AttrDict\n\n__all__ = [\n    \"MetaModule\",\n    \"subdict\",\n    \"superdict\",\n    \"leveldict\",\n    \"leveliter\",\n    \"batch_meta_parameters\",\n    \"batch_meta_state_dict\",\n]\n\n\ndef subdict(dictionary, key=None):\n    if dictionary is None:\n        return None\n    if (key is None) or (key == \"\"):\n        return dictionary\n    key_re = re.compile(r\"^{0}\\.(.+)\".format(re.escape(key)))\n    return AttrDict(\n        OrderedDict(\n            (key_re.sub(r\"\\1\", k), value)\n            for (k, value) in dictionary.items()\n            if key_re.match(k) is not None\n        )\n    )\n\n\ndef superdict(dictionary, 
key=None):\n    if dictionary is None:\n        return None\n    if (key is None) or (key == \"\"):\n        return dictionary\n    return AttrDict(OrderedDict((key + \".\" + k, value) for (k, value) in dictionary.items()))\n\n\ndef leveldict(dictionary, depth=0):\n    return AttrDict(leveliter(dictionary, depth=depth))\n\n\ndef leveliter(dictionary, depth=0):\n    \"\"\"\n    depth == 0 is root\n    \"\"\"\n    for key, value in dictionary.items():\n        if key.count(\".\") == depth:\n            yield key, value\n\n\nclass MetaModule(nn.Module):\n    \"\"\"\n    Base class for PyTorch meta-learning modules. These modules accept an\n    additional argument `params` in their `forward` method.\n\n    Notes\n    -----\n    Objects inherited from `MetaModule` are fully compatible with PyTorch\n    modules from `torch.nn.Module`. The argument `params` is a dictionary of\n    tensors, with full support of the computation graph (for differentiation).\n\n    Based on SIREN's torchmeta with some additional features/changes.\n\n    All meta weights must not have the batch dimension, as they are later tiled\n    to the given batch size after unsqueezing the first dimension (e.g. a\n    weight of dimension [d_out x d_in] is tiled to have the dimension [batch x\n    d_out x d_in]).  Requiring all meta weights to have a batch dimension of 1\n    (e.g. [1 x d_out x d_in] from the earlier example) could be a more natural\n    choice, but this results in silent failures.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._meta_state_dict = set()\n        self._meta_params = set()\n\n    def register_meta_buffer(self, name: str, param: nn.Parameter):\n        \"\"\"\n        Registers a trainable or nontrainable parameter as a meta buffer. 
This\n        can later be retrieved with meta_state_dict.\n        \"\"\"\n        self.register_buffer(name, param)\n        self._meta_state_dict.add(name)\n\n    def register_meta_parameter(self, name: str, parameter: nn.Parameter):\n        \"\"\"\n        Registers a meta parameter so it is included in named_meta_parameters\n        and meta_state_dict.\n        \"\"\"\n        self.register_parameter(name, parameter)\n        self._meta_params.add(name)\n        self._meta_state_dict.add(name)\n\n    def register_meta(self, name: str, parameter: nn.Parameter, trainable: bool = True):\n        if trainable:\n            self.register_meta_parameter(name, parameter)\n        else:\n            self.register_meta_buffer(name, parameter)\n\n    def register(self, name: str, parameter: nn.Parameter, meta: bool, trainable: bool = True):\n        if meta:\n            if trainable:\n                self.register_meta_parameter(name, parameter)\n            else:\n                self.register_meta_buffer(name, parameter)\n        else:\n            if trainable:\n                self.register_parameter(name, parameter)\n            else:\n                self.register_buffer(name, parameter)\n\n    def named_meta_parameters(self, prefix=\"\", recurse=True):\n        \"\"\"\n        Returns an iterator over (name, parameter) pairs for all meta parameters.\n        \"\"\"\n\n        def meta_iterator(module):\n            meta = module._meta_params if isinstance(module, MetaModule) else set()\n            for name, param in module._parameters.items():\n                if name in meta:\n                    yield name, param\n\n        gen = self._named_members(\n            meta_iterator,\n            prefix=prefix,\n            recurse=recurse,\n        )\n        for name, param in gen:\n            yield name, param\n\n    def named_nonmeta_parameters(self, prefix=\"\", recurse=True):\n        def _iterator(module):\n            meta = module._meta_params if isinstance(module, MetaModule) else set()\n            for name, param in module._parameters.items():\n                if name not in meta:\n                    yield name, param\n\n        gen = self._named_members(\n            _iterator,\n            prefix=prefix,\n            recurse=recurse,\n        )\n        for name, param in gen:\n            yield name, param\n\n    def nonmeta_parameters(self, prefix=\"\", recurse=True):\n        for _, param in self.named_nonmeta_parameters(prefix=prefix, recurse=recurse):\n            yield param\n\n    def meta_state_dict(self, prefix=\"\", recurse=True):\n        \"\"\"\n        Returns a dict mapping names to all meta parameters/buffers.\n\n        Unlike module.state_dict(), this preserves requires_grad, because we may\n        want to compute gradients w.r.t. meta buffers without necessarily\n        updating them automatically.\n        \"\"\"\n\n        def meta_iterator(module):\n            meta = module._meta_state_dict if isinstance(module, MetaModule) else set()\n            for name, param in itertools.chain(module._buffers.items(), module._parameters.items()):\n                if name in meta:\n                    yield name, param\n\n        gen = self._named_members(\n            meta_iterator,\n            prefix=prefix,\n            recurse=recurse,\n        )\n        return dict(gen)\n\n    def update(self, params=None):\n        \"\"\"\n        Updates the parameter list before the forward pass so that if `params`\n        is None or doesn't have a certain key, the module uses the default\n        parameter/buffer registered in the module.\n        \"\"\"\n        if params is None:\n            params = AttrDict()\n        params = AttrDict(params)\n        named_params = set([name for name, _ in self.named_parameters()])\n        for name, param in self.named_parameters():\n            params.setdefault(name, param)\n        for name, param in self.state_dict().items():\n            if name not in named_params:\n                params.setdefault(name, param)\n        return params\n\n\ndef batch_meta_parameters(net, batch_size):\n    params = AttrDict()\n    for name, param in net.named_meta_parameters():\n        params[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape))\n    return params\n\n\ndef batch_meta_state_dict(net, batch_size):\n    state_dict = AttrDict()\n    for name, param in net.meta_state_dict().items():\n        state_dict[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape))\n    return state_dict\n"
  },
  {
    "path": "shap_e/models/nn/ops.py",
    "content": "import math\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom shap_e.util.collections import AttrDict\n\nfrom .meta import MetaModule, subdict\nfrom .pointnet2_utils import sample_and_group, sample_and_group_all\n\n\ndef gelu(x):\n    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))\n\n\ndef swish(x):\n    return x * torch.sigmoid(x)\n\n\ndef quick_gelu(x):\n    return x * torch.sigmoid(1.702 * x)\n\n\ndef torch_gelu(x):\n    return torch.nn.functional.gelu(x)\n\n\ndef geglu(x):\n    v, gates = x.chunk(2, dim=-1)\n    return v * gelu(gates)\n\n\nclass SirenSin:\n    def __init__(self, w0=30.0):\n        self.w0 = w0\n\n    def __call__(self, x):\n        return torch.sin(self.w0 * x)\n\n\ndef get_act(name):\n    return {\n        \"relu\": torch.nn.functional.relu,\n        \"leaky_relu\": torch.nn.functional.leaky_relu,\n        \"swish\": swish,\n        \"tanh\": torch.tanh,\n        \"gelu\": gelu,\n        \"quick_gelu\": quick_gelu,\n        \"torch_gelu\": torch_gelu,\n        \"gelu2\": quick_gelu,\n        \"geglu\": geglu,\n        \"sigmoid\": torch.sigmoid,\n        \"sin\": torch.sin,\n        \"sin30\": SirenSin(w0=30.0),\n        \"softplus\": F.softplus,\n        \"exp\": torch.exp,\n        \"identity\": lambda x: x,\n    }[name]\n\n\ndef zero_init(affine):\n    nn.init.constant_(affine.weight, 0.0)\n    if affine.bias is not None:\n        nn.init.constant_(affine.bias, 0.0)\n\n\ndef siren_init_first_layer(affine, init_scale: float = 1.0):\n    n_input = affine.weight.shape[1]\n    u = init_scale / n_input\n    nn.init.uniform_(affine.weight, -u, u)\n    if affine.bias is not None:\n        nn.init.constant_(affine.bias, 0.0)\n\n\ndef siren_init(affine, coeff=1.0, init_scale: float = 1.0):\n    n_input = affine.weight.shape[1]\n    u = init_scale * np.sqrt(6.0 / n_input) / coeff\n    
nn.init.uniform_(affine.weight, -u, u)\n    if affine.bias is not None:\n        nn.init.constant_(affine.bias, 0.0)\n\n\ndef siren_init_30(affine, init_scale: float = 1.0):\n    siren_init(affine, coeff=30.0, init_scale=init_scale)\n\n\ndef std_init(affine, init_scale: float = 1.0):\n    n_in = affine.weight.shape[1]\n    stddev = init_scale / math.sqrt(n_in)\n    nn.init.normal_(affine.weight, std=stddev)\n    if affine.bias is not None:\n        nn.init.constant_(affine.bias, 0.0)\n\n\ndef mlp_init(affines, init: Optional[str] = None, init_scale: float = 1.0):\n    if init == \"siren30\":\n        for idx, affine in enumerate(affines):\n            init_fn = siren_init_first_layer if idx == 0 else siren_init_30\n            init_fn(affine, init_scale=init_scale)\n    elif init == \"siren\":\n        for idx, affine in enumerate(affines):\n            init_fn = siren_init_first_layer if idx == 0 else siren_init\n            init_fn(affine, init_scale=init_scale)\n    elif init is None:\n        for affine in affines:\n            std_init(affine, init_scale=init_scale)\n    else:\n        raise NotImplementedError(init)\n\n\nclass MetaLinear(MetaModule):\n    def __init__(\n        self,\n        n_in,\n        n_out,\n        bias: bool = True,\n        meta_scale: bool = True,\n        meta_shift: bool = True,\n        meta_proj: bool = False,\n        meta_bias: bool = False,\n        trainable_meta: bool = False,\n        **kwargs,\n    ):\n        super().__init__()\n        register_meta_fn = (\n            self.register_meta_parameter if trainable_meta else self.register_meta_buffer\n        )\n        if meta_scale:\n            register_meta_fn(\"scale\", nn.Parameter(torch.ones(n_out, **kwargs)))\n        if meta_shift:\n            register_meta_fn(\"shift\", nn.Parameter(torch.zeros(n_out, **kwargs)))\n\n        register_proj_fn = self.register_parameter if not meta_proj else register_meta_fn\n        register_proj_fn(\"weight\", nn.Parameter(torch.empty((n_out, n_in), **kwargs)))\n\n        if not bias:\n            self.register_parameter(\"bias\", None)\n        else:\n            register_bias_fn = self.register_parameter if not meta_bias else register_meta_fn\n            register_bias_fn(\"bias\", nn.Parameter(torch.empty(n_out, **kwargs)))\n\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n\n        # from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear\n\n        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with\n        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see\n        # https://github.com/pytorch/pytorch/issues/57109\n        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n        if self.bias is not None:\n            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)\n            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n            nn.init.uniform_(self.bias, -bound, bound)\n\n    def _bcast(self, op, left, right):\n        if right.ndim == 2:\n            # Has dimension [batch x d_output]\n            right = right.unsqueeze(1)\n        return op(left, right)\n\n    def forward(self, x, params=None):\n        params = self.update(params)\n\n        batch_size, *shape, d_in = x.shape\n        x = x.view(batch_size, -1, d_in)\n\n        if params.weight.ndim == 2:\n            # Shared weight: [d_out x d_in]\n            h = torch.einsum(\"bni,oi->bno\", x, params.weight)\n        elif params.weight.ndim == 3:\n            # Per-sample weight: [batch x d_out x d_in]\n            h = torch.einsum(\"bni,boi->bno\", x, params.weight)\n        else:\n            raise ValueError(f\"expected a 2D or 3D weight, got shape {tuple(params.weight.shape)}\")\n\n        if params.bias is not None:\n            h = self._bcast(torch.add, h, params.bias)\n\n        if params.scale is not None:\n            h = self._bcast(torch.mul, h, params.scale)\n\n        if params.shift is not None:\n            h = self._bcast(torch.add, h, params.shift)\n\n        h = h.view(batch_size, *shape, -1)\n        return h\n\n\ndef Conv(n_dim, d_in, d_out, 
kernel, stride=1, padding=0, dilation=1, **kwargs):\n    cls = {\n        1: nn.Conv1d,\n        2: nn.Conv2d,\n        3: nn.Conv3d,\n    }[n_dim]\n    return cls(d_in, d_out, kernel, stride=stride, padding=padding, dilation=dilation, **kwargs)\n\n\ndef flatten(x):\n    batch_size, *shape, n_channels = x.shape\n    n_ctx = np.prod(shape)\n    return x.view(batch_size, n_ctx, n_channels), AttrDict(\n        shape=shape, n_ctx=n_ctx, n_channels=n_channels\n    )\n\n\ndef unflatten(x, info):\n    batch_size = x.shape[0]\n    return x.view(batch_size, *info.shape, info.n_channels)\n\n\ndef torchify(x):\n    extent = list(range(1, x.ndim - 1))\n    return x.permute([0, x.ndim - 1, *extent])\n\n\ndef untorchify(x):\n    extent = list(range(2, x.ndim))\n    return x.permute([0, *extent, 1])\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        d_input: int,\n        d_hidden: List[int],\n        d_output: int,\n        act_name: str = \"quick_gelu\",\n        bias: bool = True,\n        init: Optional[str] = None,\n        init_scale: float = 1.0,\n        zero_out: bool = False,\n    ):\n        \"\"\"\n        Required: d_input, d_hidden, d_output\n        Optional: act_name, bias\n        \"\"\"\n        super().__init__()\n\n        ds = [d_input] + d_hidden + [d_output]\n        affines = [nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(ds[:-1], ds[1:])]\n        self.d = ds\n        self.affines = nn.ModuleList(affines)\n        self.act = get_act(act_name)\n\n        mlp_init(self.affines, init=init, init_scale=init_scale)\n        if zero_out:\n            zero_init(affines[-1])\n\n    def forward(self, h, options: Optional[AttrDict] = None, log_prefix: str = \"\"):\n        options = AttrDict() if options is None else AttrDict(options)\n        *hid, out = self.affines\n        for i, f in enumerate(hid):\n            h = self.act(f(h))\n        h = out(h)\n        return h\n\n\nclass MetaMLP(MetaModule):\n    def __init__(\n        
self,\n        d_input: int,\n        d_hidden: List[int],\n        d_output: int,\n        act_name: str = \"quick_gelu\",\n        bias: bool = True,\n        meta_scale: bool = True,\n        meta_shift: bool = True,\n        meta_proj: bool = False,\n        meta_bias: bool = False,\n        trainable_meta: bool = False,\n        init: Optional[str] = None,\n        init_scale: float = 1.0,\n        zero_out: bool = False,\n    ):\n        super().__init__()\n        ds = [d_input] + d_hidden + [d_output]\n        affines = [\n            MetaLinear(\n                d_in,\n                d_out,\n                bias=bias,\n                meta_scale=meta_scale,\n                meta_shift=meta_shift,\n                meta_proj=meta_proj,\n                meta_bias=meta_bias,\n                trainable_meta=trainable_meta,\n            )\n            for d_in, d_out in zip(ds[:-1], ds[1:])\n        ]\n        self.d = ds\n        self.affines = nn.ModuleList(affines)\n        self.act = get_act(act_name)\n\n        mlp_init(affines, init=init, init_scale=init_scale)\n        if zero_out:\n            zero_init(affines[-1])\n\n    def forward(self, h, params=None, options: Optional[AttrDict] = None, log_prefix: str = \"\"):\n        options = AttrDict() if options is None else AttrDict(options)\n        params = self.update(params)\n        *hid, out = self.affines\n        for i, layer in enumerate(hid):\n            h = self.act(layer(h, params=subdict(params, f\"{log_prefix}affines.{i}\")))\n        last = len(self.affines) - 1\n        h = out(h, params=subdict(params, f\"{log_prefix}affines.{last}\"))\n        return h\n\n\nclass LayerNorm(nn.LayerNorm):\n    def __init__(\n        self, norm_shape: Union[int, Tuple[int]], eps: float = 1e-5, elementwise_affine: bool = True\n    ):\n        super().__init__(norm_shape, eps=eps, elementwise_affine=elementwise_affine)\n        self.width = np.prod(norm_shape)\n        self.max_numel = 65535 * self.width\n\n   
 def forward(self, input):\n        if input.numel() > self.max_numel:\n            return F.layer_norm(\n                input.float(), self.normalized_shape, self.weight, self.bias, self.eps\n            ).type_as(input)\n        else:\n            return super(LayerNorm, self).forward(input.float()).type_as(input)\n\n\nclass PointSetEmbedding(nn.Module):\n    def __init__(\n        self,\n        *,\n        radius: float,\n        n_point: int,\n        n_sample: int,\n        d_input: int,\n        d_hidden: List[int],\n        patch_size: int = 1,\n        stride: int = 1,\n        activation: str = \"swish\",\n        group_all: bool = False,\n        padding_mode: str = \"zeros\",\n        fps_method: str = \"fps\",\n        **kwargs,\n    ):\n        super().__init__()\n        self.n_point = n_point\n        self.radius = radius\n        self.n_sample = n_sample\n        self.mlp_convs = nn.ModuleList()\n        self.act = get_act(activation)\n        self.patch_size = patch_size\n        self.stride = stride\n        last_channel = d_input + 3\n        for out_channel in d_hidden:\n            self.mlp_convs.append(\n                nn.Conv2d(\n                    last_channel,\n                    out_channel,\n                    kernel_size=(patch_size, 1),\n                    stride=(stride, 1),\n                    padding=(patch_size // 2, 0),\n                    padding_mode=padding_mode,\n                    **kwargs,\n                )\n            )\n            last_channel = out_channel\n        self.group_all = group_all\n        self.fps_method = fps_method\n\n    def forward(self, xyz, points):\n        \"\"\"\n        Input:\n            xyz: input points position data, [B, C, N]\n            points: input points data, [B, D, N]\n        Return:\n            new_points: sample points feature data, [B, d_hidden[-1], n_point]\n        \"\"\"\n        xyz = xyz.permute(0, 2, 1)\n        if points is not None:\n            points = 
points.permute(0, 2, 1)\n\n        if self.group_all:\n            new_xyz, new_points = sample_and_group_all(xyz, points)\n        else:\n            new_xyz, new_points = sample_and_group(\n                self.n_point,\n                self.radius,\n                self.n_sample,\n                xyz,\n                points,\n                deterministic=not self.training,\n                fps_method=self.fps_method,\n            )\n        # new_xyz: sampled points position data, [B, n_point, C]\n        # new_points: sampled points data, [B, n_point, n_sample, C+D]\n        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, n_sample, n_point]\n        for i, conv in enumerate(self.mlp_convs):\n            new_points = self.act(self.apply_conv(new_points, conv))\n\n        new_points = new_points.mean(dim=2)\n        return new_points\n\n    def apply_conv(self, points: torch.Tensor, conv: nn.Module):\n        batch, channels, n_samples, _ = points.shape\n        # Shuffle the representations\n        if self.patch_size > 1:\n            # TODO shuffle deterministically when not self.training\n            _, indices = torch.rand(batch, channels, n_samples, 1, device=points.device).sort(dim=2)\n            points = torch.gather(points, 2, torch.broadcast_to(indices, points.shape))\n        return conv(points)\n"
  },
  {
    "path": "shap_e/models/nn/pointnet2_utils.py",
    "content": "\"\"\"\nBased on https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet2_utils.py\n\nMIT License\n\nCopyright (c) 2019 benny\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\nfrom time import time\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef timeit(tag, t):\n    print(\"{}: {}s\".format(tag, time() - t))\n    return time()\n\n\ndef pc_normalize(pc):\n    l = pc.shape[0]\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\n\ndef square_distance(src, dst):\n    \"\"\"\n    Calculate Euclid distance between each two points.\n\n    src^T * dst = xn * xm + yn * ym + zn * zm;\n    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;\n    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;\n    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2\n         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst\n\n    Input:\n 
       src: source points, [B, N, C]\n        dst: target points, [B, M, C]\n    Output:\n        dist: per-point square distance, [B, N, M]\n    \"\"\"\n    B, N, _ = src.shape\n    _, M, _ = dst.shape\n    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))\n    dist += torch.sum(src**2, -1).view(B, N, 1)\n    dist += torch.sum(dst**2, -1).view(B, 1, M)\n    return dist\n\n\ndef index_points(points, idx):\n    \"\"\"\n\n    Input:\n        points: input points data, [B, N, C]\n        idx: sample index data, [B, S]\n    Return:\n        new_points:, indexed points data, [B, S, C]\n    \"\"\"\n    device = points.device\n    B = points.shape[0]\n    view_shape = list(idx.shape)\n    view_shape[1:] = [1] * (len(view_shape) - 1)\n    repeat_shape = list(idx.shape)\n    repeat_shape[0] = 1\n    batch_indices = (\n        torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)\n    )\n    new_points = points[batch_indices, idx, :]\n    return new_points\n\n\ndef farthest_point_sample(xyz, npoint, deterministic=False):\n    \"\"\"\n    Input:\n        xyz: pointcloud data, [B, N, 3]\n        npoint: number of samples\n    Return:\n        centroids: sampled pointcloud index, [B, npoint]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)\n    distance = torch.ones(B, N).to(device) * 1e10\n    if deterministic:\n        farthest = torch.arange(0, B, dtype=torch.long).to(device)\n    else:\n        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)\n    batch_indices = torch.arange(B, dtype=torch.long).to(device)\n    for i in range(npoint):\n        centroids[:, i] = farthest\n        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)\n        dist = torch.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = torch.max(distance, -1)[1]\n    return centroids\n\n\ndef 
query_ball_point(radius, nsample, xyz, new_xyz):\n    \"\"\"\n    Input:\n        radius: local region radius\n        nsample: max sample number in local region\n        xyz: all points, [B, N, 3]\n        new_xyz: query points, [B, S, 3]\n    Return:\n        group_idx: grouped points index, [B, S, nsample]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    _, S, _ = new_xyz.shape\n    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])\n    sqrdists = square_distance(new_xyz, xyz)\n    group_idx[sqrdists > radius**2] = N\n    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]\n    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])\n    mask = group_idx == N\n    group_idx[mask] = group_first[mask]\n    return group_idx\n\n\ndef sample_and_group(\n    npoint,\n    radius,\n    nsample,\n    xyz,\n    points,\n    returnfps=False,\n    deterministic=False,\n    fps_method: str = \"fps\",\n):\n    \"\"\"\n    Input:\n        npoint: number of centroids to sample\n        radius: local region radius\n        nsample: max sample number in local region\n        xyz: input points position data, [B, N, 3]\n        points: input points data, [B, N, D]\n    Return:\n        new_xyz: sampled points position data, [B, npoint, 3]\n        new_points: sampled points data, [B, npoint, nsample, 3+D]\n    \"\"\"\n    B, N, C = xyz.shape\n    S = npoint\n    if fps_method == \"fps\":\n        fps_idx = farthest_point_sample(xyz, npoint, deterministic=deterministic)  # [B, npoint]\n    elif fps_method == \"first\":\n        fps_idx = torch.arange(npoint, device=xyz.device)[None].repeat(B, 1)\n    else:\n        raise ValueError(f\"Unknown FPS method: {fps_method}\")\n    new_xyz = index_points(xyz, fps_idx)\n    idx = query_ball_point(radius, nsample, xyz, new_xyz)\n    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]\n    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)\n\n    if points is not None:\n        grouped_points = index_points(points, idx)\n        new_points 
= torch.cat(\n            [grouped_xyz_norm, grouped_points], dim=-1\n        )  # [B, npoint, nsample, C+D]\n    else:\n        new_points = grouped_xyz_norm\n    if returnfps:\n        return new_xyz, new_points, grouped_xyz, fps_idx\n    else:\n        return new_xyz, new_points\n\n\ndef sample_and_group_all(xyz, points):\n    \"\"\"\n    Input:\n        xyz: input points position data, [B, N, 3]\n        points: input points data, [B, N, D]\n    Return:\n        new_xyz: sampled points position data, [B, 1, 3]\n        new_points: sampled points data, [B, 1, N, 3+D]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    new_xyz = torch.zeros(B, 1, C).to(device)\n    grouped_xyz = xyz.view(B, 1, N, C)\n    if points is not None:\n        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)\n    else:\n        new_points = grouped_xyz\n    return new_xyz, new_points\n\n\nclass PointNetSetAbstraction(nn.Module):\n    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):\n        super(PointNetSetAbstraction, self).__init__()\n        self.npoint = npoint\n        self.radius = radius\n        self.nsample = nsample\n        self.mlp_convs = nn.ModuleList()\n        self.mlp_bns = nn.ModuleList()\n        last_channel = in_channel\n        for out_channel in mlp:\n            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))\n            self.mlp_bns.append(nn.BatchNorm2d(out_channel))\n            last_channel = out_channel\n        self.group_all = group_all\n\n    def forward(self, xyz, points):\n        \"\"\"\n        Input:\n            xyz: input points position data, [B, C, N]\n            points: input points data, [B, D, N]\n        Return:\n            new_xyz: sampled points position data, [B, C, S]\n            new_points_concat: sample points feature data, [B, D', S]\n        \"\"\"\n        xyz = xyz.permute(0, 2, 1)\n        if points is not None:\n            points = points.permute(0, 
2, 1)\n\n        if self.group_all:\n            new_xyz, new_points = sample_and_group_all(xyz, points)\n        else:\n            new_xyz, new_points = sample_and_group(\n                self.npoint, self.radius, self.nsample, xyz, points, deterministic=not self.training\n            )\n        # new_xyz: sampled points position data, [B, npoint, C]\n        # new_points: sampled points data, [B, npoint, nsample, C+D]\n        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample,npoint]\n        for i, conv in enumerate(self.mlp_convs):\n            bn = self.mlp_bns[i]\n            new_points = F.relu(bn(conv(new_points)))\n\n        new_points = torch.max(new_points, 2)[0]\n        new_xyz = new_xyz.permute(0, 2, 1)\n        return new_xyz, new_points\n\n\nclass PointNetSetAbstractionMsg(nn.Module):\n    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):\n        super(PointNetSetAbstractionMsg, self).__init__()\n        self.npoint = npoint\n        self.radius_list = radius_list\n        self.nsample_list = nsample_list\n        self.conv_blocks = nn.ModuleList()\n        self.bn_blocks = nn.ModuleList()\n        for i in range(len(mlp_list)):\n            convs = nn.ModuleList()\n            bns = nn.ModuleList()\n            last_channel = in_channel + 3\n            for out_channel in mlp_list[i]:\n                convs.append(nn.Conv2d(last_channel, out_channel, 1))\n                bns.append(nn.BatchNorm2d(out_channel))\n                last_channel = out_channel\n            self.conv_blocks.append(convs)\n            self.bn_blocks.append(bns)\n\n    def forward(self, xyz, points):\n        \"\"\"\n        Input:\n            xyz: input points position data, [B, C, N]\n            points: input points data, [B, D, N]\n        Return:\n            new_xyz: sampled points position data, [B, C, S]\n            new_points_concat: sample points feature data, [B, D', S]\n        \"\"\"\n        xyz = xyz.permute(0, 2, 
1)\n        if points is not None:\n            points = points.permute(0, 2, 1)\n\n        B, N, C = xyz.shape\n        S = self.npoint\n        new_xyz = index_points(xyz, farthest_point_sample(xyz, S, deterministic=not self.training))\n        new_points_list = []\n        for i, radius in enumerate(self.radius_list):\n            K = self.nsample_list[i]\n            group_idx = query_ball_point(radius, K, xyz, new_xyz)\n            grouped_xyz = index_points(xyz, group_idx)\n            grouped_xyz -= new_xyz.view(B, S, 1, C)\n            if points is not None:\n                grouped_points = index_points(points, group_idx)\n                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)\n            else:\n                grouped_points = grouped_xyz\n\n            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]\n            for j in range(len(self.conv_blocks[i])):\n                conv = self.conv_blocks[i][j]\n                bn = self.bn_blocks[i][j]\n                grouped_points = F.relu(bn(conv(grouped_points)))\n            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]\n            new_points_list.append(new_points)\n\n        new_xyz = new_xyz.permute(0, 2, 1)\n        new_points_concat = torch.cat(new_points_list, dim=1)\n        return new_xyz, new_points_concat\n\n\nclass PointNetFeaturePropagation(nn.Module):\n    def __init__(self, in_channel, mlp):\n        super(PointNetFeaturePropagation, self).__init__()\n        self.mlp_convs = nn.ModuleList()\n        self.mlp_bns = nn.ModuleList()\n        last_channel = in_channel\n        for out_channel in mlp:\n            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))\n            self.mlp_bns.append(nn.BatchNorm1d(out_channel))\n            last_channel = out_channel\n\n    def forward(self, xyz1, xyz2, points1, points2):\n        \"\"\"\n        Input:\n            xyz1: input points position data, [B, C, N]\n            xyz2: 
sampled input points position data, [B, C, S]\n            points1: input points data, [B, D, N]\n            points2: input points data, [B, D, S]\n        Return:\n            new_points: upsampled points data, [B, D', N]\n        \"\"\"\n        xyz1 = xyz1.permute(0, 2, 1)\n        xyz2 = xyz2.permute(0, 2, 1)\n\n        points2 = points2.permute(0, 2, 1)\n        B, N, C = xyz1.shape\n        _, S, _ = xyz2.shape\n\n        if S == 1:\n            interpolated_points = points2.repeat(1, N, 1)\n        else:\n            dists = square_distance(xyz1, xyz2)\n            dists, idx = dists.sort(dim=-1)\n            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]\n\n            dist_recip = 1.0 / (dists + 1e-8)\n            norm = torch.sum(dist_recip, dim=2, keepdim=True)\n            weight = dist_recip / norm\n            interpolated_points = torch.sum(\n                index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2\n            )\n\n        if points1 is not None:\n            points1 = points1.permute(0, 2, 1)\n            new_points = torch.cat([points1, interpolated_points], dim=-1)\n        else:\n            new_points = interpolated_points\n\n        new_points = new_points.permute(0, 2, 1)\n        for i, conv in enumerate(self.mlp_convs):\n            bn = self.mlp_bns[i]\n            new_points = F.relu(bn(conv(new_points)))\n        return new_points\n"
  },
  {
    "path": "shap_e/models/nn/utils.py",
    "content": "from typing import Iterable, Union\n\nimport numpy as np\nimport torch\n\nArrayType = Union[np.ndarray, Iterable[int], torch.Tensor]\n\n\ndef to_torch(arr: ArrayType, dtype=torch.float):\n    if isinstance(arr, torch.Tensor):\n        return arr\n    return torch.from_numpy(np.array(arr)).to(dtype)\n\n\ndef sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:\n    \"\"\"\n    Sample from the given discrete probability distribution with replacement.\n\n    The i-th bin is assumed to have mass pmf[i].\n\n    :param pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all()\n    :param n_samples: number of samples\n\n    :return: indices sampled with replacement\n    \"\"\"\n\n    *shape, support_size, last_dim = pmf.shape\n    assert last_dim == 1\n\n    cdf = torch.cumsum(pmf.view(-1, support_size), dim=1)\n    inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device))\n\n    return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1)\n\n\ndef safe_divide(a, b, epsilon=1e-6):\n    return a / torch.where(b < 0, b - epsilon, b + epsilon)\n"
  },
  {
    "path": "shap_e/models/query.py",
    "content": "from dataclasses import dataclass\nfrom typing import Callable, Optional\n\nimport torch\n\n\n@dataclass\nclass Query:\n    # Both of these are of shape [batch_size x ... x 3]\n    position: torch.Tensor\n    direction: Optional[torch.Tensor] = None\n\n    t_min: Optional[torch.Tensor] = None\n    t_max: Optional[torch.Tensor] = None\n\n    def copy(self) -> \"Query\":\n        return Query(\n            position=self.position,\n            direction=self.direction,\n            t_min=self.t_min,\n            t_max=self.t_max,\n        )\n\n    def map_tensors(self, f: Callable[[torch.Tensor], torch.Tensor]) -> \"Query\":\n        return Query(\n            position=f(self.position),\n            direction=f(self.direction) if self.direction is not None else None,\n            t_min=f(self.t_min) if self.t_min is not None else None,\n            t_max=f(self.t_max) if self.t_max is not None else None,\n        )\n"
  },
  {
    "path": "shap_e/models/renderer.py",
    "content": "from abc import abstractmethod\nfrom typing import Callable, Dict, List, Optional, Tuple\n\nimport numpy as np\nimport torch\n\nfrom shap_e.models.nn.camera import (\n    DifferentiableCamera,\n    DifferentiableProjectiveCamera,\n    get_image_coords,\n    projective_camera_frame,\n)\nfrom shap_e.models.nn.meta import MetaModule\nfrom shap_e.util.collections import AttrDict\n\n\nclass Renderer(MetaModule):\n    \"\"\"\n    A rendering abstraction that can render rays and views by calling the\n    appropriate models. The models are instantiated outside but registered in\n    this module.\n    \"\"\"\n\n    @abstractmethod\n    def render_views(\n        self,\n        batch: AttrDict,\n        params: Optional[Dict] = None,\n        options: Optional[Dict] = None,\n    ) -> AttrDict:\n        \"\"\"\n        Returns a backproppable rendering of a view\n\n        :param batch: contains\n            - height: Optional[int]\n            - width: Optional[int]\n            - inner_batch_size or ray_batch_size: Optional[int] defaults to 4096 rays\n\n            And additionally, to specify poses with a default up direction:\n            - poses: [batch_size x *shape x 2 x 3] where poses[:, ..., 0, :] are the camera\n                positions, and poses[:, ..., 1, :] are the z-axis (toward the object) of\n                the camera frame.\n            - camera: DifferentiableCamera. Assumes the same camera position\n                across batch for simplicity.  Could eventually support\n                batched cameras.\n\n            or to specify a batch of arbitrary poses:\n            - cameras: DifferentiableCameraBatch of shape [batch_size x *shape].\n\n        :param params: Meta parameters\n        :param options: Optional[Dict]\n        \"\"\"\n\n\nclass RayRenderer(Renderer):\n    \"\"\"\n    A rendering abstraction that can render rays and views by calling the\n    appropriate models. 
The models are instantiated outside but registered in\n    this module.\n    \"\"\"\n\n    @abstractmethod\n    def render_rays(\n        self,\n        batch: AttrDict,\n        params: Optional[Dict] = None,\n        options: Optional[Dict] = None,\n    ) -> AttrDict:\n        \"\"\"\n        :param batch: has\n            - rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray.\n            - radii (optional): [batch_size x ... x 1] the \"thickness\" of each ray.\n        :param options: Optional[Dict]\n        \"\"\"\n\n    def render_views(\n        self,\n        batch: AttrDict,\n        params: Optional[Dict] = None,\n        options: Optional[Dict] = None,\n        device: torch.device = torch.device(\"cuda\"),\n    ) -> AttrDict:\n        output = render_views_from_rays(\n            self.render_rays,\n            batch,\n            params=params,\n            options=options,\n            device=self.device,\n        )\n        return output\n\n    def forward(\n        self,\n        batch: AttrDict,\n        params: Optional[Dict] = None,\n        options: Optional[Dict] = None,\n    ) -> AttrDict:\n        \"\"\"\n        :param batch: must contain either\n\n            - rays: [batch_size x ... 
x 2 x 3] specify the origin and direction of each ray.\n\n            or\n\n            - poses: [batch_size x 2 x 3] where poses[:, 0] are the camera\n                positions, and poses[:, 1] are the z-axis (toward the object) of\n                the camera frame.\n            - camera: an instance of Camera that implements camera_rays\n\n            or\n\n            - cameras: DifferentiableCameraBatch of shape [batch_size x *shape].\n\n            For both of the above two options, these may be specified.\n            - height: Optional[int]\n            - width: Optional[int]\n            - ray_batch_size or inner_batch_size: Optional[int] defaults to 4096 rays\n\n        :param params: a dictionary of optional meta parameters.\n        :param options: A Dict of other hyperparameters that could be\n            related to rendering or debugging\n\n        :return: a dictionary containing\n\n            - channels: [batch_size, *shape, n_channels]\n            - distances: [batch_size, *shape, 1]\n            - transmittance: [batch_size, *shape, 1]\n            - aux_losses: Dict[str, torch.Tensor]\n        \"\"\"\n\n        if \"rays\" in batch:\n            for key in [\"poses\", \"camera\", \"height\", \"width\"]:\n                assert key not in batch\n            return self.render_rays(batch, params=params, options=options)\n        elif \"poses\" in batch or \"cameras\" in batch:\n            assert \"rays\" not in batch\n            if \"poses\" in batch:\n                assert \"camera\" in batch\n            else:\n                assert \"camera\" not in batch\n            return self.render_views(batch, params=params, options=options)\n\n        raise NotImplementedError\n\n\ndef get_camera_from_batch(batch: AttrDict) -> Tuple[DifferentiableCamera, int, Tuple[int]]:\n    if \"poses\" in batch:\n        assert not \"cameras\" in batch\n        batch_size, *inner_shape, n_vecs, spatial_dim = batch.poses.shape\n        assert n_vecs == 2 and 
spatial_dim == 3\n        inner_batch_size = int(np.prod(inner_shape))\n        poses = batch.poses.view(batch_size * inner_batch_size, 2, 3)\n        position, direction = poses[:, 0], poses[:, 1]\n        camera = projective_camera_frame(position, direction, batch.camera)\n    elif \"cameras\" in batch:\n        assert not \"camera\" in batch\n        batch_size, *inner_shape = batch.cameras.shape\n        camera = batch.cameras.flat_camera\n    else:\n        raise ValueError(f'neither \"poses\" nor \"cameras\" found in keys: {batch.keys()}')\n    if \"height\" in batch and \"width\" in batch:\n        camera = camera.resize_image(batch.width, batch.height)\n    return camera, batch_size, inner_shape\n\n\ndef append_tensor(val_list: Optional[List[torch.Tensor]], output: Optional[torch.Tensor]):\n    if val_list is None:\n        return [output]\n    return val_list + [output]\n\n\ndef render_views_from_rays(\n    render_rays: Callable[[AttrDict, AttrDict, AttrDict], AttrDict],\n    batch: AttrDict,\n    params: Optional[Dict] = None,\n    options: Optional[Dict] = None,\n    device: torch.device = torch.device(\"cuda\"),\n) -> AttrDict:\n    camera, batch_size, inner_shape = get_camera_from_batch(batch)\n    inner_batch_size = int(np.prod(inner_shape))\n\n    coords = get_image_coords(camera.width, camera.height).to(device)\n    coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])\n    rays = camera.camera_rays(coords)\n\n    # mip-NeRF radii calculation from: https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/datasets.py#L193-L200\n    directions = rays.view(batch_size, inner_batch_size, camera.height, camera.width, 2, 3)[\n        ..., 1, :\n    ]\n    neighbor_dists = torch.linalg.norm(directions[:, :, :, 1:] - directions[:, :, :, :-1], dim=-1)\n    neighbor_dists = torch.cat([neighbor_dists, neighbor_dists[:, :, :, -2:-1]], dim=3)\n    radii = (neighbor_dists * 2 / 
np.sqrt(12)).view(batch_size, -1, 1)\n\n    rays = rays.view(batch_size, inner_batch_size * camera.height * camera.width, 2, 3)\n\n    if isinstance(camera, DifferentiableProjectiveCamera):\n        # Compute the camera z direction corresponding to every ray's pixel.\n        # Used for depth computations below.\n        z_directions = (\n            (camera.z / torch.linalg.norm(camera.z, dim=-1, keepdim=True))\n            .reshape([batch_size, inner_batch_size, 1, 3])\n            .repeat(1, 1, camera.width * camera.height, 1)\n            .reshape(batch_size, inner_batch_size * camera.height * camera.width, 3)\n        )\n\n    ray_batch_size = batch.get(\"ray_batch_size\", batch.get(\"inner_batch_size\", 4096))\n    assert rays.shape[1] % ray_batch_size == 0\n    n_batches = rays.shape[1] // ray_batch_size\n\n    output_list = AttrDict(aux_losses=dict())\n\n    for idx in range(n_batches):\n        rays_batch = AttrDict(\n            rays=rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size],\n            radii=radii[:, idx * ray_batch_size : (idx + 1) * ray_batch_size],\n        )\n        output = render_rays(rays_batch, params=params, options=options)\n\n        if isinstance(camera, DifferentiableProjectiveCamera):\n            z_batch = z_directions[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]\n            ray_directions = rays_batch.rays[:, :, 1]\n            z_dots = (ray_directions * z_batch).sum(-1, keepdim=True)\n            output.depth = output.distances * z_dots\n\n        output_list = output_list.combine(output, append_tensor)\n\n    def _resize(val_list: List[torch.Tensor]):\n        val = torch.cat(val_list, dim=1)\n        assert val.shape[1] == inner_batch_size * camera.height * camera.width\n        return val.view(batch_size, *inner_shape, camera.height, camera.width, -1)\n\n    def _avg(_key: str, loss_list: List[torch.Tensor]):\n        return sum(loss_list) / n_batches\n\n    output = AttrDict(\n        {name: _resize(val_list) for 
name, val_list in output_list.items() if name != \"aux_losses\"}\n    )\n    output.aux_losses = output_list.aux_losses.map(_avg)\n\n    return output\n"
  },
  {
    "path": "shap_e/models/stf/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/models/stf/base.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Any, Dict, Optional\n\nimport torch\n\nfrom shap_e.models.query import Query\nfrom shap_e.models.renderer import append_tensor\nfrom shap_e.util.collections import AttrDict\n\n\nclass Model(ABC):\n    @abstractmethod\n    def forward(\n        self,\n        query: Query,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict[str, Any]:\n        \"\"\"\n        Predict an attribute given position\n        \"\"\"\n\n    def forward_batched(\n        self,\n        query: Query,\n        query_batch_size: int = 4096,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict[str, Any]:\n        if not query.position.numel():\n            # Avoid torch.cat() of zero tensors.\n            return self(query, params=params, options=options)\n\n        if options.cache is None:\n            created_cache = True\n            options.cache = AttrDict()\n        else:\n            created_cache = False\n\n        results_list = AttrDict()\n        for i in range(0, query.position.shape[1], query_batch_size):\n            out = self(\n                query=query.map_tensors(lambda x, i=i: x[:, i : i + query_batch_size]),\n                params=params,\n                options=options,\n            )\n            results_list = results_list.combine(out, append_tensor)\n\n        if created_cache:\n            del options[\"cache\"]\n\n        return results_list.map(lambda key, tensor_list: torch.cat(tensor_list, dim=1))\n"
  },
  {
    "path": "shap_e/models/stf/mlp.py",
    "content": "from functools import partial\nfrom typing import Any, Dict, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\n\nfrom shap_e.models.nn.checkpoint import checkpoint\nfrom shap_e.models.nn.encoding import encode_position, maybe_encode_direction\nfrom shap_e.models.nn.meta import MetaModule, subdict\nfrom shap_e.models.nn.ops import MetaLinear, get_act, mlp_init\nfrom shap_e.models.query import Query\nfrom shap_e.util.collections import AttrDict\n\nfrom .base import Model\n\n\nclass MLPModel(MetaModule, Model):\n    def __init__(\n        self,\n        n_output: int,\n        output_activation: str,\n        # Positional encoding parameters\n        posenc_version: str = \"v1\",\n        # Direction related channel prediction\n        insert_direction_at: Optional[int] = None,\n        # MLP parameters\n        d_hidden: int = 256,\n        n_hidden_layers: int = 4,\n        activation: str = \"relu\",\n        init: Optional[str] = None,\n        init_scale: float = 1.0,\n        meta_parameters: bool = False,\n        trainable_meta: bool = False,\n        meta_proj: bool = True,\n        meta_bias: bool = True,\n        meta_start: int = 0,\n        meta_stop: Optional[int] = None,\n        n_meta_layers: Optional[int] = None,\n        register_freqs: bool = False,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__()\n\n        if register_freqs:\n            self.register_buffer(\"freqs\", 2.0 ** torch.arange(10, device=device).view(1, 10))\n\n        # Positional encoding\n        self.posenc_version = posenc_version\n        dummy = torch.eye(1, 3)\n        d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1]\n        d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1]\n\n        # Instantiate the MLP\n        mlp_widths = [d_hidden] * n_hidden_layers\n        input_widths = [d_posenc_pos, *mlp_widths]\n        output_widths = mlp_widths + [n_output]\n\n 
       self.meta_parameters = meta_parameters\n\n        # When this model is used jointly to express NeRF, it may have to\n        # process directions as well in which case we simply concatenate\n        # the direction representation at the specified layer.\n        self.insert_direction_at = insert_direction_at\n        if insert_direction_at is not None:\n            input_widths[self.insert_direction_at] += d_posenc_dir\n\n        linear_cls = lambda meta: (\n            partial(\n                MetaLinear,\n                meta_scale=False,\n                meta_shift=False,\n                meta_proj=meta_proj,\n                meta_bias=meta_bias,\n                trainable_meta=trainable_meta,\n            )\n            if meta\n            else nn.Linear\n        )\n\n        if meta_stop is None:\n            if n_meta_layers is not None:\n                assert n_meta_layers > 0\n                meta_stop = meta_start + n_meta_layers - 1\n            else:\n                meta_stop = n_hidden_layers\n\n        if meta_parameters:\n            metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]\n        else:\n            metas = [False] * (n_hidden_layers + 1)\n\n        self.mlp = nn.ModuleList(\n            [\n                linear_cls(meta)(d_in, d_out, device=device)\n                for meta, d_in, d_out in zip(metas, input_widths, output_widths)\n            ]\n        )\n\n        mlp_init(self.mlp, init=init, init_scale=init_scale)\n\n        self.activation = get_act(activation)\n        self.output_activation = get_act(output_activation)\n\n        self.device = device\n        self.to(device)\n\n    def forward(\n        self,\n        query: Query,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict:\n        \"\"\"\n        :param position: [batch_size x ... 
x 3]\n        :param params: Meta parameters\n        :param options: Optional hyperparameters\n        \"\"\"\n\n        # query.direction is None typically for SDF models and training\n        h_final, _h_directionless = self._mlp(\n            query.position, query.direction, params=params, options=options\n        )\n        return self.output_activation(h_final)\n\n    def _run_mlp(\n        self, position: torch.Tensor, direction: torch.Tensor, params: AttrDict[str, torch.Tensor]\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        :return: the final and directionless activations at the given query\n        \"\"\"\n        h_preact = h = encode_position(self.posenc_version, position=position)\n        h_directionless = None\n        for i, layer in enumerate(self.mlp):\n            if i == self.insert_direction_at:\n                h_directionless = h_preact\n                h_direction = maybe_encode_direction(\n                    self.posenc_version, position=position, direction=direction\n                )\n                h = torch.cat([h, h_direction], dim=-1)\n            if isinstance(layer, MetaLinear):\n                h = layer(h, params=subdict(params, f\"mlp.{i}\"))\n            else:\n                h = layer(h)\n            h_preact = h\n            if i < len(self.mlp) - 1:\n                h = self.activation(h)\n        h_final = h\n        if h_directionless is None:\n            h_directionless = h_preact\n        return h_final, h_directionless\n\n    def _mlp(\n        self,\n        position: torch.Tensor,\n        direction: Optional[torch.Tensor] = None,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        :param position: [batch_size x ... 
x 3]\n        :param params: Meta parameters\n        :param options: Optional hyperparameters\n        :return: the final and directionless activations at the given query\n        \"\"\"\n        params = self.update(params)\n        options = AttrDict() if options is None else AttrDict(options)\n\n        mlp = partial(self._run_mlp, direction=direction, params=params)\n        parameters = []\n        for i, layer in enumerate(self.mlp):\n            if isinstance(layer, MetaLinear):\n                parameters.extend(list(subdict(params, f\"mlp.{i}\").values()))\n            else:\n                parameters.extend(layer.parameters())\n\n        h_final, h_directionless = checkpoint(\n            mlp, (position,), parameters, options.checkpoint_stf_model\n        )\n\n        return h_final, h_directionless\n\n\nclass MLPSDFModel(MLPModel):\n    def __init__(self, initial_bias: float = -0.1, **kwargs):\n        super().__init__(n_output=1, output_activation=\"identity\", **kwargs)\n        self.mlp[-1].bias.data.fill_(initial_bias)\n\n    def forward(\n        self,\n        query: Query,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict[str, Any]:\n        signed_distance = super().forward(query=query, params=params, options=options)\n        return AttrDict(signed_distance=signed_distance)\n\n\nclass MLPTextureFieldModel(MLPModel):\n    def __init__(\n        self,\n        n_channels: int = 3,\n        **kwargs,\n    ):\n        super().__init__(n_output=n_channels, output_activation=\"sigmoid\", **kwargs)\n\n    def forward(\n        self,\n        query: Query,\n        params: Optional[Dict[str, torch.Tensor]] = None,\n        options: Optional[Dict[str, Any]] = None,\n    ) -> AttrDict[str, Any]:\n        channels = super().forward(query=query, params=params, options=options)\n        return AttrDict(channels=channels)\n"
  },
  {
    "path": "shap_e/models/stf/renderer.py",
    "content": "import warnings\nfrom abc import ABC, abstractmethod\nfrom functools import partial\nfrom typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom shap_e.models.nn.camera import DifferentiableCamera, DifferentiableProjectiveCamera\nfrom shap_e.models.nn.meta import subdict\nfrom shap_e.models.nn.utils import to_torch\nfrom shap_e.models.query import Query\nfrom shap_e.models.renderer import Renderer, get_camera_from_batch\nfrom shap_e.models.volume import BoundingBoxVolume, Volume\nfrom shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR\nfrom shap_e.rendering.mc import marching_cubes\nfrom shap_e.rendering.torch_mesh import TorchMesh\nfrom shap_e.rendering.view_data import ProjectiveCamera\nfrom shap_e.util.collections import AttrDict\n\nfrom .base import Model\n\n\nclass STFRendererBase(ABC):\n    @abstractmethod\n    def get_signed_distance(\n        self,\n        position: torch.Tensor,\n        params: Dict[str, torch.Tensor],\n        options: AttrDict[str, Any],\n    ) -> torch.Tensor:\n        pass\n\n    @abstractmethod\n    def get_texture(\n        self,\n        position: torch.Tensor,\n        params: Dict[str, torch.Tensor],\n        options: AttrDict[str, Any],\n    ) -> torch.Tensor:\n        pass\n\n\nclass STFRenderer(Renderer, STFRendererBase):\n    def __init__(\n        self,\n        sdf: Model,\n        tf: Model,\n        volume: Volume,\n        grid_size: int,\n        texture_channels: Sequence[str] = (\"R\", \"G\", \"B\"),\n        channel_scale: Sequence[float] = (255.0, 255.0, 255.0),\n        ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,\n        diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,\n        specular_color: Union[float, Tuple[float]] = 0.0,\n        output_srgb: bool = True,\n        device: torch.device = torch.device(\"cuda\"),\n        
**kwargs,\n    ):\n        super().__init__(**kwargs)\n        assert isinstance(volume, BoundingBoxVolume), \"cannot sample points in unknown volume\"\n        self.sdf = sdf\n        self.tf = tf\n        self.volume = volume\n        self.grid_size = grid_size\n        self.texture_channels = texture_channels\n        self.channel_scale = to_torch(channel_scale).to(device)\n        self.ambient_color = ambient_color\n        self.diffuse_color = diffuse_color\n        self.specular_color = specular_color\n        self.output_srgb = output_srgb\n        self.device = device\n        self.to(device)\n\n    def render_views(\n        self,\n        batch: Dict,\n        params: Optional[Dict] = None,\n        options: Optional[Dict] = None,\n    ) -> AttrDict:\n        params = self.update(params)\n        options = AttrDict() if not options else AttrDict(options)\n\n        sdf_fn = partial(self.sdf.forward_batched, params=subdict(params, \"sdf\"))\n        tf_fn = partial(self.tf.forward_batched, params=subdict(params, \"tf\"))\n        nerstf_fn = None\n\n        return render_views_from_stf(\n            batch,\n            options,\n            sdf_fn=sdf_fn,\n            tf_fn=tf_fn,\n            nerstf_fn=nerstf_fn,\n            volume=self.volume,\n            grid_size=self.grid_size,\n            channel_scale=self.channel_scale,\n            texture_channels=self.texture_channels,\n            ambient_color=self.ambient_color,\n            diffuse_color=self.diffuse_color,\n            specular_color=self.specular_color,\n            output_srgb=self.output_srgb,\n            device=self.device,\n        )\n\n    def get_signed_distance(\n        self,\n        query: Query,\n        params: Dict[str, torch.Tensor],\n        options: AttrDict[str, Any],\n    ) -> torch.Tensor:\n        return self.sdf(\n            query,\n            params=subdict(params, \"sdf\"),\n            options=options,\n        ).signed_distance\n\n    def get_texture(\n       
 self,\n        query: Query,\n        params: Dict[str, torch.Tensor],\n        options: AttrDict[str, Any],\n    ) -> torch.Tensor:\n        return self.tf(\n            query,\n            params=subdict(params, \"tf\"),\n            options=options,\n        ).channels\n\n\ndef render_views_from_stf(\n    batch: Dict,\n    options: AttrDict[str, Any],\n    *,\n    sdf_fn: Optional[Callable],\n    tf_fn: Optional[Callable],\n    nerstf_fn: Optional[Callable],\n    volume: BoundingBoxVolume,\n    grid_size: int,\n    channel_scale: torch.Tensor,\n    texture_channels: Sequence[str] = (\"R\", \"G\", \"B\"),\n    ambient_color: Union[float, Tuple[float]] = 0.0,\n    diffuse_color: Union[float, Tuple[float]] = 1.0,\n    specular_color: Union[float, Tuple[float]] = 0.2,\n    output_srgb: bool = False,\n    device: torch.device = torch.device(\"cuda\"),\n) -> AttrDict:\n    \"\"\"\n    :param batch: contains either [\"poses\", \"camera\"], or [\"cameras\"]. Can\n        optionally contain any of [\"height\", \"width\", \"query_batch_size\"]\n    :param options: controls checkpointing, caching, and rendering\n    :param sdf_fn: returns [batch_size, query_batch_size, n_output] where\n        n_output >= 1.\n    :param tf_fn: returns [batch_size, query_batch_size, n_channels]\n    :param volume: AABB volume\n    :param grid_size: SDF sampling resolution\n    :param texture_channels: what texture to predict\n    :param channel_scale: how each channel is scaled\n    :return: at least\n        channels: [batch_size, len(cameras), height, width, 3]\n        transmittance: [batch_size, len(cameras), height, width, 1]\n        aux_losses: AttrDict[str, torch.Tensor]\n    \"\"\"\n    camera, batch_size, inner_shape = get_camera_from_batch(batch)\n    inner_batch_size = int(np.prod(inner_shape))\n    assert camera.width == camera.height, \"only square views are supported\"\n    assert camera.x_fov == camera.y_fov, \"only square views are supported\"\n    assert 
isinstance(camera, DifferentiableProjectiveCamera)\n\n    device = camera.origin.device\n    device_type = device.type\n\n    TO_CACHE = [\"fields\", \"raw_meshes\", \"raw_signed_distance\", \"raw_density\", \"mesh_mask\", \"meshes\"]\n    if options.cache is not None and all(key in options.cache for key in TO_CACHE):\n        fields = options.cache.fields\n        raw_meshes = options.cache.raw_meshes\n        raw_signed_distance = options.cache.raw_signed_distance\n        raw_density = options.cache.raw_density\n        mesh_mask = options.cache.mesh_mask\n    else:\n        query_batch_size = batch.get(\"query_batch_size\", batch.get(\"ray_batch_size\", 4096))\n        query_points = volume_query_points(volume, grid_size)\n        fn = nerstf_fn if sdf_fn is None else sdf_fn\n        sdf_out = fn(\n            query=Query(position=query_points[None].repeat(batch_size, 1, 1)),\n            query_batch_size=query_batch_size,\n            options=options,\n        )\n        raw_signed_distance = sdf_out.signed_distance\n        raw_density = None\n        if \"density\" in sdf_out:\n            raw_density = sdf_out.density\n        with torch.autocast(device_type, enabled=False):\n            fields = sdf_out.signed_distance.float()\n            raw_signed_distance = sdf_out.signed_distance\n            assert (\n                len(fields.shape) == 3 and fields.shape[-1] == 1\n            ), f\"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}\"\n            fields = fields.reshape(batch_size, *([grid_size] * 3))\n\n            # Force a negative border around the SDFs to close off all the models.\n            full_grid = torch.zeros(\n                batch_size,\n                grid_size + 2,\n                grid_size + 2,\n                grid_size + 2,\n                device=fields.device,\n                dtype=fields.dtype,\n            )\n            full_grid.fill_(-1.0)\n            full_grid[:, 1:-1, 1:-1, 1:-1] = fields\n      
      fields = full_grid\n\n            raw_meshes = []\n            mesh_mask = []\n            for field in fields:\n                raw_mesh = marching_cubes(field, volume.bbox_min, volume.bbox_max - volume.bbox_min)\n                if len(raw_mesh.faces) == 0:\n                    # DDP deadlocks when there are unused parameters on some ranks\n                    # and not others, so we make sure the field is a dependency in\n                    # the graph regardless of empty meshes.\n                    vertex_dependency = field.mean()\n                    raw_mesh = TorchMesh(\n                        verts=torch.zeros(3, 3, device=device) + vertex_dependency,\n                        faces=torch.tensor([[0, 1, 2]], dtype=torch.long, device=device),\n                    )\n                    # Make sure we only feed back zero gradients to the field\n                    # by masking out the final renderings of this mesh.\n                    mesh_mask.append(False)\n                else:\n                    mesh_mask.append(True)\n                raw_meshes.append(raw_mesh)\n            mesh_mask = torch.tensor(mesh_mask, device=device)\n\n        max_vertices = max(len(m.verts) for m in raw_meshes)\n\n        fn = nerstf_fn if tf_fn is None else tf_fn\n        tf_out = fn(\n            query=Query(\n                position=torch.stack(\n                    [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],\n                    dim=0,\n                )\n            ),\n            query_batch_size=query_batch_size,\n            options=options,\n        )\n\n        if \"cache\" in options:\n            options.cache.fields = fields\n            options.cache.raw_meshes = raw_meshes\n            options.cache.raw_signed_distance = raw_signed_distance\n            options.cache.raw_density = raw_density\n            options.cache.mesh_mask = mesh_mask\n\n    if output_srgb:\n        tf_out.channels = 
_convert_srgb_to_linear(tf_out.channels)\n\n    # Make sure the raw meshes have colors.\n    with torch.autocast(device_type, enabled=False):\n        textures = tf_out.channels.float()\n        assert len(textures.shape) == 3 and textures.shape[-1] == len(\n            texture_channels\n        ), f\"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}\"\n        for m, texture in zip(raw_meshes, textures):\n            texture = texture[: len(m.verts)]\n            m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))}\n\n    args = dict(\n        options=options,\n        texture_channels=texture_channels,\n        ambient_color=ambient_color,\n        diffuse_color=diffuse_color,\n        specular_color=specular_color,\n        camera=camera,\n        batch_size=batch_size,\n        inner_batch_size=inner_batch_size,\n        inner_shape=inner_shape,\n        raw_meshes=raw_meshes,\n        tf_out=tf_out,\n    )\n\n    try:\n        out = _render_with_pytorch3d(**args)\n    except ModuleNotFoundError as exc:\n        warnings.warn(f\"exception rendering with PyTorch3D: {exc}\")\n        warnings.warn(\n            \"falling back on native PyTorch renderer, which does not support full gradients\"\n        )\n        out = _render_with_raycast(**args)\n\n    # Apply mask to prevent gradients for empty meshes.\n    reshaped_mask = mesh_mask.view([-1] + [1] * (len(out.channels.shape) - 1))\n    out.channels = torch.where(reshaped_mask, out.channels, torch.zeros_like(out.channels))\n    out.transmittance = torch.where(\n        reshaped_mask, out.transmittance, torch.ones_like(out.transmittance)\n    )\n\n    if output_srgb:\n        out.channels = _convert_linear_to_srgb(out.channels)\n    out.channels = out.channels * (1 - out.transmittance) * channel_scale.view(-1)\n\n    # This might be useful information to have downstream\n    out.raw_meshes = raw_meshes\n    out.fields = fields\n    
out.mesh_mask = mesh_mask\n    out.raw_signed_distance = raw_signed_distance\n    out.aux_losses = AttrDict(cross_entropy=cross_entropy_sdf_loss(fields))\n    if raw_density is not None:\n        out.raw_density = raw_density\n\n    return out\n\n\ndef _render_with_pytorch3d(\n    options: AttrDict,\n    texture_channels: Sequence[str],\n    ambient_color: Union[float, Tuple[float]],\n    diffuse_color: Union[float, Tuple[float]],\n    specular_color: Union[float, Tuple[float]],\n    camera: DifferentiableCamera,\n    batch_size: int,\n    inner_shape: Sequence[int],\n    inner_batch_size: int,\n    raw_meshes: List[TorchMesh],\n    tf_out: AttrDict,\n):\n    _ = tf_out\n\n    # Lazy import because pytorch3d is installed lazily.\n    from shap_e.rendering.pytorch3d_util import (\n        blender_uniform_lights,\n        convert_cameras_torch,\n        convert_meshes,\n        render_images,\n    )\n\n    n_channels = len(texture_channels)\n    device = camera.origin.device\n    device_type = device.type\n\n    with torch.autocast(device_type, enabled=False):\n        meshes = convert_meshes(raw_meshes)\n\n        lights = blender_uniform_lights(\n            batch_size,\n            device,\n            ambient_color=ambient_color,\n            diffuse_color=diffuse_color,\n            specular_color=specular_color,\n        )\n\n        # Separate camera intrinsics for each view, so that we can\n        # create a new camera for each batch of views.\n        cam_shape = [batch_size, inner_batch_size, -1]\n        position = camera.origin.reshape(cam_shape)\n        x = camera.x.reshape(cam_shape)\n        y = camera.y.reshape(cam_shape)\n        z = camera.z.reshape(cam_shape)\n\n        results = []\n        for i in range(inner_batch_size):\n            sub_cams = convert_cameras_torch(\n                position[:, i], x[:, i], y[:, i], z[:, i], fov=camera.x_fov\n            )\n            imgs = render_images(\n                camera.width,\n                
meshes,\n                sub_cams,\n                lights,\n                use_checkpoint=options.checkpoint_render,\n                **options.get(\"render_options\", {}),\n            )\n            results.append(imgs)\n        views = torch.stack(results, dim=1)\n        views = views.view(batch_size, *inner_shape, camera.height, camera.width, n_channels + 1)\n\n        out = AttrDict(\n            channels=views[..., :-1],  # [batch_size, *inner_shape, height, width, n_channels]\n            transmittance=1 - views[..., -1:],  # [batch_size, *inner_shape, height, width, 1]\n            meshes=meshes,\n        )\n\n    return out\n\n\ndef _render_with_raycast(\n    options: AttrDict,\n    texture_channels: Sequence[str],\n    ambient_color: Union[float, Tuple[float]],\n    diffuse_color: Union[float, Tuple[float]],\n    specular_color: Union[float, Tuple[float]],\n    camera: DifferentiableCamera,\n    batch_size: int,\n    inner_shape: Sequence[int],\n    inner_batch_size: int,\n    raw_meshes: List[TorchMesh],\n    tf_out: AttrDict,\n):\n    assert np.mean(np.array(specular_color)) == 0\n\n    from shap_e.rendering.raycast.render import render_diffuse_mesh\n    from shap_e.rendering.raycast.types import TriMesh as TorchTriMesh\n\n    device = camera.origin.device\n    device_type = device.type\n\n    cam_shape = [batch_size, inner_batch_size, -1]\n    origin = camera.origin.reshape(cam_shape)\n    x = camera.x.reshape(cam_shape)\n    y = camera.y.reshape(cam_shape)\n    z = camera.z.reshape(cam_shape)\n\n    with torch.autocast(device_type, enabled=False):\n        all_meshes = []\n        for i, mesh in enumerate(raw_meshes):\n            all_meshes.append(\n                TorchTriMesh(\n                    faces=mesh.faces.long(),\n                    vertices=mesh.verts.float(),\n                    vertex_colors=tf_out.channels[i, : len(mesh.verts)].float(),\n                )\n            )\n        all_images = []\n        for i, mesh in 
enumerate(all_meshes):\n            for j in range(inner_batch_size):\n                all_images.append(\n                    render_diffuse_mesh(\n                        camera=ProjectiveCamera(\n                            origin=origin[i, j].detach().cpu().numpy(),\n                            x=x[i, j].detach().cpu().numpy(),\n                            y=y[i, j].detach().cpu().numpy(),\n                            z=z[i, j].detach().cpu().numpy(),\n                            width=camera.width,\n                            height=camera.height,\n                            x_fov=camera.x_fov,\n                            y_fov=camera.y_fov,\n                        ),\n                        mesh=mesh,\n                        diffuse=float(np.array(diffuse_color).mean()),\n                        ambient=float(np.array(ambient_color).mean()),\n                        ray_batch_size=16,  # low memory usage\n                        checkpoint=options.checkpoint_render,\n                    )\n                )\n\n        n_channels = len(texture_channels)\n        views = torch.stack(all_images).view(\n            batch_size, *inner_shape, camera.height, camera.width, n_channels + 1\n        )\n        return AttrDict(\n            channels=views[..., :-1],  # [batch_size, *inner_shape, height, width, n_channels]\n            transmittance=1 - views[..., -1:],  # [batch_size, *inner_shape, height, width, 1]\n            meshes=all_meshes,\n        )\n\n\ndef _convert_srgb_to_linear(u: torch.Tensor) -> torch.Tensor:\n    return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4)\n\n\ndef _convert_linear_to_srgb(u: torch.Tensor) -> torch.Tensor:\n    return torch.where(u <= 0.0031308, 12.92 * u, 1.055 * (u ** (1 / 2.4)) - 0.055)\n\n\ndef cross_entropy_sdf_loss(fields: torch.Tensor):\n    logits = F.logsigmoid(fields)\n    signs = (fields > 0).float()\n\n    losses = []\n    for dim in range(1, 4):\n        n = logits.shape[dim]\n        for 
(t_start, t_end, p_start, p_end) in [(0, -1, 1, n), (1, n, 0, -1)]:\n            targets = slice_fields(signs, dim, t_start, t_end)\n            preds = slice_fields(logits, dim, p_start, p_end)\n            losses.append(\n                F.binary_cross_entropy_with_logits(preds, targets, reduction=\"none\")\n                .flatten(1)\n                .mean()\n            )\n    return torch.stack(losses, dim=-1).sum()\n\n\ndef slice_fields(fields: torch.Tensor, dim: int, start: int, end: int):\n    if dim == 1:\n        return fields[:, start:end]\n    elif dim == 2:\n        return fields[:, :, start:end]\n    elif dim == 3:\n        return fields[:, :, :, start:end]\n    else:\n        raise ValueError(f\"cannot slice dimension {dim}\")\n\n\ndef volume_query_points(\n    volume: Volume,\n    grid_size: int,\n):\n    assert isinstance(volume, BoundingBoxVolume)\n    indices = torch.arange(grid_size**3, device=volume.bbox_min.device)\n    zs = indices % grid_size\n    ys = torch.div(indices, grid_size, rounding_mode=\"trunc\") % grid_size\n    xs = torch.div(indices, grid_size**2, rounding_mode=\"trunc\") % grid_size\n    combined = torch.stack([xs, ys, zs], dim=1)\n    return (combined.float() / (grid_size - 1)) * (\n        volume.bbox_max - volume.bbox_min\n    ) + volume.bbox_min\n"
  },
  {
    "path": "shap_e/models/transmitter/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/models/transmitter/base.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Any, Dict, Optional, Tuple\n\nimport torch.nn as nn\nfrom torch import torch\n\nfrom shap_e.models.renderer import Renderer\nfrom shap_e.util.collections import AttrDict\n\nfrom .bottleneck import latent_bottleneck_from_config, latent_warp_from_config\nfrom .params_proj import flatten_param_shapes, params_proj_from_config\n\n\nclass Encoder(nn.Module, ABC):\n    def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]]):\n        \"\"\"\n        Instantiate the encoder with information about the renderer's input\n        parameters. This information can be used to create output layers to\n        generate the necessary latents.\n        \"\"\"\n        super().__init__()\n        self.param_shapes = param_shapes\n        self.device = device\n\n    @abstractmethod\n    def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:\n        \"\"\"\n        Encode a batch of data into a batch of latent information.\n        \"\"\"\n\n\nclass VectorEncoder(Encoder):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        param_shapes: Dict[str, Tuple[int]],\n        params_proj: Dict[str, Any],\n        d_latent: int,\n        latent_bottleneck: Optional[Dict[str, Any]] = None,\n        latent_warp: Optional[Dict[str, Any]] = None,\n    ):\n        super().__init__(device=device, param_shapes=param_shapes)\n        if latent_bottleneck is None:\n            latent_bottleneck = dict(name=\"identity\")\n        if latent_warp is None:\n            latent_warp = dict(name=\"identity\")\n        self.d_latent = d_latent\n        self.params_proj = params_proj_from_config(\n            params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent\n        )\n        self.latent_bottleneck = latent_bottleneck_from_config(\n            latent_bottleneck, device=device, d_latent=d_latent\n        )\n        
self.latent_warp = latent_warp_from_config(latent_warp, device=device)\n\n    def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:\n        h = self.encode_to_bottleneck(batch, options=options)\n        return self.bottleneck_to_params(h, options=options)\n\n    def encode_to_bottleneck(\n        self, batch: AttrDict, options: Optional[AttrDict] = None\n    ) -> torch.Tensor:\n        return self.latent_warp.warp(\n            self.latent_bottleneck(self.encode_to_vector(batch, options=options), options=options),\n            options=options,\n        )\n\n    @abstractmethod\n    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:\n        \"\"\"\n        Encode the batch into a single latent vector.\n        \"\"\"\n\n    def bottleneck_to_params(\n        self, vector: torch.Tensor, options: Optional[AttrDict] = None\n    ) -> AttrDict:\n        _ = options\n        return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)\n\n\nclass ChannelsEncoder(VectorEncoder):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        param_shapes: Dict[str, Tuple[int]],\n        params_proj: Dict[str, Any],\n        d_latent: int,\n        latent_bottleneck: Optional[Dict[str, Any]] = None,\n        latent_warp: Optional[Dict[str, Any]] = None,\n    ):\n        super().__init__(\n            device=device,\n            param_shapes=param_shapes,\n            params_proj=params_proj,\n            d_latent=d_latent,\n            latent_bottleneck=latent_bottleneck,\n            latent_warp=latent_warp,\n        )\n        self.flat_shapes = flatten_param_shapes(param_shapes)\n        self.latent_ctx = sum(flat[0] for flat in self.flat_shapes.values())\n\n    @abstractmethod\n    def encode_to_channels(\n        self, batch: AttrDict, options: Optional[AttrDict] = None\n    ) -> torch.Tensor:\n        \"\"\"\n        Encode the batch into a 
per-data-point set of latents.\n        :return: [batch_size, latent_ctx, latent_width]\n        \"\"\"\n\n    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:\n        return self.encode_to_channels(batch, options=options).flatten(1)\n\n    def bottleneck_to_channels(\n        self, vector: torch.Tensor, options: Optional[AttrDict] = None\n    ) -> torch.Tensor:\n        _ = options\n        return vector.view(vector.shape[0], self.latent_ctx, -1)\n\n    def bottleneck_to_params(\n        self, vector: torch.Tensor, options: Optional[AttrDict] = None\n    ) -> AttrDict:\n        _ = options\n        return self.params_proj(\n            self.bottleneck_to_channels(self.latent_warp.unwarp(vector)), options=options\n        )\n\n\nclass Transmitter(nn.Module):\n    def __init__(self, encoder: Encoder, renderer: Renderer):\n        super().__init__()\n        self.encoder = encoder\n        self.renderer = renderer\n\n    def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:\n        \"\"\"\n        Transmit the batch through the encoder and then the renderer.\n        \"\"\"\n        params = self.encoder(batch, options=options)\n        return self.renderer(batch, params=params, options=options)\n\n\nclass VectorDecoder(nn.Module):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        param_shapes: Dict[str, Tuple[int]],\n        params_proj: Dict[str, Any],\n        d_latent: int,\n        latent_warp: Optional[Dict[str, Any]] = None,\n        renderer: Renderer,\n    ):\n        super().__init__()\n        self.device = device\n        self.param_shapes = param_shapes\n\n        if latent_warp is None:\n            latent_warp = dict(name=\"identity\")\n        self.d_latent = d_latent\n        self.params_proj = params_proj_from_config(\n            params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent\n        )\n        
self.latent_warp = latent_warp_from_config(latent_warp, device=device)\n        self.renderer = renderer\n\n    def bottleneck_to_params(\n        self, vector: torch.Tensor, options: Optional[AttrDict] = None\n    ) -> AttrDict:\n        _ = options\n        return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)\n\n\nclass ChannelsDecoder(VectorDecoder):\n    def __init__(\n        self,\n        *,\n        latent_ctx: int,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.latent_ctx = latent_ctx\n\n    def bottleneck_to_channels(\n        self, vector: torch.Tensor, options: Optional[AttrDict] = None\n    ) -> torch.Tensor:\n        _ = options\n        return vector.view(vector.shape[0], self.latent_ctx, -1)\n\n    def bottleneck_to_params(\n        self, vector: torch.Tensor, options: Optional[AttrDict] = None\n    ) -> AttrDict:\n        _ = options\n        return self.params_proj(\n            self.bottleneck_to_channels(self.latent_warp.unwarp(vector)), options=options\n        )\n"
  },
  {
    "path": "shap_e/models/transmitter/bottleneck.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Any, Dict, Optional\n\nimport numpy as np\nimport torch.nn as nn\nfrom torch import torch\n\nfrom shap_e.diffusion.gaussian_diffusion import diffusion_from_config\nfrom shap_e.util.collections import AttrDict\n\n\nclass LatentBottleneck(nn.Module, ABC):\n    def __init__(self, *, device: torch.device, d_latent: int):\n        super().__init__()\n        self.device = device\n        self.d_latent = d_latent\n\n    @abstractmethod\n    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        pass\n\n\nclass LatentWarp(nn.Module, ABC):\n    def __init__(self, *, device: torch.device):\n        super().__init__()\n        self.device = device\n\n    @abstractmethod\n    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        pass\n\n    @abstractmethod\n    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        pass\n\n\nclass IdentityLatentWarp(LatentWarp):\n    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        _ = options\n        return x\n\n    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        _ = options\n        return x\n\n\nclass Tan2LatentWarp(LatentWarp):\n    def __init__(self, *, coeff1: float = 1.0, device: torch.device):\n        super().__init__(device=device)\n        self.coeff1 = coeff1\n        self.scale = np.tan(np.tan(1.0) * coeff1)\n\n    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        _ = options\n        return ((x.float().tan() * self.coeff1).tan() / self.scale).to(x.dtype)\n\n    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        _ = options\n        return ((x.float() * self.scale).arctan() / self.coeff1).arctan().to(x.dtype)\n\n\nclass IdentityLatentBottleneck(LatentBottleneck):\n    def forward(self, 
x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        _ = options\n        return x\n\n\nclass ClampNoiseBottleneck(LatentBottleneck):\n    def __init__(self, *, device: torch.device, d_latent: int, noise_scale: float):\n        super().__init__(device=device, d_latent=d_latent)\n        self.noise_scale = noise_scale\n\n    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        _ = options\n        x = x.tanh()\n        if not self.training:\n            return x\n        return x + torch.randn_like(x) * self.noise_scale\n\n\nclass ClampDiffusionNoiseBottleneck(LatentBottleneck):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        d_latent: int,\n        diffusion: Dict[str, Any],\n        diffusion_prob: float = 1.0,\n    ):\n        super().__init__(device=device, d_latent=d_latent)\n        self.diffusion = diffusion_from_config(diffusion)\n        self.diffusion_prob = diffusion_prob\n\n    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        _ = options\n        x = x.tanh()\n        if not self.training:\n            return x\n        t = torch.randint(low=0, high=self.diffusion.num_timesteps, size=(len(x),), device=x.device)\n        t = torch.where(\n            torch.rand(len(x), device=x.device) < self.diffusion_prob, t, torch.zeros_like(t)\n        )\n        return self.diffusion.q_sample(x, t)\n\n\ndef latent_bottleneck_from_config(config: Dict[str, Any], device: torch.device, d_latent: int):\n    name = config.pop(\"name\")\n    if name == \"clamp_noise\":\n        return ClampNoiseBottleneck(**config, device=device, d_latent=d_latent)\n    elif name == \"identity\":\n        return IdentityLatentBottleneck(**config, device=device, d_latent=d_latent)\n    elif name == \"clamp_diffusion_noise\":\n        return ClampDiffusionNoiseBottleneck(**config, device=device, d_latent=d_latent)\n    else:\n        raise 
ValueError(f\"unknown latent bottleneck: {name}\")\n\n\ndef latent_warp_from_config(config: Dict[str, Any], device: torch.device):\n    name = config.pop(\"name\")\n    if name == \"identity\":\n        return IdentityLatentWarp(**config, device=device)\n    elif name == \"tan2\":\n        return Tan2LatentWarp(**config, device=device)\n    else:\n        raise ValueError(f\"unknown latent warping function: {name}\")\n"
  },
  {
    "path": "shap_e/models/transmitter/channels_encoder.py",
    "content": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom PIL import Image\nfrom torch import torch\n\nfrom shap_e.models.generation.perceiver import SimplePerceiver\nfrom shap_e.models.generation.transformer import Transformer\nfrom shap_e.models.nn.camera import DifferentiableProjectiveCamera\nfrom shap_e.models.nn.encoding import (\n    MultiviewPointCloudEmbedding,\n    MultiviewPoseEmbedding,\n    PosEmbLinear,\n)\nfrom shap_e.models.nn.ops import PointSetEmbedding\nfrom shap_e.rendering.point_cloud import PointCloud\nfrom shap_e.rendering.view_data import ProjectiveCamera\nfrom shap_e.util.collections import AttrDict\n\nfrom .base import ChannelsEncoder\n\n\nclass TransformerChannelsEncoder(ChannelsEncoder, ABC):\n    \"\"\"\n    Encode point clouds using a transformer model with an extra output\n    token used to extract a latent vector.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        param_shapes: Dict[str, Tuple[int]],\n        params_proj: Dict[str, Any],\n        d_latent: int = 512,\n        latent_bottleneck: Optional[Dict[str, Any]] = None,\n        latent_warp: Optional[Dict[str, Any]] = None,\n        n_ctx: int = 1024,\n        width: int = 512,\n        layers: int = 12,\n        heads: int = 8,\n        init_scale: float = 0.25,\n        latent_scale: float = 1.0,\n    ):\n        super().__init__(\n            device=device,\n            param_shapes=param_shapes,\n            params_proj=params_proj,\n            d_latent=d_latent,\n            latent_bottleneck=latent_bottleneck,\n            latent_warp=latent_warp,\n        )\n        self.width = width\n        self.device = device\n        self.dtype = dtype\n\n      
  self.n_ctx = n_ctx\n\n        self.backbone = Transformer(\n            device=device,\n            dtype=dtype,\n            n_ctx=n_ctx + self.latent_ctx,\n            width=width,\n            layers=layers,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.register_parameter(\n            \"output_tokens\",\n            nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),\n        )\n        self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype)\n        self.latent_scale = latent_scale\n\n    @abstractmethod\n    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:\n        pass\n\n    def encode_to_channels(\n        self, batch: AttrDict, options: Optional[AttrDict] = None\n    ) -> torch.Tensor:\n        h = self.encode_input(batch, options=options)\n        h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)\n        h = self.ln_pre(h)\n        h = self.backbone(h)\n        h = h[:, -self.latent_ctx :]\n        h = self.ln_post(h)\n        h = self.output_proj(h)\n        return h\n\n\nclass PerceiverChannelsEncoder(ChannelsEncoder, ABC):\n    \"\"\"\n    Encode point clouds using a perceiver model with an extra output\n    token used to extract a latent vector.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        param_shapes: Dict[str, Tuple[int]],\n        params_proj: Dict[str, Any],\n        min_unrolls: int,\n        max_unrolls: int,\n        d_latent: int = 512,\n        latent_bottleneck: Optional[Dict[str, Any]] = None,\n        latent_warp: Optional[Dict[str, Any]] = None,\n        width: int = 512,\n        layers: int = 12,\n        xattn_layers: int = 1,\n        heads: int = 8,\n      
  init_scale: float = 0.25,\n        # Training hparams\n        inner_batch_size: Union[int, List[int]] = 1,\n        data_ctx: int = 1,\n    ):\n        super().__init__(\n            device=device,\n            param_shapes=param_shapes,\n            params_proj=params_proj,\n            d_latent=d_latent,\n            latent_bottleneck=latent_bottleneck,\n            latent_warp=latent_warp,\n        )\n        self.width = width\n        self.device = device\n        self.dtype = dtype\n\n        if isinstance(inner_batch_size, int):\n            inner_batch_size = [inner_batch_size]\n        self.inner_batch_size = inner_batch_size\n        self.data_ctx = data_ctx\n        self.min_unrolls = min_unrolls\n        self.max_unrolls = max_unrolls\n\n        encoder_fn = lambda inner_batch_size: SimplePerceiver(\n            device=device,\n            dtype=dtype,\n            n_ctx=self.data_ctx + self.latent_ctx,\n            n_data=inner_batch_size,\n            width=width,\n            layers=xattn_layers,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.encoder = (\n            encoder_fn(self.inner_batch_size[0])\n            if len(self.inner_batch_size) == 1\n            else nn.ModuleList([encoder_fn(inner_bsz) for inner_bsz in self.inner_batch_size])\n        )\n        self.processor = Transformer(\n            device=device,\n            dtype=dtype,\n            n_ctx=self.data_ctx + self.latent_ctx,\n            layers=layers - xattn_layers,\n            width=width,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.register_parameter(\n            \"output_tokens\",\n            nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),\n        )\n        self.output_proj = nn.Linear(width, d_latent, 
device=device, dtype=dtype)\n\n    @abstractmethod\n    def get_h_and_iterator(\n        self, batch: AttrDict, options: Optional[AttrDict] = None\n    ) -> Tuple[torch.Tensor, Iterable[Union[torch.Tensor, Tuple]]]:\n        \"\"\"\n        :return: a tuple of (\n            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],\n            an iterator over the given data\n        )\n        \"\"\"\n\n    def encode_to_channels(\n        self, batch: AttrDict, options: Optional[AttrDict] = None\n    ) -> torch.Tensor:\n        h, it = self.get_h_and_iterator(batch, options=options)\n        n_unrolls = self.get_n_unrolls()\n\n        for _ in range(n_unrolls):\n            data = next(it)\n            if isinstance(data, tuple):\n                for data_i, encoder_i in zip(data, self.encoder):\n                    h = encoder_i(h, data_i)\n            else:\n                h = self.encoder(h, data)\n            h = self.processor(h)\n\n        h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :]))\n        return h\n\n    def get_n_unrolls(self):\n        if self.training:\n            n_unrolls = torch.randint(\n                self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device\n            )\n            # Keep the unroll count identical across ranks; skip when not running distributed.\n            if dist.is_available() and dist.is_initialized():\n                dist.broadcast(n_unrolls, 0)\n            n_unrolls = n_unrolls.item()\n        else:\n            n_unrolls = self.max_unrolls\n        return n_unrolls\n\n\n@dataclass\nclass DatasetIterator:\n\n    embs: torch.Tensor  # [batch_size, dataset_size, *shape]\n    batch_size: int\n\n    def __iter__(self):\n        self._reset()\n        return self\n\n    def __next__(self):\n        _outer_batch_size, dataset_size, *_shape = self.embs.shape\n\n        while True:\n            start = self.idx\n            self.idx += self.batch_size\n            end = self.idx\n            if end <= dataset_size:\n                break\n            self._reset()\n\n        return self.embs[:, start:end]\n\n    def _reset(self):\n   
     self._shuffle()\n        self.idx = 0  # pylint: disable=attribute-defined-outside-init\n\n    def _shuffle(self):\n        outer_batch_size, dataset_size, *shape = self.embs.shape\n        idx = torch.stack(\n            [\n                torch.randperm(dataset_size, device=self.embs.device)\n                for _ in range(outer_batch_size)\n            ],\n            dim=0,\n        )\n        idx = idx.view(outer_batch_size, dataset_size, *([1] * len(shape)))\n        idx = torch.broadcast_to(idx, self.embs.shape)\n        self.embs = torch.gather(self.embs, 1, idx)\n\n\nclass PointCloudTransformerChannelsEncoder(TransformerChannelsEncoder):\n    \"\"\"\n    Encode point clouds using a transformer model with an extra output\n    token used to extract a latent vector.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        input_channels: int = 6,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.input_channels = input_channels\n        self.input_proj = nn.Linear(\n            input_channels, self.width, device=self.device, dtype=self.dtype\n        )\n\n    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:\n        _ = options\n        points = batch.points\n        h = self.input_proj(points.permute(0, 2, 1))  # NCL -> NLC\n        return h\n\n\nclass PointCloudPerceiverChannelsEncoder(PerceiverChannelsEncoder):\n    \"\"\"\n    Encode point clouds using a perceiver model with an extra output\n    token used to extract a latent vector.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        cross_attention_dataset: str = \"pcl\",\n        fps_method: str = \"fps\",\n        # point cloud hyperparameters\n        input_channels: int = 6,\n        pos_emb: Optional[str] = None,\n        # multiview hyperparameters\n        image_size: int = 256,\n        patch_size: int = 32,\n        pose_dropout: float = 0.0,\n        use_depth: bool = False,\n        
max_depth: float = 5.0,\n        # point conv hyperparameters\n        pointconv_radius: float = 0.5,\n        pointconv_samples: int = 32,\n        pointconv_hidden: Optional[List[int]] = None,\n        pointconv_patch_size: int = 1,\n        pointconv_stride: int = 1,\n        pointconv_padding_mode: str = \"zeros\",\n        use_pointconv: bool = False,\n        # other hyperparameters\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        assert cross_attention_dataset in (\n            \"pcl\",\n            \"multiview\",\n            \"dense_pose_multiview\",\n            \"multiview_pcl\",\n            \"pcl_and_multiview_pcl\",\n            \"incorrect_multiview_pcl\",\n            \"pcl_and_incorrect_multiview_pcl\",\n        )\n        assert fps_method in (\"fps\", \"first\")\n        self.cross_attention_dataset = cross_attention_dataset\n        self.fps_method = fps_method\n        self.input_channels = input_channels\n        self.input_proj = PosEmbLinear(\n            pos_emb,\n            input_channels,\n            self.width,\n            device=self.device,\n            dtype=self.dtype,\n        )\n        self.use_pointconv = use_pointconv\n        if use_pointconv:\n            if pointconv_hidden is None:\n                pointconv_hidden = [self.width]\n            self.point_conv = PointSetEmbedding(\n                n_point=self.data_ctx,\n                radius=pointconv_radius,\n                n_sample=pointconv_samples,\n                d_input=self.input_proj.weight.shape[0],\n                d_hidden=pointconv_hidden,\n                patch_size=pointconv_patch_size,\n                stride=pointconv_stride,\n                padding_mode=pointconv_padding_mode,\n                fps_method=fps_method,\n                device=self.device,\n                dtype=self.dtype,\n            )\n        if self.cross_attention_dataset == \"multiview\":\n            self.image_size = image_size\n            self.patch_size 
= patch_size\n            self.pose_dropout = pose_dropout\n            self.use_depth = use_depth\n            self.max_depth = max_depth\n            pos_ctx = (image_size // patch_size) ** 2\n            self.register_parameter(\n                \"pos_emb\",\n                nn.Parameter(\n                    torch.randn(\n                        pos_ctx * self.inner_batch_size[0],\n                        self.width,\n                        device=self.device,\n                        dtype=self.dtype,\n                    )\n                ),\n            )\n            self.patch_emb = nn.Conv2d(\n                in_channels=3 if not use_depth else 4,\n                out_channels=self.width,\n                kernel_size=patch_size,\n                stride=patch_size,\n                device=self.device,\n                dtype=self.dtype,\n            )\n            self.camera_emb = nn.Sequential(\n                nn.Linear(\n                    3 * 4 + 1, self.width, device=self.device, dtype=self.dtype\n                ),  # input size is for origin+x+y+z+fov\n                nn.GELU(),\n                nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),\n            )\n        elif self.cross_attention_dataset == \"dense_pose_multiview\":\n            # The number of output features is halved, because a patch_size of\n            # 32 ends up with a large patch_emb weight.\n            self.view_pose_width = self.width // 2\n            self.image_size = image_size\n            self.patch_size = patch_size\n            self.use_depth = use_depth\n            self.max_depth = max_depth\n            self.mv_pose_embed = MultiviewPoseEmbedding(\n                posemb_version=\"nerf\",\n                n_channels=4 if self.use_depth else 3,\n                out_features=self.view_pose_width,\n                device=self.device,\n                dtype=self.dtype,\n            )\n            pos_ctx = (image_size // patch_size) ** 2\n   
         # Positional embedding is unnecessary because pose information is baked into each pixel\n            self.patch_emb = nn.Conv2d(\n                in_channels=self.view_pose_width,\n                out_channels=self.width,\n                kernel_size=patch_size,\n                stride=patch_size,\n                device=self.device,\n                dtype=self.dtype,\n            )\n\n        elif (\n            self.cross_attention_dataset == \"multiview_pcl\"\n            or self.cross_attention_dataset == \"incorrect_multiview_pcl\"\n        ):\n            self.view_pose_width = self.width // 2\n            self.image_size = image_size\n            self.patch_size = patch_size\n            self.max_depth = max_depth\n            assert use_depth\n            self.mv_pcl_embed = MultiviewPointCloudEmbedding(\n                posemb_version=\"nerf\",\n                n_channels=3,\n                out_features=self.view_pose_width,\n                device=self.device,\n                dtype=self.dtype,\n            )\n            self.patch_emb = nn.Conv2d(\n                in_channels=self.view_pose_width,\n                out_channels=self.width,\n                kernel_size=patch_size,\n                stride=patch_size,\n                device=self.device,\n                dtype=self.dtype,\n            )\n\n        elif (\n            self.cross_attention_dataset == \"pcl_and_multiview_pcl\"\n            or self.cross_attention_dataset == \"pcl_and_incorrect_multiview_pcl\"\n        ):\n            self.view_pose_width = self.width // 2\n            self.image_size = image_size\n            self.patch_size = patch_size\n            self.max_depth = max_depth\n            assert use_depth\n            self.mv_pcl_embed = MultiviewPointCloudEmbedding(\n                posemb_version=\"nerf\",\n                n_channels=3,\n                out_features=self.view_pose_width,\n                device=self.device,\n                dtype=self.dtype,\n     
       )\n            self.patch_emb = nn.Conv2d(\n                in_channels=self.view_pose_width,\n                out_channels=self.width,\n                kernel_size=patch_size,\n                stride=patch_size,\n                device=self.device,\n                dtype=self.dtype,\n            )\n\n    def get_h_and_iterator(\n        self, batch: AttrDict, options: Optional[AttrDict] = None\n    ) -> Tuple[torch.Tensor, Iterable]:\n        \"\"\"\n        :return: a tuple of (\n            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],\n            an iterator over the given data\n        )\n        \"\"\"\n        options = AttrDict() if options is None else options\n\n        # Build the initial query embeddings\n        points = batch.points.permute(0, 2, 1)  # NCL -> NLC\n        if self.use_pointconv:\n            points = self.input_proj(points).permute(0, 2, 1)  # NLC -> NCL\n            xyz = batch.points[:, :3]\n            data_tokens = self.point_conv(xyz, points).permute(0, 2, 1)  # NCL -> NLC\n        else:\n            fps_samples = self.sample_pcl_fps(points)\n            data_tokens = self.input_proj(fps_samples)\n        batch_size = points.shape[0]\n        latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)\n        h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))\n        assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)\n\n        # Build the dataset embedding iterator\n        dataset_fn = {\n            \"pcl\": self.get_pcl_dataset,\n            \"multiview\": self.get_multiview_dataset,\n            \"dense_pose_multiview\": self.get_dense_pose_multiview_dataset,\n            \"pcl_and_multiview_pcl\": self.get_pcl_and_multiview_pcl_dataset,\n            \"multiview_pcl\": self.get_multiview_pcl_dataset,\n        }[self.cross_attention_dataset]\n        it = dataset_fn(batch, options=options)\n\n        return h, it\n\n    def 
sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:\n        return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)\n\n    def get_pcl_dataset(\n        self,\n        batch: AttrDict,\n        options: Optional[AttrDict[str, Any]] = None,\n        inner_batch_size: Optional[int] = None,\n    ) -> Iterable:\n        _ = options\n        if inner_batch_size is None:\n            inner_batch_size = self.inner_batch_size[0]\n        points = batch.points.permute(0, 2, 1)  # NCL -> NLC\n        dataset_emb = self.input_proj(points)\n        assert dataset_emb.shape[1] >= inner_batch_size\n        return iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))\n\n    def get_multiview_dataset(\n        self,\n        batch: AttrDict,\n        options: Optional[AttrDict] = None,\n        inner_batch_size: Optional[int] = None,\n    ) -> Iterable:\n        _ = options\n\n        if inner_batch_size is None:\n            inner_batch_size = self.inner_batch_size[0]\n\n        dataset_emb = self.encode_views(batch)\n        batch_size, num_views, n_patches, width = dataset_emb.shape\n\n        assert num_views >= inner_batch_size\n\n        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))\n\n        def gen():\n            while True:\n                examples = next(it)\n                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)\n                views = examples.reshape(batch_size, -1, width) + self.pos_emb\n                yield views\n\n        return gen()\n\n    def get_dense_pose_multiview_dataset(\n        self,\n        batch: AttrDict,\n        options: Optional[AttrDict] = None,\n        inner_batch_size: Optional[int] = None,\n    ) -> Iterable:\n        _ = options\n\n        if inner_batch_size is None:\n            inner_batch_size = self.inner_batch_size[0]\n\n        dataset_emb = self.encode_dense_pose_views(batch)\n        batch_size, num_views, n_patches, 
width = dataset_emb.shape\n\n        assert num_views >= inner_batch_size\n\n        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))\n\n        def gen():\n            while True:\n                examples = next(it)\n                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)\n                views = examples.reshape(batch_size, -1, width)\n                yield views\n\n        return gen()\n\n    def get_pcl_and_multiview_pcl_dataset(\n        self,\n        batch: AttrDict,\n        options: Optional[AttrDict] = None,\n        use_distance: bool = True,\n    ) -> Iterable:\n        _ = options\n\n        pcl_it = self.get_pcl_dataset(\n            batch, options=options, inner_batch_size=self.inner_batch_size[0]\n        )\n        multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)\n        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape\n\n        assert num_views >= self.inner_batch_size[1]\n\n        multiview_pcl_it = iter(\n            DatasetIterator(multiview_pcl_emb, batch_size=self.inner_batch_size[1])\n        )\n\n        def gen():\n            while True:\n                pcl = next(pcl_it)\n                multiview_pcl = next(multiview_pcl_it)\n                assert multiview_pcl.shape == (\n                    batch_size,\n                    self.inner_batch_size[1],\n                    n_patches,\n                    self.width,\n                )\n                yield pcl, multiview_pcl.reshape(batch_size, -1, width)\n\n        return gen()\n\n    def get_multiview_pcl_dataset(\n        self,\n        batch: AttrDict,\n        options: Optional[AttrDict] = None,\n        inner_batch_size: Optional[int] = None,\n        use_distance: bool = True,\n    ) -> Iterable:\n        _ = options\n\n        if inner_batch_size is None:\n            inner_batch_size = self.inner_batch_size[0]\n\n        multiview_pcl_emb = 
self.encode_multiview_pcl(batch, use_distance=use_distance)\n        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape\n\n        assert num_views >= inner_batch_size\n\n        multiview_pcl_it = iter(DatasetIterator(multiview_pcl_emb, batch_size=inner_batch_size))\n\n        def gen():\n            while True:\n                multiview_pcl = next(multiview_pcl_it)\n                assert multiview_pcl.shape == (\n                    batch_size,\n                    inner_batch_size,\n                    n_patches,\n                    self.width,\n                )\n                yield multiview_pcl.reshape(batch_size, -1, width)\n\n        return gen()\n\n    def encode_views(self, batch: AttrDict) -> torch.Tensor:\n        \"\"\"\n        :return: [batch_size, num_views, n_patches, width]\n        \"\"\"\n        all_views = self.views_to_tensor(batch.views).to(self.device)\n        if self.use_depth:\n            all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)\n        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)\n\n        batch_size, num_views, _, _, _ = all_views.shape\n\n        views_proj = self.patch_emb(\n            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])\n        )\n        views_proj = (\n            views_proj.reshape([batch_size, num_views, self.width, -1])\n            .permute(0, 1, 3, 2)\n            .contiguous()\n        )  # [batch_size x num_views x n_patches x width]\n\n        # [batch_size, num_views, 1, 2 * width]\n        camera_proj = self.camera_emb(all_cameras).reshape(\n            [batch_size, num_views, 1, self.width * 2]\n        )\n        pose_dropout = self.pose_dropout if self.training else 0.0\n        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout\n        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))\n        scale, shift = camera_proj.chunk(2, dim=3)\n        
views_proj = views_proj * (scale + 1.0) + shift\n        return views_proj\n\n    def encode_dense_pose_views(self, batch: AttrDict) -> torch.Tensor:\n        \"\"\"\n        :return: [batch_size, num_views, n_patches, width]\n        \"\"\"\n        all_views = self.views_to_tensor(batch.views).to(self.device)\n        if self.use_depth:\n            depths = self.depths_to_tensor(batch.depths)\n            all_views = torch.cat([all_views, depths], dim=2)\n\n        dense_poses, _ = self.dense_pose_cameras_to_tensor(batch.cameras)\n        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)\n        position, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]\n        all_view_poses = self.mv_pose_embed(all_views, position, direction)\n\n        batch_size, num_views, _, _, _ = all_view_poses.shape\n\n        views_proj = self.patch_emb(\n            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])\n        )\n        views_proj = (\n            views_proj.reshape([batch_size, num_views, self.width, -1])\n            .permute(0, 1, 3, 2)\n            .contiguous()\n        )  # [batch_size x num_views x n_patches x width]\n\n        return views_proj\n\n    def encode_multiview_pcl(self, batch: AttrDict, use_distance: bool = True) -> torch.Tensor:\n        \"\"\"\n        :return: [batch_size, num_views, n_patches, width]\n        \"\"\"\n        all_views = self.views_to_tensor(batch.views).to(self.device)\n        depths = self.raw_depths_to_tensor(batch.depths)\n        all_view_alphas = self.view_alphas_to_tensor(batch.view_alphas).to(self.device)\n        mask = all_view_alphas >= 0.999\n\n        dense_poses, camera_z = self.dense_pose_cameras_to_tensor(batch.cameras)\n        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)\n\n        origin, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]\n        if use_distance:\n            ray_depth_factor = torch.sum(direction * camera_z[..., None, None], dim=2, 
keepdim=True)\n            depths = depths / ray_depth_factor\n        position = origin + depths * direction\n        all_view_poses = self.mv_pcl_embed(all_views, origin, position, mask)\n\n        batch_size, num_views, _, _, _ = all_view_poses.shape\n\n        views_proj = self.patch_emb(\n            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])\n        )\n        views_proj = (\n            views_proj.reshape([batch_size, num_views, self.width, -1])\n            .permute(0, 1, 3, 2)\n            .contiguous()\n        )  # [batch_size x num_views x n_patches x width]\n\n        return views_proj\n\n    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].\n        \"\"\"\n        if isinstance(views, torch.Tensor):\n            return views\n\n        tensor_batch = []\n        num_views = len(views[0])\n        for inner_list in views:\n            assert len(inner_list) == num_views\n            inner_batch = []\n            for img in inner_list:\n                img = img.resize((self.image_size,) * 2).convert(\"RGB\")\n                inner_batch.append(\n                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)\n                    / 127.5\n                    - 1\n                )\n            tensor_batch.append(torch.stack(inner_batch, dim=0))\n        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)\n\n    def depths_to_tensor(\n        self, depths: Union[torch.Tensor, List[List[Image.Image]]]\n    ) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].\n        \"\"\"\n        if isinstance(depths, torch.Tensor):\n            return depths\n\n        tensor_batch = []\n        num_views = len(depths[0])\n        for inner_list in depths:\n            assert 
len(inner_list) == num_views\n            inner_batch = []\n            for arr in inner_list:\n                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth\n                tensor = tensor * 2 - 1\n                tensor = F.interpolate(\n                    tensor[None, None],\n                    (self.image_size,) * 2,\n                    mode=\"nearest\",\n                )\n                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))\n            tensor_batch.append(torch.cat(inner_batch, dim=0))\n        return torch.stack(tensor_batch, dim=0)\n\n    def view_alphas_to_tensor(\n        self, view_alphas: Union[torch.Tensor, List[List[Image.Image]]]\n    ) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 1 x size x size] tensor in the range [0, 1].\n        \"\"\"\n        if isinstance(view_alphas, torch.Tensor):\n            return view_alphas\n\n        tensor_batch = []\n        num_views = len(view_alphas[0])\n        for inner_list in view_alphas:\n            assert len(inner_list) == num_views\n            inner_batch = []\n            for img in inner_list:\n                tensor = (\n                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)\n                    / 255.0\n                )\n                tensor = F.interpolate(\n                    tensor[None, None],\n                    (self.image_size,) * 2,\n                    mode=\"nearest\",\n                )\n                inner_batch.append(tensor)\n            tensor_batch.append(torch.cat(inner_batch, dim=0))\n        return torch.stack(tensor_batch, dim=0)\n\n    def raw_depths_to_tensor(\n        self, depths: Union[torch.Tensor, List[List[Image.Image]]]\n    ) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 1 x size x size] tensor\n        \"\"\"\n        if isinstance(depths, torch.Tensor):\n            return depths\n\n        
tensor_batch = []\n        num_views = len(depths[0])\n        for inner_list in depths:\n            assert len(inner_list) == num_views\n            inner_batch = []\n            for arr in inner_list:\n                tensor = torch.from_numpy(arr).clamp(max=self.max_depth)\n                tensor = F.interpolate(\n                    tensor[None, None],\n                    (self.image_size,) * 2,\n                    mode=\"nearest\",\n                )\n                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))\n            tensor_batch.append(torch.cat(inner_batch, dim=0))\n        return torch.stack(tensor_batch, dim=0)\n\n    def cameras_to_tensor(\n        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]\n    ) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 3*4+1] tensor of camera information.\n        \"\"\"\n        if isinstance(cameras, torch.Tensor):\n            return cameras\n        outer_batch = []\n        for inner_list in cameras:\n            inner_batch = []\n            for camera in inner_list:\n                inner_batch.append(\n                    np.array(\n                        [\n                            *camera.x,\n                            *camera.y,\n                            *camera.z,\n                            *camera.origin,\n                            camera.x_fov,\n                        ]\n                    )\n                )\n            outer_batch.append(np.stack(inner_batch, axis=0))\n        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()\n\n    def dense_pose_cameras_to_tensor(\n        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Returns a tuple of (rays, z_directions) where\n            - rays: [batch, num_views, height, width, 2, 3] tensor of camera information.\n            - z_directions: [batch, num_views, 3] tensor of 
camera z directions.\n        \"\"\"\n        if isinstance(cameras, torch.Tensor):\n            raise NotImplementedError\n\n        for inner_list in cameras:\n            assert len(inner_list) == len(cameras[0])\n\n        camera = cameras[0][0]\n        flat_camera = DifferentiableProjectiveCamera(\n            origin=torch.from_numpy(\n                np.stack(\n                    [cam.origin for inner_list in cameras for cam in inner_list],\n                    axis=0,\n                )\n            ).to(self.device),\n            x=torch.from_numpy(\n                np.stack(\n                    [cam.x for inner_list in cameras for cam in inner_list],\n                    axis=0,\n                )\n            ).to(self.device),\n            y=torch.from_numpy(\n                np.stack(\n                    [cam.y for inner_list in cameras for cam in inner_list],\n                    axis=0,\n                )\n            ).to(self.device),\n            z=torch.from_numpy(\n                np.stack(\n                    [cam.z for inner_list in cameras for cam in inner_list],\n                    axis=0,\n                )\n            ).to(self.device),\n            width=camera.width,\n            height=camera.height,\n            x_fov=camera.x_fov,\n            y_fov=camera.y_fov,\n        )\n        batch_size = len(cameras) * len(cameras[0])\n        coords = (\n            flat_camera.image_coords()\n            .to(flat_camera.origin.device)\n            .unsqueeze(0)\n            .repeat(batch_size, 1, 1)\n        )\n        rays = flat_camera.camera_rays(coords)\n        return (\n            rays.view(len(cameras), len(cameras[0]), camera.height, camera.width, 2, 3).to(\n                self.device\n            ),\n            flat_camera.z.view(len(cameras), len(cameras[0]), 3).to(self.device),\n        )\n\n\ndef sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = \"fps\") -> torch.Tensor:\n    \"\"\"\n    Run 
farthest-point sampling on a batch of point clouds.\n\n    :param points: batch of shape [N x num_points x (3 + n_channels)].\n    :param data_ctx: subsample count.\n    :param method: either 'fps' or 'first'. Using 'first' assumes that the\n                   points are already sorted according to FPS sampling.\n    :return: batch of shape [N x min(num_points, data_ctx) x (3 + n_channels)].\n    \"\"\"\n    n_points = points.shape[1]\n    if n_points == data_ctx:\n        return points\n    if method == \"first\":\n        return points[:, :data_ctx]\n    elif method == \"fps\":\n        batch = points.cpu().split(1, dim=0)\n        fps = [sample_fps(x, n_samples=data_ctx) for x in batch]\n        return torch.cat(fps, dim=0).to(points.device)\n    else:\n        raise ValueError(f\"unsupported farthest-point sampling method: {method}\")\n\n\ndef sample_fps(example: torch.Tensor, n_samples: int) -> torch.Tensor:\n    \"\"\"\n    :param example: [1, n_points, 3 + n_channels]\n    :return: [1, n_samples, 3 + n_channels]\n    \"\"\"\n    points = example.cpu().squeeze(0).numpy()\n    coords, raw_channels = points[:, :3], points[:, 3:]\n    n_points, n_channels = raw_channels.shape\n    assert n_samples <= n_points\n    channels = {str(idx): raw_channels[:, idx] for idx in range(n_channels)}\n    max_points = min(32768, n_points)\n    fps_pcl = (\n        PointCloud(coords=coords, channels=channels)\n        .random_sample(max_points)\n        .farthest_point_sample(n_samples)\n    )\n    fps_channels = np.stack([fps_pcl.channels[str(idx)] for idx in range(n_channels)], axis=1)\n    fps = np.concatenate([fps_pcl.coords, fps_channels], axis=1)\n    fps = torch.from_numpy(fps).unsqueeze(0)\n    assert fps.shape == (1, n_samples, 3 + n_channels)\n    return fps\n"
  },
  {
    "path": "shap_e/models/transmitter/multiview_encoder.py",
    "content": "from typing import Any, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom PIL import Image\n\nfrom shap_e.models.generation.transformer import Transformer\nfrom shap_e.rendering.view_data import ProjectiveCamera\nfrom shap_e.util.collections import AttrDict\n\nfrom .base import VectorEncoder\n\n\nclass MultiviewTransformerEncoder(VectorEncoder):\n    \"\"\"\n    Encode cameras and views using a transformer model with extra output\n    token(s) used to extract a latent vector.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        param_shapes: Dict[str, Tuple[int]],\n        params_proj: Dict[str, Any],\n        latent_bottleneck: Optional[Dict[str, Any]] = None,\n        d_latent: int = 512,\n        latent_ctx: int = 1,\n        num_views: int = 20,\n        image_size: int = 256,\n        patch_size: int = 32,\n        use_depth: bool = False,\n        max_depth: float = 5.0,\n        width: int = 512,\n        layers: int = 12,\n        heads: int = 8,\n        init_scale: float = 0.25,\n        pos_emb_init_scale: float = 1.0,\n    ):\n        super().__init__(\n            device=device,\n            param_shapes=param_shapes,\n            params_proj=params_proj,\n            latent_bottleneck=latent_bottleneck,\n            d_latent=d_latent,\n        )\n        self.num_views = num_views\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.use_depth = use_depth\n        self.max_depth = max_depth\n        self.n_ctx = num_views * (1 + (image_size // patch_size) ** 2)\n        self.latent_ctx = latent_ctx\n        self.width = width\n\n        assert d_latent % latent_ctx == 0\n\n        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.backbone = Transformer(\n            device=device,\n            dtype=dtype,\n            
n_ctx=self.n_ctx + latent_ctx,\n            width=width,\n            layers=layers,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.register_parameter(\n            \"output_tokens\",\n            nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)),\n        )\n        self.register_parameter(\n            \"pos_emb\",\n            nn.Parameter(\n                pos_emb_init_scale * torch.randn(self.n_ctx, width, device=device, dtype=dtype)\n            ),\n        )\n        self.patch_emb = nn.Conv2d(\n            in_channels=3 if not use_depth else 4,\n            out_channels=width,\n            kernel_size=patch_size,\n            stride=patch_size,\n            device=device,\n            dtype=dtype,\n        )\n        self.camera_emb = nn.Sequential(\n            nn.Linear(\n                3 * 4 + 1, width, device=device, dtype=dtype\n            ),  # input size is for origin+x+y+z+fov\n            nn.GELU(),\n            nn.Linear(width, width, device=device, dtype=dtype),\n        )\n        self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype)\n\n    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:\n        _ = options\n\n        all_views = self.views_to_tensor(batch.views).to(self.device)\n        if self.use_depth:\n            all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)\n        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)\n\n        batch_size, num_views, _, _, _ = all_views.shape\n\n        views_proj = self.patch_emb(\n            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])\n        )\n        views_proj = (\n            views_proj.reshape([batch_size, num_views, self.width, -1])\n            .permute(0, 1, 3, 2)\n            .contiguous()\n        )  # 
[batch_size x num_views x n_patches x width]\n\n        cameras_proj = self.camera_emb(all_cameras).reshape([batch_size, num_views, 1, self.width])\n\n        h = torch.cat([views_proj, cameras_proj], dim=2).reshape([batch_size, -1, self.width])\n        h = h + self.pos_emb\n        h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)\n        h = self.ln_pre(h)\n        h = self.backbone(h)\n        h = self.ln_post(h)\n        h = h[:, self.n_ctx :]\n        h = self.output_proj(h).flatten(1)\n\n        return h\n\n    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].\n        \"\"\"\n        if isinstance(views, torch.Tensor):\n            return views\n\n        tensor_batch = []\n        for inner_list in views:\n            assert len(inner_list) == self.num_views\n            inner_batch = []\n            for img in inner_list:\n                img = img.resize((self.image_size,) * 2).convert(\"RGB\")\n                inner_batch.append(\n                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)\n                    / 127.5\n                    - 1\n                )\n            tensor_batch.append(torch.stack(inner_batch, dim=0))\n        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)\n\n    def depths_to_tensor(\n        self, depths: Union[torch.Tensor, List[List[Image.Image]]]\n    ) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].\n        \"\"\"\n        if isinstance(depths, torch.Tensor):\n            return depths\n\n        tensor_batch = []\n        for inner_list in depths:\n            assert len(inner_list) == self.num_views\n            inner_batch = []\n            for arr in inner_list:\n                tensor = 
torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth\n                tensor = tensor * 2 - 1\n                tensor = F.interpolate(\n                    tensor[None, None],\n                    (self.image_size,) * 2,\n                    mode=\"nearest\",\n                )\n                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))\n            tensor_batch.append(torch.cat(inner_batch, dim=0))\n        return torch.stack(tensor_batch, dim=0)\n\n    def cameras_to_tensor(\n        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]\n    ) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 3*4+1] tensor of camera information.\n        \"\"\"\n        if isinstance(cameras, torch.Tensor):\n            return cameras\n        outer_batch = []\n        for inner_list in cameras:\n            inner_batch = []\n            for camera in inner_list:\n                inner_batch.append(\n                    np.array(\n                        [\n                            *camera.x,\n                            *camera.y,\n                            *camera.z,\n                            *camera.origin,\n                            camera.x_fov,\n                        ]\n                    )\n                )\n            outer_batch.append(np.stack(inner_batch, axis=0))\n        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()\n"
  },
  {
    "path": "shap_e/models/transmitter/params_proj.py",
    "content": "import math\nfrom abc import ABC, abstractmethod\nfrom collections import OrderedDict\nfrom typing import Any, Dict, Optional, Tuple\n\nimport numpy as np\nimport torch.nn as nn\nfrom torch import torch\n\nfrom shap_e.util.collections import AttrDict\n\n\ndef flatten_param_shapes(param_shapes: Dict[str, Tuple[int]]):\n    flat_shapes = OrderedDict(\n        (name, (int(np.prod(shape)) // shape[-1], shape[-1]))\n        for name, shape in param_shapes.items()\n    )\n    return flat_shapes\n\n\nclass ParamsProj(nn.Module, ABC):\n    def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int):\n        super().__init__()\n        self.device = device\n        self.param_shapes = param_shapes\n        self.d_latent = d_latent\n\n    @abstractmethod\n    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        pass\n\n\nclass LinearParamsProj(ParamsProj):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        param_shapes: Dict[str, Tuple[int]],\n        d_latent: int,\n        init_scale: Optional[float] = None,\n    ):\n        super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)\n        self.param_shapes = param_shapes\n        self.projections = nn.ModuleDict({})\n        for k, v in param_shapes.items():\n            self.projections[_sanitize_name(k)] = nn.Linear(\n                d_latent, int(np.prod(v)), device=device\n            )\n            if init_scale is not None:\n                scale = init_scale / math.sqrt(d_latent)\n                mod = self.projections[_sanitize_name(k)]\n                nn.init.normal_(mod.weight, std=scale)\n                nn.init.zeros_(mod.bias)\n\n    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        out = AttrDict()\n        for k in self.param_shapes.keys():\n            proj = self.projections[_sanitize_name(k)]\n            out[k] = 
proj(x).reshape([len(x), *self.param_shapes[k]])\n        return out\n\n\nclass MLPParamsProj(ParamsProj):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        param_shapes: Dict[str, Tuple[int]],\n        d_latent: int,\n        hidden_size: Optional[int] = None,\n    ):\n        super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)\n        if hidden_size is None:\n            hidden_size = d_latent\n        self.param_shapes = param_shapes\n        self.projections = nn.ModuleDict({})\n        for k, v in param_shapes.items():\n            self.projections[_sanitize_name(k)] = nn.Sequential(\n                nn.Linear(d_latent, hidden_size, device=device),\n                nn.GELU(),\n                nn.Linear(hidden_size, int(np.prod(v)), device=device),\n            )\n\n    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n        out = AttrDict()\n        for k in self.param_shapes.keys():\n            proj = self.projections[_sanitize_name(k)]\n            out[k] = proj(x).reshape([len(x), *self.param_shapes[k]])\n        return out\n\n\nclass ChannelsProj(nn.Module):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        vectors: int,\n        channels: int,\n        d_latent: int,\n        init_scale: float = 1.0,\n        learned_scale: Optional[float] = None,\n        use_ln: bool = False,\n    ):\n        super().__init__()\n        self.proj = nn.Linear(d_latent, vectors * channels, device=device)\n        self.use_ln = use_ln\n        self.learned_scale = learned_scale\n        if use_ln:\n            self.norm = nn.LayerNorm(normalized_shape=(channels,), device=device)\n            if learned_scale is not None:\n                self.norm.weight.data.fill_(learned_scale)\n            scale = init_scale / math.sqrt(d_latent)\n        elif learned_scale is not None:\n            gain = torch.ones((channels,), device=device) * 
learned_scale\n            self.register_parameter(\"gain\", nn.Parameter(gain))\n            scale = init_scale / math.sqrt(d_latent)\n        else:\n            scale = init_scale / math.sqrt(d_latent * channels)\n        nn.init.normal_(self.proj.weight, std=scale)\n        nn.init.zeros_(self.proj.bias)\n        self.d_latent = d_latent\n        self.vectors = vectors\n        self.channels = channels\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_bvd = x\n        w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)\n        b_vc = self.proj.bias.view(1, self.vectors, self.channels)\n        h = torch.einsum(\"bvd,vcd->bvc\", x_bvd, w_vcd)\n        if self.use_ln:\n            h = self.norm(h)\n        elif self.learned_scale is not None:\n            h = h * self.gain.view(1, 1, -1)\n        h = h + b_vc\n        return h\n\n\nclass ChannelsParamsProj(ParamsProj):\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        param_shapes: Dict[str, Tuple[int]],\n        d_latent: int,\n        init_scale: float = 1.0,\n        learned_scale: Optional[float] = None,\n        use_ln: bool = False,\n    ):\n        super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)\n        self.param_shapes = param_shapes\n        self.projections = nn.ModuleDict({})\n        self.flat_shapes = flatten_param_shapes(param_shapes)\n        self.learned_scale = learned_scale\n        self.use_ln = use_ln\n        for k, (vectors, channels) in self.flat_shapes.items():\n            self.projections[_sanitize_name(k)] = ChannelsProj(\n                device=device,\n                vectors=vectors,\n                channels=channels,\n                d_latent=d_latent,\n                init_scale=init_scale,\n                learned_scale=learned_scale,\n                use_ln=use_ln,\n            )\n\n    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:\n 
       out = AttrDict()\n        start = 0\n        for k, shape in self.param_shapes.items():\n            vectors, _ = self.flat_shapes[k]\n            end = start + vectors\n            x_bvd = x[:, start:end]\n            out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)\n            start = end\n        return out\n\n\ndef params_proj_from_config(\n    config: Dict[str, Any], device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int\n):\n    name = config.pop(\"name\")\n    if name == \"linear\":\n        return LinearParamsProj(\n            **config, device=device, param_shapes=param_shapes, d_latent=d_latent\n        )\n    elif name == \"mlp\":\n        return MLPParamsProj(**config, device=device, param_shapes=param_shapes, d_latent=d_latent)\n    elif name == \"channels\":\n        return ChannelsParamsProj(\n            **config, device=device, param_shapes=param_shapes, d_latent=d_latent\n        )\n    else:\n        raise ValueError(f\"unknown params proj: {name}\")\n\n\ndef _sanitize_name(x: str) -> str:\n    return x.replace(\".\", \"__\")\n"
  },
  {
    "path": "shap_e/models/transmitter/pc_encoder.py",
    "content": "from abc import abstractmethod\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom PIL import Image\nfrom torch import torch\n\nfrom shap_e.models.generation.perceiver import SimplePerceiver\nfrom shap_e.models.generation.transformer import Transformer\nfrom shap_e.models.nn.encoding import PosEmbLinear\nfrom shap_e.rendering.view_data import ProjectiveCamera\nfrom shap_e.util.collections import AttrDict\n\nfrom .base import VectorEncoder\nfrom .channels_encoder import DatasetIterator, sample_pcl_fps\n\n\nclass PointCloudTransformerEncoder(VectorEncoder):\n    \"\"\"\n    Encode point clouds using a transformer model with an extra output\n    token used to extract a latent vector.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        param_shapes: Dict[str, Tuple[int]],\n        params_proj: Dict[str, Any],\n        latent_bottleneck: Optional[Dict[str, Any]] = None,\n        d_latent: int = 512,\n        latent_ctx: int = 1,\n        input_channels: int = 6,\n        n_ctx: int = 1024,\n        width: int = 512,\n        layers: int = 12,\n        heads: int = 8,\n        init_scale: float = 0.25,\n        pos_emb: Optional[str] = None,\n    ):\n        super().__init__(\n            device=device,\n            param_shapes=param_shapes,\n            params_proj=params_proj,\n            latent_bottleneck=latent_bottleneck,\n            d_latent=d_latent,\n        )\n        self.input_channels = input_channels\n        self.n_ctx = n_ctx\n        self.latent_ctx = latent_ctx\n\n        assert d_latent % latent_ctx == 0\n\n        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.backbone = Transformer(\n            device=device,\n            dtype=dtype,\n            n_ctx=n_ctx + latent_ctx,\n            
width=width,\n            layers=layers,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.register_parameter(\n            \"output_tokens\",\n            nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)),\n        )\n\n        self.input_proj = PosEmbLinear(pos_emb, input_channels, width, device=device, dtype=dtype)\n        self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype)\n\n    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:\n        _ = options\n        points = batch.points.permute(0, 2, 1)  # NCL -> NLC\n        h = self.input_proj(points)\n        h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)\n        h = self.ln_pre(h)\n        h = self.backbone(h)\n        h = self.ln_post(h)\n        h = h[:, self.n_ctx :]\n        h = self.output_proj(h).flatten(1)\n        return h\n\n\nclass PerceiverEncoder(VectorEncoder):\n    \"\"\"\n    Encode point clouds using a perceiver model with an extra output\n    token used to extract a latent vector.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        device: torch.device,\n        dtype: torch.dtype,\n        param_shapes: Dict[str, Tuple[int]],\n        params_proj: Dict[str, Any],\n        latent_bottleneck: Optional[Dict[str, Any]] = None,\n        d_latent: int = 512,\n        latent_ctx: int = 1,\n        width: int = 512,\n        layers: int = 12,\n        xattn_layers: int = 1,\n        heads: int = 8,\n        init_scale: float = 0.25,\n        # Training hparams\n        inner_batch_size: int = 1,\n        data_ctx: int = 1,\n        min_unrolls: int,\n        max_unrolls: int,\n    ):\n        super().__init__(\n            device=device,\n            param_shapes=param_shapes,\n            params_proj=params_proj,\n            
latent_bottleneck=latent_bottleneck,\n            d_latent=d_latent,\n        )\n        self.width = width\n        self.device = device\n        self.dtype = dtype\n        self.latent_ctx = latent_ctx\n\n        self.inner_batch_size = inner_batch_size\n        self.data_ctx = data_ctx\n        self.min_unrolls = min_unrolls\n        self.max_unrolls = max_unrolls\n\n        self.encoder = SimplePerceiver(\n            device=device,\n            dtype=dtype,\n            n_ctx=self.data_ctx + self.latent_ctx,\n            n_data=self.inner_batch_size,\n            width=width,\n            layers=xattn_layers,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.processor = Transformer(\n            device=device,\n            dtype=dtype,\n            n_ctx=self.data_ctx + self.latent_ctx,\n            layers=layers - xattn_layers,\n            width=width,\n            heads=heads,\n            init_scale=init_scale,\n        )\n        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)\n        self.register_parameter(\n            \"output_tokens\",\n            nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),\n        )\n        self.output_proj = nn.Linear(width, d_latent // self.latent_ctx, device=device, dtype=dtype)\n\n    @abstractmethod\n    def get_h_and_iterator(\n        self, batch: AttrDict, options: Optional[AttrDict] = None\n    ) -> Tuple[torch.Tensor, Iterable]:\n        \"\"\"\n        :return: a tuple of (\n            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],\n            an iterator over the given data\n        )\n        \"\"\"\n\n    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:\n        h, it = self.get_h_and_iterator(batch, options=options)\n        n_unrolls = self.get_n_unrolls()\n\n        for _ in 
range(n_unrolls):\n            data = next(it)\n            h = self.encoder(h, data)\n            h = self.processor(h)\n\n        h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :]))\n        return h.flatten(1)\n\n    def get_n_unrolls(self):\n        if self.training:\n            n_unrolls = torch.randint(\n                self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device\n            )\n            dist.broadcast(n_unrolls, 0)\n            n_unrolls = n_unrolls.item()\n        else:\n            n_unrolls = self.max_unrolls\n        return n_unrolls\n\n\nclass PointCloudPerceiverEncoder(PerceiverEncoder):\n    \"\"\"\n    Encode point clouds using a transformer model with an extra output\n    token used to extract a latent vector.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        cross_attention_dataset: str = \"pcl\",\n        fps_method: str = \"fps\",\n        # point cloud hyperparameters\n        input_channels: int = 6,\n        pos_emb: Optional[str] = None,\n        # multiview hyperparameters\n        image_size: int = 256,\n        patch_size: int = 32,\n        pose_dropout: float = 0.0,\n        use_depth: bool = False,\n        max_depth: float = 5.0,\n        # other hyperparameters\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        assert cross_attention_dataset in (\"pcl\", \"multiview\")\n        assert fps_method in (\"fps\", \"first\")\n        self.cross_attention_dataset = cross_attention_dataset\n        self.fps_method = fps_method\n        self.input_channels = input_channels\n        self.input_proj = PosEmbLinear(\n            pos_emb, input_channels, self.width, device=self.device, dtype=self.dtype\n        )\n        if self.cross_attention_dataset == \"multiview\":\n            self.image_size = image_size\n            self.patch_size = patch_size\n            self.pose_dropout = pose_dropout\n            self.use_depth = use_depth\n            self.max_depth = 
max_depth\n            pos_ctx = (image_size // patch_size) ** 2\n            self.register_parameter(\n                \"pos_emb\",\n                nn.Parameter(\n                    torch.randn(\n                        pos_ctx * self.inner_batch_size,\n                        self.width,\n                        device=self.device,\n                        dtype=self.dtype,\n                    )\n                ),\n            )\n            self.patch_emb = nn.Conv2d(\n                in_channels=3 if not use_depth else 4,\n                out_channels=self.width,\n                kernel_size=patch_size,\n                stride=patch_size,\n                device=self.device,\n                dtype=self.dtype,\n            )\n            self.camera_emb = nn.Sequential(\n                nn.Linear(\n                    3 * 4 + 1, self.width, device=self.device, dtype=self.dtype\n                ),  # input size is for origin+x+y+z+fov\n                nn.GELU(),\n                nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),\n            )\n\n    def get_h_and_iterator(\n        self, batch: AttrDict, options: Optional[AttrDict] = None\n    ) -> Tuple[torch.Tensor, Iterable]:\n        \"\"\"\n        :return: a tuple of (\n            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],\n            an iterator over the given data\n        )\n        \"\"\"\n        options = AttrDict() if options is None else options\n\n        # Build the initial query embeddings\n        points = batch.points.permute(0, 2, 1)  # NCL -> NLC\n        fps_samples = self.sample_pcl_fps(points)\n        batch_size = points.shape[0]\n        data_tokens = self.input_proj(fps_samples)\n        latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)\n        h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))\n        assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)\n\n   
     # Build the dataset embedding iterator\n        dataset_fn = {\n            \"pcl\": self.get_pcl_dataset,\n            \"multiview\": self.get_multiview_dataset,\n        }[self.cross_attention_dataset]\n        it = dataset_fn(batch, options=options)\n\n        return h, it\n\n    def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:\n        return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)\n\n    def get_pcl_dataset(\n        self, batch: AttrDict, options: Optional[AttrDict[str, Any]] = None\n    ) -> Iterable:\n        _ = options\n        dataset_emb = self.input_proj(batch.points.permute(0, 2, 1))  # NCL -> NLC\n        assert dataset_emb.shape[1] >= self.inner_batch_size\n        return iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size))\n\n    def get_multiview_dataset(\n        self, batch: AttrDict, options: Optional[AttrDict] = None\n    ) -> Iterable:\n        _ = options\n\n        dataset_emb = self.encode_views(batch)\n        batch_size, num_views, n_patches, width = dataset_emb.shape\n\n        assert num_views >= self.inner_batch_size\n\n        it = iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size))\n\n        def gen():\n            while True:\n                examples = next(it)\n                assert examples.shape == (batch_size, self.inner_batch_size, n_patches, self.width)\n                views = examples.reshape(batch_size, -1, width) + self.pos_emb\n                yield views\n\n        return gen()\n\n    def encode_views(self, batch: AttrDict) -> torch.Tensor:\n        \"\"\"\n        :return: [batch_size, num_views, n_patches, width]\n        \"\"\"\n        all_views = self.views_to_tensor(batch.views).to(self.device)\n        if self.use_depth:\n            all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)\n        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)\n\n        batch_size, num_views, _, _, _ 
= all_views.shape\n\n        views_proj = self.patch_emb(\n            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])\n        )\n        views_proj = (\n            views_proj.reshape([batch_size, num_views, self.width, -1])\n            .permute(0, 1, 3, 2)\n            .contiguous()\n        )  # [batch_size x num_views x n_patches x width]\n\n        # [batch_size, num_views, 1, 2 * width]\n        camera_proj = self.camera_emb(all_cameras).reshape(\n            [batch_size, num_views, 1, self.width * 2]\n        )\n        pose_dropout = self.pose_dropout if self.training else 0.0\n        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout\n        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))\n        scale, shift = camera_proj.chunk(2, dim=3)\n        views_proj = views_proj * (scale + 1.0) + shift\n        return views_proj\n\n    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].\n        \"\"\"\n        if isinstance(views, torch.Tensor):\n            return views\n\n        tensor_batch = []\n        num_views = len(views[0])\n        for inner_list in views:\n            assert len(inner_list) == num_views\n            inner_batch = []\n            for img in inner_list:\n                img = img.resize((self.image_size,) * 2).convert(\"RGB\")\n                inner_batch.append(\n                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)\n                    / 127.5\n                    - 1\n                )\n            tensor_batch.append(torch.stack(inner_batch, dim=0))\n        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)\n\n    def depths_to_tensor(\n        self, depths: Union[torch.Tensor, List[List[Image.Image]]]\n    ) -> torch.Tensor:\n        \"\"\"\n        
Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].\n        \"\"\"\n        if isinstance(depths, torch.Tensor):\n            return depths\n\n        tensor_batch = []\n        num_views = len(depths[0])\n        for inner_list in depths:\n            assert len(inner_list) == num_views\n            inner_batch = []\n            for arr in inner_list:\n                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth\n                tensor = tensor * 2 - 1\n                tensor = F.interpolate(\n                    tensor[None, None],\n                    (self.image_size,) * 2,\n                    mode=\"nearest\",\n                )\n                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))\n            tensor_batch.append(torch.cat(inner_batch, dim=0))\n        return torch.stack(tensor_batch, dim=0)\n\n    def cameras_to_tensor(\n        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]\n    ) -> torch.Tensor:\n        \"\"\"\n        Returns a [batch x num_views x 3*4+1] tensor of camera information.\n        \"\"\"\n        if isinstance(cameras, torch.Tensor):\n            return cameras\n        outer_batch = []\n        for inner_list in cameras:\n            inner_batch = []\n            for camera in inner_list:\n                inner_batch.append(\n                    np.array(\n                        [\n                            *camera.x,\n                            *camera.y,\n                            *camera.z,\n                            *camera.origin,\n                            camera.x_fov,\n                        ]\n                    )\n                )\n            outer_batch.append(np.stack(inner_batch, axis=0))\n        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()\n"
  },
  {
    "path": "shap_e/models/volume.py",
    "content": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Dict, Optional, Tuple\n\nimport torch\n\nfrom shap_e.models.nn.meta import MetaModule\nfrom shap_e.models.nn.utils import ArrayType, safe_divide, to_torch\n\n\n@dataclass\nclass VolumeRange:\n    t0: torch.Tensor\n    t1: torch.Tensor\n    intersected: torch.Tensor\n\n    def __post_init__(self):\n        assert self.t0.shape == self.t1.shape == self.intersected.shape\n\n    def next_t0(self):\n        \"\"\"\n        Given convex volume1 and volume2, where volume1 is contained in\n        volume2, this function returns the t0 at which rays leave volume1 and\n        intersect with volume2 \\\\ volume1.\n        \"\"\"\n        return self.t1 * self.intersected.float()\n\n    def extend(self, another: \"VolumeRange\") -> \"VolumeRange\":\n        \"\"\"\n        The ranges at which rays intersect with either one, or both, or none of\n        the self and another are merged together.\n        \"\"\"\n        return VolumeRange(\n            t0=torch.where(self.intersected, self.t0, another.t0),\n            t1=torch.where(another.intersected, another.t1, self.t1),\n            intersected=torch.logical_or(self.intersected, another.intersected),\n        )\n\n    def partition(self, ts) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Partitions t0 and t1 into n_samples intervals.\n\n        :param ts: [batch_size, *shape, n_samples, 1]\n        :return: a tuple of (\n            lower: [batch_size, *shape, n_samples, 1]\n            upper: [batch_size, *shape, n_samples, 1]\n            delta: [batch_size, *shape, n_samples, 1]\n        ) where\n\n            ts \\\\in [lower, upper]\n            deltas = upper - lower\n        \"\"\"\n        mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5\n        lower = torch.cat([self.t0[..., None, :], mids], dim=-2)\n        upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)\n        
delta = upper - lower\n        assert lower.shape == upper.shape == delta.shape == ts.shape\n        return lower, upper, delta\n\n\nclass Volume(ABC):\n    \"\"\"\n    An abstraction of rendering volume.\n    \"\"\"\n\n    @abstractmethod\n    def intersect(\n        self,\n        origin: torch.Tensor,\n        direction: torch.Tensor,\n        t0_lower: Optional[torch.Tensor] = None,\n        params: Optional[Dict] = None,\n        epsilon: float = 1e-6,\n    ) -> VolumeRange:\n        \"\"\"\n        :param origin: [batch_size, *shape, 3]\n        :param direction: [batch_size, *shape, 3]\n        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.\n        :param params: Optional meta parameters in case Volume is parametric\n        :param epsilon: to stabilize calculations\n\n        :return: A tuple of (t0, t1, intersected) where each has a shape\n            [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is\n            in the volume for all t in [t0, t1]. 
If the volume is bounded, t1 is guaranteed\n            to be on the boundary of the volume.\n        \"\"\"\n\n\nclass BoundingBoxVolume(MetaModule, Volume):\n    \"\"\"\n    Axis-aligned bounding box defined by the two opposite corners.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        bbox_min: ArrayType,\n        bbox_max: ArrayType,\n        min_dist: float = 0.0,\n        min_t_range: float = 1e-3,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        \"\"\"\n        :param bbox_min: the left/bottommost corner of the bounding box\n        :param bbox_max: the other corner of the bounding box\n        :param min_dist: all rays should start at least this distance away from the origin.\n        \"\"\"\n        super().__init__()\n\n        self.bbox_min = to_torch(bbox_min).to(device)\n        self.bbox_max = to_torch(bbox_max).to(device)\n        self.min_dist = min_dist\n        self.min_t_range = min_t_range\n        self.bbox = torch.stack([self.bbox_min, self.bbox_max])\n        assert self.bbox.shape == (2, 3)\n        assert self.min_dist >= 0.0\n        assert self.min_t_range > 0.0\n        self.device = device\n\n    def intersect(\n        self,\n        origin: torch.Tensor,\n        direction: torch.Tensor,\n        t0_lower: Optional[torch.Tensor] = None,\n        params: Optional[Dict] = None,\n        epsilon=1e-6,\n    ) -> VolumeRange:\n        \"\"\"\n        :param origin: [batch_size, *shape, 3]\n        :param direction: [batch_size, *shape, 3]\n        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.\n        :param params: Optional meta parameters in case Volume is parametric\n        :param epsilon: to stabilize calculations\n\n        :return: A tuple of (t0, t1, intersected) where each has a shape\n            [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is\n            in the volume for all t in [t0, t1]. 
If the volume is bounded, t1 is guaranteed\n            to be on the boundary of the volume.\n        \"\"\"\n\n        batch_size, *shape, _ = origin.shape\n        ones = [1] * len(shape)\n        bbox = self.bbox.view(1, *ones, 2, 3)\n        ts = safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)\n\n        # Cases to think about:\n        #\n        #   1. t1 <= t0: the ray does not pass through the AABB.\n        #   2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.\n        #   3. t0 <= 0 <= t1: the ray starts from inside the BB\n        #   4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.\n        #\n        # 1 and 4 are clearly handled from t0 < t1 below.\n        # Making t0 at least min_dist (>= 0) takes care of 2 and 3.\n        t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)\n        t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values\n        assert t0.shape == t1.shape == (batch_size, *shape, 1)\n        if t0_lower is not None:\n            assert t0.shape == t0_lower.shape\n            t0 = torch.maximum(t0, t0_lower)\n\n        intersected = t0 + self.min_t_range < t1\n        t0 = torch.where(intersected, t0, torch.zeros_like(t0))\n        t1 = torch.where(intersected, t1, torch.ones_like(t1))\n\n        return VolumeRange(t0=t0, t1=t1, intersected=intersected)\n\n\nclass UnboundedVolume(MetaModule, Volume):\n    \"\"\"\n    Originally used in NeRF. Unbounded volume but with a limited visibility\n    when rendering (e.g. 
objects that are farther away than the max_dist from\n    the ray origin are not considered)\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        max_dist: float,\n        min_dist: float = 0.0,\n        min_t_range: float = 1e-3,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__()\n        self.max_dist = max_dist\n        self.min_dist = min_dist\n        self.min_t_range = min_t_range\n        assert self.min_dist >= 0.0\n        assert self.min_t_range > 0.0\n        self.device = device\n\n    def intersect(\n        self,\n        origin: torch.Tensor,\n        direction: torch.Tensor,\n        t0_lower: Optional[torch.Tensor] = None,\n        params: Optional[Dict] = None,\n        epsilon: float = 1e-6,\n    ) -> VolumeRange:\n        \"\"\"\n        :param origin: [batch_size, *shape, 3]\n        :param direction: [batch_size, *shape, 3]\n        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.\n        :param params: Optional meta parameters in case Volume is parametric\n        :param epsilon: unused; accepted for interface compatibility with Volume.intersect\n\n        :return: A tuple of (t0, t1, intersected) where each has a shape\n            [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is\n            in the volume for all t in [t0, t1].
If the volume is bounded, t1 is guaranteed\n            to be on the boundary of the volume.\n        \"\"\"\n\n        batch_size, *shape, _ = origin.shape\n        t0 = torch.zeros(batch_size, *shape, 1, dtype=origin.dtype, device=origin.device)\n        if t0_lower is not None:\n            t0 = torch.maximum(t0, t0_lower)\n        t1 = t0 + self.max_dist\n        t0 = t0.clamp(self.min_dist)\n        return VolumeRange(t0=t0, t1=t1, intersected=t0 + self.min_t_range < t1)\n\n\nclass SphericalVolume(MetaModule, Volume):\n    \"\"\"\n    Used in NeRF++. Probably unnecessary unless we want to reproduce\n    their results.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        radius: float,\n        center: ArrayType = (0.0, 0.0, 0.0),\n        min_dist: float = 0.0,\n        min_t_range: float = 1e-3,\n        device: torch.device = torch.device(\"cuda\"),\n    ):\n        super().__init__()\n\n        self.radius = radius\n        self.center = to_torch(center).to(device)\n        self.min_dist = min_dist\n        self.min_t_range = min_t_range\n        assert self.min_dist >= 0.0\n        assert self.min_t_range > 0.0\n        self.device = device\n\n    def intersect(\n        self,\n        origin: torch.Tensor,\n        direction: torch.Tensor,\n        t0_lower: Optional[torch.Tensor] = None,\n        params: Optional[Dict] = None,\n        epsilon=1e-6,\n    ) -> VolumeRange:\n        raise NotImplementedError\n"
  },
  {
    "path": "shap_e/rendering/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/rendering/_mc_table.py",
    "content": "# Treat a cube as a bitmap, and create the index into this array in order of\n# ZYX (note Z is the most significant digit).\n# The resulting object is an array of triangles, where each triangle is 6\n# indices. Each consecutive pair of indices within this triangle represents an\n# edge spanning two corners (identified by the indices).\n#\n# The corners of a cube are indexed as follows\n#\n#    (0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0),\n#    (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1)\n#\n# Here is a visualization of the cube indices:\n#\n#        6 + -----------------------+ 7\n#         /|                       /|\n#        / |                      / |\n#       /  |                     /  |\n#    4 +------------------------+ 5 |\n#      |   |                    |   |\n#      |   |                    |   |\n#      |   |                    |   |\n#      |   | 2                  |   | 3\n#      |   +--------------------|---+\n#      |  /                     |  /\n#      | /                      | /\n#      |/                       |/\n#      +------------------------+\n#     0                           1\n#\n# Derived using model3d, in particular this function:\n# https://github.com/unixpickle/model3d/blob/7a3adb982c154c80c1a22032b5a0695160a7f96d/model3d/mc.go#L434\n#\nMC_TABLE = [\n    [],\n    [[0, 1, 0, 2, 0, 4]],\n    [[1, 0, 1, 5, 1, 3]],\n    [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2]],\n    [[2, 0, 2, 3, 2, 6]],\n    [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4]],\n    [[1, 0, 1, 5, 1, 3], [2, 6, 0, 2, 3, 2]],\n    [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4]],\n    [[3, 1, 3, 7, 3, 2]],\n    [[0, 2, 0, 4, 0, 1], [3, 7, 2, 3, 1, 3]],\n    [[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0]],\n    [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5]],\n    [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6]],\n    [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6]],\n    [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7]],\n    [[0, 4, 
1, 5, 3, 7], [0, 4, 3, 7, 2, 6]],\n    [[4, 0, 4, 6, 4, 5]],\n    [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1]],\n    [[1, 5, 1, 3, 1, 0], [4, 6, 5, 4, 0, 4]],\n    [[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2]],\n    [[2, 0, 2, 3, 2, 6], [4, 5, 0, 4, 6, 4]],\n    [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1]],\n    [[2, 6, 2, 0, 3, 2], [1, 0, 1, 5, 3, 1], [6, 4, 5, 4, 0, 4]],\n    [[1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6]],\n    [[3, 1, 3, 7, 3, 2], [6, 4, 5, 4, 0, 4]],\n    [[4, 5, 0, 1, 4, 6], [0, 1, 0, 2, 4, 6], [7, 3, 2, 3, 1, 3]],\n    [[3, 2, 1, 0, 3, 7], [1, 0, 1, 5, 3, 7], [6, 4, 5, 4, 0, 4]],\n    [[3, 7, 3, 2, 1, 5], [3, 2, 6, 4, 1, 5], [1, 5, 6, 4, 5, 4], [3, 2, 2, 0, 6, 4]],\n    [[3, 7, 2, 6, 3, 1], [2, 6, 2, 0, 3, 1], [5, 4, 0, 4, 6, 4]],\n    [[1, 0, 1, 3, 5, 4], [1, 3, 2, 6, 5, 4], [1, 3, 3, 7, 2, 6], [5, 4, 2, 6, 4, 6]],\n    [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7], [4, 5, 0, 4, 4, 6]],\n    [[6, 2, 4, 6, 4, 5], [4, 5, 5, 1, 6, 2], [6, 2, 5, 1, 7, 3]],\n    [[5, 1, 5, 4, 5, 7]],\n    [[0, 1, 0, 2, 0, 4], [5, 7, 1, 5, 4, 5]],\n    [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3]],\n    [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3]],\n    [[2, 0, 2, 3, 2, 6], [7, 5, 1, 5, 4, 5]],\n    [[2, 6, 0, 4, 2, 3], [0, 4, 0, 1, 2, 3], [7, 5, 1, 5, 4, 5]],\n    [[5, 7, 1, 3, 5, 4], [1, 3, 1, 0, 5, 4], [6, 2, 0, 2, 3, 2]],\n    [[3, 1, 3, 2, 7, 5], [3, 2, 0, 4, 7, 5], [3, 2, 2, 6, 0, 4], [7, 5, 0, 4, 5, 4]],\n    [[3, 7, 3, 2, 3, 1], [5, 4, 7, 5, 1, 5]],\n    [[0, 4, 0, 1, 2, 0], [3, 1, 3, 7, 2, 3], [4, 5, 7, 5, 1, 5]],\n    [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0]],\n    [[0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7]],\n    [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6], [4, 5, 7, 5, 1, 5]],\n    [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6], [5, 7, 1, 5, 5, 4]],\n    [[2, 6, 2, 0, 3, 7], [2, 0, 4, 5, 3, 7], [3, 7, 4, 
5, 7, 5], [2, 0, 0, 1, 4, 5]],\n    [[4, 0, 5, 4, 5, 7], [5, 7, 7, 3, 4, 0], [4, 0, 7, 3, 6, 2]],\n    [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0]],\n    [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6]],\n    [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7]],\n    [[0, 2, 4, 6, 5, 7], [0, 2, 5, 7, 1, 3]],\n    [[5, 1, 4, 0, 5, 7], [4, 0, 4, 6, 5, 7], [3, 2, 6, 2, 0, 2]],\n    [[2, 3, 2, 6, 0, 1], [2, 6, 7, 5, 0, 1], [0, 1, 7, 5, 1, 5], [2, 6, 6, 4, 7, 5]],\n    [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7], [2, 6, 0, 2, 2, 3]],\n    [[3, 1, 2, 3, 2, 6], [2, 6, 6, 4, 3, 1], [3, 1, 6, 4, 7, 5]],\n    [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0], [2, 3, 1, 3, 7, 3]],\n    [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6], [3, 2, 1, 3, 3, 7]],\n    [[0, 1, 0, 4, 2, 3], [0, 4, 5, 7, 2, 3], [0, 4, 4, 6, 5, 7], [2, 3, 5, 7, 3, 7]],\n    [[7, 5, 3, 7, 3, 2], [3, 2, 2, 0, 7, 5], [7, 5, 2, 0, 6, 4]],\n    [[0, 4, 4, 6, 5, 7], [0, 4, 5, 7, 1, 5], [0, 2, 1, 3, 3, 7], [3, 7, 2, 6, 0, 2]],\n    [\n        [3, 1, 7, 3, 6, 2],\n        [6, 2, 0, 1, 3, 1],\n        [6, 4, 0, 1, 6, 2],\n        [6, 4, 5, 1, 0, 1],\n        [6, 4, 7, 5, 5, 1],\n    ],\n    [\n        [4, 0, 6, 4, 7, 5],\n        [7, 5, 1, 0, 4, 0],\n        [7, 3, 1, 0, 7, 5],\n        [7, 3, 2, 0, 1, 0],\n        [7, 3, 6, 2, 2, 0],\n    ],\n    [[7, 3, 6, 2, 6, 4], [7, 5, 7, 3, 6, 4]],\n    [[6, 2, 6, 7, 6, 4]],\n    [[0, 4, 0, 1, 0, 2], [6, 7, 4, 6, 2, 6]],\n    [[1, 0, 1, 5, 1, 3], [7, 6, 4, 6, 2, 6]],\n    [[1, 3, 0, 2, 1, 5], [0, 2, 0, 4, 1, 5], [7, 6, 4, 6, 2, 6]],\n    [[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0]],\n    [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3]],\n    [[6, 4, 2, 0, 6, 7], [2, 0, 2, 3, 6, 7], [5, 1, 3, 1, 0, 1]],\n    [[1, 5, 1, 3, 0, 4], [1, 3, 7, 6, 0, 4], [0, 4, 7, 6, 4, 6], [1, 3, 3, 2, 7, 6]],\n    [[3, 2, 3, 1, 3, 7], [6, 4, 2, 6, 7, 6]],\n    [[3, 7, 3, 2, 1, 3], [0, 2, 0, 4, 1, 0], [7, 6, 4, 6, 2, 6]],\n    [[1, 5, 3, 7, 1, 0], [3, 7, 
3, 2, 1, 0], [4, 6, 2, 6, 7, 6]],\n    [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5], [6, 4, 2, 6, 6, 7]],\n    [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0]],\n    [[0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7]],\n    [[0, 2, 0, 1, 4, 6], [0, 1, 3, 7, 4, 6], [0, 1, 1, 5, 3, 7], [4, 6, 3, 7, 6, 7]],\n    [[7, 3, 6, 7, 6, 4], [6, 4, 4, 0, 7, 3], [7, 3, 4, 0, 5, 1]],\n    [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5]],\n    [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5]],\n    [[6, 7, 4, 5, 6, 2], [4, 5, 4, 0, 6, 2], [3, 1, 0, 1, 5, 1]],\n    [[2, 0, 2, 6, 3, 1], [2, 6, 4, 5, 3, 1], [2, 6, 6, 7, 4, 5], [3, 1, 4, 5, 1, 5]],\n    [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7]],\n    [[0, 1, 2, 3, 6, 7], [0, 1, 6, 7, 4, 5]],\n    [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7], [1, 3, 0, 1, 1, 5]],\n    [[5, 4, 1, 5, 1, 3], [1, 3, 3, 2, 5, 4], [5, 4, 3, 2, 7, 6]],\n    [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5], [1, 3, 7, 3, 2, 3]],\n    [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5], [3, 7, 2, 3, 3, 1]],\n    [[0, 1, 1, 5, 3, 7], [0, 1, 3, 7, 2, 3], [0, 4, 2, 6, 6, 7], [6, 7, 4, 5, 0, 4]],\n    [\n        [6, 2, 7, 6, 5, 4],\n        [5, 4, 0, 2, 6, 2],\n        [5, 1, 0, 2, 5, 4],\n        [5, 1, 3, 2, 0, 2],\n        [5, 1, 7, 3, 3, 2],\n    ],\n    [[3, 1, 3, 7, 2, 0], [3, 7, 5, 4, 2, 0], [2, 0, 5, 4, 0, 4], [3, 7, 7, 6, 5, 4]],\n    [[1, 0, 3, 1, 3, 7], [3, 7, 7, 6, 1, 0], [1, 0, 7, 6, 5, 4]],\n    [\n        [1, 0, 5, 1, 7, 3],\n        [7, 3, 2, 0, 1, 0],\n        [7, 6, 2, 0, 7, 3],\n        [7, 6, 4, 0, 2, 0],\n        [7, 6, 5, 4, 4, 0],\n    ],\n    [[7, 6, 5, 4, 5, 1], [7, 3, 7, 6, 5, 1]],\n    [[5, 7, 5, 1, 5, 4], [6, 2, 7, 6, 4, 6]],\n    [[0, 2, 0, 4, 1, 0], [5, 4, 5, 7, 1, 5], [2, 6, 7, 6, 4, 6]],\n    [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3], [2, 6, 7, 6, 4, 6]],\n    [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3], [6, 7, 4, 6, 6, 2]],\n    
[[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0], [1, 5, 4, 5, 7, 5]],\n    [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3], [5, 1, 4, 5, 5, 7]],\n    [[0, 2, 2, 3, 6, 7], [0, 2, 6, 7, 4, 6], [0, 1, 4, 5, 5, 7], [5, 7, 1, 3, 0, 1]],\n    [\n        [5, 4, 7, 5, 3, 1],\n        [3, 1, 0, 4, 5, 4],\n        [3, 2, 0, 4, 3, 1],\n        [3, 2, 6, 4, 0, 4],\n        [3, 2, 7, 6, 6, 4],\n    ],\n    [[5, 4, 5, 7, 1, 5], [3, 7, 3, 2, 1, 3], [4, 6, 2, 6, 7, 6]],\n    [[1, 0, 0, 2, 0, 4], [1, 5, 5, 4, 5, 7], [3, 2, 1, 3, 3, 7], [2, 6, 7, 6, 4, 6]],\n    [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0], [6, 2, 7, 6, 6, 4]],\n    [\n        [0, 4, 2, 3, 0, 2],\n        [0, 4, 3, 7, 2, 3],\n        [0, 4, 4, 5, 3, 7],\n        [4, 5, 5, 7, 3, 7],\n        [6, 7, 4, 6, 2, 6],\n    ],\n    [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0], [5, 4, 7, 5, 5, 1]],\n    [\n        [0, 1, 4, 6, 0, 4],\n        [0, 1, 6, 7, 4, 6],\n        [0, 1, 1, 3, 6, 7],\n        [1, 3, 3, 7, 6, 7],\n        [5, 7, 1, 5, 4, 5],\n    ],\n    [\n        [6, 7, 4, 6, 0, 2],\n        [0, 2, 3, 7, 6, 7],\n        [0, 1, 3, 7, 0, 2],\n        [0, 1, 5, 7, 3, 7],\n        [0, 1, 4, 5, 5, 7],\n    ],\n    [[4, 0, 6, 7, 4, 6], [4, 0, 7, 3, 6, 7], [4, 0, 5, 7, 7, 3], [4, 5, 5, 7, 4, 0]],\n    [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0]],\n    [[0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7]],\n    [[1, 3, 1, 0, 5, 7], [1, 0, 2, 6, 5, 7], [5, 7, 2, 6, 7, 6], [1, 0, 0, 4, 2, 6]],\n    [[2, 0, 6, 2, 6, 7], [6, 7, 7, 5, 2, 0], [2, 0, 7, 5, 3, 1]],\n    [[0, 4, 0, 2, 1, 5], [0, 2, 6, 7, 1, 5], [0, 2, 2, 3, 6, 7], [1, 5, 6, 7, 5, 7]],\n    [[7, 6, 5, 7, 5, 1], [5, 1, 1, 0, 7, 6], [7, 6, 1, 0, 3, 2]],\n    [\n        [2, 0, 3, 2, 7, 6],\n        [7, 6, 4, 0, 2, 0],\n        [7, 5, 4, 0, 7, 6],\n        [7, 5, 1, 0, 4, 0],\n        [7, 5, 3, 1, 1, 0],\n    ],\n    [[7, 5, 3, 1, 3, 2], [7, 6, 7, 5, 3, 2]],\n    [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 
6, 2], [6, 2, 5, 1, 4, 0], [3, 1, 7, 3, 3, 2]],\n    [\n        [0, 2, 1, 5, 0, 1],\n        [0, 2, 5, 7, 1, 5],\n        [0, 2, 2, 6, 5, 7],\n        [2, 6, 6, 7, 5, 7],\n        [3, 7, 2, 3, 1, 3],\n    ],\n    [\n        [3, 7, 2, 3, 0, 1],\n        [0, 1, 5, 7, 3, 7],\n        [0, 4, 5, 7, 0, 1],\n        [0, 4, 6, 7, 5, 7],\n        [0, 4, 2, 6, 6, 7],\n    ],\n    [[2, 0, 3, 7, 2, 3], [2, 0, 7, 5, 3, 7], [2, 0, 6, 7, 7, 5], [2, 6, 6, 7, 2, 0]],\n    [\n        [5, 7, 1, 5, 0, 4],\n        [0, 4, 6, 7, 5, 7],\n        [0, 2, 6, 7, 0, 4],\n        [0, 2, 3, 7, 6, 7],\n        [0, 2, 1, 3, 3, 7],\n    ],\n    [[1, 0, 5, 7, 1, 5], [1, 0, 7, 6, 5, 7], [1, 0, 3, 7, 7, 6], [1, 3, 3, 7, 1, 0]],\n    [[0, 2, 0, 1, 0, 4], [3, 7, 6, 7, 5, 7]],\n    [[7, 5, 7, 3, 7, 6]],\n    [[7, 3, 7, 5, 7, 6]],\n    [[0, 1, 0, 2, 0, 4], [6, 7, 3, 7, 5, 7]],\n    [[1, 3, 1, 0, 1, 5], [7, 6, 3, 7, 5, 7]],\n    [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2], [6, 7, 3, 7, 5, 7]],\n    [[2, 6, 2, 0, 2, 3], [7, 5, 6, 7, 3, 7]],\n    [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4], [5, 7, 6, 7, 3, 7]],\n    [[1, 5, 1, 3, 0, 1], [2, 3, 2, 6, 0, 2], [5, 7, 6, 7, 3, 7]],\n    [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4], [7, 6, 3, 7, 7, 5]],\n    [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2]],\n    [[7, 6, 3, 2, 7, 5], [3, 2, 3, 1, 7, 5], [4, 0, 1, 0, 2, 0]],\n    [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2]],\n    [[2, 3, 2, 0, 6, 7], [2, 0, 1, 5, 6, 7], [2, 0, 0, 4, 1, 5], [6, 7, 1, 5, 7, 5]],\n    [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1]],\n    [[0, 4, 0, 1, 2, 6], [0, 1, 5, 7, 2, 6], [2, 6, 5, 7, 6, 7], [0, 1, 1, 3, 5, 7]],\n    [[1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6]],\n    [[5, 1, 7, 5, 7, 6], [7, 6, 6, 2, 5, 1], [5, 1, 6, 2, 4, 0]],\n    [[4, 5, 4, 0, 4, 6], [7, 3, 5, 7, 6, 7]],\n    [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1], [3, 7, 5, 7, 6, 7]],\n    [[4, 6, 4, 5, 0, 4], [1, 5, 1, 3, 0, 1], [6, 7, 3, 7, 5, 7]],\n    [[5, 
1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2], [7, 3, 5, 7, 7, 6]],\n    [[2, 3, 2, 6, 0, 2], [4, 6, 4, 5, 0, 4], [3, 7, 5, 7, 6, 7]],\n    [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1], [7, 5, 6, 7, 7, 3]],\n    [[0, 1, 1, 5, 1, 3], [0, 2, 2, 3, 2, 6], [4, 5, 0, 4, 4, 6], [5, 7, 6, 7, 3, 7]],\n    [\n        [1, 3, 5, 4, 1, 5],\n        [1, 3, 4, 6, 5, 4],\n        [1, 3, 3, 2, 4, 6],\n        [3, 2, 2, 6, 4, 6],\n        [7, 6, 3, 7, 5, 7],\n    ],\n    [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2], [0, 4, 6, 4, 5, 4]],\n    [[1, 0, 0, 2, 4, 6], [1, 0, 4, 6, 5, 4], [1, 3, 5, 7, 7, 6], [7, 6, 3, 2, 1, 3]],\n    [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2], [4, 6, 5, 4, 4, 0]],\n    [\n        [7, 5, 6, 7, 2, 3],\n        [2, 3, 1, 5, 7, 5],\n        [2, 0, 1, 5, 2, 3],\n        [2, 0, 4, 5, 1, 5],\n        [2, 0, 6, 4, 4, 5],\n    ],\n    [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1], [4, 0, 6, 4, 4, 5]],\n    [\n        [4, 6, 5, 4, 1, 0],\n        [1, 0, 2, 6, 4, 6],\n        [1, 3, 2, 6, 1, 0],\n        [1, 3, 7, 6, 2, 6],\n        [1, 3, 5, 7, 7, 6],\n    ],\n    [\n        [1, 5, 0, 2, 1, 0],\n        [1, 5, 2, 6, 0, 2],\n        [1, 5, 5, 7, 2, 6],\n        [5, 7, 7, 6, 2, 6],\n        [4, 6, 5, 4, 0, 4],\n    ],\n    [[5, 1, 4, 6, 5, 4], [5, 1, 6, 2, 4, 6], [5, 1, 7, 6, 6, 2], [5, 7, 7, 6, 5, 1]],\n    [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1]],\n    [[7, 3, 5, 1, 7, 6], [5, 1, 5, 4, 7, 6], [2, 0, 4, 0, 1, 0]],\n    [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4]],\n    [[0, 2, 0, 4, 1, 3], [0, 4, 6, 7, 1, 3], [1, 3, 6, 7, 3, 7], [0, 4, 4, 5, 6, 7]],\n    [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1], [0, 2, 3, 2, 6, 2]],\n    [[1, 5, 5, 4, 7, 6], [1, 5, 7, 6, 3, 7], [1, 0, 3, 2, 2, 6], [2, 6, 0, 4, 1, 0]],\n    [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4], [2, 0, 3, 2, 2, 6]],\n    [\n        [2, 3, 6, 2, 4, 0],\n        [4, 0, 1, 3, 2, 3],\n        [4, 5, 1, 3, 4, 0],\n        [4, 5, 7, 3, 
1, 3],\n        [4, 5, 6, 7, 7, 3],\n    ],\n    [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6]],\n    [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6], [0, 4, 1, 0, 0, 2]],\n    [[1, 0, 5, 4, 7, 6], [1, 0, 7, 6, 3, 2]],\n    [[2, 3, 0, 2, 0, 4], [0, 4, 4, 5, 2, 3], [2, 3, 4, 5, 6, 7]],\n    [[1, 3, 1, 5, 0, 2], [1, 5, 7, 6, 0, 2], [1, 5, 5, 4, 7, 6], [0, 2, 7, 6, 2, 6]],\n    [\n        [5, 1, 4, 5, 6, 7],\n        [6, 7, 3, 1, 5, 1],\n        [6, 2, 3, 1, 6, 7],\n        [6, 2, 0, 1, 3, 1],\n        [6, 2, 4, 0, 0, 1],\n    ],\n    [[6, 7, 2, 6, 2, 0], [2, 0, 0, 1, 6, 7], [6, 7, 0, 1, 4, 5]],\n    [[6, 2, 4, 0, 4, 5], [6, 7, 6, 2, 4, 5]],\n    [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1]],\n    [[1, 5, 1, 0, 3, 7], [1, 0, 4, 6, 3, 7], [1, 0, 0, 2, 4, 6], [3, 7, 4, 6, 7, 6]],\n    [[1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6]],\n    [[6, 4, 7, 6, 7, 3], [7, 3, 3, 1, 6, 4], [6, 4, 3, 1, 2, 0]],\n    [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1], [2, 3, 6, 2, 2, 0]],\n    [\n        [7, 6, 3, 7, 1, 5],\n        [1, 5, 4, 6, 7, 6],\n        [1, 0, 4, 6, 1, 5],\n        [1, 0, 2, 6, 4, 6],\n        [1, 0, 3, 2, 2, 6],\n    ],\n    [\n        [1, 0, 3, 7, 1, 3],\n        [1, 0, 7, 6, 3, 7],\n        [1, 0, 0, 4, 7, 6],\n        [0, 4, 4, 6, 7, 6],\n        [2, 6, 0, 2, 3, 2],\n    ],\n    [[3, 1, 7, 6, 3, 7], [3, 1, 6, 4, 7, 6], [3, 1, 2, 6, 6, 4], [3, 2, 2, 6, 3, 1]],\n    [[3, 2, 3, 1, 7, 6], [3, 1, 0, 4, 7, 6], [7, 6, 0, 4, 6, 4], [3, 1, 1, 5, 0, 4]],\n    [\n        [0, 1, 2, 0, 6, 4],\n        [6, 4, 5, 1, 0, 1],\n        [6, 7, 5, 1, 6, 4],\n        [6, 7, 3, 1, 5, 1],\n        [6, 7, 2, 3, 3, 1],\n    ],\n    [[0, 1, 4, 0, 4, 6], [4, 6, 6, 7, 0, 1], [0, 1, 6, 7, 2, 3]],\n    [[6, 7, 2, 3, 2, 0], [6, 4, 6, 7, 2, 0]],\n    [\n        [2, 6, 0, 2, 1, 3],\n        [1, 3, 7, 6, 2, 6],\n        [1, 5, 7, 6, 1, 3],\n        [1, 5, 4, 6, 7, 6],\n        [1, 5, 0, 4, 4, 6],\n    ],\n   
 [[1, 5, 1, 0, 1, 3], [4, 6, 7, 6, 2, 6]],\n    [[0, 1, 2, 6, 0, 2], [0, 1, 6, 7, 2, 6], [0, 1, 4, 6, 6, 7], [0, 4, 4, 6, 0, 1]],\n    [[6, 7, 6, 2, 6, 4]],\n    [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4]],\n    [[7, 5, 6, 4, 7, 3], [6, 4, 6, 2, 7, 3], [1, 0, 2, 0, 4, 0]],\n    [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4], [0, 1, 5, 1, 3, 1]],\n    [[2, 0, 0, 4, 1, 5], [2, 0, 1, 5, 3, 1], [2, 6, 3, 7, 7, 5], [7, 5, 6, 4, 2, 6]],\n    [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4]],\n    [[3, 2, 3, 7, 1, 0], [3, 7, 6, 4, 1, 0], [3, 7, 7, 5, 6, 4], [1, 0, 6, 4, 0, 4]],\n    [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4], [1, 5, 3, 1, 1, 0]],\n    [\n        [7, 3, 5, 7, 4, 6],\n        [4, 6, 2, 3, 7, 3],\n        [4, 0, 2, 3, 4, 6],\n        [4, 0, 1, 3, 2, 3],\n        [4, 0, 5, 1, 1, 3],\n    ],\n    [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5]],\n    [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5], [0, 1, 2, 0, 0, 4]],\n    [[1, 0, 1, 5, 3, 2], [1, 5, 4, 6, 3, 2], [3, 2, 4, 6, 2, 6], [1, 5, 5, 7, 4, 6]],\n    [\n        [0, 2, 4, 0, 5, 1],\n        [5, 1, 3, 2, 0, 2],\n        [5, 7, 3, 2, 5, 1],\n        [5, 7, 6, 2, 3, 2],\n        [5, 7, 4, 6, 6, 2],\n    ],\n    [[2, 0, 3, 1, 7, 5], [2, 0, 7, 5, 6, 4]],\n    [[4, 6, 0, 4, 0, 1], [0, 1, 1, 3, 4, 6], [4, 6, 1, 3, 5, 7]],\n    [[0, 2, 1, 0, 1, 5], [1, 5, 5, 7, 0, 2], [0, 2, 5, 7, 4, 6]],\n    [[5, 7, 4, 6, 4, 0], [5, 1, 5, 7, 4, 0]],\n    [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2]],\n    [[0, 1, 0, 2, 4, 5], [0, 2, 3, 7, 4, 5], [4, 5, 3, 7, 5, 7], [0, 2, 2, 6, 3, 7]],\n    [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2], [1, 0, 5, 1, 1, 3]],\n    [\n        [1, 5, 3, 1, 2, 0],\n        [2, 0, 4, 5, 1, 5],\n        [2, 6, 4, 5, 2, 0],\n        [2, 6, 7, 5, 4, 5],\n        [2, 6, 3, 7, 7, 5],\n    ],\n    [[2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5]],\n    [[3, 2, 7, 3, 7, 5], [7, 5, 5, 4, 3, 2], [3, 2, 5, 
4, 1, 0]],\n    [\n        [2, 3, 0, 4, 2, 0],\n        [2, 3, 4, 5, 0, 4],\n        [2, 3, 3, 7, 4, 5],\n        [3, 7, 7, 5, 4, 5],\n        [1, 5, 3, 1, 0, 1],\n    ],\n    [[3, 2, 1, 5, 3, 1], [3, 2, 5, 4, 1, 5], [3, 2, 7, 5, 5, 4], [3, 7, 7, 5, 3, 2]],\n    [[2, 6, 2, 3, 0, 4], [2, 3, 7, 5, 0, 4], [2, 3, 3, 1, 7, 5], [0, 4, 7, 5, 4, 5]],\n    [\n        [3, 2, 1, 3, 5, 7],\n        [5, 7, 6, 2, 3, 2],\n        [5, 4, 6, 2, 5, 7],\n        [5, 4, 0, 2, 6, 2],\n        [5, 4, 1, 0, 0, 2],\n    ],\n    [\n        [4, 5, 0, 4, 2, 6],\n        [2, 6, 7, 5, 4, 5],\n        [2, 3, 7, 5, 2, 6],\n        [2, 3, 1, 5, 7, 5],\n        [2, 3, 0, 1, 1, 5],\n    ],\n    [[2, 3, 2, 0, 2, 6], [1, 5, 7, 5, 4, 5]],\n    [[5, 7, 4, 5, 4, 0], [4, 0, 0, 2, 5, 7], [5, 7, 0, 2, 1, 3]],\n    [[5, 4, 1, 0, 1, 3], [5, 7, 5, 4, 1, 3]],\n    [[0, 2, 4, 5, 0, 4], [0, 2, 5, 7, 4, 5], [0, 2, 1, 5, 5, 7], [0, 1, 1, 5, 0, 2]],\n    [[5, 4, 5, 1, 5, 7]],\n    [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3]],\n    [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3], [0, 2, 4, 0, 0, 1]],\n    [[3, 7, 3, 1, 2, 6], [3, 1, 5, 4, 2, 6], [3, 1, 1, 0, 5, 4], [2, 6, 5, 4, 6, 4]],\n    [\n        [6, 4, 2, 6, 3, 7],\n        [3, 7, 5, 4, 6, 4],\n        [3, 1, 5, 4, 3, 7],\n        [3, 1, 0, 4, 5, 4],\n        [3, 1, 2, 0, 0, 4],\n    ],\n    [[2, 0, 2, 3, 6, 4], [2, 3, 1, 5, 6, 4], [6, 4, 1, 5, 4, 5], [2, 3, 3, 7, 1, 5]],\n    [\n        [0, 4, 1, 0, 3, 2],\n        [3, 2, 6, 4, 0, 4],\n        [3, 7, 6, 4, 3, 2],\n        [3, 7, 5, 4, 6, 4],\n        [3, 7, 1, 5, 5, 4],\n    ],\n    [\n        [1, 3, 0, 1, 4, 5],\n        [4, 5, 7, 3, 1, 3],\n        [4, 6, 7, 3, 4, 5],\n        [4, 6, 2, 3, 7, 3],\n        [4, 6, 0, 2, 2, 3],\n    ],\n    [[3, 7, 3, 1, 3, 2], [5, 4, 6, 4, 0, 4]],\n    [[3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4]],\n    [\n        [3, 1, 2, 6, 3, 2],\n        [3, 1, 6, 4, 2, 6],\n        [3, 1, 1, 5, 6, 4],\n        [1, 5, 5, 4, 
6, 4],\n        [0, 4, 1, 0, 2, 0],\n    ],\n    [[4, 5, 6, 4, 6, 2], [6, 2, 2, 3, 4, 5], [4, 5, 2, 3, 0, 1]],\n    [[2, 3, 6, 4, 2, 6], [2, 3, 4, 5, 6, 4], [2, 3, 0, 4, 4, 5], [2, 0, 0, 4, 2, 3]],\n    [[1, 3, 5, 1, 5, 4], [5, 4, 4, 6, 1, 3], [1, 3, 4, 6, 0, 2]],\n    [[1, 3, 0, 4, 1, 0], [1, 3, 4, 6, 0, 4], [1, 3, 5, 4, 4, 6], [1, 5, 5, 4, 1, 3]],\n    [[4, 6, 0, 2, 0, 1], [4, 5, 4, 6, 0, 1]],\n    [[4, 6, 4, 0, 4, 5]],\n    [[4, 0, 6, 2, 7, 3], [4, 0, 7, 3, 5, 1]],\n    [[1, 5, 0, 1, 0, 2], [0, 2, 2, 6, 1, 5], [1, 5, 2, 6, 3, 7]],\n    [[3, 7, 1, 3, 1, 0], [1, 0, 0, 4, 3, 7], [3, 7, 0, 4, 2, 6]],\n    [[3, 1, 2, 0, 2, 6], [3, 7, 3, 1, 2, 6]],\n    [[0, 4, 2, 0, 2, 3], [2, 3, 3, 7, 0, 4], [0, 4, 3, 7, 1, 5]],\n    [[3, 7, 1, 5, 1, 0], [3, 2, 3, 7, 1, 0]],\n    [[0, 4, 1, 3, 0, 1], [0, 4, 3, 7, 1, 3], [0, 4, 2, 3, 3, 7], [0, 2, 2, 3, 0, 4]],\n    [[3, 7, 3, 1, 3, 2]],\n    [[2, 6, 3, 2, 3, 1], [3, 1, 1, 5, 2, 6], [2, 6, 1, 5, 0, 4]],\n    [[1, 5, 3, 2, 1, 3], [1, 5, 2, 6, 3, 2], [1, 5, 0, 2, 2, 6], [1, 0, 0, 2, 1, 5]],\n    [[2, 3, 0, 1, 0, 4], [2, 6, 2, 3, 0, 4]],\n    [[2, 3, 2, 0, 2, 6]],\n    [[1, 5, 0, 4, 0, 2], [1, 3, 1, 5, 0, 2]],\n    [[1, 5, 1, 0, 1, 3]],\n    [[0, 2, 0, 1, 0, 4]],\n    [],\n]\n"
  },
  {
    "path": "shap_e/rendering/blender/__init__.py",
    "content": "from .render import render_mesh, render_model\nfrom .view_data import BlenderViewData\n\n__all__ = [\"BlenderViewData\", \"render_mesh\", \"render_model\"]\n"
  },
  {
    "path": "shap_e/rendering/blender/blender_script.py",
    "content": "\"\"\"\nScript to run within blender.\n\nProvide arguments after `--`.\nFor example: `blender -b -P blender_script.py -- --help`\n\"\"\"\n\nimport argparse\nimport json\nimport math\nimport os\nimport random\nimport sys\n\nimport bpy\nfrom mathutils import Vector\nfrom mathutils.noise import random_unit_vector\n\nMAX_DEPTH = 5.0\nFORMAT_VERSION = 6\n\n# Set by main(), these constants are passed to the script to avoid\n# duplicating them across multiple files.\nUNIFORM_LIGHT_DIRECTION = None\nBASIC_AMBIENT_COLOR = None\nBASIC_DIFFUSE_COLOR = None\n\n\ndef clear_scene():\n    bpy.ops.object.select_all(action=\"SELECT\")\n    bpy.ops.object.delete()\n\n\ndef clear_lights():\n    bpy.ops.object.select_all(action=\"DESELECT\")\n    for obj in bpy.context.scene.objects.values():\n        if isinstance(obj.data, bpy.types.Light):\n            obj.select_set(True)\n    bpy.ops.object.delete()\n\n\ndef import_model(path):\n    clear_scene()\n    _, ext = os.path.splitext(path)\n    ext = ext.lower()\n    if ext == \".obj\":\n        bpy.ops.import_scene.obj(filepath=path)\n    elif ext in [\".glb\", \".gltf\"]:\n        bpy.ops.import_scene.gltf(filepath=path)\n    elif ext == \".stl\":\n        bpy.ops.import_mesh.stl(filepath=path)\n    elif ext == \".fbx\":\n        bpy.ops.import_scene.fbx(filepath=path)\n    elif ext == \".dae\":\n        bpy.ops.wm.collada_import(filepath=path)\n    elif ext == \".ply\":\n        bpy.ops.import_mesh.ply(filepath=path)\n    else:\n        raise RuntimeError(f\"unexpected extension: {ext}\")\n\n\ndef scene_root_objects():\n    for obj in bpy.context.scene.objects.values():\n        if not obj.parent:\n            yield obj\n\n\ndef scene_bbox(single_obj=None, ignore_matrix=False):\n    bbox_min = (math.inf,) * 3\n    bbox_max = (-math.inf,) * 3\n    found = False\n    for obj in scene_meshes() if single_obj is None else [single_obj]:\n        found = True\n        for coord in obj.bound_box:\n            coord = 
Vector(coord)\n            if not ignore_matrix:\n                coord = obj.matrix_world @ coord\n            bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))\n            bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))\n    if not found:\n        raise RuntimeError(\"no objects in scene to compute bounding box for\")\n    return Vector(bbox_min), Vector(bbox_max)\n\n\ndef scene_meshes():\n    for obj in bpy.context.scene.objects.values():\n        if isinstance(obj.data, (bpy.types.Mesh)):\n            yield obj\n\n\ndef normalize_scene():\n    if len(list(scene_root_objects())) > 1:\n        # Create an empty object to be used as a parent for all root objects\n        parent_empty = bpy.data.objects.new(\"ParentEmpty\", None)\n        bpy.context.scene.collection.objects.link(parent_empty)\n\n        # Parent all root objects to the empty object\n        for obj in scene_root_objects():\n            if obj != parent_empty:\n                obj.parent = parent_empty\n\n    bbox_min, bbox_max = scene_bbox()\n    scale = 1 / max(bbox_max - bbox_min)\n\n    for obj in scene_root_objects():\n        obj.scale = obj.scale * scale\n\n    # Apply scale to matrix_world.\n    bpy.context.view_layer.update()\n\n    bbox_min, bbox_max = scene_bbox()\n    offset = -(bbox_min + bbox_max) / 2\n    for obj in scene_root_objects():\n        obj.matrix_world.translation += offset\n\n    bpy.ops.object.select_all(action=\"DESELECT\")\n\n\ndef create_camera():\n    # https://b3d.interplanety.org/en/how-to-create-camera-through-the-blender-python-api/\n    camera_data = bpy.data.cameras.new(name=\"Camera\")\n    camera_object = bpy.data.objects.new(\"Camera\", camera_data)\n    bpy.context.scene.collection.objects.link(camera_object)\n    bpy.context.scene.camera = camera_object\n\n\ndef set_camera(direction, camera_dist=2.0):\n    camera_pos = -camera_dist * direction\n    bpy.context.scene.camera.location = camera_pos\n\n    # 
https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically\n    rot_quat = direction.to_track_quat(\"-Z\", \"Y\")\n    bpy.context.scene.camera.rotation_euler = rot_quat.to_euler()\n\n    bpy.context.view_layer.update()\n\n\ndef randomize_camera(camera_dist=2.0):\n    direction = random_unit_vector()\n    set_camera(direction, camera_dist=camera_dist)\n\n\ndef pan_camera(time, axis=\"Z\", camera_dist=2.0, elevation=0.1):\n    angle = time * math.pi * 2\n    direction = [-math.cos(angle), -math.sin(angle), elevation]\n    assert axis in [\"X\", \"Y\", \"Z\"]\n    if axis == \"X\":\n        direction = [direction[2], *direction[:2]]\n    elif axis == \"Y\":\n        direction = [direction[0], elevation, direction[1]]\n    direction = Vector(direction).normalized()\n    set_camera(direction, camera_dist=camera_dist)\n\n\ndef place_camera(time, camera_pose_mode=\"random\", camera_dist_min=2.0, camera_dist_max=2.0):\n    camera_dist = random.uniform(camera_dist_min, camera_dist_max)\n    if camera_pose_mode == \"random\":\n        randomize_camera(camera_dist=camera_dist)\n    elif camera_pose_mode == \"z-circular\":\n        pan_camera(time, axis=\"Z\", camera_dist=camera_dist)\n    elif camera_pose_mode == \"z-circular-elevated\":\n        pan_camera(time, axis=\"Z\", camera_dist=camera_dist, elevation=-0.2617993878)\n    else:\n        raise ValueError(f\"Unknown camera pose mode: {camera_pose_mode}\")\n\n\ndef create_light(location, energy=1.0, angle=0.5 * math.pi / 180):\n    # https://blender.stackexchange.com/questions/215624/how-to-create-a-light-with-the-python-api-in-blender-2-92\n    light_data = bpy.data.lights.new(name=\"Light\", type=\"SUN\")\n    light_data.energy = energy\n    light_data.angle = angle\n    light_object = bpy.data.objects.new(name=\"Light\", object_data=light_data)\n\n    direction = -location\n    rot_quat = direction.to_track_quat(\"-Z\", \"Y\")\n    light_object.rotation_euler = 
rot_quat.to_euler()\n    bpy.context.view_layer.update()\n\n    bpy.context.collection.objects.link(light_object)\n    light_object.location = location\n\n\ndef create_random_lights(count=4, distance=2.0, energy=1.5):\n    clear_lights()\n    for _ in range(count):\n        create_light(random_unit_vector() * distance, energy=energy)\n\n\ndef create_camera_light():\n    clear_lights()\n    create_light(bpy.context.scene.camera.location, energy=5.0)\n\n\ndef create_uniform_light(backend):\n    clear_lights()\n    # Random direction to decorrelate axis-aligned sides.\n    pos = Vector(UNIFORM_LIGHT_DIRECTION)\n    angle = 0.0092 if backend == \"CYCLES\" else math.pi\n    create_light(pos, energy=5.0, angle=angle)\n    create_light(-pos, energy=5.0, angle=angle)\n\n\ndef create_vertex_color_shaders():\n    # By default, Blender will ignore vertex colors in both the\n    # Eevee and Cycles backends, since these colors aren't\n    # associated with a material.\n    #\n    # What we do here is create a simple material shader and link\n    # the vertex color to the material color.\n    for obj in bpy.context.scene.objects.values():\n        if not isinstance(obj.data, (bpy.types.Mesh)):\n            continue\n\n        if len(obj.data.materials):\n            # We don't want to override any existing materials.\n            continue\n\n        color_keys = (obj.data.vertex_colors or {}).keys()\n        if not len(color_keys):\n            # Many objects will have no materials *or* vertex colors.\n            continue\n\n        mat = bpy.data.materials.new(name=\"VertexColored\")\n        mat.use_nodes = True\n\n        # There should be a Principled BSDF by default.\n        bsdf_node = None\n        for node in mat.node_tree.nodes:\n            if node.type == \"BSDF_PRINCIPLED\":\n                bsdf_node = node\n        assert bsdf_node is not None, \"material has no Principled BSDF node to modify\"\n\n        socket_map = {}\n        for input in bsdf_node.inputs:\n  
          socket_map[input.name] = input\n\n        # Make sure nothing lights the object except for the diffuse color.\n        socket_map[\"Specular\"].default_value = 0.0\n        socket_map[\"Roughness\"].default_value = 1.0\n\n        v_color = mat.node_tree.nodes.new(\"ShaderNodeVertexColor\")\n        v_color.layer_name = color_keys[0]\n\n        mat.node_tree.links.new(v_color.outputs[0], socket_map[\"Base Color\"])\n\n        obj.data.materials.append(mat)\n\n\ndef create_default_materials():\n    for obj in bpy.context.scene.objects.values():\n        if isinstance(obj.data, (bpy.types.Mesh)):\n            if not len(obj.data.materials):\n                mat = bpy.data.materials.new(name=\"DefaultMaterial\")\n                mat.use_nodes = True\n                obj.data.materials.append(mat)\n\n\ndef find_materials():\n    all_materials = set()\n    for obj in bpy.context.scene.objects.values():\n        if not isinstance(obj.data, bpy.types.Mesh):\n            continue\n        for mat in obj.data.materials:\n            all_materials.add(mat)\n    return all_materials\n\n\ndef delete_all_materials():\n    for obj in bpy.context.scene.objects.values():\n        if isinstance(obj.data, bpy.types.Mesh):\n            # https://blender.stackexchange.com/questions/146714/removing-all-material-slots-in-one-go\n            obj.data.materials.clear()\n\n\ndef setup_material_extraction_shaders(capturing_material_alpha: bool):\n    \"\"\"\n    Change every material to emit texture colors (or alpha) rather than having\n    an actual reflective color. 
Returns a function to undo the changes to the\n    materials.\n    \"\"\"\n    # Objects can share materials, so we first find all of the\n    # materials in the project, and then modify them each once.\n    undo_fns = []\n    for mat in find_materials():\n        undo_fn = setup_material_extraction_shader_for_material(mat, capturing_material_alpha)\n        if undo_fn is not None:\n            undo_fns.append(undo_fn)\n    return lambda: [undo_fn() for undo_fn in undo_fns]\n\n\ndef setup_material_extraction_shader_for_material(mat, capturing_material_alpha: bool):\n    mat.use_nodes = True\n\n    # By default, most imported models should use the regular\n    # \"Principled BSDF\" material, so we should always find this.\n    # If not, this shader manipulation logic won't work.\n    bsdf_node = None\n    for node in mat.node_tree.nodes:\n        if node.type == \"BSDF_PRINCIPLED\":\n            bsdf_node = node\n    assert bsdf_node is not None, \"material has no Principled BSDF node to modify\"\n\n    socket_map = {}\n    for input in bsdf_node.inputs:\n        socket_map[input.name] = input\n    for name in [\"Base Color\", \"Emission\", \"Emission Strength\", \"Alpha\", \"Specular\"]:\n        assert name in socket_map.keys(), f\"{name} not in {list(socket_map.keys())}\"\n\n    old_base_color = get_socket_value(mat.node_tree, socket_map[\"Base Color\"])\n    old_alpha = get_socket_value(mat.node_tree, socket_map[\"Alpha\"])\n    old_emission = get_socket_value(mat.node_tree, socket_map[\"Emission\"])\n    old_emission_strength = get_socket_value(mat.node_tree, socket_map[\"Emission Strength\"])\n    old_specular = get_socket_value(mat.node_tree, socket_map[\"Specular\"])\n\n    # Make sure the base color of all objects is black and the opacity\n    # is 1, so that we are effectively just telling the shader what color\n    # to make the pixels.\n    clear_socket_input(mat.node_tree, socket_map[\"Base Color\"])\n    socket_map[\"Base Color\"].default_value = [0, 
0, 0, 1]\n    clear_socket_input(mat.node_tree, socket_map[\"Alpha\"])\n    socket_map[\"Alpha\"].default_value = 1\n    clear_socket_input(mat.node_tree, socket_map[\"Specular\"])\n    socket_map[\"Specular\"].default_value = 0.0\n\n    old_blend_method = mat.blend_method\n    mat.blend_method = \"OPAQUE\"\n\n    if capturing_material_alpha:\n        set_socket_value(mat.node_tree, socket_map[\"Emission\"], old_alpha)\n    else:\n        set_socket_value(mat.node_tree, socket_map[\"Emission\"], old_base_color)\n    clear_socket_input(mat.node_tree, socket_map[\"Emission Strength\"])\n    socket_map[\"Emission Strength\"].default_value = 1.0\n\n    def undo_fn():\n        mat.blend_method = old_blend_method\n        set_socket_value(mat.node_tree, socket_map[\"Base Color\"], old_base_color)\n        set_socket_value(mat.node_tree, socket_map[\"Alpha\"], old_alpha)\n        set_socket_value(mat.node_tree, socket_map[\"Emission\"], old_emission)\n        set_socket_value(mat.node_tree, socket_map[\"Emission Strength\"], old_emission_strength)\n        set_socket_value(mat.node_tree, socket_map[\"Specular\"], old_specular)\n\n    return undo_fn\n\n\ndef get_socket_value(tree, socket):\n    default = socket.default_value\n    if not isinstance(default, float):\n        default = list(default)\n    for link in tree.links:\n        if link.to_socket == socket:\n            return (link.from_socket, default)\n    return (None, default)\n\n\ndef clear_socket_input(tree, socket):\n    for link in list(tree.links):\n        if link.to_socket == socket:\n            tree.links.remove(link)\n\n\ndef set_socket_value(tree, socket, socket_and_default):\n    clear_socket_input(tree, socket)\n    old_source_socket, default = socket_and_default\n    if isinstance(default, float) and not isinstance(socket.default_value, float):\n        # Codepath for setting Emission to a previous alpha value.\n        socket.default_value = [default] * 3 + [1.0]\n    else:\n        
socket.default_value = default\n    if old_source_socket is not None:\n        tree.links.new(old_source_socket, socket)\n\n\ndef setup_nodes(output_path, capturing_material_alpha: bool = False, basic_lighting: bool = False):\n    tree = bpy.context.scene.node_tree\n    links = tree.links\n\n    for node in tree.nodes:\n        tree.nodes.remove(node)\n\n    # Helpers to perform math on links and constants.\n    def node_op(op: str, *args, clamp=False):\n        node = tree.nodes.new(type=\"CompositorNodeMath\")\n        node.operation = op\n        if clamp:\n            node.use_clamp = True\n        for i, arg in enumerate(args):\n            if isinstance(arg, (int, float)):\n                node.inputs[i].default_value = arg\n            else:\n                links.new(arg, node.inputs[i])\n        return node.outputs[0]\n\n    def node_clamp(x, maximum=1.0):\n        return node_op(\"MINIMUM\", x, maximum)\n\n    def node_mul(x, y, **kwargs):\n        return node_op(\"MULTIPLY\", x, y, **kwargs)\n\n    def node_add(x, y, **kwargs):\n        return node_op(\"ADD\", x, y, **kwargs)\n\n    def node_abs(x, **kwargs):\n        return node_op(\"ABSOLUTE\", x, **kwargs)\n\n    input_node = tree.nodes.new(type=\"CompositorNodeRLayers\")\n    input_node.scene = bpy.context.scene\n\n    input_sockets = {}\n    for output in input_node.outputs:\n        input_sockets[output.name] = output\n\n    if capturing_material_alpha:\n        color_socket = input_sockets[\"Image\"]\n    else:\n        raw_color_socket = input_sockets[\"Image\"]\n        if basic_lighting:\n            # Compute diffuse lighting\n            normal_xyz = tree.nodes.new(type=\"CompositorNodeSeparateXYZ\")\n            tree.links.new(input_sockets[\"Normal\"], normal_xyz.inputs[0])\n            normal_x, normal_y, normal_z = [normal_xyz.outputs[i] for i in range(3)]\n            dot = node_add(\n                node_mul(UNIFORM_LIGHT_DIRECTION[0], normal_x),\n                node_add(\n             
       node_mul(UNIFORM_LIGHT_DIRECTION[1], normal_y),\n                    node_mul(UNIFORM_LIGHT_DIRECTION[2], normal_z),\n                ),\n            )\n            diffuse = node_abs(dot)\n            # Compute ambient + diffuse lighting\n            brightness = node_add(BASIC_AMBIENT_COLOR, node_mul(BASIC_DIFFUSE_COLOR, diffuse))\n            # Modulate the RGB channels using the total brightness.\n            rgba_node = tree.nodes.new(type=\"CompositorNodeSepRGBA\")\n            tree.links.new(raw_color_socket, rgba_node.inputs[0])\n            combine_node = tree.nodes.new(type=\"CompositorNodeCombRGBA\")\n            for i in range(3):\n                tree.links.new(node_mul(rgba_node.outputs[i], brightness), combine_node.inputs[i])\n            tree.links.new(rgba_node.outputs[3], combine_node.inputs[3])\n            raw_color_socket = combine_node.outputs[0]\n\n        # We apply sRGB here so that our fixed-point depth map and material\n        # alpha values are not sRGB, and so that we perform ambient+diffuse\n        # lighting in linear RGB space.\n        color_node = tree.nodes.new(type=\"CompositorNodeConvertColorSpace\")\n        color_node.from_color_space = \"Linear\"\n        color_node.to_color_space = \"sRGB\"\n        tree.links.new(raw_color_socket, color_node.inputs[0])\n        color_socket = color_node.outputs[0]\n    split_node = tree.nodes.new(type=\"CompositorNodeSepRGBA\")\n    tree.links.new(color_socket, split_node.inputs[0])\n    # Create separate file output nodes for every channel we care about.\n    # The process calling this script must decide how to recombine these\n    # channels, possibly into a single image.\n    for i, channel in enumerate(\"rgba\") if not capturing_material_alpha else [(0, \"MatAlpha\")]:\n        output_node = tree.nodes.new(type=\"CompositorNodeOutputFile\")\n        output_node.base_path = f\"{output_path}_{channel}\"\n        links.new(split_node.outputs[i], output_node.inputs[0])\n\n    if 
capturing_material_alpha:\n        # No need to re-write depth here.\n        return\n\n    depth_out = node_clamp(node_mul(input_sockets[\"Depth\"], 1 / MAX_DEPTH))\n    output_node = tree.nodes.new(type=\"CompositorNodeOutputFile\")\n    output_node.base_path = f\"{output_path}_depth\"\n    links.new(depth_out, output_node.inputs[0])\n\n\ndef render_scene(output_path, fast_mode: bool, extract_material: bool, basic_lighting: bool):\n    use_workbench = bpy.context.scene.render.engine == \"BLENDER_WORKBENCH\"\n    if use_workbench:\n        # We must use a different engine to compute depth maps.\n        bpy.context.scene.render.engine = \"BLENDER_EEVEE\"\n        bpy.context.scene.eevee.taa_render_samples = 1  # faster, since we discard image.\n    if fast_mode:\n        if bpy.context.scene.render.engine == \"BLENDER_EEVEE\":\n            bpy.context.scene.eevee.taa_render_samples = 1\n        elif bpy.context.scene.render.engine == \"CYCLES\":\n            bpy.context.scene.cycles.samples = 256\n    else:\n        if bpy.context.scene.render.engine == \"CYCLES\":\n            # We should still impose a per-frame time limit\n            # so that we don't timeout completely.\n            bpy.context.scene.cycles.time_limit = 40\n    bpy.context.view_layer.update()\n    bpy.context.scene.use_nodes = True\n    bpy.context.scene.view_layers[\"ViewLayer\"].use_pass_z = True\n    if basic_lighting:\n        bpy.context.scene.view_layers[\"ViewLayer\"].use_pass_normal = True\n    bpy.context.scene.view_settings.view_transform = \"Raw\"  # sRGB done in graph nodes\n    bpy.context.scene.render.film_transparent = True\n    bpy.context.scene.render.resolution_x = 512\n    bpy.context.scene.render.resolution_y = 512\n    bpy.context.scene.render.image_settings.file_format = \"PNG\"\n    bpy.context.scene.render.image_settings.color_mode = \"BW\"\n    bpy.context.scene.render.image_settings.color_depth = \"16\"\n    bpy.context.scene.render.filepath = output_path\n    if 
extract_material:\n        for do_alpha in [False, True]:\n            undo_fn = setup_material_extraction_shaders(capturing_material_alpha=do_alpha)\n            setup_nodes(output_path, capturing_material_alpha=do_alpha)\n            bpy.ops.render.render(write_still=True)\n            undo_fn()\n    else:\n        setup_nodes(output_path, basic_lighting=basic_lighting)\n        bpy.ops.render.render(write_still=True)\n\n    # The output images must be moved from their own sub-directories, or\n    # discarded if we are using workbench for the color.\n    for channel_name in [\"r\", \"g\", \"b\", \"a\", \"depth\", *([\"MatAlpha\"] if extract_material else [])]:\n        sub_dir = f\"{output_path}_{channel_name}\"\n        image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0])\n        name, ext = os.path.splitext(output_path)\n        if channel_name == \"depth\" or not use_workbench:\n            os.rename(image_path, f\"{name}_{channel_name}{ext}\")\n        else:\n            os.remove(image_path)\n        os.removedirs(sub_dir)\n\n    if use_workbench:\n        # Re-render RGBA using workbench with texture mode, since this seems\n        # to show the most reasonable colors when lighting is broken.\n        bpy.context.scene.use_nodes = False\n        bpy.context.scene.render.engine = \"BLENDER_WORKBENCH\"\n        bpy.context.scene.render.image_settings.color_mode = \"RGBA\"\n        bpy.context.scene.render.image_settings.color_depth = \"8\"\n        bpy.context.scene.display.shading.color_type = \"TEXTURE\"\n        bpy.context.scene.display.shading.light = \"FLAT\"\n        if fast_mode:\n            # Single pass anti-aliasing.\n            bpy.context.scene.display.render_aa = \"FXAA\"\n        os.remove(output_path)\n        bpy.ops.render.render(write_still=True)\n        bpy.context.scene.render.image_settings.color_mode = \"BW\"\n        bpy.context.scene.render.image_settings.color_depth = \"16\"\n\n\ndef scene_fov():\n    x_fov = 
bpy.context.scene.camera.data.angle_x\n    y_fov = bpy.context.scene.camera.data.angle_y\n    width = bpy.context.scene.render.resolution_x\n    height = bpy.context.scene.render.resolution_y\n    if bpy.context.scene.camera.data.angle == x_fov:\n        y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width)\n    else:\n        x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height)\n    return x_fov, y_fov\n\n\ndef write_camera_metadata(path):\n    x_fov, y_fov = scene_fov()\n    bbox_min, bbox_max = scene_bbox()\n    matrix = bpy.context.scene.camera.matrix_world\n    with open(path, \"w\") as f:\n        json.dump(\n            dict(\n                format_version=FORMAT_VERSION,\n                max_depth=MAX_DEPTH,\n                bbox=[list(bbox_min), list(bbox_max)],\n                origin=list(matrix.col[3])[:3],\n                x_fov=x_fov,\n                y_fov=y_fov,\n                x=list(matrix.col[0])[:3],\n                y=list(-matrix.col[1])[:3],\n                z=list(-matrix.col[2])[:3],\n            ),\n            f,\n        )\n\n\ndef save_rendering_dataset(\n    input_path: str,\n    output_path: str,\n    num_images: int,\n    backend: str,\n    light_mode: str,\n    camera_pose: str,\n    camera_dist_min: float,\n    camera_dist_max: float,\n    fast_mode: bool,\n    extract_material: bool,\n    delete_material: bool,\n):\n    assert light_mode in [\"random\", \"uniform\", \"camera\", \"basic\"]\n    assert camera_pose in [\"random\", \"z-circular\", \"z-circular-elevated\"]\n\n    basic_lighting = light_mode == \"basic\"\n    assert not (basic_lighting and extract_material), \"cannot extract material with basic lighting\"\n    assert not (delete_material and extract_material), \"cannot extract material and delete it\"\n\n    import_model(input_path)\n    bpy.context.scene.render.engine = backend\n    normalize_scene()\n    if light_mode == \"random\":\n        create_random_lights()\n    elif light_mode == \"uniform\":\n  
      create_uniform_light(backend)\n    create_camera()\n    create_vertex_color_shaders()\n    if delete_material:\n        delete_all_materials()\n    if extract_material or basic_lighting:\n        create_default_materials()\n    if basic_lighting:\n        # Make sure materials are uniformly lit, so that we can light\n        # them in the output shader.\n        setup_material_extraction_shaders(capturing_material_alpha=False)\n    for i in range(num_images):\n        t = i / max(num_images - 1, 1)  # same as np.linspace(0, 1, num_images)\n        place_camera(\n            t,\n            camera_pose_mode=camera_pose,\n            camera_dist_min=camera_dist_min,\n            camera_dist_max=camera_dist_max,\n        )\n        if light_mode == \"camera\":\n            create_camera_light()\n        render_scene(\n            os.path.join(output_path, f\"{i:05}.png\"),\n            fast_mode=fast_mode,\n            extract_material=extract_material,\n            basic_lighting=basic_lighting,\n        )\n        write_camera_metadata(os.path.join(output_path, f\"{i:05}.json\"))\n    with open(os.path.join(output_path, \"info.json\"), \"w\") as f:\n        info = dict(\n            backend=backend,\n            light_mode=light_mode,\n            fast_mode=fast_mode,\n            extract_material=extract_material,\n            format_version=FORMAT_VERSION,\n            channels=[\"R\", \"G\", \"B\", \"A\", \"D\", *([\"MatAlpha\"] if extract_material else [])],\n            scale=0.5,  # The scene is bounded by [-scale, scale].\n        )\n        json.dump(info, f)\n\n\ndef main():\n    global UNIFORM_LIGHT_DIRECTION, BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR\n\n    try:\n        dash_index = sys.argv.index(\"--\")\n    except ValueError as exc:\n        raise ValueError(\"arguments must be preceded by '--'\") from exc\n\n    raw_args = sys.argv[dash_index + 1 :]\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--input_path\", required=True, 
type=str)\n    parser.add_argument(\"--output_path\", required=True, type=str)\n    parser.add_argument(\"--num_images\", required=True, type=int)\n    parser.add_argument(\"--backend\", type=str, default=\"BLENDER_EEVEE\")\n    parser.add_argument(\"--light_mode\", type=str, default=\"random\")\n    parser.add_argument(\"--camera_pose\", type=str, default=\"random\")\n    parser.add_argument(\"--camera_dist_min\", type=float, default=2.0)\n    parser.add_argument(\"--camera_dist_max\", type=float, default=2.0)\n    parser.add_argument(\"--fast_mode\", action=\"store_true\")\n    parser.add_argument(\"--extract_material\", action=\"store_true\")\n    parser.add_argument(\"--delete_material\", action=\"store_true\")\n\n    # Prevent constants from being repeated.\n    parser.add_argument(\"--uniform_light_direction\", required=True, type=float, nargs=\"+\")\n    parser.add_argument(\"--basic_ambient\", required=True, type=float)\n    parser.add_argument(\"--basic_diffuse\", required=True, type=float)\n    args = parser.parse_args(raw_args)\n\n    UNIFORM_LIGHT_DIRECTION = args.uniform_light_direction\n    BASIC_AMBIENT_COLOR = args.basic_ambient\n    BASIC_DIFFUSE_COLOR = args.basic_diffuse\n\n    save_rendering_dataset(\n        input_path=args.input_path,\n        output_path=args.output_path,\n        num_images=args.num_images,\n        backend=args.backend,\n        light_mode=args.light_mode,\n        camera_pose=args.camera_pose,\n        camera_dist_min=args.camera_dist_min,\n        camera_dist_max=args.camera_dist_max,\n        fast_mode=args.fast_mode,\n        extract_material=args.extract_material,\n        delete_material=args.delete_material,\n    )\n\n\nmain()\n"
  },
  {
    "path": "shap_e/rendering/blender/constants.py",
    "content": "UNIFORM_LIGHT_DIRECTION = [0.09387503, -0.63953443, -0.7630093]\nBASIC_AMBIENT_COLOR = 0.3\nBASIC_DIFFUSE_COLOR = 0.7\n"
  },
  {
    "path": "shap_e/rendering/blender/render.py",
    "content": "import os\nimport platform\nimport subprocess\nimport tempfile\nimport zipfile\n\nimport blobfile as bf\nimport numpy as np\nfrom PIL import Image\n\nfrom shap_e.rendering.mesh import TriMesh\n\nfrom .constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION\n\nSCRIPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"blender_script.py\")\n\n\ndef render_model(\n    model_path: str,\n    output_path: str,\n    num_images: int,\n    backend: str = \"BLENDER_EEVEE\",\n    light_mode: str = \"random\",\n    camera_pose: str = \"random\",\n    camera_dist_min: float = 2.0,\n    camera_dist_max: float = 2.0,\n    fast_mode: bool = False,\n    extract_material: bool = False,\n    delete_material: bool = False,\n    verbose: bool = False,\n    timeout: float = 15 * 60,\n):\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        tmp_in = model_path\n        tmp_out = os.path.join(tmp_dir, \"out\")\n        zip_out = tmp_out + \".zip\"\n        os.mkdir(tmp_out)\n        args = []\n        if platform.system() == \"Linux\":\n            # Needed to enable Eevee backend on headless linux.\n            args = [\"xvfb-run\", \"-a\"]\n        args.extend(\n            [\n                _blender_binary_path(),\n                \"-b\",\n                \"-P\",\n                SCRIPT_PATH,\n                \"--\",\n                \"--input_path\",\n                tmp_in,\n                \"--output_path\",\n                tmp_out,\n                \"--num_images\",\n                str(num_images),\n                \"--backend\",\n                backend,\n                \"--light_mode\",\n                light_mode,\n                \"--camera_pose\",\n                camera_pose,\n                \"--camera_dist_min\",\n                str(camera_dist_min),\n                \"--camera_dist_max\",\n                str(camera_dist_max),\n                \"--uniform_light_direction\",\n                *[str(x) for x 
in UNIFORM_LIGHT_DIRECTION],\n                \"--basic_ambient\",\n                str(BASIC_AMBIENT_COLOR),\n                \"--basic_diffuse\",\n                str(BASIC_DIFFUSE_COLOR),\n            ]\n        )\n        if fast_mode:\n            args.append(\"--fast_mode\")\n        if extract_material:\n            args.append(\"--extract_material\")\n        if delete_material:\n            args.append(\"--delete_material\")\n        if verbose:\n            subprocess.check_call(args)\n        else:\n            try:\n                output = subprocess.check_output(args, stderr=subprocess.STDOUT, timeout=timeout)\n            except subprocess.CalledProcessError as exc:\n                raise RuntimeError(f\"{exc}: {exc.output}\") from exc\n        if not os.path.exists(os.path.join(tmp_out, \"info.json\")):\n            if verbose:\n                # There is no output available, since it was\n                # logged directly to stdout/stderr.\n                raise RuntimeError(f\"render failed: output file missing\")\n            else:\n                raise RuntimeError(f\"render failed: output file missing. 
Output: {output}\")\n        _combine_rgba(tmp_out)\n        with zipfile.ZipFile(zip_out, mode=\"w\") as zf:\n            for name in os.listdir(tmp_out):\n                zf.write(os.path.join(tmp_out, name), name)\n        bf.copy(zip_out, output_path, overwrite=True)\n\n\ndef render_mesh(\n    mesh: TriMesh,\n    output_path: str,\n    num_images: int,\n    backend: str = \"BLENDER_EEVEE\",\n    **kwargs,\n):\n    if mesh.has_vertex_colors() and backend not in [\"BLENDER_EEVEE\", \"CYCLES\"]:\n        raise ValueError(f\"backend does not support vertex colors: {backend}\")\n\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        ply_path = os.path.join(tmp_dir, \"out.ply\")\n        with open(ply_path, \"wb\") as f:\n            mesh.write_ply(f)\n        render_model(\n            ply_path, output_path=output_path, num_images=num_images, backend=backend, **kwargs\n        )\n\n\ndef _combine_rgba(out_dir: str):\n    i = 0\n    while True:\n        paths = [os.path.join(out_dir, f\"{i:05}_{ch}.png\") for ch in \"rgba\"]\n        if not os.path.exists(paths[0]):\n            break\n        joined = np.stack(\n            [(np.array(Image.open(path)) >> 8).astype(np.uint8) for path in paths], axis=-1\n        )\n        Image.fromarray(joined).save(os.path.join(out_dir, f\"{i:05}.png\"))\n        for path in paths:\n            os.remove(path)\n        i += 1\n\n\ndef _blender_binary_path() -> str:\n    path = os.getenv(\"BLENDER_PATH\", None)\n    if path is not None:\n        return path\n\n    if os.path.exists(\"/Applications/Blender.app/Contents/MacOS/Blender\"):\n        return \"/Applications/Blender.app/Contents/MacOS/Blender\"\n\n    raise EnvironmentError(\n        \"To render 3D models, install Blender version 3.3.1 or higher and \"\n        \"set the environment variable `BLENDER_PATH` to the path of the Blender executable.\"\n    )\n"
  },
  {
    "path": "shap_e/rendering/blender/view_data.py",
    "content": "import itertools\nimport json\nimport zipfile\nfrom typing import BinaryIO, List, Tuple\n\nimport numpy as np\nfrom PIL import Image\n\nfrom shap_e.rendering.view_data import Camera, ProjectiveCamera, ViewData\n\n\nclass BlenderViewData(ViewData):\n    \"\"\"\n    Interact with a dataset zipfile exported by view_data.py.\n    \"\"\"\n\n    def __init__(self, f_obj: BinaryIO):\n        self.zipfile = zipfile.ZipFile(f_obj, mode=\"r\")\n        self.infos = []\n        with self.zipfile.open(\"info.json\", \"r\") as f:\n            self.info = json.load(f)\n        self.channels = list(self.info.get(\"channels\", \"RGBAD\"))\n        assert set(\"RGBA\").issubset(\n            set(self.channels)\n        ), \"The blender output should at least have RGBA images.\"\n        names = set(x.filename for x in self.zipfile.infolist())\n        for i in itertools.count():\n            name = f\"{i:05}.json\"\n            if name not in names:\n                break\n            with self.zipfile.open(name, \"r\") as f:\n                self.infos.append(json.load(f))\n\n    @property\n    def num_views(self) -> int:\n        return len(self.infos)\n\n    @property\n    def channel_names(self) -> List[str]:\n        return list(self.channels)\n\n    def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]:\n        for ch in channels:\n            if ch not in self.channel_names:\n                raise ValueError(f\"unsupported channel: {ch}\")\n\n        # Gather (a superset of) the requested channels.\n        channel_map = {}\n        if any(x in channels for x in \"RGBA\"):\n            with self.zipfile.open(f\"{index:05}.png\", \"r\") as f:\n                rgba = np.array(Image.open(f)).astype(np.float32) / 255.0\n                channel_map.update(zip(\"RGBA\", rgba.transpose([2, 0, 1])))\n        if \"D\" in channels:\n            with self.zipfile.open(f\"{index:05}_depth.png\", \"r\") as f:\n                # Decode a 16-bit 
fixed-point number.\n                fp = np.array(Image.open(f))\n                inf_dist = fp == 0xFFFF\n                channel_map[\"D\"] = np.where(\n                    inf_dist,\n                    np.inf,\n                    self.infos[index][\"max_depth\"] * (fp.astype(np.float32) / 65536),\n                )\n        if \"MatAlpha\" in channels:\n            with self.zipfile.open(f\"{index:05}_MatAlpha.png\", \"r\") as f:\n                channel_map[\"MatAlpha\"] = np.array(Image.open(f)).astype(np.float32) / 65536\n\n        # The order of channels is user-specified.\n        combined = np.stack([channel_map[k] for k in channels], axis=-1)\n\n        h, w, _ = combined.shape\n        return self.camera(index, w, h), combined\n\n    def camera(self, index: int, width: int, height: int) -> ProjectiveCamera:\n        info = self.infos[index]\n        return ProjectiveCamera(\n            origin=np.array(info[\"origin\"], dtype=np.float32),\n            x=np.array(info[\"x\"], dtype=np.float32),\n            y=np.array(info[\"y\"], dtype=np.float32),\n            z=np.array(info[\"z\"], dtype=np.float32),\n            width=width,\n            height=height,\n            x_fov=info[\"x_fov\"],\n            y_fov=info[\"y_fov\"],\n        )\n"
  },
  {
    "path": "shap_e/rendering/mc.py",
    "content": "from dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Tuple\n\nimport torch\n\nfrom ._mc_table import MC_TABLE\nfrom .torch_mesh import TorchMesh\n\n\ndef marching_cubes(\n    field: torch.Tensor,\n    min_point: torch.Tensor,\n    size: torch.Tensor,\n) -> TorchMesh:\n    \"\"\"\n    For a signed distance field, produce a mesh using marching cubes.\n\n    :param field: a 3D tensor of field values, where negative values correspond\n                  to the outside of the shape. The dimensions correspond to the\n                  x, y, and z directions, respectively.\n    :param min_point: a tensor of shape [3] containing the point corresponding\n                      to (0, 0, 0) in the field.\n    :param size: a tensor of shape [3] containing the per-axis distance between the\n                 (0, 0, 0) field corner and the (-1, -1, -1) field corner.\n    \"\"\"\n    assert len(field.shape) == 3, \"input must be a 3D scalar field\"\n    dev = field.device\n\n    grid_size = field.shape\n    grid_size_tensor = torch.tensor(grid_size).to(size)\n    lut = _lookup_table(dev)\n\n    # Create bitmasks between 0 and 255 (inclusive) indicating the state\n    # of the eight corners of each cube.\n    bitmasks = (field > 0).to(torch.uint8)\n    bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)\n    bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)\n    bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)\n\n    # Compute corner coordinates across the entire grid.\n    corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)\n    corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(\n        grid_size[0], device=dev, dtype=field.dtype\n    )[:, None, None]\n    corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(\n        grid_size[1], device=dev, dtype=field.dtype\n    )[:, None]\n    corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(\n        grid_size[2], 
device=dev, dtype=field.dtype\n    )\n\n    # Compute all vertices across all edges in the grid, even though we will\n    # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.\n    # These are all midpoints, and don't account for interpolation (which is\n    # done later based on the used edge midpoints).\n    edge_midpoints = torch.cat(\n        [\n            ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),\n            ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),\n            ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),\n        ],\n        dim=0,\n    )\n\n    # Create a flat array of [X, Y, Z] indices for each cube.\n    cube_indices = torch.zeros(\n        grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long\n    )\n    cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[\n        :, None, None\n    ]\n    cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[\n        :, None\n    ]\n    cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev)\n    flat_cube_indices = cube_indices.reshape(-1, 3)\n\n    # Create a flat array mapping each cube to 12 global edge indices.\n    edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)\n\n    # Apply the LUT to figure out the triangles.\n    # Cast to long so indexing treats these as indices, not as a byte mask.\n    flat_bitmasks = bitmasks.reshape(-1).long()\n    local_tris = lut.cases[flat_bitmasks]\n    local_masks = lut.masks[flat_bitmasks]\n    # Compute the global edge indices for the triangles.\n    global_tris = torch.gather(\n        edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)\n    ).reshape(local_tris.shape)\n    # Select the used triangles for each cube.\n    selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]\n\n    # Now we 
have a bunch of indices into the full list of possible vertices,\n    # but we want to reduce this list to only the used vertices.\n    used_vertex_indices = torch.unique(selected_tris.view(-1))\n    used_edge_midpoints = edge_midpoints[used_vertex_indices]\n    old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)\n    old_index_to_new_index[used_vertex_indices] = torch.arange(\n        len(used_vertex_indices), device=dev, dtype=torch.long\n    )\n\n    # Rewrite the triangles to use the new indices\n    selected_tris = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(\n        selected_tris.shape\n    )\n\n    # Compute the actual interpolated coordinates corresponding to edge midpoints.\n    v1 = torch.floor(used_edge_midpoints).to(torch.long)\n    v2 = torch.ceil(used_edge_midpoints).to(torch.long)\n    s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]\n    s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]\n    p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point\n    p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point\n    # The signs of s1 and s2 should be different. 
We want to find\n    # t such that t*s2 + (1-t)*s1 = 0.\n    t = (s1 / (s1 - s2))[:, None]\n    verts = t * p2 + (1 - t) * p1\n\n    return TorchMesh(verts=verts, faces=selected_tris)\n\n\ndef _create_flat_edge_indices(\n    flat_cube_indices: torch.Tensor, grid_size: Tuple[int, int, int]\n) -> torch.Tensor:\n    num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2]\n    y_offset = num_xs\n    num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2]\n    z_offset = num_xs + num_ys\n    return torch.stack(\n        [\n            # Edges spanning x-axis.\n            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]\n            + flat_cube_indices[:, 1] * grid_size[2]\n            + flat_cube_indices[:, 2],\n            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]\n            + (flat_cube_indices[:, 1] + 1) * grid_size[2]\n            + flat_cube_indices[:, 2],\n            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]\n            + flat_cube_indices[:, 1] * grid_size[2]\n            + flat_cube_indices[:, 2]\n            + 1,\n            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]\n            + (flat_cube_indices[:, 1] + 1) * grid_size[2]\n            + flat_cube_indices[:, 2]\n            + 1,\n            # Edges spanning y-axis.\n            (\n                y_offset\n                + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]\n                + flat_cube_indices[:, 1] * grid_size[2]\n                + flat_cube_indices[:, 2]\n            ),\n            (\n                y_offset\n                + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]\n                + flat_cube_indices[:, 1] * grid_size[2]\n                + flat_cube_indices[:, 2]\n            ),\n            (\n                y_offset\n                + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]\n                + flat_cube_indices[:, 1] * grid_size[2]\n                + flat_cube_indices[:, 2]\n     
                + 1\n            ),\n            (\n                y_offset\n                + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]\n                + flat_cube_indices[:, 1] * grid_size[2]\n                + flat_cube_indices[:, 2]\n                + 1\n            ),\n            # Edges spanning z-axis.\n            (\n                z_offset\n                + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)\n                + flat_cube_indices[:, 1] * (grid_size[2] - 1)\n                + flat_cube_indices[:, 2]\n            ),\n            (\n                z_offset\n                + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)\n                + flat_cube_indices[:, 1] * (grid_size[2] - 1)\n                + flat_cube_indices[:, 2]\n            ),\n            (\n                z_offset\n                + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)\n                + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)\n                + flat_cube_indices[:, 2]\n            ),\n            (\n                z_offset\n                + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)\n                + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)\n                + flat_cube_indices[:, 2]\n            ),\n        ],\n        dim=-1,\n    )\n\n\n@dataclass\nclass McLookupTable:\n    # Coordinates in triangles are represented as edge indices from 0-11.\n    # Here is an MC cell with both corner and edge indices marked.\n    #        6 + ---------- 3 ----------+ 7\n    #         /|                       /|\n    #        6 |                      7 |\n    #       /  |                     /  |\n    #    4 +--------- 2 ------------+ 5 |\n    #      |   10                   |   |\n    #      |   |                    |   11\n    #      |   |                    |   |\n    #      8   | 2                  9   | 3\n    #      |   +--------- 1 --------|---+\n    #      |  /    
                 |  /\n    #      | 4                      | 5\n    #      |/                       |/\n    #      +---------- 0 -----------+\n    #     0                           1\n    cases: torch.Tensor  # [256 x 5 x 3] long tensor\n    masks: torch.Tensor  # [256 x 5] bool tensor\n\n\n@lru_cache(maxsize=9)  # if there's more than 8 GPUs and a CPU, don't bother caching\ndef _lookup_table(device: torch.device) -> McLookupTable:\n    cases = torch.zeros(256, 5, 3, device=device, dtype=torch.long)\n    masks = torch.zeros(256, 5, device=device, dtype=torch.bool)\n\n    edge_to_index = {\n        (0, 1): 0,\n        (2, 3): 1,\n        (4, 5): 2,\n        (6, 7): 3,\n        (0, 2): 4,\n        (1, 3): 5,\n        (4, 6): 6,\n        (5, 7): 7,\n        (0, 4): 8,\n        (1, 5): 9,\n        (2, 6): 10,\n        (3, 7): 11,\n    }\n\n    for i, case in enumerate(MC_TABLE):\n        for j, tri in enumerate(case):\n            for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])):\n                cases[i, j, k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)]\n            masks[i, j] = True\n    return McLookupTable(cases=cases, masks=masks)\n"
  },
  {
    "path": "shap_e/rendering/mesh.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import BinaryIO, Dict, Optional, Union\n\nimport blobfile as bf\nimport numpy as np\n\nfrom .ply_util import write_ply\n\n\n@dataclass\nclass TriMesh:\n    \"\"\"\n    A 3D triangle mesh with optional data at the vertices and faces.\n    \"\"\"\n\n    # [N x 3] array of vertex coordinates.\n    verts: np.ndarray\n\n    # [M x 3] array of triangles, pointing to indices in verts.\n    faces: np.ndarray\n\n    # [P x 3] array of normal vectors per face.\n    normals: Optional[np.ndarray] = None\n\n    # Extra data per vertex and face.\n    vertex_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict)\n    face_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict)\n\n    @classmethod\n    def load(cls, f: Union[str, BinaryIO]) -> \"TriMesh\":\n        \"\"\"\n        Load the mesh from a .npz file.\n        \"\"\"\n        if isinstance(f, str):\n            with bf.BlobFile(f, \"rb\") as reader:\n                return cls.load(reader)\n        else:\n            obj = np.load(f)\n            keys = list(obj.keys())\n            verts = obj[\"verts\"]\n            faces = obj[\"faces\"]\n            normals = obj[\"normals\"] if \"normals\" in keys else None\n            vertex_channels = {}\n            face_channels = {}\n            for key in keys:\n                if key.startswith(\"v_\"):\n                    vertex_channels[key[2:]] = obj[key]\n                elif key.startswith(\"f_\"):\n                    face_channels[key[2:]] = obj[key]\n            return cls(\n                verts=verts,\n                faces=faces,\n                normals=normals,\n                vertex_channels=vertex_channels,\n                face_channels=face_channels,\n            )\n\n    def save(self, f: Union[str, BinaryIO]):\n        \"\"\"\n        Save the mesh to a .npz file.\n        \"\"\"\n        if isinstance(f, str):\n            with bf.BlobFile(f, \"wb\") 
as writer:\n                self.save(writer)\n        else:\n            obj_dict = dict(verts=self.verts, faces=self.faces)\n            if self.normals is not None:\n                obj_dict[\"normals\"] = self.normals\n            for k, v in self.vertex_channels.items():\n                obj_dict[f\"v_{k}\"] = v\n            for k, v in self.face_channels.items():\n                obj_dict[f\"f_{k}\"] = v\n            np.savez(f, **obj_dict)\n\n    def has_vertex_colors(self) -> bool:\n        return self.vertex_channels is not None and all(x in self.vertex_channels for x in \"RGB\")\n\n    def write_ply(self, raw_f: BinaryIO):\n        write_ply(\n            raw_f,\n            coords=self.verts,\n            rgb=(\n                np.stack([self.vertex_channels[x] for x in \"RGB\"], axis=1)\n                if self.has_vertex_colors()\n                else None\n            ),\n            faces=self.faces,\n        )\n\n    def write_obj(self, raw_f: BinaryIO):\n        if self.has_vertex_colors():\n            vertex_colors = np.stack([self.vertex_channels[x] for x in \"RGB\"], axis=1)\n            vertices = [\n                \"{} {} {} {} {} {}\".format(*coord, *color)\n                for coord, color in zip(self.verts.tolist(), vertex_colors.tolist())\n            ]\n        else:\n            vertices = [\"{} {} {}\".format(*coord) for coord in self.verts.tolist()]\n\n        faces = [\n            \"f {} {} {}\".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1))\n            for tri in self.faces.tolist()\n        ]\n\n        combined_data = [\"v \" + vertex for vertex in vertices] + faces\n\n        raw_f.writelines(\"\\n\".join(combined_data))\n"
  },
  {
    "path": "shap_e/rendering/ply_util.py",
    "content": "import struct\nfrom typing import BinaryIO, Optional\n\nimport numpy as np\n\nfrom shap_e.util.io import buffered_writer\n\n\ndef write_ply(\n    raw_f: BinaryIO,\n    coords: np.ndarray,\n    rgb: Optional[np.ndarray] = None,\n    faces: Optional[np.ndarray] = None,\n):\n    \"\"\"\n    Write a PLY file for a mesh or a point cloud.\n\n    :param coords: an [N x 3] array of floating point coordinates.\n    :param rgb: an [N x 3] array of vertex colors, in the range [0.0, 1.0].\n    :param faces: an [N x 3] array of triangles encoded as integer indices.\n    \"\"\"\n    with buffered_writer(raw_f) as f:\n        f.write(b\"ply\\n\")\n        f.write(b\"format binary_little_endian 1.0\\n\")\n        f.write(bytes(f\"element vertex {len(coords)}\\n\", \"ascii\"))\n        f.write(b\"property float x\\n\")\n        f.write(b\"property float y\\n\")\n        f.write(b\"property float z\\n\")\n        if rgb is not None:\n            f.write(b\"property uchar red\\n\")\n            f.write(b\"property uchar green\\n\")\n            f.write(b\"property uchar blue\\n\")\n        if faces is not None:\n            f.write(bytes(f\"element face {len(faces)}\\n\", \"ascii\"))\n            f.write(b\"property list uchar int vertex_index\\n\")\n        f.write(b\"end_header\\n\")\n\n        if rgb is not None:\n            rgb = (rgb * 255.499).round().astype(int)\n            vertices = [\n                (*coord, *rgb)\n                for coord, rgb in zip(\n                    coords.tolist(),\n                    rgb.tolist(),\n                )\n            ]\n            format = struct.Struct(\"<3f3B\")\n            for item in vertices:\n                f.write(format.pack(*item))\n        else:\n            format = struct.Struct(\"<3f\")\n            for vertex in coords.tolist():\n                f.write(format.pack(*vertex))\n\n        if faces is not None:\n            format = struct.Struct(\"<B3I\")\n            for tri in faces.tolist():\n       
         f.write(format.pack(len(tri), *tri))\n"
  },
  {
    "path": "shap_e/rendering/point_cloud.py",
    "content": "import random\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import BinaryIO, Dict, List, Optional, Union\n\nimport blobfile as bf\nimport numpy as np\n\nfrom shap_e.rendering.view_data import ViewData\n\nfrom .ply_util import write_ply\n\nCOLORS = frozenset([\"R\", \"G\", \"B\", \"A\"])\n\n\ndef preprocess(data, channel):\n    if channel in COLORS:\n        return np.round(data * 255.0)\n    return data\n\n\n@dataclass\nclass PointCloud:\n    \"\"\"\n    An array of points sampled on a surface. Each point may have zero or more\n    channel attributes.\n\n    :param coords: an [N x 3] array of point coordinates.\n    :param channels: a dict mapping names to [N] arrays of channel values.\n    \"\"\"\n\n    coords: np.ndarray\n    channels: Dict[str, np.ndarray]\n\n    @classmethod\n    def from_rgbd(cls, vd: ViewData, num_views: Optional[int] = None) -> \"PointCloud\":\n        \"\"\"\n        Construct a point cloud from the given view data.\n\n        The data must have a depth channel. 
All other channels will be stored\n        in the `channels` attribute of the result.\n\n        Pixels in the rendered views are not converted into points in the cloud\n        if they have infinite depth or less than 1.0 alpha.\n        \"\"\"\n        channel_names = vd.channel_names\n        if \"D\" not in channel_names:\n            raise ValueError(f\"view data must have depth channel\")\n        depth_index = channel_names.index(\"D\")\n\n        all_coords = []\n        all_channels = defaultdict(list)\n\n        if num_views is None:\n            num_views = vd.num_views\n        for i in range(num_views):\n            camera, channel_values = vd.load_view(i, channel_names)\n            flat_values = channel_values.reshape([-1, len(channel_names)])\n\n            # Create an array of integer (x, y) image coordinates for Camera methods.\n            image_coords = camera.image_coords()\n\n            # Select subset of pixels that have meaningful depth/color.\n            image_mask = np.isfinite(flat_values[:, depth_index])\n            if \"A\" in channel_names:\n                image_mask = image_mask & (flat_values[:, channel_names.index(\"A\")] >= 1 - 1e-5)\n            image_coords = image_coords[image_mask]\n            flat_values = flat_values[image_mask]\n\n            # Use the depth and camera information to compute the coordinates\n            # corresponding to every visible pixel.\n            camera_rays = camera.camera_rays(image_coords)\n            camera_origins = camera_rays[:, 0]\n            camera_directions = camera_rays[:, 1]\n            depth_dirs = camera.depth_directions(image_coords)\n            ray_scales = flat_values[:, depth_index] / np.sum(\n                camera_directions * depth_dirs, axis=-1\n            )\n            coords = camera_origins + camera_directions * ray_scales[:, None]\n\n            all_coords.append(coords)\n            for j, name in enumerate(channel_names):\n                if name != \"D\":\n   
                 all_channels[name].append(flat_values[:, j])\n\n        if len(all_coords) == 0:\n            return cls(coords=np.zeros([0, 3], dtype=np.float32), channels={})\n\n        return cls(\n            coords=np.concatenate(all_coords, axis=0),\n            channels={k: np.concatenate(v, axis=0) for k, v in all_channels.items()},\n        )\n\n    @classmethod\n    def load(cls, f: Union[str, BinaryIO]) -> \"PointCloud\":\n        \"\"\"\n        Load the point cloud from a .npz file.\n        \"\"\"\n        if isinstance(f, str):\n            with bf.BlobFile(f, \"rb\") as reader:\n                return cls.load(reader)\n        else:\n            obj = np.load(f)\n            keys = list(obj.keys())\n            return PointCloud(\n                coords=obj[\"coords\"],\n                channels={k: obj[k] for k in keys if k != \"coords\"},\n            )\n\n    def save(self, f: Union[str, BinaryIO]):\n        \"\"\"\n        Save the point cloud to a .npz file.\n        \"\"\"\n        if isinstance(f, str):\n            with bf.BlobFile(f, \"wb\") as writer:\n                self.save(writer)\n        else:\n            np.savez(f, coords=self.coords, **self.channels)\n\n    def write_ply(self, raw_f: BinaryIO):\n        write_ply(\n            raw_f,\n            coords=self.coords,\n            rgb=(\n                np.stack([self.channels[x] for x in \"RGB\"], axis=1)\n                if all(x in self.channels for x in \"RGB\")\n                else None\n            ),\n        )\n\n    def random_sample(self, num_points: int, **subsample_kwargs) -> \"PointCloud\":\n        \"\"\"\n        Sample a random subset of this PointCloud.\n\n        :param num_points: maximum number of points to sample.\n        :param subsample_kwargs: arguments to self.subsample().\n        :return: a reduced PointCloud, or self if num_points is not less than\n                 the current number of points.\n        \"\"\"\n        if len(self.coords) <= 
num_points:\n            return self\n        indices = np.random.choice(len(self.coords), size=(num_points,), replace=False)\n        return self.subsample(indices, **subsample_kwargs)\n\n    def farthest_point_sample(\n        self, num_points: int, init_idx: Optional[int] = None, **subsample_kwargs\n    ) -> \"PointCloud\":\n        \"\"\"\n        Sample a subset of the point cloud that is evenly distributed in space.\n\n        First, a random point is selected. Then each successive point is chosen\n        such that it is furthest from the currently selected points.\n\n        The time complexity of this operation is O(NM), where N is the original\n        number of points and M is the reduced number. Therefore, performance\n        can be improved by randomly subsampling points with random_sample()\n        before running farthest_point_sample().\n\n        :param num_points: maximum number of points to sample.\n        :param init_idx: if specified, the first point to sample.\n        :param subsample_kwargs: arguments to self.subsample().\n        :return: a reduced PointCloud, or self if num_points is not less than\n                 the current number of points.\n        \"\"\"\n        if len(self.coords) <= num_points:\n            return self\n        init_idx = random.randrange(len(self.coords)) if init_idx is None else init_idx\n        indices = np.zeros([num_points], dtype=np.int64)\n        indices[0] = init_idx\n        sq_norms = np.sum(self.coords**2, axis=-1)\n\n        def compute_dists(idx: int):\n            # Utilize equality: ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B).\n            return sq_norms + sq_norms[idx] - 2 * (self.coords @ self.coords[idx])\n\n        cur_dists = compute_dists(init_idx)\n        for i in range(1, num_points):\n            idx = np.argmax(cur_dists)\n            indices[i] = idx\n\n            # Without this line, we may duplicate an index more than once if\n            # there are duplicate points, due to 
rounding errors.\n            cur_dists[idx] = -1\n\n            cur_dists = np.minimum(cur_dists, compute_dists(idx))\n\n        return self.subsample(indices, **subsample_kwargs)\n\n    def subsample(self, indices: np.ndarray, average_neighbors: bool = False) -> \"PointCloud\":\n        if not average_neighbors:\n            return PointCloud(\n                coords=self.coords[indices],\n                channels={k: v[indices] for k, v in self.channels.items()},\n            )\n\n        new_coords = self.coords[indices]\n        neighbor_indices = PointCloud(coords=new_coords, channels={}).nearest_points(self.coords)\n\n        # Make sure every point points to itself, which might not\n        # be the case if points are duplicated or there is rounding\n        # error.\n        neighbor_indices[indices] = np.arange(len(indices))\n\n        new_channels = {}\n        for k, v in self.channels.items():\n            v_sum = np.zeros_like(v[: len(indices)])\n            v_count = np.zeros_like(v[: len(indices)])\n            np.add.at(v_sum, neighbor_indices, v)\n            np.add.at(v_count, neighbor_indices, 1)\n            new_channels[k] = v_sum / v_count\n        return PointCloud(coords=new_coords, channels=new_channels)\n\n    def select_channels(self, channel_names: List[str]) -> np.ndarray:\n        data = np.stack([preprocess(self.channels[name], name) for name in channel_names], axis=-1)\n        return data\n\n    def nearest_points(self, points: np.ndarray, batch_size: int = 16384) -> np.ndarray:\n        \"\"\"\n        For each point in another set of points, compute the point in this\n        pointcloud which is closest.\n\n        :param points: an [N x 3] array of points.\n        :param batch_size: the number of neighbor distances to compute at once.\n                           Smaller values save memory, while larger values may\n                           make the computation faster.\n        :return: an [N] array of indices into 
self.coords.\n        \"\"\"\n        norms = np.sum(self.coords**2, axis=-1)\n        all_indices = []\n        for i in range(0, len(points), batch_size):\n            batch = points[i : i + batch_size]\n            dists = norms + np.sum(batch**2, axis=-1)[:, None] - 2 * (batch @ self.coords.T)\n            all_indices.append(np.argmin(dists, axis=-1))\n        return np.concatenate(all_indices, axis=0)\n\n    def combine(self, other: \"PointCloud\") -> \"PointCloud\":\n        assert self.channels.keys() == other.channels.keys()\n        return PointCloud(\n            coords=np.concatenate([self.coords, other.coords], axis=0),\n            channels={\n                k: np.concatenate([v, other.channels[k]], axis=0) for k, v in self.channels.items()\n            },\n        )\n"
  },
  {
    "path": "shap_e/rendering/pytorch3d_util.py",
    "content": "import copy\nimport inspect\nfrom typing import Any, Callable, List, Sequence, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom pytorch3d.renderer import (\n    BlendParams,\n    DirectionalLights,\n    FoVPerspectiveCameras,\n    MeshRasterizer,\n    MeshRenderer,\n    RasterizationSettings,\n    SoftPhongShader,\n    TexturesVertex,\n)\nfrom pytorch3d.renderer.utils import TensorProperties\nfrom pytorch3d.structures import Meshes\n\nfrom shap_e.models.nn.checkpoint import checkpoint\n\nfrom .blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION\nfrom .torch_mesh import TorchMesh\nfrom .view_data import ProjectiveCamera\n\n# Using a larger value like 1e-4 seems to result in weird issues\n# for our high-poly meshes.\nDEFAULT_RENDER_SIGMA = 1e-5\n\nDEFAULT_RENDER_GAMMA = 1e-4\n\n\ndef render_images(\n    image_size: int,\n    meshes: Meshes,\n    cameras: Any,\n    lights: Any,\n    sigma: float = DEFAULT_RENDER_SIGMA,\n    gamma: float = DEFAULT_RENDER_GAMMA,\n    max_faces_per_bin=100000,\n    faces_per_pixel=50,\n    bin_size=None,\n    use_checkpoint: bool = False,\n) -> torch.Tensor:\n    if use_checkpoint:\n        # Decompose all of our arguments into a bunch of tensor lists\n        # so that autograd can keep track of what the op depends on.\n        verts_list = meshes.verts_list()\n        faces_list = meshes.faces_list()\n        assert isinstance(meshes.textures, TexturesVertex)\n        assert isinstance(lights, BidirectionalLights)\n        textures = meshes.textures.verts_features_padded()\n        light_vecs, light_fn = _deconstruct_tensor_props(lights)\n        camera_vecs, camera_fn = _deconstruct_tensor_props(cameras)\n\n        def ckpt_fn(\n            *args: torch.Tensor,\n            num_verts=len(verts_list),\n            num_light_vecs=len(light_vecs),\n            num_camera_vecs=len(camera_vecs),\n            light_fn=light_fn,\n            camera_fn=camera_fn,\n            
faces_list=faces_list\n        ):\n            args = list(args)\n            verts_list = args[:num_verts]\n            del args[:num_verts]\n            light_vecs = args[:num_light_vecs]\n            del args[:num_light_vecs]\n            camera_vecs = args[:num_camera_vecs]\n            del args[:num_camera_vecs]\n            textures = args.pop(0)\n\n            meshes = Meshes(verts=verts_list, faces=faces_list, textures=TexturesVertex(textures))\n            lights = light_fn(light_vecs)\n            cameras = camera_fn(camera_vecs)\n            return render_images(\n                image_size=image_size,\n                meshes=meshes,\n                cameras=cameras,\n                lights=lights,\n                sigma=sigma,\n                gamma=gamma,\n                max_faces_per_bin=max_faces_per_bin,\n                faces_per_pixel=faces_per_pixel,\n                bin_size=bin_size,\n                use_checkpoint=False,\n            )\n\n        result = checkpoint(ckpt_fn, (*verts_list, *light_vecs, *camera_vecs, textures), (), True)\n    else:\n        raster_settings_soft = RasterizationSettings(\n            image_size=image_size,\n            blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,\n            faces_per_pixel=faces_per_pixel,\n            max_faces_per_bin=max_faces_per_bin,\n            bin_size=bin_size,\n            perspective_correct=False,\n        )\n        renderer = MeshRenderer(\n            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings_soft),\n            shader=SoftPhongShader(\n                device=meshes.device,\n                cameras=cameras,\n                lights=lights,\n                blend_params=BlendParams(sigma=sigma, gamma=gamma, background_color=(0, 0, 0)),\n            ),\n        )\n        result = renderer(meshes)\n\n    return result\n\n\ndef _deconstruct_tensor_props(\n    props: TensorProperties,\n) -> Tuple[List[torch.Tensor], Callable[[List[torch.Tensor]], 
TensorProperties]]:\n    vecs = []\n    names = []\n    other_props = {}\n    for k in dir(props):\n        if k.startswith(\"__\"):\n            continue\n        v = getattr(props, k)\n        if inspect.ismethod(v):\n            continue\n        if torch.is_tensor(v):\n            vecs.append(v)\n            names.append(k)\n        else:\n            other_props[k] = v\n\n    def recreate_fn(vecs_arg):\n        other = type(props)(device=props.device)\n        for k, v in other_props.items():\n            setattr(other, k, copy.deepcopy(v))\n        for name, vec in zip(names, vecs_arg):\n            setattr(other, name, vec)\n        return other\n\n    return vecs, recreate_fn\n\n\ndef convert_meshes(raw_meshes: Sequence[TorchMesh], default_brightness=0.8) -> Meshes:\n    meshes = Meshes(\n        verts=[mesh.verts for mesh in raw_meshes], faces=[mesh.faces for mesh in raw_meshes]\n    )\n    rgbs = []\n    for mesh in raw_meshes:\n        if mesh.vertex_channels and all(k in mesh.vertex_channels for k in \"RGB\"):\n            rgbs.append(torch.stack([mesh.vertex_channels[k] for k in \"RGB\"], axis=-1))\n        else:\n            # Fall back to a uniform gray of the given brightness for every vertex.\n            rgbs.append(\n                torch.ones(\n                    len(mesh.verts),\n                    3,\n                    device=mesh.verts.device,\n                    dtype=mesh.verts.dtype,\n                )\n                * default_brightness\n            )\n    meshes.textures = TexturesVertex(verts_features=rgbs)\n    return meshes\n\n\ndef convert_cameras(\n    cameras: Sequence[ProjectiveCamera], device: torch.device\n) -> FoVPerspectiveCameras:\n    Rs = []\n    Ts = []\n    for camera in cameras:\n        assert (\n            camera.width == camera.height and camera.x_fov == camera.y_fov\n        ), \"viewports must be square\"\n        assert camera.x_fov == cameras[0].x_fov, \"all cameras must have same field-of-view\"\n        R = np.stack([-camera.x, -camera.y, camera.z], axis=0).T\n        T = -R.T @ camera.origin\n        
Rs.append(R)\n        Ts.append(T)\n    return FoVPerspectiveCameras(\n        R=np.stack(Rs, axis=0),\n        T=np.stack(Ts, axis=0),\n        fov=cameras[0].x_fov,\n        degrees=False,\n        device=device,\n    )\n\n\ndef convert_cameras_torch(\n    origins: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor, zs: torch.Tensor, fov: float\n) -> FoVPerspectiveCameras:\n    Rs = []\n    Ts = []\n    for origin, x, y, z in zip(origins, xs, ys, zs):\n        R = torch.stack([-x, -y, z], axis=0).T\n        T = -R.T @ origin\n        Rs.append(R)\n        Ts.append(T)\n    return FoVPerspectiveCameras(\n        R=torch.stack(Rs, dim=0),\n        T=torch.stack(Ts, dim=0),\n        fov=fov,\n        degrees=False,\n        device=origins.device,\n    )\n\n\ndef blender_uniform_lights(\n    batch_size: int,\n    device: torch.device,\n    ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,\n    diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,\n    specular_color: Union[float, Tuple[float]] = 0.0,\n) -> \"BidirectionalLights\":\n    \"\"\"\n    Create a light that attempts to match the light used by the Blender\n    renderer when run with `--light_mode basic`.\n    \"\"\"\n    if isinstance(ambient_color, float):\n        ambient_color = (ambient_color,) * 3\n    if isinstance(diffuse_color, float):\n        diffuse_color = (diffuse_color,) * 3\n    if isinstance(specular_color, float):\n        specular_color = (specular_color,) * 3\n    return BidirectionalLights(\n        ambient_color=(ambient_color,) * batch_size,\n        diffuse_color=(diffuse_color,) * batch_size,\n        specular_color=(specular_color,) * batch_size,\n        direction=(UNIFORM_LIGHT_DIRECTION,) * batch_size,\n        device=device,\n    )\n\n\nclass BidirectionalLights(DirectionalLights):\n    \"\"\"\n    Adapted from here, but effectively shines the light in both positive and negative directions:\n    
https://github.com/facebookresearch/pytorch3d/blob/efea540bbcab56fccde6f4bc729d640a403dac56/pytorch3d/renderer/lighting.py#L159\n    \"\"\"\n\n    def diffuse(self, normals, points=None) -> torch.Tensor:\n        return torch.maximum(\n            super().diffuse(normals, points=points), super().diffuse(-normals, points=points)\n        )\n\n    def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:\n        return torch.maximum(\n            super().specular(normals, points, camera_position, shininess),\n            super().specular(-normals, points, camera_position, shininess),\n        )\n"
  },
  {
    "path": "shap_e/rendering/raycast/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/rendering/raycast/_utils.py",
    "content": "import torch\n\n\ndef normalize(v: torch.Tensor) -> torch.Tensor:\n    return v / torch.linalg.norm(v, dim=-1, keepdim=True)\n\n\ndef cross_product(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:\n    return torch.stack(\n        [\n            v1[..., 1] * v2[..., 2] - v2[..., 1] * v1[..., 2],\n            -(v1[..., 0] * v2[..., 2] - v2[..., 0] * v1[..., 2]),\n            v1[..., 0] * v2[..., 1] - v2[..., 0] * v1[..., 1],\n        ],\n        dim=-1,\n    )\n"
  },
  {
    "path": "shap_e/rendering/raycast/cast.py",
    "content": "from typing import Iterator, Optional, Tuple\n\nimport numpy as np\nimport torch\n\nfrom shap_e.rendering.view_data import ProjectiveCamera\n\nfrom ._utils import cross_product\nfrom .types import RayCollisions, Rays, TriMesh\n\n\ndef cast_camera(\n    camera: ProjectiveCamera,\n    mesh: TriMesh,\n    ray_batch_size: Optional[int] = None,\n    checkpoint: Optional[bool] = None,\n) -> Iterator[RayCollisions]:\n    pixel_indices = np.arange(camera.width * camera.height)\n    image_coords = np.stack([pixel_indices % camera.width, pixel_indices // camera.width], axis=1)\n    rays = camera.camera_rays(image_coords)\n    batch_size = ray_batch_size or len(rays)\n    checkpoint = checkpoint if checkpoint is not None else batch_size < len(rays)\n    for i in range(0, len(rays), batch_size):\n        sub_rays = rays[i : i + batch_size]\n        origins = torch.from_numpy(sub_rays[:, 0]).to(mesh.vertices)\n        directions = torch.from_numpy(sub_rays[:, 1]).to(mesh.vertices)\n        yield cast_rays(Rays(origins=origins, directions=directions), mesh, checkpoint=checkpoint)\n\n\ndef cast_rays(rays: Rays, mesh: TriMesh, checkpoint: bool = False) -> RayCollisions:\n    \"\"\"\n    Cast a batch of rays onto a mesh.\n    \"\"\"\n    if checkpoint:\n        collides, ray_dists, tri_indices, barycentric, normals = RayCollisionFunction.apply(\n            rays.origins, rays.directions, mesh.faces, mesh.vertices\n        )\n        return RayCollisions(\n            collides=collides,\n            ray_dists=ray_dists,\n            tri_indices=tri_indices,\n            barycentric=barycentric,\n            normals=normals,\n        )\n\n    # https://github.com/unixpickle/vae-textures/blob/2968549ddd4a3487f9437d4db00793324453cd59/vae_textures/render.py#L98\n    normals = mesh.normals()  # [N x 3]\n    directions = rays.directions  # [M x 3]\n    collides = (directions @ normals.T).abs() > 1e-8  # [N x M]\n\n    tris = mesh.vertices[mesh.faces]  # [N x 3 x 3]\n    v1 
= tris[:, 1] - tris[:, 0]\n    v2 = tris[:, 2] - tris[:, 0]\n\n    cross1 = cross_product(directions[:, None], v2[None])  # [N x M x 3]\n    det = torch.sum(cross1 * v1[None], dim=-1)  # [N x M]\n    collides = torch.logical_and(collides, det.abs() > 1e-8)\n\n    invDet = 1 / det  # [N x M]\n    o = rays.origins[:, None] - tris[None, :, 0]  # [N x M x 3]\n    bary1 = invDet * torch.sum(o * cross1, dim=-1)  # [N x M]\n    collides = torch.logical_and(collides, torch.logical_and(bary1 >= 0, bary1 <= 1))\n\n    cross2 = cross_product(o, v1[None])  # [N x M x 3]\n    bary2 = invDet * torch.sum(directions[:, None] * cross2, dim=-1)  # [N x M]\n    # Require bary1 + bary2 <= 1 so the hit lies inside the triangle, not just its parallelogram.\n    collides = torch.logical_and(collides, torch.logical_and(bary2 >= 0, bary2 <= 1 - bary1))\n\n    bary0 = 1 - (bary1 + bary2)\n\n    # Make sure this is in the positive part of the ray.\n    scale = invDet * torch.sum(v2 * cross2, dim=-1)\n    collides = torch.logical_and(collides, scale > 0)\n\n    # Select the nearest collision\n    ray_dists, tri_indices = torch.min(\n        torch.where(collides, scale, torch.tensor(torch.inf).to(scale)), dim=-1\n    )  # [N]\n    nearest_bary = torch.stack(\n        [\n            bary0[range(len(tri_indices)), tri_indices],\n            bary1[range(len(tri_indices)), tri_indices],\n            bary2[range(len(tri_indices)), tri_indices],\n        ],\n        dim=-1,\n    )\n\n    return RayCollisions(\n        collides=torch.any(collides, dim=-1),\n        ray_dists=ray_dists,\n        tri_indices=tri_indices,\n        barycentric=nearest_bary,\n        normals=normals[tri_indices],\n    )\n\n\nclass RayCollisionFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx, origins, directions, faces, vertices\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        ctx.save_for_backward(origins, directions, faces, vertices)\n        with torch.no_grad():\n            res = cast_rays(\n                Rays(origins=origins, 
directions=directions),\n                TriMesh(faces=faces, vertices=vertices),\n                checkpoint=False,\n            )\n        return (res.collides, res.ray_dists, res.tri_indices, res.barycentric, res.normals)\n\n    @staticmethod\n    def backward(\n        ctx, _collides_grad, ray_dists_grad, _tri_indices_grad, barycentric_grad, normals_grad\n    ):\n        origins, directions, faces, vertices = ctx.saved_tensors\n\n        origins = origins.detach().requires_grad_(True)\n        directions = directions.detach().requires_grad_(True)\n        vertices = vertices.detach().requires_grad_(True)\n\n        with torch.enable_grad():\n            outputs = cast_rays(\n                Rays(origins=origins, directions=directions),\n                TriMesh(faces=faces, vertices=vertices),\n                checkpoint=False,\n            )\n\n        origins_grad, directions_grad, vertices_grad = torch.autograd.grad(\n            (outputs.ray_dists, outputs.barycentric, outputs.normals),\n            (origins, directions, vertices),\n            (ray_dists_grad, barycentric_grad, normals_grad),\n        )\n        return (origins_grad, directions_grad, None, vertices_grad)\n"
  },
  {
    "path": "shap_e/rendering/raycast/render.py",
    "content": "from typing import Optional, Sequence\n\nimport torch\n\nfrom shap_e.rendering.blender.constants import (\n    BASIC_AMBIENT_COLOR,\n    BASIC_DIFFUSE_COLOR,\n    UNIFORM_LIGHT_DIRECTION,\n)\nfrom shap_e.rendering.view_data import ProjectiveCamera\n\nfrom .cast import cast_camera\nfrom .types import RayCollisions, TriMesh\n\n\ndef render_diffuse_mesh(\n    camera: ProjectiveCamera,\n    mesh: TriMesh,\n    light_direction: Sequence[float] = tuple(UNIFORM_LIGHT_DIRECTION),\n    diffuse: float = BASIC_DIFFUSE_COLOR,\n    ambient: float = BASIC_AMBIENT_COLOR,\n    ray_batch_size: Optional[int] = None,\n    checkpoint: Optional[bool] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Return an [H x W x 4] RGBA tensor of the rendered image.\n    The pixels are floating points, with alpha in the range [0, 1] and the\n    other colors matching the scale used by the mesh's vertex colors.\n    \"\"\"\n    light_direction = torch.tensor(\n        light_direction, device=mesh.vertices.device, dtype=mesh.vertices.dtype\n    )\n\n    all_collisions = RayCollisions.collect(\n        cast_camera(\n            camera=camera,\n            mesh=mesh,\n            ray_batch_size=ray_batch_size,\n            checkpoint=checkpoint,\n        )\n    )\n    if mesh.vertex_colors is None:\n        # Default to a uniform light gray for every vertex of the mesh.\n        vertex_colors = (\n            torch.tensor([[0.8, 0.8, 0.8]]).to(mesh.vertices).repeat(len(mesh.vertices), 1)\n        )\n    else:\n        vertex_colors = mesh.vertex_colors\n\n    light_coeffs = ambient + (\n        diffuse * torch.sum(all_collisions.normals * light_direction, dim=-1).abs()\n    )\n    # Interpolate the hit triangle's vertex colors with the barycentric weights.\n    face_colors = vertex_colors[mesh.faces[all_collisions.tri_indices]]  # [num_rays x 3 x 3]\n    bary_products = torch.sum(face_colors * all_collisions.barycentric[..., None], axis=-2)\n    out_colors = bary_products * light_coeffs[..., None]\n    res = torch.where(all_collisions.collides[:, None], out_colors, torch.zeros_like(out_colors))\n    return torch.cat([res, 
all_collisions.collides[:, None].float()], dim=-1).view(\n        camera.height, camera.width, 4\n    )\n"
  },
  {
    "path": "shap_e/rendering/raycast/types.py",
    "content": "from dataclasses import dataclass\nfrom typing import Iterable, Optional\n\nimport numpy as np\nimport torch\n\nimport shap_e.rendering.mesh\n\nfrom ._utils import cross_product, normalize\n\n\n@dataclass\nclass Rays:\n    \"\"\"\n    A ray in ray casting.\n    \"\"\"\n\n    origins: torch.Tensor  # [N x 3] float tensor\n    directions: torch.Tensor  # [N x 3] float tensor\n\n    def normalized_directions(self) -> torch.Tensor:\n        return normalize(self.directions)\n\n\n@dataclass\nclass RayCollisions:\n    \"\"\"\n    The result of casting N rays onto a mesh.\n    \"\"\"\n\n    collides: torch.Tensor  # [N] boolean tensor\n    ray_dists: torch.Tensor  # [N] float tensor\n    tri_indices: torch.Tensor  # [N] long tensor\n    barycentric: torch.Tensor  # [N x 3] float tensor\n    normals: torch.Tensor  # [N x 3] float tensor\n\n    @classmethod\n    def collect(cls, it: Iterable[\"RayCollisions\"]) -> \"RayCollisions\":\n        res = None\n        for x in it:\n            if res is None:\n                res = x\n            else:\n                res = cls(\n                    collides=torch.cat([res.collides, x.collides]),\n                    ray_dists=torch.cat([res.ray_dists, x.ray_dists]),\n                    tri_indices=torch.cat([res.tri_indices, x.tri_indices]),\n                    barycentric=torch.cat([res.barycentric, x.barycentric]),\n                    normals=torch.cat([res.normals, x.normals]),\n                )\n        if res is None:\n            raise ValueError(\"cannot collect an empty iterable of RayCollisions\")\n        return res\n\n\n@dataclass\nclass TriMesh:\n    faces: torch.Tensor  # [N x 3] long tensor\n    vertices: torch.Tensor  # [N x 3] float tensor\n\n    vertex_colors: Optional[torch.Tensor] = None\n\n    def normals(self) -> torch.Tensor:\n        \"\"\"\n        Returns an [N x 3] batch of normal vectors per triangle assuming the\n        right-hand rule.\n        \"\"\"\n        tris = 
self.vertices[self.faces]\n        v1 = tris[:, 1] - tris[:, 0]\n        v2 = tris[:, 2] - tris[:, 0]\n        return normalize(cross_product(v1, v2))\n\n    @classmethod\n    def from_numpy(cls, x: shap_e.rendering.mesh.TriMesh) -> \"TriMesh\":\n        vertex_colors = None\n        if x.vertex_channels and all(ch in x.vertex_channels for ch in \"RGB\"):\n            vertex_colors = torch.from_numpy(\n                np.stack([x.vertex_channels[ch] for ch in \"RGB\"], axis=-1)\n            )\n        return cls(\n            faces=torch.from_numpy(x.faces),\n            vertices=torch.from_numpy(x.verts),\n            vertex_colors=vertex_colors,\n        )\n\n    def to(self, *args, **kwargs) -> \"TriMesh\":\n        return TriMesh(\n            faces=self.faces.to(*args, **kwargs),\n            vertices=self.vertices.to(*args, **kwargs),\n            vertex_colors=None\n            if self.vertex_colors is None\n            else self.vertex_colors.to(*args, **kwargs),\n        )\n"
  },
  {
    "path": "shap_e/rendering/torch_mesh.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import Dict, Optional\n\nimport torch\n\nfrom .mesh import TriMesh\n\n\n@dataclass\nclass TorchMesh:\n    \"\"\"\n    A 3D triangle mesh with optional data at the vertices and faces.\n    \"\"\"\n\n    # [N x 3] array of vertex coordinates.\n    verts: torch.Tensor\n\n    # [M x 3] array of triangles, pointing to indices in verts.\n    faces: torch.Tensor\n\n    # Extra data per vertex and face.\n    vertex_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict)\n    face_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict)\n\n    def tri_mesh(self) -> TriMesh:\n        \"\"\"\n        Create a CPU version of the mesh.\n        \"\"\"\n        return TriMesh(\n            verts=self.verts.detach().cpu().numpy(),\n            faces=self.faces.cpu().numpy(),\n            vertex_channels=(\n                {k: v.detach().cpu().numpy() for k, v in self.vertex_channels.items()}\n                if self.vertex_channels is not None\n                else None\n            ),\n            face_channels=(\n                {k: v.detach().cpu().numpy() for k, v in self.face_channels.items()}\n                if self.face_channels is not None\n                else None\n            ),\n        )\n"
  },
  {
    "path": "shap_e/rendering/view_data.py",
    "content": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Tuple\n\nimport numpy as np\n\n\n@dataclass\nclass Camera(ABC):\n    \"\"\"\n    An object describing how a camera corresponds to pixels in an image.\n    \"\"\"\n\n    @abstractmethod\n    def image_coords(self) -> np.ndarray:\n        \"\"\"\n        :return: ([self.height, self.width, 2]).reshape(self.height * self.width, 2) image coordinates\n        \"\"\"\n\n    @abstractmethod\n    def camera_rays(self, coords: np.ndarray) -> np.ndarray:\n        \"\"\"\n        For every (x, y) coordinate in a rendered image, compute the ray of the\n        corresponding pixel.\n\n        :param coords: an [N x 2] integer array of 2D image coordinates.\n        :return: an [N x 2 x 3] array of [2 x 3] (origin, direction) tuples.\n                 The direction should always be unit length.\n        \"\"\"\n\n    def depth_directions(self, coords: np.ndarray) -> np.ndarray:\n        \"\"\"\n        For every (x, y) coordinate in a rendered image, get the direction that\n        corresponds to \"depth\" in an RGBD rendering.\n\n        This may raise an exception if there is no \"D\" channel in the\n        corresponding ViewData.\n\n        :param coords: an [N x 2] integer array of 2D image coordinates.\n        :return: an [N x 3] array of normalized depth directions.\n        \"\"\"\n        _ = coords\n        raise NotImplementedError\n\n    @abstractmethod\n    def center_crop(self) -> \"Camera\":\n        \"\"\"\n        Creates a new camera with the same intrinsics and direction as this one,\n        but with a center crop to a square of the smaller dimension.\n        \"\"\"\n\n    @abstractmethod\n    def resize_image(self, width: int, height: int) -> \"Camera\":\n        \"\"\"\n        Creates a new camera with the same intrinsics and direction as this one,\n        but with resized image dimensions.\n        \"\"\"\n\n    @abstractmethod\n    
def scale_scene(self, factor: float) -> \"Camera\":\n        \"\"\"\n        Creates a new camera with the same intrinsics and direction as this one,\n        but with the scene rescaled by the given factor.\n        \"\"\"\n\n\n@dataclass\nclass ProjectiveCamera(Camera):\n    \"\"\"\n    A Camera implementation for a standard pinhole camera.\n\n    The camera rays shoot away from the origin in the z direction, with the x\n    and y directions corresponding to the positive horizontal and vertical axes\n    in image space.\n    \"\"\"\n\n    origin: np.ndarray\n    x: np.ndarray\n    y: np.ndarray\n    z: np.ndarray\n    width: int\n    height: int\n    x_fov: float\n    y_fov: float\n\n    def image_coords(self) -> np.ndarray:\n        ind = np.arange(self.width * self.height)\n        coords = np.stack([ind % self.width, ind // self.width], axis=1).astype(np.float32)\n        return coords\n\n    def camera_rays(self, coords: np.ndarray) -> np.ndarray:\n        fracs = (coords / (np.array([self.width, self.height], dtype=np.float32) - 1)) * 2 - 1\n        fracs = fracs * np.tan(np.array([self.x_fov, self.y_fov]) / 2)\n        directions = self.z + self.x * fracs[:, :1] + self.y * fracs[:, 1:]\n        directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True)\n        return np.stack([np.broadcast_to(self.origin, directions.shape), directions], axis=1)\n\n    def depth_directions(self, coords: np.ndarray) -> np.ndarray:\n        return np.tile((self.z / np.linalg.norm(self.z))[None], [len(coords), 1])\n\n    def resize_image(self, width: int, height: int) -> \"ProjectiveCamera\":\n        \"\"\"\n        Creates a new camera for the resized view assuming the aspect ratio does not change.\n        \"\"\"\n        assert width * self.height == height * self.width, \"The aspect ratio should not change.\"\n        return ProjectiveCamera(\n            origin=self.origin,\n            x=self.x,\n            y=self.y,\n            z=self.z,\n           
 width=width,\n            height=height,\n            x_fov=self.x_fov,\n            y_fov=self.y_fov,\n        )\n\n    def center_crop(self) -> \"ProjectiveCamera\":\n        \"\"\"\n        Creates a new camera for the center-cropped view\n        \"\"\"\n        size = min(self.width, self.height)\n        fov = min(self.x_fov, self.y_fov)\n        return ProjectiveCamera(\n            origin=self.origin,\n            x=self.x,\n            y=self.y,\n            z=self.z,\n            width=size,\n            height=size,\n            x_fov=fov,\n            y_fov=fov,\n        )\n\n    def scale_scene(self, factor: float) -> \"ProjectiveCamera\":\n        \"\"\"\n        Creates a new camera with the same intrinsics and direction as this one,\n        but with the camera frame rescaled by the given factor.\n        \"\"\"\n        return ProjectiveCamera(\n            origin=self.origin * factor,\n            x=self.x,\n            y=self.y,\n            z=self.z,\n            width=self.width,\n            height=self.height,\n            x_fov=self.x_fov,\n            y_fov=self.y_fov,\n        )\n\n\nclass ViewData(ABC):\n    \"\"\"\n    A collection of rendered camera views of a scene or object.\n\n    This is a generalization of a NeRF dataset, since NeRF datasets only encode\n    RGB or RGBA data, whereas this dataset supports arbitrary channels.\n    \"\"\"\n\n    @property\n    @abstractmethod\n    def num_views(self) -> int:\n        \"\"\"\n        The number of rendered views.\n        \"\"\"\n\n    @property\n    @abstractmethod\n    def channel_names(self) -> List[str]:\n        \"\"\"\n        Get all of the supported channels available for the views.\n\n        This can be arbitrary, but there are some standard names:\n        \"R\", \"G\", \"B\", \"A\" (alpha), and \"D\" (depth).\n        \"\"\"\n\n    @abstractmethod\n    def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]:\n        \"\"\"\n        Load the 
given channels from the view at the given index.\n\n        :return: a tuple (camera_view, data), where data is a float array of\n                 shape [height x width x num_channels].\n        \"\"\"\n\n\nclass MemoryViewData(ViewData):\n    \"\"\"\n    A ViewData that is implemented in memory.\n    \"\"\"\n\n    def __init__(self, channels: Dict[str, np.ndarray], cameras: List[Camera]):\n        assert all(v.shape[0] == len(cameras) for v in channels.values())\n        self.channels = channels\n        self.cameras = cameras\n\n    @property\n    def num_views(self) -> int:\n        return len(self.cameras)\n\n    @property\n    def channel_names(self) -> List[str]:\n        return list(self.channels.keys())\n\n    def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]:\n        outputs = [self.channels[channel][index] for channel in channels]\n        return self.cameras[index], np.stack(outputs, axis=-1)\n"
  },
  {
    "path": "shap_e/util/__init__.py",
    "content": ""
  },
  {
    "path": "shap_e/util/collections.py",
    "content": "from collections import OrderedDict\nfrom typing import Any, Callable, Dict, List, Optional\nfrom typing import OrderedDict, Generic, TypeVar\n\nK = TypeVar('K')\nV = TypeVar('V')\n\nclass AttrDict(OrderedDict[K, V], Generic[K, V]):\n    \"\"\"\n    An attribute dictionary that automatically handles nested keys joined by \"/\".\n\n    Originally copied from: https://stackoverflow.com/questions/3031219/recursively-access-dict-via-attributes-as-well-as-index-access\n    \"\"\"\n\n    MARKER = object()\n\n    # pylint: disable=super-init-not-called\n    def __init__(self, *args, **kwargs):\n        if len(args) == 0:\n            for key, value in kwargs.items():\n                self.__setitem__(key, value)\n        else:\n            assert len(args) == 1\n            assert isinstance(args[0], (dict, AttrDict))\n            for key, value in args[0].items():\n                self.__setitem__(key, value)\n\n    def __contains__(self, key):\n        if \"/\" in key:\n            keys = key.split(\"/\")\n            key, next_key = keys[0], \"/\".join(keys[1:])\n            return key in self and next_key in self[key]\n        return super(AttrDict, self).__contains__(key)\n\n    def __setitem__(self, key, value):\n        if \"/\" in key:\n            keys = key.split(\"/\")\n            key, next_key = keys[0], \"/\".join(keys[1:])\n            if key not in self:\n                self[key] = AttrDict()\n            self[key].__setitem__(next_key, value)\n            return\n\n        if isinstance(value, dict) and not isinstance(value, AttrDict):\n            value = AttrDict(**value)\n        if isinstance(value, list):\n            value = [AttrDict(val) if isinstance(val, dict) else val for val in value]\n        super(AttrDict, self).__setitem__(key, value)\n\n    def __getitem__(self, key):\n        if \"/\" in key:\n            keys = key.split(\"/\")\n            key, next_key = keys[0], \"/\".join(keys[1:])\n            val = self[key]\n     
       if not isinstance(val, AttrDict):\n                raise ValueError\n            return val.__getitem__(next_key)\n\n        return self.get(key, None)\n\n    def all_keys(\n        self,\n        leaves_only: bool = False,\n        parent: Optional[str] = None,\n    ) -> List[str]:\n        keys = []\n        for key in self.keys():\n            cur = key if parent is None else f\"{parent}/{key}\"\n            if not leaves_only or not isinstance(self[key], dict):\n                keys.append(cur)\n            if isinstance(self[key], dict):\n                keys.extend(self[key].all_keys(leaves_only=leaves_only, parent=cur))\n        return keys\n\n    def dumpable(self, strip=True):\n        \"\"\"\n        Casts into OrderedDict and removes internal attributes\n        \"\"\"\n\n        def _dump(val):\n            if isinstance(val, AttrDict):\n                return val.dumpable()\n            elif isinstance(val, list):\n                return [_dump(v) for v in val]\n            return val\n\n        if strip:\n            return {k: _dump(v) for k, v in self.items() if not k.startswith(\"_\")}\n        return {k: _dump(v if not k.startswith(\"_\") else repr(v)) for k, v in self.items()}\n\n    def map(\n        self,\n        map_fn: Callable[[Any, Any], Any],\n        should_map: Optional[Callable[[Any, Any], bool]] = None,\n    ) -> \"AttrDict\":\n        \"\"\"\n        Creates a copy of self where some or all values are transformed by\n        map_fn.\n\n        :param should_map: If provided, only those values that evaluate to true\n            are converted; otherwise, all values are mapped.\n        \"\"\"\n\n        def _apply(key, val):\n            if isinstance(val, AttrDict):\n                return val.map(map_fn, should_map)\n            elif should_map is None or should_map(key, val):\n                return map_fn(key, val)\n            return val\n\n        return AttrDict({k: _apply(k, v) for k, v in self.items()})\n\n    def 
__eq__(self, other):\n        return self.keys() == other.keys() and all(self[k] == other[k] for k in self.keys())\n\n    def combine(\n        self,\n        other: Dict[str, Any],\n        combine_fn: Callable[[Optional[Any], Optional[Any]], Any],\n    ) -> \"AttrDict\":\n        \"\"\"\n        Some values may be missing, but the dictionary structures must be the\n        same.\n\n        :param combine_fn: a (possibly non-commutative) function to combine the\n            values\n        \"\"\"\n\n        def _apply(val, other_val):\n            if val is not None and isinstance(val, AttrDict):\n                assert isinstance(other_val, AttrDict)\n                return val.combine(other_val, combine_fn)\n            return combine_fn(val, other_val)\n\n        # TODO nit: this changes the ordering..\n        keys = self.keys() | other.keys()\n        return AttrDict({k: _apply(self[k], other[k]) for k in keys})\n\n    __setattr__, __getattr__ = __setitem__, __getitem__\n"
  },
  {
    "path": "shap_e/util/data_util.py",
    "content": "import tempfile\nfrom contextlib import contextmanager\nfrom typing import Iterator, Optional, Union\n\nimport blobfile as bf\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom shap_e.rendering.blender.render import render_mesh, render_model\nfrom shap_e.rendering.blender.view_data import BlenderViewData\nfrom shap_e.rendering.mesh import TriMesh\nfrom shap_e.rendering.point_cloud import PointCloud\nfrom shap_e.rendering.view_data import ViewData\nfrom shap_e.util.collections import AttrDict\nfrom shap_e.util.image_util import center_crop, get_alpha, remove_alpha, resize\n\n\ndef load_or_create_multimodal_batch(\n    device: torch.device,\n    *,\n    mesh_path: Optional[str] = None,\n    model_path: Optional[str] = None,\n    cache_dir: Optional[str] = None,\n    point_count: int = 2**14,\n    random_sample_count: int = 2**19,\n    pc_num_views: int = 40,\n    mv_light_mode: Optional[str] = None,\n    mv_num_views: int = 20,\n    mv_image_size: int = 512,\n    mv_alpha_removal: str = \"black\",\n    verbose: bool = False,\n) -> AttrDict:\n    if verbose:\n        print(\"creating point cloud...\")\n    pc = load_or_create_pc(\n        mesh_path=mesh_path,\n        model_path=model_path,\n        cache_dir=cache_dir,\n        random_sample_count=random_sample_count,\n        point_count=point_count,\n        num_views=pc_num_views,\n        verbose=verbose,\n    )\n    raw_pc = np.concatenate([pc.coords, pc.select_channels([\"R\", \"G\", \"B\"])], axis=-1)\n    encode_me = torch.from_numpy(raw_pc).float().to(device)\n    batch = AttrDict(points=encode_me.t()[None])\n    if mv_light_mode:\n        if verbose:\n            print(\"creating multiview...\")\n        with load_or_create_multiview(\n            mesh_path=mesh_path,\n            model_path=model_path,\n            cache_dir=cache_dir,\n            num_views=mv_num_views,\n            extract_material=False,\n            light_mode=mv_light_mode,\n            
verbose=verbose,\n        ) as mv:\n            cameras, views, view_alphas, depths = [], [], [], []\n            for view_idx in range(mv.num_views):\n                camera, view = mv.load_view(\n                    view_idx,\n                    [\"R\", \"G\", \"B\", \"A\"] if \"A\" in mv.channel_names else [\"R\", \"G\", \"B\"],\n                )\n                depth = None\n                if \"D\" in mv.channel_names:\n                    _, depth = mv.load_view(view_idx, [\"D\"])\n                    depth = process_depth(depth, mv_image_size)\n                view, alpha = process_image(\n                    np.round(view * 255.0).astype(np.uint8), mv_alpha_removal, mv_image_size\n                )\n                camera = camera.center_crop().resize_image(mv_image_size, mv_image_size)\n                cameras.append(camera)\n                views.append(view)\n                view_alphas.append(alpha)\n                depths.append(depth)\n            batch.depths = [depths]\n            batch.views = [views]\n            batch.view_alphas = [view_alphas]\n            batch.cameras = [cameras]\n    return normalize_input_batch(batch, pc_scale=2.0, color_scale=1.0 / 255.0)\n\n\ndef load_or_create_pc(\n    *,\n    mesh_path: Optional[str],\n    model_path: Optional[str],\n    cache_dir: Optional[str],\n    random_sample_count: int,\n    point_count: int,\n    num_views: int,\n    verbose: bool = False,\n) -> PointCloud:\n\n    assert (model_path is not None) ^ (\n        mesh_path is not None\n    ), \"must specify exactly one of model_path or mesh_path\"\n    path = model_path if model_path is not None else mesh_path\n\n    if cache_dir is not None:\n        cache_path = bf.join(\n            cache_dir,\n            f\"pc_{bf.basename(path)}_mat_{num_views}_{random_sample_count}_{point_count}.npz\",\n        )\n        if bf.exists(cache_path):\n            return PointCloud.load(cache_path)\n    else:\n        cache_path = None\n\n    with 
load_or_create_multiview(\n        mesh_path=mesh_path,\n        model_path=model_path,\n        cache_dir=cache_dir,\n        num_views=num_views,\n        verbose=verbose,\n    ) as mv:\n        if verbose:\n            print(\"extracting point cloud from multiview...\")\n        pc = mv_to_pc(\n            multiview=mv, random_sample_count=random_sample_count, point_count=point_count\n        )\n        if cache_path is not None:\n            pc.save(cache_path)\n        return pc\n\n\n@contextmanager\ndef load_or_create_multiview(\n    *,\n    mesh_path: Optional[str],\n    model_path: Optional[str],\n    cache_dir: Optional[str],\n    num_views: int = 20,\n    extract_material: bool = True,\n    light_mode: Optional[str] = None,\n    verbose: bool = False,\n) -> Iterator[BlenderViewData]:\n\n    assert (model_path is not None) ^ (\n        mesh_path is not None\n    ), \"must specify exactly one of model_path or mesh_path\"\n    path = model_path if model_path is not None else mesh_path\n\n    if extract_material:\n        assert light_mode is None, \"light_mode is ignored when extract_material=True\"\n    else:\n        assert light_mode is not None, \"must specify light_mode when extract_material=False\"\n\n    if cache_dir is not None:\n        if extract_material:\n            cache_path = bf.join(cache_dir, f\"mv_{bf.basename(path)}_mat_{num_views}.zip\")\n        else:\n            cache_path = bf.join(cache_dir, f\"mv_{bf.basename(path)}_{light_mode}_{num_views}.zip\")\n        if bf.exists(cache_path):\n            with bf.BlobFile(cache_path, \"rb\") as f:\n                yield BlenderViewData(f)\n                return\n    else:\n        cache_path = None\n\n    common_kwargs = dict(\n        fast_mode=True,\n        extract_material=extract_material,\n        camera_pose=\"random\",\n        light_mode=light_mode or \"uniform\",\n        verbose=verbose,\n    )\n\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        tmp_path = 
bf.join(tmp_dir, \"out.zip\")\n        if mesh_path is not None:\n            mesh = TriMesh.load(mesh_path)\n            render_mesh(\n                mesh=mesh,\n                output_path=tmp_path,\n                num_images=num_views,\n                backend=\"BLENDER_EEVEE\",\n                **common_kwargs,\n            )\n        elif model_path is not None:\n            render_model(\n                model_path,\n                output_path=tmp_path,\n                num_images=num_views,\n                backend=\"BLENDER_EEVEE\",\n                **common_kwargs,\n            )\n        if cache_path is not None:\n            bf.copy(tmp_path, cache_path)\n        with bf.BlobFile(tmp_path, \"rb\") as f:\n            yield BlenderViewData(f)\n\n\ndef mv_to_pc(multiview: ViewData, random_sample_count: int, point_count: int) -> PointCloud:\n    pc = PointCloud.from_rgbd(multiview)\n\n    # Handle empty samples.\n    if len(pc.coords) == 0:\n        pc = PointCloud(\n            coords=np.zeros([1, 3]),\n            channels=dict(zip(\"RGB\", np.zeros([3, 1]))),\n        )\n    while len(pc.coords) < point_count:\n        pc = pc.combine(pc)\n        # Prevent duplicate points; some models may not like it.\n        pc.coords += np.random.normal(size=pc.coords.shape) * 1e-4\n\n    pc = pc.random_sample(random_sample_count)\n    pc = pc.farthest_point_sample(point_count, average_neighbors=True)\n\n    return pc\n\n\ndef normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict:\n    res = batch.copy()\n    scale_vec = torch.tensor([*([pc_scale] * 3), *([color_scale] * 3)], device=batch.points.device)\n    res.points = res.points * scale_vec[:, None]\n\n    if \"cameras\" in res:\n        res.cameras = [[cam.scale_scene(pc_scale) for cam in cams] for cams in res.cameras]\n\n    if \"depths\" in res:\n        res.depths = [[depth * pc_scale for depth in depths] for depths in res.depths]\n\n    return res\n\n\ndef 
process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray:\n    depth_img = center_crop(depth_img)\n    depth_img = resize(depth_img, width=image_size, height=image_size)\n    return np.squeeze(depth_img)\n\n\ndef process_image(\n    img_or_img_arr: Union[Image.Image, np.ndarray], alpha_removal: str, image_size: int\n):\n    if isinstance(img_or_img_arr, np.ndarray):\n        img = Image.fromarray(img_or_img_arr)\n        img_arr = img_or_img_arr\n    else:\n        img = img_or_img_arr\n        img_arr = np.array(img)\n        if len(img_arr.shape) == 2:\n            # Grayscale\n            rgb = Image.new(\"RGB\", img.size)\n            rgb.paste(img)\n            img = rgb\n            img_arr = np.array(img)\n\n    img = center_crop(img)\n    alpha = get_alpha(img)\n    img = remove_alpha(img, mode=alpha_removal)\n    alpha = alpha.resize((image_size,) * 2, resample=Image.BILINEAR)\n    img = img.resize((image_size,) * 2, resample=Image.BILINEAR)\n    return img, alpha\n"
  },
  {
    "path": "shap_e/util/image_util.py",
    "content": "import random\nfrom typing import Any, List, Optional, Union\n\nimport blobfile as bf\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom PIL import Image\n\n\ndef center_crop(\n    img: Union[Image.Image, torch.Tensor, np.ndarray]\n) -> Union[Image.Image, torch.Tensor, np.ndarray]:\n    \"\"\"\n    Center crops an image to a square whose side is the shorter image dimension.\n    \"\"\"\n    if isinstance(img, (np.ndarray, torch.Tensor)):\n        height, width = img.shape[:2]\n    else:\n        width, height = img.size\n    size = min(width, height)\n    left, top = (width - size) // 2, (height - size) // 2\n    right, bottom = left + size, top + size\n    if isinstance(img, (np.ndarray, torch.Tensor)):\n        img = img[top:bottom, left:right]\n    else:\n        img = img.crop((left, top, right, bottom))\n    return img\n\n\ndef resize(\n    img: Union[Image.Image, torch.Tensor, np.ndarray],\n    *,\n    height: int,\n    width: int,\n    min_value: Optional[Any] = None,\n    max_value: Optional[Any] = None,\n) -> Union[Image.Image, torch.Tensor, np.ndarray]:\n    \"\"\"\n    :param img: image in HWC (or HW) order\n    :param height: output height\n    :param width: output width\n    :param min_value: optional lower bound to clamp the output to\n    :param max_value: optional upper bound to clamp the output to\n    :return: the resized image, with the same type and dtype as img. Area\n        interpolation is used, so this is intended for downsampling.\n    \"\"\"\n\n    orig, cls = img, type(img)\n    if isinstance(img, Image.Image):\n        img = np.array(img)\n    dtype = img.dtype\n    if isinstance(img, np.ndarray):\n        img = torch.from_numpy(img)\n    ndim = img.ndim\n    if img.ndim == 2:\n        img = img.unsqueeze(-1)\n\n    if min_value is None and max_value is None:\n        # .clamp throws an error when both are None\n        min_value = -np.inf\n\n    img = img.permute(2, 0, 1)\n    size = (height, width)\n    img = (\n        F.interpolate(img[None].float(), size=size, mode=\"area\")[0]\n        .clamp(min_value, max_value)\n        .to(img.dtype)\n        .permute(1, 2, 0)\n    )\n\n    if ndim < img.ndim:\n        img = img.squeeze(-1)\n    if not isinstance(orig, torch.Tensor):\n        # Tensor inputs already have their original dtype from .to() above;\n        # only array and PIL inputs need converting back to NumPy.\n        img = img.numpy().astype(dtype)\n    if 
isinstance(orig, Image.Image):\n        img = Image.fromarray(img)\n\n    return img\n\n\ndef get_alpha(img: Image.Image) -> Image.Image:\n    \"\"\"\n    :return: the alpha channel separated out as a grayscale image (fully opaque if\n        the image has no alpha channel)\n    \"\"\"\n    img_arr = np.asarray(img)\n    if img_arr.ndim == 3 and img_arr.shape[2] == 4:\n        alpha = img_arr[:, :, 3]\n    else:\n        alpha = np.full(img_arr.shape[:2], 255, dtype=np.uint8)\n    alpha = Image.fromarray(alpha)\n    return alpha\n\n\ndef remove_alpha(img: Image.Image, mode: str = \"random\") -> Image.Image:\n    \"\"\"\n    No-op if the image doesn't have an alpha channel.\n\n    :param mode: \"random\" (the default) composites onto a randomly chosen black,\n        gray, checkerboard, or noise background; \"black\" or \"white\" composites\n        onto a solid background of that color\n\n    :return: image with alpha removed\n    \"\"\"\n    img_arr = np.asarray(img)\n    if img_arr.ndim == 3 and img_arr.shape[2] == 4:\n        # Add bg to get rid of alpha channel\n        if mode == \"random\":\n            height, width = img_arr.shape[:2]\n            bg = Image.fromarray(\n                random.choice([_black_bg, _gray_bg, _checker_bg, _noise_bg])(height, width)\n            )\n            bg.paste(img, mask=img)\n            img = bg\n        elif mode == \"black\" or mode == \"white\":\n            img_arr = img_arr.astype(float)\n            rgb, alpha = img_arr[:, :, :3], img_arr[:, :, -1:] / 255\n            background = np.zeros((1, 1, 3)) if mode == \"black\" else np.full((1, 1, 3), 255)\n            rgb = rgb * alpha + background * (1 - alpha)\n            img = Image.fromarray(np.round(rgb).astype(np.uint8))\n    return img\n\n\ndef _black_bg(h: int, w: int) -> np.ndarray:\n    return np.zeros([h, w, 3], dtype=np.uint8)\n\n\ndef _gray_bg(h: int, w: int) -> np.ndarray:\n    return (np.zeros([h, w, 3]) + np.random.randint(low=0, high=256)).astype(np.uint8)\n\n\ndef _checker_bg(h: int, w: int) -> np.ndarray:\n    checker_size = np.ceil(np.exp(np.random.uniform() * np.log(min(h, w))))\n    c1 = np.random.randint(low=0, high=256)\n    c2 = 
np.random.randint(low=0, high=256)\n\n    xs = np.arange(w)[None, :, None] + np.random.randint(low=0, high=checker_size + 1)\n    ys = np.arange(h)[:, None, None] + np.random.randint(low=0, high=checker_size + 1)\n\n    fields = np.logical_xor((xs // checker_size) % 2 == 0, (ys // checker_size) % 2 == 0)\n    return np.where(fields, np.array([c1] * 3), np.array([c2] * 3)).astype(np.uint8)\n\n\ndef _noise_bg(h: int, w: int) -> np.ndarray:\n    return np.random.randint(low=0, high=256, size=[h, w, 3]).astype(np.uint8)\n\n\ndef load_image(image_path: str) -> Image.Image:\n    with bf.BlobFile(image_path, \"rb\") as thefile:\n        img = Image.open(thefile)\n        img.load()\n    return img\n\n\ndef make_tile(images: List[Union[np.ndarray, Image.Image]], columns=8) -> Image.Image:\n    \"\"\"\n    to test, run\n        >>> display(make_tile([(np.zeros((128, 128, 3)) + c).astype(np.uint8) for c in np.linspace(0, 255, 15)]))\n    \"\"\"\n    images = list(map(np.array, images))\n    size = images[0].shape[0]\n    n = round_up(len(images), columns)\n    n_blanks = n - len(images)\n    images.extend([np.zeros((size, size, 3), dtype=np.uint8)] * n_blanks)\n    images = (\n        np.array(images)\n        .reshape(n // columns, columns, size, size, 3)\n        .transpose([0, 2, 1, 3, 4])\n        .reshape(n // columns * size, columns * size, 3)\n    )\n    return Image.fromarray(images)\n\n\ndef round_up(n: int, b: int) -> int:\n    return (n + b - 1) // b * b\n"
  },
  {
    "path": "shap_e/util/io.py",
    "content": "import io\nfrom contextlib import contextmanager\nfrom typing import Any, BinaryIO, Iterator, Union\n\nimport blobfile as bf\nimport yaml\n\nfrom shap_e.util.collections import AttrDict\n\n\ndef read_config(path_or_file: Union[str, io.IOBase]) -> Any:\n    if isinstance(path_or_file, io.IOBase):\n        obj = yaml.load(path_or_file, Loader=yaml.SafeLoader)\n    else:\n        with bf.BlobFile(path_or_file, \"rb\") as f:\n            try:\n                obj = yaml.load(f, Loader=yaml.SafeLoader)\n            except Exception as exc:\n                with bf.BlobFile(path_or_file, \"rb\") as f:\n                    print(f.read())\n                raise exc\n    if isinstance(obj, dict):\n        return AttrDict(obj)\n    return obj\n\n\n@contextmanager\ndef buffered_writer(raw_f: BinaryIO) -> Iterator[io.BufferedIOBase]:\n    if isinstance(raw_f, io.BufferedIOBase):\n        yield raw_f\n    else:\n        f = io.BufferedWriter(raw_f)\n        yield f\n        f.flush()\n"
  },
  {
    "path": "shap_e/util/notebooks.py",
    "content": "import base64\nimport io\nfrom typing import Union\n\nimport ipywidgets as widgets\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera\nfrom shap_e.models.transmitter.base import Transmitter, VectorDecoder\nfrom shap_e.rendering.torch_mesh import TorchMesh\nfrom shap_e.util.collections import AttrDict\n\n\ndef create_pan_cameras(size: int, device: torch.device) -> DifferentiableCameraBatch:\n    origins = []\n    xs = []\n    ys = []\n    zs = []\n    for theta in np.linspace(0, 2 * np.pi, num=20):\n        z = np.array([np.sin(theta), np.cos(theta), -0.5])\n        z /= np.sqrt(np.sum(z**2))\n        origin = -z * 4\n        x = np.array([np.cos(theta), -np.sin(theta), 0.0])\n        y = np.cross(z, x)\n        origins.append(origin)\n        xs.append(x)\n        ys.append(y)\n        zs.append(z)\n    return DifferentiableCameraBatch(\n        shape=(1, len(xs)),\n        flat_camera=DifferentiableProjectiveCamera(\n            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),\n            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),\n            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),\n            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),\n            width=size,\n            height=size,\n            x_fov=0.7,\n            y_fov=0.7,\n        ),\n    )\n\n\n@torch.no_grad()\ndef decode_latent_images(\n    xm: Union[Transmitter, VectorDecoder],\n    latent: torch.Tensor,\n    cameras: DifferentiableCameraBatch,\n    rendering_mode: str = \"stf\",\n):\n    decoded = xm.renderer.render_views(\n        AttrDict(cameras=cameras),\n        params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(\n            latent[None]\n        ),\n        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),\n    )\n    arr = 
decoded.channels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()\n    return [Image.fromarray(x) for x in arr]\n\n\n@torch.no_grad()\ndef decode_latent_mesh(\n    xm: Union[Transmitter, VectorDecoder],\n    latent: torch.Tensor,\n) -> TorchMesh:\n    decoded = xm.renderer.render_views(\n        AttrDict(cameras=create_pan_cameras(2, latent.device)),  # lowest resolution possible\n        params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(\n            latent[None]\n        ),\n        options=AttrDict(rendering_mode=\"stf\", render_with_direction=False),\n    )\n    return decoded.raw_meshes[0]\n\n\ndef gif_widget(images):\n    writer = io.BytesIO()\n    images[0].save(\n        writer, format=\"GIF\", save_all=True, append_images=images[1:], duration=100, loop=0\n    )\n    writer.seek(0)\n    data = base64.b64encode(writer.read()).decode(\"ascii\")\n    return widgets.HTML(f'<img src=\"data:image/gif;base64,{data}\" />')\n"
  }
]