UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement
## 📖 Abstract
In this report, we introduce **UltraShape 1.0**, a scalable 3D diffusion framework for high-fidelity 3D geometry generation. The proposed approach adopts a **two-stage generation pipeline**: a coarse global structure is first synthesized and then refined to produce detailed, high-quality geometry.
To support reliable 3D generation, we develop a comprehensive data processing pipeline that includes a novel **watertight processing method** and **high-quality data filtering**. This pipeline improves the geometric quality of publicly available 3D datasets by removing low-quality samples, filling holes, and thickening thin structures, while preserving fine-grained geometric details.
To enable fine-grained geometry refinement, we decouple spatial localization from geometric detail synthesis in the diffusion process. We achieve this by performing **voxel-based refinement** at fixed spatial locations, where voxel queries derived from coarse geometry provide explicit positional anchors encoded via **RoPE**, allowing the diffusion model to focus on synthesizing local geometric details within a reduced, structured solution space.
Extensive evaluations demonstrate that UltraShape 1.0 performs competitively with existing open-source methods in both data processing quality and geometry generation.
## 🔥 News
* **[2025-12-25]** 📄 We released the technical report of **UltraShape 1.0** on arXiv.
* **[2025-12-26]** 🚀 We released the inference code and pre-trained models.
* **[2025-12-31]** 🚀 We released the training code.
## 🗓️ To-Do List
- [x] Release inference code
- [x] Release pre-trained weights (Hugging Face)
- [x] Release training code
- [ ] Release data processing scripts
## 🛠️ Installation & Usage
### 1. Environment Setup
```bash
git clone https://github.com/PKU-YuanGroup/UltraShape-1.0.git
cd UltraShape-1.0
# 1. Create and activate the environment
conda create -n ultrashape python=3.10
conda activate ultrashape
# 2. Install PyTorch (CUDA 12.1 recommended)
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
# 3. Install dependencies
pip install -r requirements.txt
# 4. Install cubvh (Required for MC acceleration)
pip install git+https://github.com/ashawkey/cubvh --no-build-isolation
# For Training & Sampling (Optional)
pip install --no-build-isolation "git+https://github.com/facebookresearch/pytorch3d.git@stable"
pip install https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_cluster-1.6.3%2Bpt25cu121-cp310-cp310-linux_x86_64.whl
```
#### ⬇️ Model Weights
Download the pre-trained weights from Hugging Face ([infinith/UltraShape](https://huggingface.co/infinith/UltraShape/tree/main)) and place them in your checkpoint directory (e.g., `./checkpoints/`).
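If you prefer the command line, the weights can also be fetched with the `huggingface_hub` CLI (already listed in `requirements.txt`); the target directory below is just an example:
```bash
huggingface-cli download infinith/UltraShape --local-dir ./checkpoints
```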
### 2. Generate Coarse Mesh
First, use Hunyuan3D-2.1 to generate a coarse mesh from your input image.
Repository: [Tencent-Hunyuan/Hunyuan3D-2.1](https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1)
Follow the instructions in the Hunyuan3D-2.1 repository to obtain the initial mesh file (e.g., .glb or .obj).
### 3. Generate Refined Mesh
Once you have the coarse mesh, use the provided script to run the refinement stage.
Run the inference script:
```bash
sh scripts/run.sh
```
The following arguments can be adjusted in the script:
- **image**: path to the reference image.
- **mesh**: path to the coarse mesh.
- **output_dir**: directory in which to save the refined result.
- **ckpt**: path to the downloaded UltraShape checkpoint.
- **steps**: number of DiT sampling steps (default: 50; can be reduced to 12 to speed up generation).
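For reference, `scripts/run.sh` simply wraps `scripts/infer_dit_refine.py`; an equivalent explicit invocation with the flags above spelled out (using the repo's example inputs) looks like this:
```bash
python scripts/infer_dit_refine.py \
    --config configs/infer_dit_refine.yaml \
    --ckpt checkpoints/ultrashape_v1.pt \
    --image inputs/image/1.png \
    --mesh inputs/coarse_mesh/1.glb \
    --output_dir outputs \
    --steps 50
```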
*Alternatively, you can run the gradio app for interactive inference:*
```bash
python scripts/gradio_app.py --ckpt checkpoints/ultrashape_v1.pt
```
#### Low VRAM
If you run out of GPU memory, try the following (a combined example follows this list):
1. Use a lower `num_latents` (try 8192).
2. Use a lower `chunk_size` (try 2048).
3. Pass the `--low_vram` flag to `gradio_app.py` or `infer_dit_refine.py`.
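For example, a low-VRAM run of the inference script that combines all three options (same example inputs as above) could look like:
```bash
python scripts/infer_dit_refine.py \
    --config configs/infer_dit_refine.yaml \
    --ckpt checkpoints/ultrashape_v1.pt \
    --image inputs/image/1.png \
    --mesh inputs/coarse_mesh/1.glb \
    --num_latents 8192 \
    --chunk_size 2048 \
    --low_vram
```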
### 4. Data Preparation & Training
First, prepare the data, including watertight meshes and rendered images.
Then, run the sampling script as follows:
```bash
python scripts/sampling.py \
--mesh_json data/mesh_paths.json \
--output_dir data/sample
```
Here, `mesh_json` is a JSON file containing a list of file paths to the watertight meshes.
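A minimal `data/mesh_paths.json` might look like the following (the file names are placeholders):
```json
[
    "data/watertight/00001.obj",
    "data/watertight/00002.obj"
]
```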
The multi-node training script is:
```bash
sh train.sh [node_idx]
```
`train.sh` expects the following settings:
- **training_data_list**: folder containing `train.json` and `val.json`, which store the dataset ID lists.
- **sample_pcd_dir**: directory containing the sampled `.npz` files.
- **image_data_json**: JSON file listing the paths of the rendered images.

You can switch between VAE and DiT training in `train.sh`, and specify the output directory and configuration file there as well.
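Under the hood, `train.sh` drives the training entry point `main.py`. As a rough single-node sketch, using flags from `main.py`'s argument parser (the GPU count and output directory are examples), a DiT training run could be launched as:
```bash
python main.py \
    --config configs/train_dit_refine.yaml \
    --num_gpus 8 \
    --use_amp \
    --amp_type bf16 \
    --output_dir outputs/dit_refine
```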
## 🔗 BibTeX
If you find this repository helpful, please cite our report:
```bibtex
@article{jia2025ultrashape,
title={UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement},
author={Jia, Tanghui and Yan, Dongyu and Hao, Dehao and Li, Yang and Zhang, Kaiyi and He, Xianyi and Li, Lanjiong and Chen, Jinnan and Jiang, Lutao and Yin, Qishen and Quan, Long and Chen, Ying-Cong and Yuan, Li},
journal={arXiv preprint arXiv:2512.21185},
year={2025}
}
```
## Acknowledgements
Our code is built upon the excellent work of [Hunyuan3D-2.1](https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1). The core idea of our method is greatly inspired by [LATTICE](https://arxiv.org/abs/2512.03052). We deeply appreciate the contributions of these works to the 3D generation community. Please also consider citing **Hunyuan3D 2.1** and **LATTICE**:
- **[Hunyuan3D-2.1](https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1)**
- **[LATTICE](https://lattice3d.github.io/)**
================================================
FILE: configs/infer_dit_refine.yaml
================================================
model:
target: ultrashape.pipelines.UltraShapePipeline
params:
# 1. VAE Config
vae_config:
target: ultrashape.models.autoencoders.ShapeVAE
params:
num_latents: &token_num 32768 # infer token_num
embed_dim: 64
num_freqs: 8
include_pi: false
heads: 16
width: 1024
point_feats: 4
num_encoder_layers: 8
num_decoder_layers: 16
pc_size: 409600 # num_s (204800) + num_u (204800)
pc_sharpedge_size: 0
downsample_ratio: 20
qkv_bias: false
qk_norm: true
scale_factor: 1.0039506158752403
geo_decoder_mlp_expand_ratio: 4
# geo_decoder_downsample_ratio: 1
geo_decoder_ln_post: true
enable_flashvdm: true
jitter_query: false
voxel_query: true
voxel_query_res: &voxel_query_res 128
# 2. DiT Denoiser Config
dit_cfg:
target: ultrashape.models.denoisers.dit_mask.RefineDiT
params:
input_size: *token_num
in_channels: 64
hidden_size: 2048
context_dim: 1024
depth: 21
num_heads: 16
qk_norm: true
text_len: 1370
qk_norm_type: 'rms'
qkv_bias: false
num_moe_layers: 6
num_experts: 8
moe_top_k: 2
voxel_query_res: *voxel_query_res
# 3. Image Encoder Config
conditioner_config:
target: ultrashape.models.conditioner_mask.SingleImageEncoder
params:
drop_ratio: 0.0
main_image_encoder:
type: DinoImageEncoder
kwargs:
version: 'facebook/dinov2-large'
image_size: 1022
use_cls_token: true
# 4. Scheduler Config
scheduler_cfg:
target: ultrashape.schedulers.FlowMatchEulerDiscreteScheduler
params:
num_train_timesteps: 1000
# 5. Image Processor
image_processor_cfg:
target: ultrashape.preprocessors.ImageProcessorV2
params:
size: 1024
================================================
FILE: configs/train_dit_refine.yaml
================================================
name: "UltraShape Refine DiT"
training:
# ckpt_path:
steps: 10_0000_0000
use_amp: true
amp_type: "bf16"
base_lr: 1e-5
gradient_clip_val: 1.0
gradient_clip_algorithm: "norm"
every_n_train_steps: 2500
val_check_interval: 1000
limit_val_batches: 16
accumulate_grad_batches: 4
dataset:
target: ultrashape.data.objaverse_dit.ObjaverseDataModule
params:
batch_size: 1
num_workers: 4
val_num_workers: 4
# data
training_data_list: data/data_list
sample_pcd_dir: data/sample
image_data_json: data/render.json
# image
image_size: &image_size 1022 # 518
mean: &mean [0.5, 0.5, 0.5]
std: &std [0.5, 0.5, 0.5]
padding: true
# input_pcd
pc_size: &pc_size 163840
pc_sharpedge_size: &pc_sharpedge_size 0
sharpedge_label: &sharpedge_label true
return_normal: true
model:
target: ultrashape.models.diffusion.flow_matching_dit_trainer.Diffuser
params:
ckpt_path: ckpt/dit_step=XXX.ckpt
scale_by_std: false
z_scale_factor: &z_scale_factor 1.0039506158752403
torch_compile: false
vae_config:
target: ultrashape.models.autoencoders.ShapeVAE
from_pretrained: ckpt/vae_step=XXX.ckpt
params:
num_latents: &num_latents 8192 # 4096
embed_dim: 64
num_freqs: 8
include_pi: false
heads: 16
width: 1024
point_feats: 4
num_encoder_layers: 8
num_decoder_layers: 16
pc_size: *pc_size
pc_sharpedge_size: *pc_sharpedge_size
downsample_ratio: 20
qkv_bias: false
qk_norm: true
scale_factor: *z_scale_factor
geo_decoder_mlp_expand_ratio: 4
geo_decoder_downsample_ratio: 1
geo_decoder_ln_post: true
enable_flashvdm: true
jitter_query: false
voxel_query: true
voxel_query_res: 128
cond_config:
target: ultrashape.models.conditioner_mask.SingleImageEncoder
params:
drop_ratio: 0.1
# disable_drop: false
main_image_encoder:
type: DinoImageEncoder
kwargs:
version: 'facebook/dinov2-large'
image_size: *image_size
use_cls_token: true
dit_cfg:
target: ultrashape.models.denoisers.dit_mask.RefineDiT
params:
input_size: *num_latents
in_channels: 64
hidden_size: 2048
context_dim: 1024
depth: 21
num_heads: 16
qk_norm: true
text_len: 5330 # 1370
qk_norm_type: 'rms'
qkv_bias: false
num_moe_layers: 6
num_experts: 8
moe_top_k: 2
scheduler_cfg:
transport:
target: ultrashape.models.diffusion.transport.create_transport
params:
path_type: Linear
prediction: velocity
sampler:
target: ultrashape.models.diffusion.transport.Sampler
params: {}
ode_params:
sampling_method: euler
num_steps: &num_steps 50
optimizer_cfg:
optimizer:
target: torch.optim.AdamW
params:
betas: [0.9, 0.99]
eps: 1.e-6
weight_decay: 1.e-2
scheduler:
target: ultrashape.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
params:
warm_up_steps: 500 # 5000
f_start: 1.e-6
f_min: 1.e-3
f_max: 1.0
pipeline_cfg:
target: ultrashape.pipelines.UltraShapePipeline
image_processor_cfg:
target: ultrashape.preprocessors.ImageProcessorV2
params: {}
================================================
FILE: configs/train_vae_refine.yaml
================================================
name: "UltraShape Refine VAE"
training:
# ckpt_path:
steps: 10_0000_0000
use_amp: true
amp_type: "bf16"
base_lr: 1e-5
gradient_clip_val: 1.0
gradient_clip_algorithm: "norm"
every_n_train_steps: 2500
val_check_interval: 1000
limit_val_batches: 16
dataset:
target: ultrashape.data.objaverse_vae.ObjaverseDataModule
params:
batch_size: 4
num_workers: 4
val_num_workers: 4
# data
training_data_list: data/data_list
sample_pcd_dir: data/sample
# input_pcd
pc_size: &pc_size 163840
pc_sharpedge_size: &pc_sharpedge_size 0
sharpedge_label: &sharpedge_label true
return_normal: true
# sup_pcd
sup_near_uni_size: 100000
sup_near_sharp_size: 100000
sup_space_size: 100000
tsdf_threshold: 0.01
model:
target: ultrashape.models.autoencoders.VAETrainer
params:
ckpt_path: ckpt/vae_step=15000.ckpt
torch_compile: false
save_dir: outputs/vae_recon
mc_res: 512
vae_config:
target: ultrashape.models.autoencoders.ShapeVAE
params:
num_latents: &num_latents 8192 # 4096
embed_dim: 64
num_freqs: 8
include_pi: false
heads: 16
width: 1024
point_feats: 4
num_encoder_layers: 8
num_decoder_layers: 16
pc_size: *pc_size
pc_sharpedge_size: *pc_sharpedge_size
downsample_ratio: 20
qkv_bias: false
qk_norm: true
geo_decoder_mlp_expand_ratio: 4
geo_decoder_downsample_ratio: 1
geo_decoder_ln_post: true
enable_flashvdm: true
jitter_query: true
optimizer_cfg:
optimizer:
target: torch.optim.AdamW
params:
betas: [0.9, 0.99]
eps: 1.e-6
weight_decay: 1.e-2
scheduler:
target: ultrashape.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
params:
warm_up_steps: 500 # 5000
f_start: 1.e-6
f_min: 1.e-3
f_max: 1.0
loss_cfg:
lambda_logits: 1.
lambda_kl: 0.001
# lambda_eik: -1.
# lambda_sn: -1.
# lambda_sign: -1.
================================================
FILE: docs/carousel.css
================================================
.x-carousel-tags {
width: 100%;
display: flex;
align-items: center;
justify-content: left;
flex-wrap: wrap;
}
.x-carousel-tag {
background-color: rgba(255, 255, 255, 0.9);
box-shadow: rgba(0, 0, 0, 0.1) 0px 2px 4px;
border: 2px solid transparent;
border-radius: 8px;
cursor: pointer;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
transition: all 0.3s ease;
color: #2a2a2a;
margin: 4px;
padding: 8px 16px;
text-align: center;
}
.x-carousel-tag:hover {
background-color: rgba(255, 255, 255, 1);
transform: translateY(-2px);
box-shadow: rgba(0, 0, 0, 0.15) 0px 4px 6px;
}
.x-carousel-tag.active { border-color: #666; background-color: rgba(255, 255, 255, 1); }
.x-carousel-slider {
width: 100%;
display: flex;
align-items: center;
justify-content: center;
flex-wrap: wrap;
}
.x-carousel-slider-item {
max-width: 100%;
width: 100%;
flex: 1 0 0;
}
.x-carousel-nav {
width: 100%;
height: 40px;
display: flex;
align-items: center;
justify-content: space-between;
}
.x-carousel-switch {
width: 50px;
height: 25px;
margin: 8px;
border-radius: 25px;
cursor: pointer;
user-select: none;
color: rgba(42, 42, 42, 0.4);
font-size: 24px;
font-weight: 500;
transition: all 0.25s ease;
display: flex;
align-items: center;
justify-content: center;
}
.x-carousel-switch:hover {
color: rgba(42, 42, 42, 0.8);
transform: scale(1.05);
}
.x-carousel-pages {
display: flex;
align-items: center;
justify-content: center;
flex-wrap: wrap;
}
.x-carousel-page {
width: 10px;
height: 10px;
border-radius: 10px;
background: rgba(42, 42, 42, 0.2);
/* background: linear-gradient(107.54deg, #0078d4 .39%, #8661c5 51.23%, #ff9349 100%) fixed; */
margin: 0 3px;
cursor: pointer;
transition: all 0.25s ease;
}
.x-carousel-page:hover {
background: rgba(42, 42, 42, 0.5);
}
.x-carousel-page.x-carousel-page-active {
width: 12px;
height: 13px;
}
================================================
FILE: docs/carousel.js
================================================
/**
* Carousel functionality for research project page
* Handles navigation, filtering, and page indicators for carousel components
*/
(function() {
'use strict';
/**
* Initialize carousel functionality
* @param {string} carouselId - ID of the carousel container
*/
function initCarousel(carouselId) {
const carousel = document.getElementById(carouselId);
if (!carousel) return;
const slider = carousel.querySelector('.x-carousel-slider');
const allItems = carousel.querySelectorAll('.x-carousel-slider-item');
const prevBtn = carousel.querySelector('.x-carousel-nav .x-carousel-switch:first-child');
const nextBtn = carousel.querySelector('.x-carousel-nav .x-carousel-switch:last-child');
const pages = carousel.querySelectorAll('.x-carousel-page');
const tags = carousel.querySelectorAll('.x-carousel-tag');
if (!slider || !allItems.length) return;
let currentIndex = 0;
let currentFilter = 'all'; // Current filter: 'all', 'class1', 'class2', 'class3'
let filteredItems = Array.from(allItems); // Currently visible items
/**
* Filter items by tag
* @param {string} filter - Filter value: 'all', 'class1', 'class2', 'class3'
*/
function filterItems(filter) {
currentFilter = filter;
// Filter items based on data-tag attribute
if (filter === 'all') {
filteredItems = Array.from(allItems);
} else {
filteredItems = Array.from(allItems).filter(item => {
return item.getAttribute('data-tag') === filter;
});
}
// Reset to first item after filtering
currentIndex = 0;
// Update visibility of all items
allItems.forEach(item => {
if (filteredItems.includes(item)) {
item.style.display = 'block';
} else {
item.style.display = 'none';
}
});
// Show first filtered item
goToSlide(0);
updatePages();
}
/**
* Navigate to a specific slide (within filtered items)
* @param {number} index - Index of the slide to show
*/
function goToSlide(index) {
const totalItems = filteredItems.length;
currentIndex = Math.max(0, Math.min(index, totalItems - 1));
// Hide all filtered items
filteredItems.forEach(item => {
item.style.display = 'none';
});
// Show current item - use block instead of flex to preserve card's flex layout
if (filteredItems[currentIndex]) {
filteredItems[currentIndex].style.display = 'block';
}
updatePages();
updateButtons();
}
/**
* Update page indicators to reflect current slide
*/
function updatePages() {
const totalItems = filteredItems.length;
pages.forEach((page, index) => {
if (index === currentIndex && index < totalItems) {
page.classList.add('x-carousel-page-active');
} else {
page.classList.remove('x-carousel-page-active');
}
});
// Hide unused page indicators
pages.forEach((page, index) => {
if (index >= totalItems) {
page.style.display = 'none';
} else {
page.style.display = '';
}
});
}
/**
* Update navigation buttons state (enable/disable at boundaries)
*/
function updateButtons() {
const totalItems = filteredItems.length;
if (prevBtn) {
prevBtn.style.opacity = currentIndex === 0 ? '0.3' : '1';
prevBtn.style.cursor = currentIndex === 0 ? 'not-allowed' : 'pointer';
}
if (nextBtn) {
nextBtn.style.opacity = currentIndex === totalItems - 1 ? '0.3' : '1';
nextBtn.style.cursor = currentIndex === totalItems - 1 ? 'not-allowed' : 'pointer';
}
}
/**
* Tag filtering (only for results-gen carousel)
*/
if (tags.length && carouselId === 'results-gen') {
tags.forEach((tag) => {
tag.addEventListener('click', function() {
const filter = tag.getAttribute('data-filter');
if (!filter) return;
// Remove active class from all tags
tags.forEach(t => t.classList.remove('active'));
// Add active class to clicked tag
tag.classList.add('active');
// Filter items
filterItems(filter);
});
});
}
// Previous/Next buttons
if (prevBtn) {
prevBtn.addEventListener('click', function() {
if (currentIndex > 0) {
goToSlide(currentIndex - 1);
}
});
}
if (nextBtn) {
nextBtn.addEventListener('click', function() {
const totalItems = filteredItems.length;
if (currentIndex < totalItems - 1) {
goToSlide(currentIndex + 1);
}
});
}
// Page indicators - click to jump to specific slide
pages.forEach((page, index) => {
page.addEventListener('click', function() {
goToSlide(index);
});
});
// Initialize - filter to 'all' and show first item
filterItems('all');
}
// Initialize all carousels when DOM is ready
document.addEventListener('DOMContentLoaded', function() {
initCarousel('results-gen');
initCarousel('results-recon');
});
})();
document.addEventListener('DOMContentLoaded', function() {
// 1. Create a container for displaying the prompt image up front (hidden initially)
const promptImgContainer = document.createElement('div');
promptImgContainer.id = 'glb-prompt-image-container';
promptImgContainer.style.cssText = `
position: fixed;
bottom: 20px;
right: 20px;
width: 200px;
height: 200px;
z-index: 10000; /* keep on top of everything */
display: none; /* hidden by default */
background-color: white;
padding: 5px;
border-radius: 8px;
box-shadow: 0 4px 12px rgba(0,0,0,0.3);
cursor: pointer; /* hint that clicking closes it */
`;
// Create the image element
const promptImg = document.createElement('img');
promptImg.style.cssText = `
width: 100%;
height: 100%;
object-fit: contain;
display: block;
`;
promptImgContainer.appendChild(promptImg);
// Add a "click to close" hint label (optional)
const closeTip = document.createElement('div');
closeTip.innerText = "Click to close";
closeTip.style.cssText = "position:absolute; top:-25px; right:0; color:white; font-size:12px; background:rgba(0,0,0,0.5); padding:2px 5px; border-radius:4px;";
promptImgContainer.appendChild(closeTip);
document.body.appendChild(promptImgContainer);
// Hide the container when it is clicked
promptImgContainer.addEventListener('click', function() {
this.style.display = 'none';
});
// 2. Attach a click handler to every "View GLB" button
const buttons = document.querySelectorAll('.x-button');
buttons.forEach(btn => {
btn.addEventListener('click', function(e) {
// Read the data-prompt attribute defined in the HTML (e.g. assets/images/1.png)
const imgUrl = this.getAttribute('data-prompt');
if (imgUrl) {
promptImg.src = imgUrl;
promptImgContainer.style.display = 'block'; // show the image
}
});
});
// 3. (Optional) If your GLB viewer has a "close" button (e.g. class .close-viewer),
// add logic here so the prompt image is hidden when the viewer is closed.
// Assume the close button class is .close-btn (confirm the actual class name).
/*
const closeGlbBtn = document.querySelector('.close-btn-class-name');
if(closeGlbBtn) {
closeGlbBtn.addEventListener('click', () => {
promptImgContainer.style.display = 'none';
});
}
*/
});
================================================
FILE: docs/index.html
================================================
UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement
UltraShape 1.0
High-Fidelity 3D Shape Generation via Scalable Geometric Refinement
We introduce UltraShape-1.0, a scalable two-stage diffusion framework for high-quality 3D geometry generation, enhanced by an advanced data processing pipeline that ensures geometric details through watertight processing and quality filtering.
@article{jia2025ultrashape,
title={UltraShape 1.0: High-Fidelity 3D Shape Generation via Scalable Geometric Refinement},
author={Jia, Tanghui and Yan, Dongyu and Hao, Dehao and Li, Yang and Zhang, Kaiyi and He, Xianyi and Li, Lanjiong and Chen, Jinnan and Jiang, Lutao and Yin, Qishen and Quan, Long and Chen, Ying-Cong and Yuan, Li},
journal={arXiv preprint arXiv:2512.21185},
year={2025}
}
================================================
FILE: docs/main.css
================================================
* {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
a {
color: #5b6acf;
text-decoration: none;
}
a.link:focus, a.link:hover {
color: #8b5cf6;
text-decoration: none;
}
html {
scroll-behavior: smooth;
scroll-snap-type: y proximity;
}
body {
background: #faf9f7;
position: relative;
margin: 0px;
padding: 0px;
color: #2a2a2a;
overflow-x: hidden;
}
p {
position: relative;
margin: 16px;
font-size: 16px;
font-weight: 300;
text-align: justify;
}
p span {
font-weight: 500;
}
.x-row {
width: 100%;
display: flex;
align-items: center;
justify-content: center;
flex-wrap: nowrap;
}
.x-column {
display: flex;
align-items: center;
justify-content: center;
flex-wrap: nowrap;
flex-direction: column;
}
.x-center-text {
margin: 16px 32px;
text-align: center;
}
.x-left-align {
display: flex;
align-items: center;
justify-content: left;
flex-wrap: nowrap;
}
.x-right-align {
display: flex;
align-items: center;
justify-content: right;
flex-wrap: nowrap;
}
.x-flex-spacer {
flex: 1;
}
.x-labels {
position: absolute;
top: 8px;
right: 6px;
display: flex;
align-items: center;
justify-content: left;
flex-direction: row-reverse;
}
.x-label {
height: 20px;
padding: 0px 6px;
margin: 0px 2px;
color: #2a2a2a;
font-size: 12px;
font-weight: 600;
background: rgba(45, 45, 45, 0.1);
border-radius: 16px;
display: flex;
align-items: center;
justify-content: center;
}
.x-button {
height: 36px;
padding: 0px 14px;
background: rgba(45, 45, 45, 0.08);
color: #2a2a2a;
border-radius: 50px;
box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1);
font-size: 16px;
font-weight: 300;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
transition: all 0.2s ease;
}
.x-button.small {
height: 32px;
padding: 0px 12px;
border-radius: 50px;
font-size: 14px;
font-weight: 600;
}
.x-button:hover {
background: rgba(45, 45, 45, 0.15);
transform: translateY(-2px);
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.15);
}
.x-button.disabled {
background: rgba(45, 45, 45, 0.05);
color: rgba(42, 42, 42, 0.4);
cursor: default;
}
.x-button.disabled:hover {
background: rgba(45, 45, 45, 0.05);
transform: none;
}
.x-gradient-font {
background: linear-gradient(270deg, #845ade 0%, #2e6ed6 25%, #ff7d4b 75%, #ec9b0b 100%);
background-clip: text;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.x-gradient-block {
color: #3f3f3f;
background: linear-gradient(270deg, #845ade2f 0%, #2e6ed62f 25%, #ff7d4b2f 75%, #ec9b0b2f 100%);
border-radius: 16px;
}
.x-gradient-border {
position: relative;
padding: 1px;
margin: 3px;
border: 3px solid transparent;
background: white;
background-clip: padding-box;
border-radius: 16px;
}
.x-gradient-border::before {
content: '';
position: absolute;
top: 0; right: 0; bottom: 0; left: 0;
z-index: -1;
margin: -3px;
border-radius: 16px;
background: linear-gradient(270deg, #845ade 0%, #2e6ed6 25%, #ff7d4b 75%, #ec9b0b 100%);
}
.x-section-title {
text-align: center;
margin: 100px 0px 48px 0px;
font-size: 36px;
font-weight: 600;
letter-spacing: 2px;
text-transform: uppercase;
color: #333;
position: relative;
padding-bottom: 20px;
}
.x-section-title::after {
content: '';
position: absolute;
bottom: 0;
left: 50%;
transform: translateX(-50%);
width: 80px;
height: 3px;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
border-radius: 2px;
}
.x-note {
color: rgba(42, 42, 42, 0.7);
font-size: 14px;
font-weight: 300;
}
.x-card {
position: relative;
display: flex;
align-items: center;
justify-content: center;
flex-wrap: wrap;
}
.x-card .caption {
height: 200px;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
color: #2a2a2a;
font-size: 16px;
font-weight: 600;
width: 100%;
}
.x-handwriting {
width: 100%;
font-family: 'Segoe Print';
font-size: 12px;
font-weight: 600;
line-height: 1.5;
color: black;
text-align: justify;
}
.x-image-prompt {
position: relative;
height: calc(100% - 2px);
aspect-ratio: 1/1;
display: flex;
align-items: center;
justify-content: center;
}
.x-image-prompt img {
max-width: 100%;
max-height: 100%;
}
.x-small-header {
text-align: center;
margin-top: 64px;
margin-bottom: 24px;
margin-left: 4px;
font-size: 16px;
font-weight: 500;
letter-spacing: 4px;
text-transform: uppercase;
color: #666;
}
.x-dot-card {
background: rgba(255, 255, 255, 0.8);
border-radius: 16px;
padding: 24px;
display: flex;
flex-direction: column;
box-shadow: 0px 2px 8px rgba(0, 0, 0, 0.08);
}
.x-dot-card-title {
margin-top: 0;
margin-bottom: 0px;
font-size: 20px;
color: #2a2a2a;
display: flex;
align-items: center;
gap: 10px;
}
#main {
max-width: 1000px;
margin: 0px auto;
padding-bottom: 200px;
}
.author-info {
display: flex;
justify-content: center;
align-items: center;
gap: 32px;
padding: 8px;
}
.author-link {
color: #2a2a2a;
text-decoration: none;
font-weight: 500;
}
.author-link:focus, .author-link:hover {
text-decoration: underline;
}
.affiliation-link {
font-size: 14px;
color: rgba(42, 42, 42, 0.7);
text-decoration: none;
font-weight: 300;
}
#links {
margin: 16px 0;
display: flex;
align-items: center;
justify-content: center;
flex-wrap: wrap;
}
#links div {
margin: 4px 8px;
height: 38px;
display: flex;
align-items: center;
justify-content: center;
}
#links a {
height: 20px;
padding: 8px 16px;
color: #2a2a2a;
font-size: 16px;
font-weight: 300;
background: rgba(255, 255, 255, 0.9);
border-radius: 50px;
box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1);
display: flex;
align-items: center;
justify-content: center;
transition: all 0.2s ease;
}
#links a:hover {
background: rgba(255, 255, 255, 1);
transform: translateY(-2px);
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.15);
}
#links a.disabled {
background-color: rgba(200, 200, 200, 0.5);
color: rgba(42, 42, 42, 0.5);
}
#links a.disabled:hover {
background-color: rgba(200, 200, 200, 0.5);
}
#links a::before {
/* Material Icons */
font-family: 'Material Icons' !important;
font-style: normal;
font-weight: normal;
font-variant: normal;
text-transform: none;
line-height: 1;
letter-spacing: normal;
word-wrap: normal;
white-space: nowrap;
direction: ltr;
/* Better Font Rendering =========== */
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
text-rendering: optimizeLegibility;
font-feature-settings: 'liga';
margin-right: 8px;
font-size: 20px;
}
#links #paper::before {
content: "description";
}
#links #arxiv::before {
content: "article";
}
#links #code::before {
content: "code";
}
#links #poster::before {
content: "picture_as_pdf";
}
#links #video::before {
content: "play_circle";
}
#links #demo::before {
content: "rocket_launch";
}
.feature-container {
max-width: 1000px;
margin: 32px auto;
font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
}
.feature-tabs {
display: grid;
grid-template-columns: repeat(4, 1fr);
gap: 16px;
}
.feature-tab {
aspect-ratio: 1 / 1;
background-color: rgba(255, 255, 255, 0.8);
border: 2px solid transparent;
border-radius: 12px;
cursor: pointer;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
transition: all 0.3s ease;
color: #666;
padding: 10px;
text-align: center;
box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.08);
}
.feature-tab:hover {
background-color: rgba(255, 255, 255, 1);
transform: translateY(-2px);
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.12);
}
.feature-tab.active-yellow { border-color: #facd5c; color: #facd5c; background-color: rgba(251, 191, 36, 0.1); }
.feature-tab.active-red { border-color: #d17969; color: #d17969; background-color: rgba(209, 121, 105, 0.1); }
.feature-tab.active-blue { border-color: #60a5fa; color: #60a5fa; background-color: rgba(96, 165, 250, 0.1); }
.feature-tab.active-purple { border-color: #b7a5ff; color: #b7a5ff; background-color: rgba(167, 139, 250, 0.1); }
.feature-tab svg {
width: 64px;
height: 64px;
margin-bottom: 8px;
fill: currentColor;
}
.feature-tab span {
font-size: 15px;
font-weight: 400;
line-height: 1.2;
}
.feature-panel {
display: none;
padding: 32px;
animation: fadeIn 0.4s ease;
}
.feature-panel.active {
display: block;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
@media (max-width: 600px) {
.feature-tabs {
grid-template-columns: repeat(2, 1fr);
}
.feature-panel {
padding: 16px;
}
}
.bibtex-entry {
margin: 32px auto;
max-width: 900px;
padding: 0px 24px;
color: #2a2a2a;
font-family: consolas, monospace;
white-space: pre;
text-wrap: wrap;
font-size: 14px;
font-weight: 300;
display: flex;
flex-direction: column;
align-items: left;
justify-content: center;
}
.line {
display: grid;
grid-template-columns: max-content 1fr;
gap: 0;
}
.key {
font-family: consolas, monospace;
text-align: right;
padding-right: 0;
}
.value {
font-family: consolas, monospace;
word-break: break-word;
}
#bottombar {
position: absolute;
bottom: 0px;
height: 100px;
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
justify-content: space-around;
user-select: none;
}
#bottombar .row {
width: 90%;
padding: 0px 5%;
display: flex;
align-items: center;
justify-content: space-between;
user-select: none;
}
#bottombar div {
color: rgba(42, 42, 42, 0.7);
font-size: 12px;
font-weight: 500;
}
#bottombar div a {
color: rgba(42, 42, 42, 0.7);
font-size: 12px;
font-weight: 500;
}
#bottombar div a:hover {
color: rgba(42, 42, 42, 1);
font-size: 12px;
}
#bottombar div span {
font-weight: 700;
}
.scroll-indicator {
position: fixed;
bottom: 4px;
left: 50%;
transform: translateX(-50%);
display: flex;
flex-direction: column;
align-items: center;
color: #666;
cursor: pointer;
z-index: 1000;
animation: scroll-bounce 2s infinite;
transition: opacity 0.5s ease, visibility 0.5s;
}
.scroll-indicator.hidden {
opacity: 0;
visibility: hidden;
}
@keyframes scroll-bounce {
0%, 20%, 50%, 80%, 100% {
transform: translateX(-50%) translateY(0);
}
40% {
transform: translateX(-50%) translateY(-10px);
}
60% {
transform: translateX(-50%) translateY(-5px);
}
}
/* Animated Gradient Background with Mouse Follow */
.animated-gradient {
--mouse-x: 50%;
--mouse-y: 50%;
background:
radial-gradient(circle at var(--mouse-x) var(--mouse-y), rgba(102, 126, 234, 0.8) 0%, transparent 50%),
radial-gradient(circle at calc(100% - var(--mouse-x)) calc(100% - var(--mouse-y)), rgba(118, 75, 162, 0.8) 0%, transparent 50%),
radial-gradient(circle at var(--mouse-x) calc(100% - var(--mouse-y)), rgba(240, 147, 251, 0.6) 0%, transparent 50%),
linear-gradient(135deg, #667eea 0%, #764ba2 50%, #f093fb 100%);
background-size: 200% 200%;
animation: gradient-shift 20s ease infinite;
transition: background 0.1s ease-out;
}
/* Hero expandable effect */
.hero-expandable {
position: sticky;
top: 0;
z-index: 10;
transition: height 0.6s cubic-bezier(0.4, 0, 0.2, 1);
margin-bottom: 0;
}
@keyframes gradient-shift {
0% {
background-position: 0% 50%;
}
50% {
background-position: 100% 50%;
}
100% {
background-position: 0% 50%;
}
}
================================================
FILE: docs/pv.css
================================================
.pv-video-wrapper {
position: relative;
width: 100%;
aspect-ratio: 16 / 9;
margin: 0 auto;
background-color: #2a2a2a;
overflow: hidden;
border-radius: 8px;
}
.pv-video-element {
width: 100%;
height: 100%;
display: block;
object-fit: contain;
}
.pv-poster-overlay {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
cursor: pointer;
z-index: 10;
display: flex;
justify-content: center;
align-items: center;
background-size: cover;
background-position: center;
transition: opacity 0.3s ease;
}
.pv-play-btn {
width: 64px;
height: 64px;
background-color: rgba(255, 255, 255, 0.9);
border: 2px solid #2a2a2a;
border-radius: 50%;
display: flex;
justify-content: center;
align-items: center;
transition: transform 0.2s ease, background-color 0.2s;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2);
}
.pv-play-btn::after {
content: '';
display: block;
width: 0;
height: 0;
border-top: 10px solid transparent;
border-bottom: 10px solid transparent;
border-left: 16px solid #2a2a2a;
margin-left: 4px;
}
.pv-poster-overlay:hover .pv-play-btn {
transform: scale(1.1);
background-color: rgba(255, 255, 255, 1);
}
.pv-video-wrapper.is-playing .pv-poster-overlay {
opacity: 0;
pointer-events: none;
}
================================================
FILE: docs/style.css
================================================
/* Icomoon font removed - now using Material Icons via CDN */
/* Material Icons are loaded in index.html via Google Fonts CDN */
================================================
FILE: docs/stylesheet.css
================================================
/* Font definitions removed - using system Segoe UI font */
/* All font-face declarations for Avenir Next Cyr have been removed */
/* The site now uses 'Segoe UI' as defined in main.css and style.css */
================================================
FILE: docs/window.css
================================================
#fullscreen {
position: fixed;
top: 0;
left: 0;
width: 100vw;
height: 100vh;
background: transparent;
display: none;
align-items: center;
justify-content: center;
z-index: 1000;
user-select: none;
backdrop-filter: blur(10px);
opacity: 0;
transition: opacity 0.25s ease;
}
#fullscreen #window {
position: relative;
min-width: 25vw;
min-height: 25vh;
max-width: 100vw;
max-height: 90vh;
background: #ffffff;
border-radius: 16px;
box-shadow: 0px 4px 16px rgba(0, 0, 0, 0.2);
padding: 8px;
}
#fullscreen #window #close {
position: absolute;
top: 0px;
right: 0px;
width: 31px;
height: 30px;
padding: 0px 0px 2px 1px;
color: black;
font-size: 16px;
font-weight: 700;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
transition: all 0.2s ease;
z-index: 100;
}
#fullscreen #window #close {
color: #2a2a2a;
}
#fullscreen #window #close:hover {
color: #d32f2f;
}
#fullscreen #window #content {
max-width: calc(100vw - 16px);
max-height: calc(90vh - 16px);
overflow-x: hidden;
overflow-y: auto;
}
.modelviewer-container {
width: 500px;
height: 500px;
margin: 8px;
border-radius: 8px;
background: white;
box-shadow: inset 0px 0px 4px rgba(0, 0, 0, 0.25);
overflow: hidden;
position: relative;
}
.modelviewer-container model-viewer {
width: 100% !important;
height: 100% !important;
display: block !important;
background-color: #f5f5f5;
}
.modelviewer-container model-viewer button {
height: 16px;
padding: 0px 6px;
background: rgba(255, 255, 255, 0.75);
border-radius: 50px;
box-shadow: 0px 0px 4px rgba(0, 0, 0, 0.25);
border: none;
font-size: 12px;
font-weight: 300;
display: none;
opacity: 0;
align-items: center;
justify-content: center;
pointer-events: none;
}
.modelviewer-panel {
width: 300px;
margin: 8px;
margin-top: 0px;
display: flex;
flex-direction: column;
align-items: start;
justify-content: start;
}
.modelviewer-panel-desc {
width: 100%;
}
.modelviewer-panel-desc div {
font-size: 16px;
font-weight: 500;
margin: 4px;
}
.modelviewer-panel-prompt {
width: calc(100% - 16px);
height: 250px;
padding: 8px;
background: #f5f5f5;
border-radius: 8px;
box-shadow: inset 0px 0px 4px rgba(0, 0, 0, 0.1);
display: flex;
align-items: start;
justify-content: center;
overflow-y: auto;
user-select: text;
}
.modelviewer-panel-button {
height: 40px;
margin: 4px 4px;
padding: 0px 14px;
background: rgba(45, 45, 45, 0.08);
color: #2a2a2a;
border-radius: 50px;
box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1);
font-size: 16px;
font-weight: 300;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
transition: all 0.2s ease;
}
.modelviewer-panel-button.small {
height: 32px;
padding: 0px 12px;
border-radius: 50px;
font-size: 14px;
font-weight: 300;
}
.modelviewer-panel-button.tiny {
height: 24px;
padding: 0px 10px;
border-radius: 50px;
font-size: 12px;
font-weight: 300;
}
.modelviewer-panel-button:hover {
background: rgba(45, 45, 45, 0.15);
transform: translateY(-2px);
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.15);
}
.modelviewer-panel-button.checked {
border: 2px solid #2a2a2a;
background: rgba(45, 45, 45, 0.1);
color: #2a2a2a;
}
================================================
FILE: inputs/coarse_mesh/1.glb
================================================
[File too large to display: 25.1 MB]
================================================
FILE: main.py
================================================
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
# import warnings
# warnings.filterwarnings("ignore")
import os
import torch
import argparse
from pathlib import Path
from typing import Tuple, List
from omegaconf import OmegaConf, DictConfig
from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
allow_ops_in_compiled_graph()
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.loggers import Logger, TensorBoardLogger
from pytorch_lightning.utilities import rank_zero_info
from ultrashape.utils import get_config_from_file, instantiate_from_config
class SetupCallback(Callback):
def __init__(self, config: DictConfig, basedir: Path, logdir: str = "log", ckptdir: str = "ckpt") -> None:
super().__init__()
self.logdir = basedir / logdir
self.ckptdir = basedir / ckptdir
self.config = config
def on_fit_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
if trainer.global_rank == 0:
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
def setup_callbacks(config: DictConfig) -> Tuple[List[Callback], Logger]:
training_cfg = config.training
basedir = Path(training_cfg.output_dir)
os.makedirs(basedir, exist_ok=True)
all_callbacks = []
setup_callback = SetupCallback(config, basedir)
all_callbacks.append(setup_callback)
checkpoint_callback = ModelCheckpoint(
dirpath=setup_callback.ckptdir,
filename="ckpt-{step:08d}",
save_top_k=-1,
verbose=False,
every_n_train_steps=training_cfg.every_n_train_steps)
all_callbacks.append(checkpoint_callback)
if "callbacks" in config:
for key, value in config['callbacks'].items():
custom_callback = instantiate_from_config(value)
all_callbacks.append(custom_callback)
logger = TensorBoardLogger(save_dir=str(setup_callback.logdir), name="tensorboard")
return all_callbacks, logger
def merge_cfg(cfg, arg_cfg):
for key in arg_cfg.keys():
if key in cfg.training:
arg_cfg[key] = cfg.training[key]
cfg.training = DictConfig(arg_cfg)
return cfg
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--fast", action='store_true')
parser.add_argument("-c", "--config", type=str, required=True)
parser.add_argument("-s", "--seed", type=int, default=0)
parser.add_argument("-nn", "--num_nodes", type=int, default=1)
parser.add_argument("-ng", "--num_gpus", type=int, default=1)
parser.add_argument("-u", "--update_every", type=int, default=1)
parser.add_argument("-st", "--steps", type=int, default=50000000)
parser.add_argument("-lr", "--base_lr", type=float, default=4.5e-6)
parser.add_argument("-a", "--use_amp", default=False, action="store_true")
parser.add_argument("--amp_type", type=str, default="16")
parser.add_argument("--gradient_clip_val", type=float, default=None)
parser.add_argument("--gradient_clip_algorithm", type=str, default=None)
parser.add_argument("--every_n_train_steps", type=int, default=50000)
parser.add_argument("--log_every_n_steps", type=int, default=50)
parser.add_argument("--val_check_interval", type=int, default=1024)
parser.add_argument("--limit_val_batches", type=int, default=64)
parser.add_argument("--monitor", type=str, default="val/total_loss")
parser.add_argument("--output_dir", type=str, help="the output directory to save everything.")
parser.add_argument("--ckpt_path", type=str, default="", help="the restore checkpoints.")
parser.add_argument("--deepspeed", default=False, action="store_true")
parser.add_argument("--deepspeed2", default=False, action="store_true")
parser.add_argument("--scale_lr", type=bool, nargs="?", const=True, default=False,
help="scale base-lr by ngpu * batch_size * n_accumulate")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
if args.fast:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('medium')
torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.05
# Set random seed
pl.seed_everything(args.seed, workers=True)
# Load configuration
config = get_config_from_file(args.config)
config = merge_cfg(config, vars(args))
training_cfg = config.training
# print config
rank_zero_info("Begin to print configuration ...")
rank_zero_info(OmegaConf.to_yaml(config))
rank_zero_info("Finish print ...")
# Setup callbacks
callbacks, loggers = setup_callbacks(config)
# Build data modules
data: pl.LightningDataModule = instantiate_from_config(config.dataset)
# Build model
model: pl.LightningModule = instantiate_from_config(config.model)
nodes = args.num_nodes
ngpus = args.num_gpus
base_lr = training_cfg.base_lr
accumulate_grad_batches = training_cfg.update_every
batch_size = config.dataset.params.batch_size
if 'NNODES' in os.environ:
nodes = int(os.environ['NNODES'])
training_cfg.num_nodes = nodes
args.num_nodes = nodes
if args.scale_lr:
model.learning_rate = accumulate_grad_batches * nodes * ngpus * batch_size * base_lr
info = f"Setting learning rate to {model.learning_rate:.2e} = {accumulate_grad_batches} (accumulate)"
info += f" * {nodes} (nodes) * {ngpus} (num_gpus) * {batch_size} (batchsize) * {base_lr:.2e} (base_lr)"
rank_zero_info(info)
else:
model.learning_rate = base_lr
rank_zero_info("++++ NOT USING LR SCALING ++++")
rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")
# Build trainer
if args.num_nodes > 1 or args.num_gpus > 1:
if args.deepspeed:
ddp_strategy = DeepSpeedStrategy(stage=1)
elif args.deepspeed2:
ddp_strategy = 'deepspeed_stage_2'
else:
ddp_strategy = DDPStrategy(find_unused_parameters=False, bucket_cap_mb=1500)
else:
ddp_strategy = 'ddp'
rank_zero_info(f'*' * 100)
if training_cfg.use_amp:
amp_type = training_cfg.amp_type
assert amp_type in ['bf16', '16', '32'], f"Invalid amp_type: {amp_type}"
rank_zero_info(f'Using {amp_type} precision')
else:
amp_type = 32
rank_zero_info(f'Using 32 bit precision')
rank_zero_info(f'*' * 100)
trainer = pl.Trainer(
max_steps=training_cfg.steps,
precision=amp_type,
callbacks=callbacks,
accelerator="gpu",
devices=args.num_gpus,
num_nodes=training_cfg.num_nodes,
strategy=ddp_strategy,
gradient_clip_val=training_cfg.get('gradient_clip_val'),
gradient_clip_algorithm=training_cfg.get('gradient_clip_algorithm'),
accumulate_grad_batches=args.update_every,
logger=loggers,
log_every_n_steps=training_cfg.log_every_n_steps,
val_check_interval=training_cfg.val_check_interval,
limit_val_batches=training_cfg.limit_val_batches,
check_val_every_n_epoch=None
)
# Train
if training_cfg.ckpt_path == '':
training_cfg.ckpt_path = None
trainer.fit(model, datamodule=data, ckpt_path=training_cfg.ckpt_path)
================================================
FILE: requirements.txt
================================================
accelerate==1.1.1
diffusers==0.30.0
deepspeed
diso==0.1.4
einops==0.8.1
flash_attn==2.8.3
huggingface_hub==0.36.0
imageio==2.36.0
ipywidgets==8.1.7
jaxtyping==0.3.4
matplotlib==3.10.8
numpy==1.24.4
omegaconf==2.3.0
opencv_python==4.10.0.84
opencv_python_headless==4.11.0.86
pandas==2.3.3
Pillow==12.0.0
pymeshlab==2022.2.post3
pythreejs==2.4.2
pytorch_lightning==1.9.5
PyYAML==6.0.2
safetensors==0.7.0
sageattention==1.0.6
scikit-image==0.24.0
onnxruntime
rembg
tensorboard
timm==1.0.22
torchdiffeq==0.2.5
tqdm==4.66.5
transformers==4.37.2
trimesh==4.4.7
typeguard==4.3.0
wandb==0.23.1
================================================
FILE: scripts/gradio_app.py
================================================
import argparse
import gc
import os
import sys
import gradio as gr
import torch
from omegaconf import OmegaConf
# Add project root to path
sys.path.append(os.getcwd())
from ultrashape.rembg import BackgroundRemover
from ultrashape.utils.misc import instantiate_from_config
from ultrashape.surface_loaders import SharpEdgeSurfaceLoader
from ultrashape.utils import voxelize_from_point
from ultrashape.pipelines import UltraShapePipeline
# Global variables to cache the model
MODEL_CACHE = {}
def get_pipeline_cached(config_path, ckpt_path, device='cuda', low_vram=False):
# Check if we have a valid cached pipeline for this checkpoint
if "pipeline" in MODEL_CACHE and MODEL_CACHE.get("ckpt_path") == ckpt_path:
print("Using cached pipeline...")
return MODEL_CACHE["pipeline"], MODEL_CACHE["config"]
# Clear old cache if it exists (e.g. different checkpoint)
if MODEL_CACHE:
print("Clearing old model cache...")
MODEL_CACHE.clear()
gc.collect()
torch.cuda.empty_cache()
print(f"Loading config from {config_path}...")
config = OmegaConf.load(config_path)
print("Instantiating VAE...")
vae = instantiate_from_config(config.model.params.vae_config)
print("Instantiating DiT...")
dit = instantiate_from_config(config.model.params.dit_cfg)
print("Instantiating Conditioner...")
conditioner = instantiate_from_config(config.model.params.conditioner_config)
print("Instantiating Scheduler & Processor...")
scheduler = instantiate_from_config(config.model.params.scheduler_cfg)
image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
print(f"Loading weights from {ckpt_path}...")
weights = torch.load(ckpt_path, map_location='cpu')
vae.load_state_dict(weights['vae'], strict=True)
dit.load_state_dict(weights['dit'], strict=True)
conditioner.load_state_dict(weights['conditioner'], strict=True)
vae.eval().to(device)
dit.eval().to(device)
conditioner.eval().to(device)
if hasattr(vae, 'enable_flashvdm_decoder'):
vae.enable_flashvdm_decoder()
print("Creating Pipeline...")
pipeline = UltraShapePipeline(
vae=vae,
model=dit,
scheduler=scheduler,
conditioner=conditioner,
image_processor=image_processor
)
if low_vram:
pipeline.enable_model_cpu_offload()
MODEL_CACHE["pipeline"] = pipeline
MODEL_CACHE["config"] = config
MODEL_CACHE["ckpt_path"] = ckpt_path
return pipeline, config
def predict(
image_input,
mesh_input,
steps,
scale,
octree_res,
num_latents,
chunk_size,
seed,
remove_bg,
ckpt_path,
low_vram
):
# Aggressive memory cleanup at start
gc.collect()
torch.cuda.empty_cache()
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config_path = "configs/infer_dit_refine.yaml"
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config not found at {config_path}")
pipeline, config = get_pipeline_cached(config_path, ckpt_path, device, low_vram)
voxel_res = config.model.params.vae_config.params.voxel_query_res
print(f"Initializing Surface Loader (Token Num: {num_latents})...")
loader = SharpEdgeSurfaceLoader(
num_sharp_points=204800,
num_uniform_points=204800,
)
print(f"Processing inputs...")
if image_input is None:
raise ValueError("Image input is required")
if mesh_input is None:
raise ValueError("Mesh input is required")
# Handle image input
if isinstance(image_input, dict):
# In newer gradio versions Image component might return a dict for mask etc, but usually just PIL/numpy
# if type='pil' it is PIL.Image
pass
image = image_input.convert("RGBA")
if remove_bg or image.mode != 'RGBA':
rembg = BackgroundRemover()
image = rembg(image)
# Handle mesh input - Gradio Model3D returns path to file
surface = loader(mesh_input, normalize_scale=scale).to(device, dtype=torch.float16)
pc = surface[:, :, :3] # [B, N, 3]
# Voxelize
_, voxel_idx = voxelize_from_point(pc, num_latents, resolution=voxel_res)
print("Running diffusion process...")
gen_device = "cpu" if low_vram else device
generator = torch.Generator(gen_device).manual_seed(int(seed))
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
mesh_out_list, _ = pipeline(
image=image,
voxel_cond=voxel_idx,
generator=generator,
box_v=1.0,
mc_level=0.0,
octree_resolution=int(octree_res),
num_chunks=int(chunk_size),
num_inference_steps=int(steps)
)
# Save output
output_dir = "outputs_gradio"
os.makedirs(output_dir, exist_ok=True)
base_name = "output"
save_path = os.path.join(output_dir, f"{base_name}_refined.glb")
mesh_out = mesh_out_list[0]
mesh_out.export(save_path)
print(f"Successfully saved to {save_path}")
return save_path
finally:
# Aggressive memory cleanup at end
gc.collect()
torch.cuda.empty_cache()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="UltraShape Gradio App")
parser.add_argument("--ckpt", type=str, required=True, help="Path to split checkpoint (.pt)")
parser.add_argument("--share", action="store_true", help="Share the gradio app")
parser.add_argument("--low_vram", action="store_true", help="Optimize for low VRAM usage")
args = parser.parse_args()
# Define Gradio Interface
with gr.Blocks(title="UltraShape Inference") as demo:
gr.Markdown("# UltraShape Inference: Mesh & Image Refinement")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Input Image", image_mode="RGBA")
mesh_input = gr.Model3D(label="Input Coarse Mesh (.glb, .obj)")
with gr.Accordion("Advanced Parameters", open=True):
steps = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Inference Steps")
scale = gr.Slider(minimum=0.1, maximum=2.0, value=0.99, label="Mesh Normalization Scale")
octree_res = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label="Octree Resolution")
num_latents = gr.Slider(minimum=1024, maximum=32768, value=32768, step=128,
label="Number of Latent Tokens (Use 8192 if OOM)")
chunk_size = gr.Slider(minimum=512, maximum=10000, value=2048, step=512,
label="Chunk Size (Use 2000 if OOM)")
seed = gr.Number(value=42, label="Random Seed")
remove_bg = gr.Checkbox(label="Remove Background", value=False)
run_btn = gr.Button("Run Inference", variant="primary")
with gr.Column():
output_model = gr.Model3D(label="Refined Output Mesh")
run_btn.click(
fn=lambda img, mesh, s, sc, oct, nml, chk, sd, rm: predict(img, mesh, s, sc, oct, nml, chk, sd, rm, args.ckpt,
args.low_vram),
inputs=[image_input, mesh_input, steps, scale, octree_res, num_latents, chunk_size, seed, remove_bg],
outputs=[output_model]
)
demo.launch(share=args.share, server_name='0.0.0.0', server_port=7860)
================================================
FILE: scripts/infer_dit_refine.py
================================================
import os
import sys
import argparse
import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
# project_root = '[your_project_root_path]' # Replace with your project root path
# sys.path.insert(0, project_root)
from ultrashape.rembg import BackgroundRemover
from ultrashape.utils.misc import instantiate_from_config
from ultrashape.surface_loaders import SharpEdgeSurfaceLoader
from ultrashape.utils import voxelize_from_point
from ultrashape.pipelines import UltraShapePipeline
def load_models(config_path, ckpt_path, device='cuda'):
print(f"Loading config from {config_path}...")
config = OmegaConf.load(config_path)
print("Instantiating VAE...")
vae = instantiate_from_config(config.model.params.vae_config)
print("Instantiating DiT...")
dit = instantiate_from_config(config.model.params.dit_cfg)
print("Instantiating Conditioner...")
conditioner = instantiate_from_config(config.model.params.conditioner_config)
print("Instantiating Scheduler & Processor...")
scheduler = instantiate_from_config(config.model.params.scheduler_cfg)
image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
print(f"Loading weights from {ckpt_path}...")
weights = torch.load(ckpt_path, map_location='cpu')
vae.load_state_dict(weights['vae'], strict=True)
dit.load_state_dict(weights['dit'], strict=True)
conditioner.load_state_dict(weights['conditioner'], strict=True)
vae.eval().to(device)
dit.eval().to(device)
conditioner.eval().to(device)
if hasattr(vae, 'enable_flashvdm_decoder'):
vae.enable_flashvdm_decoder()
components = {
"vae": vae,
"dit": dit,
"conditioner": conditioner,
"scheduler": scheduler,
"image_processor": image_processor,
}
return components, config
def run_inference(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
components, config = load_models(args.config, args.ckpt, device)
pipeline = UltraShapePipeline(
vae=components['vae'],
model=components['dit'],
scheduler=components['scheduler'],
conditioner=components['conditioner'],
image_processor=components['image_processor']
)
if args.low_vram:
pipeline.enable_model_cpu_offload()
token_num = args.num_latents
voxel_res = config.model.params.vae_config.params.voxel_query_res
print(f"Initializing Surface Loader (Token Num: {token_num})...")
loader = SharpEdgeSurfaceLoader(
num_sharp_points=204800,
num_uniform_points=204800,
)
print(f"Processing inputs: {args.image} & {args.mesh}")
image = Image.open(args.image)
if args.remove_bg or image.mode != 'RGBA':
rembg = BackgroundRemover()
image = rembg(image)
surface = loader(args.mesh, normalize_scale=args.scale).to(device, dtype=torch.float16)
pc = surface[:, :, :3] # [B, N, 3]
# Voxelize
_, voxel_idx = voxelize_from_point(pc, token_num, resolution=voxel_res)
print("Running diffusion process...")
generator = torch.Generator(device).manual_seed(args.seed)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
mesh, _ = pipeline(
image=image,
voxel_cond=voxel_idx,
generator=generator,
box_v=1.0,
mc_level=0.0,
octree_resolution=args.octree_res,
num_inference_steps=args.steps,
num_chunks=args.chunk_size,
)
os.makedirs(args.output_dir, exist_ok=True)
base_name = os.path.splitext(os.path.basename(args.image))[0]
save_path = os.path.join(args.output_dir, f"{base_name}_refined.glb")
mesh = mesh[0]
mesh.export(save_path)
print(f"Successfully saved to {save_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="UltraShape Inference Script")
parser.add_argument("--config", type=str, default="configs/infer_dit2.yaml", help="Path to inference config")
parser.add_argument("--ckpt", type=str, required=True, help="Path to split checkpoint (.pt)")
parser.add_argument("--low_vram", action="store_true", help="Optimize for low VRAM usage")
parser.add_argument("--image", type=str, required=True, help="Input image path")
parser.add_argument("--mesh", type=str, required=True, help="Input coarse mesh (.glb/.obj)")
parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory")
parser.add_argument("--steps", type=int, default=50, help="Inference steps")
parser.add_argument("--scale", type=float, default=0.99, help="Mesh normalization scale")
parser.add_argument("--num_latents", type=int, default=32768, help="Number of latents")
parser.add_argument("--chunk_size", type=int, default=8000, help="Chunk size for inference")
parser.add_argument("--octree_res", type=int, default=1024, help="Marching Cubes resolution")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--remove_bg", action="store_true", help="Force remove background")
args = parser.parse_args()
run_inference(args)
================================================
FILE: scripts/install_env.sh
================================================
conda create -n ultrashape python=3.10
conda activate ultrashape
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
pip install git+https://github.com/ashawkey/cubvh --no-build-isolation
pip install --no-build-isolation "git+https://github.com/facebookresearch/pytorch3d.git@stable"
pip install https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_cluster-1.6.3%2Bpt25cu121-cp310-cp310-linux_x86_64.whl
================================================
FILE: scripts/run.sh
================================================
# sampling
# python scripts/sampling.py \
# --mesh_json data/mesh_paths.json \
# --output_dir data/sample
# inference refine_dit
python scripts/infer_dit_refine.py \
--ckpt checkpoints/ultrashape_v1.pt \
--image inputs/image/1.png \
--mesh inputs/coarse_mesh/1.glb \
--config configs/infer_dit_refine.yaml
# --steps 12
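# other useful flags (see scripts/infer_dit_refine.py for the full list):
# --output_dir outputs
# --low_vram
# --seed 42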
================================================
FILE: scripts/sampling.py
================================================
import os
import trimesh
import numpy as np
from typing import List, Optional, Any, Tuple, Union
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch3d.structures
import pytorch3d.ops
from scipy.stats import truncnorm
import json
import argparse
import cubvh
# import logging
# from tools.logger import init_log, set_all_log
# sys_logger = init_log("sampler", logging.DEBUG)
# set_all_log(level=logging.DEBUG, path='./debug/logs')
def load_mesh(mesh_path: str, device: str = "cuda") -> Tuple[torch.Tensor, torch.Tensor]:
if mesh_path.endswith(".npz"):
mesh_np = np.load(mesh_path)
vertices, faces = torch.tensor(mesh_np["vertices"], device=device), torch.tensor(mesh_np["faces"].astype('i8'), device=device)
else:
mesh = trimesh.load(mesh_path, force='mesh')
vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device=device)
faces = torch.tensor(mesh.faces, dtype=torch.long, device=device)
if faces.shape[0] > 2 * 1e8:
raise ValueError(f"too many faces {faces.shape}")
return vertices, faces
def compute_mesh_features(vertices: torch.Tensor, faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
device = vertices.device
v0 = vertices[faces[:, 0]]
v1 = vertices[faces[:, 1]]
v2 = vertices[faces[:, 2]]
face_normals = torch.cross(v1 - v0, v2 - v0, dim=1)
face_areas = torch.norm(face_normals, dim=1) * 0.5
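# The cross-product norm equals twice the triangle area, so dividing by 2 * area yields unit normals.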
face_normals = face_normals / (face_areas.unsqueeze(1) * 2 + 1e-12)
vertex_normals = torch.zeros_like(vertices)
face_normals_weighted = face_normals * face_areas.unsqueeze(1)
vertex_normals.scatter_add_(0, faces[:, 0:1].expand(-1, 3), face_normals_weighted)
vertex_normals.scatter_add_(0, faces[:, 1:2].expand(-1, 3), face_normals_weighted)
vertex_normals.scatter_add_(0, faces[:, 2:3].expand(-1, 3), face_normals_weighted)
vertex_normals = vertex_normals / (torch.norm(vertex_normals, dim=1, keepdim=True) + 1e-12)
edges = torch.cat([
faces[:, [0, 1]],
faces[:, [1, 2]],
faces[:, [2, 0]]
], dim=0)
edges_unique, edges_inverse = torch.unique(torch.sort(edges, dim=1)[0], dim=0, return_inverse=True)
edge_normals_diff = torch.norm(
vertex_normals[edges[:, 0]] - vertex_normals[edges[:, 1]],
dim=1
)
vertex_curvatures = torch.zeros(len(vertices), device=device)
vertex_curvatures.scatter_add_(0, edges[:, 0], edge_normals_diff)
vertex_curvatures.scatter_add_(0, edges[:, 1], edge_normals_diff)
vertex_degrees = torch.zeros(len(vertices), device=device)
vertex_degrees.scatter_add_(0, edges[:, 0], torch.ones_like(edge_normals_diff))
vertex_degrees.scatter_add_(0, edges[:, 1], torch.ones_like(edge_normals_diff))
vertex_curvatures = vertex_curvatures / (vertex_degrees + 1e-12)
vertex_curvatures = (vertex_curvatures - vertex_curvatures.min()) / (
vertex_curvatures.max() - vertex_curvatures.min() + 1e-12)
return face_areas, vertex_curvatures
def sample_uniform_points(
vertices: torch.Tensor,
faces: torch.Tensor,
num_samples: int,
random_seed: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if random_seed is not None:
torch.manual_seed(random_seed)
mesh = pytorch3d.structures.Meshes(verts=[vertices], faces=[faces])
points, normals = pytorch3d.ops.sample_points_from_meshes(
mesh, num_samples=num_samples, return_normals=True)
return points[0], normals[0]
def sample_surface_points(
vertices: torch.Tensor,
faces: torch.Tensor,
num_samples: int,
min_samples_per_face: int = 0,
use_curvature: bool = True,
random_seed: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Curvature-based surface sampling"""
device = vertices.device
if random_seed is not None:
torch.manual_seed(random_seed)
# Compute face areas and vertex curvatures
face_areas, vertex_curvatures = compute_mesh_features(vertices, faces)
# Compute average curvature of faces
face_curvatures = torch.mean(vertex_curvatures[faces], dim=1)
sampling_weights = face_curvatures # Use only curvature as weights
# Calculate number of sample points per face
num_faces = len(faces)
# Chunk forward
if min_samples_per_face > 0:
base_samples = torch.full((num_faces,), min_samples_per_face, device=device)
remaining_samples = num_samples - torch.sum(base_samples).item()
if remaining_samples > 0:
# Block sampling to avoid large mesh issues
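# torch.multinomial supports at most 2**24 categories, so very large meshes are sampled chunk by chunk.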
if num_faces > 2**24:
chunk_size = 1000000 # Process 1 million faces at a time
additional_counts = torch.zeros(num_faces, dtype=torch.long, device=device)  # long dtype: counts feed repeat_interleave
for start in range(0, num_faces, chunk_size):
end = min(start + chunk_size, num_faces)
chunk_weights = sampling_weights[start:end]
chunk_probs = chunk_weights / chunk_weights.sum()
# Proportionally allocate remaining samples
chunk_samples = int(remaining_samples * (end - start) / num_faces)
samples = torch.multinomial(chunk_probs, chunk_samples, replacement=True)
chunk_counts = torch.bincount(samples, minlength=chunk_size)
additional_counts[start:end] += chunk_counts[:end-start]
sample_counts = additional_counts + base_samples
else:
probs = sampling_weights / sampling_weights.sum()
additional_samples = torch.multinomial(probs, remaining_samples, replacement=True)
sample_counts = torch.bincount(additional_samples, minlength=num_faces) + base_samples
else:
sample_counts = base_samples
else:
if num_faces > 2**24:
# Chunk sampling strategy
sample_counts = torch.zeros(num_faces, dtype=torch.long, device=device)  # long dtype: counts feed repeat_interleave
chunk_size = 1000000 # Process 1 million faces at a time
chunk_samples = num_samples // ((num_faces + chunk_size - 1) // chunk_size)
for start in range(0, num_faces, chunk_size):
end = min(start + chunk_size, num_faces)
chunk_weights = sampling_weights[start:end]
chunk_probs = chunk_weights / chunk_weights.sum()
samples = torch.multinomial(chunk_probs, chunk_samples, replacement=True)
chunk_counts = torch.bincount(samples, minlength=chunk_size)
sample_counts[start:end] += chunk_counts[:end-start]
else:
probs = sampling_weights / sampling_weights.sum()
samples = torch.multinomial(probs, num_samples, replacement=True)
sample_counts = torch.bincount(samples, minlength=num_faces)
# Generate barycentric coordinates for sampled points
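# sqrt trick: (1 - sqrt(r1), sqrt(r1) * (1 - r2), sqrt(r1) * r2) is uniformly distributed over a triangle.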
total_samples = sample_counts.sum().item()
r1 = torch.sqrt(torch.rand(total_samples, device=device))
r2 = torch.rand(total_samples, device=device)
barycentric_coords = torch.stack([
1 - r1,
r1 * (1 - r2),
r1 * r2
], dim=1)
# Generate face indices
face_indices = torch.repeat_interleave(
torch.arange(num_faces, device=device),
sample_counts
)
# Get vertices of corresponding faces
face_vertices = vertices[faces[face_indices]]
# Compute 3D coordinates of sampled points
points = (barycentric_coords.unsqueeze(1) @ face_vertices).squeeze(1)
# Compute normal vectors of sampled points
v0, v1, v2 = face_vertices[:, 0], face_vertices[:, 1], face_vertices[:, 2]
face_normals = torch.cross(v1 - v0, v2 - v0, dim=1)
normals = face_normals / (torch.norm(face_normals, dim=1, keepdim=True) + 1e-12)
return points, face_indices, normals
def normalize_points_and_mesh(vertices: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Normalize mesh and point cloud to unit cube"""
device = vertices.device
vmin = vertices.min(dim=0)[0]
vmax = vertices.max(dim=0)[0]
center = (vmax + vmin) / 2
scale = (vmax - vmin).max()
margin = 0.01
scale = scale * (1 + 2 * margin)
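# Inflating the scale by 2 * margin keeps the geometry strictly inside the unit cube with a small border on each side.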
vertices_normalized = (vertices - center) / scale + 0.5
points_normalized = (points - center) / scale + 0.5
return vertices_normalized, points_normalized, center, scale
def add_gaussian_noise(uniform_surface_points: torch.Tensor, curvature_surface_points: torch.Tensor, sigma: float = 0.01) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate near-surface supervision points by perturbing shuffled surface samples with truncated-normal and multi-scale Gaussian offsets (note: `sigma` is currently unused)"""
idx1 = torch.randperm(uniform_surface_points.shape[0])
idx2 = torch.randperm(curvature_surface_points.shape[0])
uniform_surface_points = uniform_surface_points[idx1]
curvature_surface_points = curvature_surface_points[idx2]
a, b = -0.25, 0.25
mu = 0
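# Truncated-normal offsets are confined to [a, b] = [-0.25, 0.25] and drawn at two scales (0.005 near the surface, 0.05 farther out).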
# get near points (add offset on surface points)
offset1 = torch.tensor(truncnorm.rvs((a - mu) / 0.005, (b - mu) / 0.005, loc=mu, scale=0.005, size=(len(uniform_surface_points), 3)),
dtype=uniform_surface_points.dtype, device=uniform_surface_points.device)
offset2 = torch.tensor(truncnorm.rvs((a - mu) / 0.05, (b - mu) / 0.05, loc=mu, scale=0.05, size=(len(uniform_surface_points), 3)),
dtype=uniform_surface_points.dtype, device=uniform_surface_points.device)
uniform_near_points = torch.cat([
uniform_surface_points + offset1,
uniform_surface_points + offset2
], dim=0)
# Generate multi-scale noise for curvature sample points
unit_num = curvature_surface_points.shape[0] // 6
scales = [0.001, 0.003, 0.006, 0.01, 0.02, 0.04]
curvature_near_points = []
for i in range(6):
start = i * unit_num
end = (i + 1) * unit_num if i < 5 else curvature_surface_points.shape[0]
noise = torch.randn((end - start, 3), dtype=curvature_surface_points.dtype,
device=curvature_surface_points.device) * scales[i]
curvature_near_points.append(curvature_surface_points[start:end] + noise)
curvature_near_points = torch.cat(curvature_near_points, dim=0)
return uniform_near_points, curvature_near_points
def compute_points_value_bvh(
vertices: torch.Tensor,
faces: torch.Tensor,
points: torch.Tensor,
use_sdf: bool = True,
batch_size: int = 10_000_000
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute signed distances for sampled points with a GPU BVH (note: `use_sdf` and `batch_size` are currently unused)"""
device = vertices.device
# Normalize mesh and point cloud
vertices_norm, points_norm, center, scale = normalize_points_and_mesh(vertices, points)
BVH = cubvh.cuBVH(vertices_norm, faces)
# Query in the same normalized frame the BVH was built in
distances, face_id, uvw = BVH.signed_distance(points_norm, return_uvw=True, mode='watertight')
values = distances
return values, points_norm, center, scale
def save_point_cloud(
points: torch.Tensor,
output_path: str,
normals: Optional[torch.Tensor] = None,
colors: Optional[torch.Tensor] = None
) -> None:
"""Save point cloud to file"""
points_np = points.cpu().numpy()
normals_np = normals.cpu().numpy() if normals is not None else None
colors_np = None
if colors is not None:
colors_np = colors.cpu().numpy()
if colors_np.max() <= 1.0:
colors_np = (colors_np * 255).astype(np.uint8)
ext = os.path.splitext(output_path)[1].lower()
if ext == '.txt':
data_list = [points_np]
if normals_np is not None:
data_list.append(normals_np)
if colors_np is not None:
data_list.append(colors_np)
combined_data = np.hstack(data_list)
np.savetxt(output_path, combined_data, fmt='%.6f')
elif ext == '.ply':
cloud = trimesh.PointCloud(points_np, colors=colors_np)
if normals_np is not None:
cloud.metadata['normals'] = normals_np
cloud.export(output_path)
else:
raise ValueError(f"Unsupported file format: {ext}. Please use .txt or .ply")
def sample_points_in_bbox(
bbox_min: torch.Tensor,
bbox_max: torch.Tensor,
num_samples: int,
device: str = "cuda"
) -> torch.Tensor:
"""Uniformly sample points within bounding box"""
points = torch.rand(num_samples, 3, device=device)
points = points * (bbox_max - bbox_min) + bbox_min
return points
def process_single_mesh(
mesh_name:str,
mesh_path: str,
output_dir: str,
data_type:str = 'mesh',
surface_uniform_samples: int = 100000, # number of uniform sample points on the surface
surface_curvature_samples: int = 200000, # number of curvature-weighted sample points on the surface
space_samples: int = 300000, # number of sample points in space
noise_sigma: float = 0.01,
device: str = "cuda"
) -> None:
"""Process a single mesh file
Args:
mesh_path: Input mesh path
output_dir: Output directory
surface_uniform_samples: Number of uniform sample points on surface
surface_curvature_samples: Number of curvature-based sample points on surface
space_samples: Number of sample points in space
noise_sigma: Gaussian noise standard deviation
device: Computation device
"""
os.makedirs(output_dir, exist_ok=True)
if data_type == "mesh":
vertices, faces = load_mesh(mesh_path, device)
elif data_type == "sparse_voxel":
pass
vertices_normalized, _, center, scale = normalize_points_and_mesh(vertices, vertices)
space_points = torch.rand(space_samples, 3, device=device)
uniform_surface_points, uniform_surface_normals = sample_uniform_points(
vertices=vertices_normalized,
faces=faces,
num_samples=surface_uniform_samples
)
curvature_surface_points, _, curvature_surface_normals = sample_surface_points(
vertices=vertices_normalized,
faces=faces,
num_samples=surface_curvature_samples,
use_curvature=True
)
clean_surface_points = torch.cat([uniform_surface_points, curvature_surface_points], dim=0)
clean_surface_normals = torch.cat([uniform_surface_normals, curvature_surface_normals], dim=0)
surface_uni_save_path = os.path.join(output_dir, f"{mesh_name}_uni_surface")
save_point_cloud(
points=uniform_surface_points,
output_path=f"{surface_uni_save_path}.ply",
normals=uniform_surface_normals
)
surface_cur_save_path = os.path.join(output_dir, f"{mesh_name}_cur_surface")
save_point_cloud(
points=curvature_surface_points,
output_path=f"{surface_cur_save_path}.ply",
normals=curvature_surface_normals
)
uniform_near_points, curvature_near_points = add_gaussian_noise(uniform_surface_points = uniform_surface_points.clone(),
curvature_surface_points = curvature_surface_points.clone(), sigma=noise_sigma)
space_sdf, _, _, _ = compute_points_value_bvh(
vertices=vertices_normalized,
faces=faces,
points=space_points,
use_sdf=True,
batch_size=100_000_000
)
# clean_surface_sdf = torch.zeros(len(clean_surface_points), device=device)
uniform_near_sdf, _, _, _ = compute_points_value_bvh(
vertices=vertices_normalized,
faces=faces,
points=uniform_near_points,
use_sdf=True,
batch_size=100_000_000
)
curvature_near_sdf, _, _, _ = compute_points_value_bvh(
vertices=vertices_normalized,
faces=faces,
points=curvature_near_points,
use_sdf=True,
batch_size=100_000_000
)
print("sdf:",uniform_near_sdf.shape, curvature_near_sdf.shape)
base_save_path = os.path.join(output_dir, mesh_name)
np.savez(f"{base_save_path}.npz",
space_points=space_points.cpu().numpy(),
space_sdf=space_sdf.cpu().numpy(),
clean_surface_points=clean_surface_points.cpu().numpy(),
clean_surface_normals=clean_surface_normals.cpu().numpy(),
uniform_near_points=uniform_near_points.cpu().numpy(),
curvature_near_points=curvature_near_points.cpu().numpy(),
uniform_near_sdf=uniform_near_sdf.cpu().numpy(),
curvature_near_sdf=curvature_near_sdf.cpu().numpy(),
center=center.cpu().numpy(),
scale=scale.cpu().numpy())
class MeshDataset(Dataset):
def __init__(self, mesh_json: str):
with open(mesh_json, "r") as f:
self.mesh_paths = json.load(f)
# print(len(self.mesh_paths))
def __len__(self) -> int:
return len(self.mesh_paths)
def __getitem__(self, idx: int) -> dict:
mesh_path = self.mesh_paths[idx]
mesh_name = os.path.splitext(os.path.basename(mesh_path))[0]
mesh = {
"mesh_path": mesh_path,
"mesh_name": mesh_name,
}
return mesh
class MeshProcessor(pl.LightningModule):
def __init__(
self,
mesh_json: str,
output_dir: str,
data_type:str,
surface_uniform_samples: int = 20000,
surface_curvature_samples: int = 40000,
space_samples: int = 300000,
noise_sigma: float = 0.01,
batch_size: int = 1,
num_workers: int = 4
):
super().__init__()
self.save_hyperparameters()
os.makedirs(output_dir, exist_ok=True)
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:
mesh_path = batch["mesh_path"][0]
mesh_name = batch["mesh_name"][0]
# sys_logger.info(f"Processing {batch_idx}/{len(self.trainer.predict_dataloaders)}: {mesh_name} from {mesh_path}")
output_subdir = self.hparams.output_dir
try:
filename = os.path.splitext(os.path.basename(mesh_path))[0]
if os.path.exists(os.path.join(output_subdir, f"{filename}.npz")):
# sys_logger.info(f"Skipping {mesh_name} as it already exists.")
return {
"status": "success",
"mesh_name": mesh_name
}
process_single_mesh(
mesh_name=mesh_name,
mesh_path=mesh_path,
output_dir=output_subdir,
data_type = self.hparams.data_type,
surface_uniform_samples=self.hparams.surface_uniform_samples,
surface_curvature_samples=self.hparams.surface_curvature_samples,
space_samples=self.hparams.space_samples,
noise_sigma=self.hparams.noise_sigma,
device=self.device
)
return {
"status": "success",
"mesh_name": mesh_name
}
except Exception as e:
print(f"Error processing {mesh_name}: {str(e)}")
return {
"status": "error",
"mesh_name": mesh_name,
"error": str(e)
}
def predict_dataloader(self) -> DataLoader:
dataset = MeshDataset(
self.hparams.mesh_json)
return DataLoader(
dataset,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
persistent_workers=True,
shuffle=False
)
def process_mesh_directory(
mesh_json: str,
output_dir: str,
data_type: str,
surface_uniform_samples: int = 100000,
surface_curvature_samples: int = 200000,
space_samples: int = 300000,
noise_sigma: float = 0.01,
num_gpus: int = -1,
batch_size: int = 1,
num_workers: int = 4
) -> None:
model = MeshProcessor(
mesh_json=mesh_json,
output_dir=output_dir,
data_type=data_type,
surface_uniform_samples=surface_uniform_samples,
surface_curvature_samples=surface_curvature_samples,
space_samples=space_samples,
noise_sigma=noise_sigma,
batch_size=batch_size,
num_workers=num_workers
)
trainer = pl.Trainer(
accelerator="gpu",
devices=num_gpus,
strategy="ddp",
precision=32,
logger=False,
enable_progress_bar=True
)
predictions = trainer.predict(model)
success_count = sum(1 for p in predictions if p["status"] == "success")
error_count = sum(1 for p in predictions if p["status"] == "error")
print(f"\nProcessing completed:")
print(f"Successfully processed: {success_count} files")
print(f"Failed to process: {error_count} files")
if error_count > 0:
print("\nFailed files:")
for p in predictions:
if p["status"] == "error":
print(f"- {p['mesh_name']}: {p['error']}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process Mesh Directory for Sampling")
parser.add_argument("--mesh_json", type=str, default="test_mesh.json", help="Path to the mesh json file")
parser.add_argument("--output_dir", type=str, default="ultrashape_test1", help="Directory to save outputs")
parser.add_argument("--surface_uniform_samples", type=int, default=300000, help="Number of uniform samples on surface")
parser.add_argument("--surface_curvature_samples", type=int, default=300000, help="Number of curvature-based samples on surface")
parser.add_argument("--space_samples", type=int, default=400000, help="Number of samples in space")
parser.add_argument("--noise_sigma", type=float, default=0.01, help="Sigma for Gaussian noise")
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use")
parser.add_argument("--num_workers", type=int, default=16, help="Number of data loading workers")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU")
args = parser.parse_args()
# print(f"Arguments: {args}")
process_mesh_directory(
mesh_json=args.mesh_json,
output_dir=args.output_dir,
data_type='mesh',
surface_uniform_samples=args.surface_uniform_samples,
surface_curvature_samples=args.surface_curvature_samples,
space_samples=args.space_samples,
noise_sigma=args.noise_sigma,
num_gpus=args.num_gpus,
num_workers=args.num_workers,
batch_size=args.batch_size
)
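# A quick sanity check (hypothetical paths) for the generated .npz files; with the
# default CLI arguments the shapes are:
#   data = np.load("ultrashape_test1/<mesh_name>.npz")
#   data["clean_surface_points"].shape  # (600000, 3): 300k uniform + 300k curvature samples
#   data["space_sdf"].shape             # (400000,): signed distances for the space samples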
================================================
FILE: scripts/train_deepspeed.sh
================================================
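# Usage: bash scripts/train_deepspeed.sh <node_num> <node_rank> <num_gpu_per_node> <master_ip> <config> <output_dir>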
export NCCL_IB_TIMEOUT=24
export NCCL_NVLS_ENABLE=0
NET_TYPE="high"
if [[ "${NET_TYPE}" = "low" ]]; then
export NCCL_SOCKET_IFNAME=eth1
export NCCL_IB_GID_INDEX=3
export NCCL_IB_HCA=mlx5_2:1,mlx5_2:1
export NCCL_IB_SL=3
export NCCL_CHECKS_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
else
export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECKS_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_SOCKET_IFNAME=bond1
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=1
fi
# export NCCL_DEBUG=INFO
node_num=$1
node_rank=$2
num_gpu_per_node=$3
master_ip=$4
config=$5
output_dir=$6
echo node_num $node_num
echo node_rank $node_rank
echo master_ip $master_ip
echo config $config
echo output_dir $output_dir
if test -d "$output_dir"; then
cp $config $output_dir
else
mkdir -p "$output_dir"
cp $config $output_dir
fi
NODE_RANK=$node_rank \
HF_HUB_OFFLINE=0 \
MASTER_PORT=12348 \
MASTER_ADDR=$master_ip \
NCCL_SOCKET_IFNAME=bond1 \
NCCL_IB_GID_INDEX=3 \
NCCL_NVLS_ENABLE=0 \
python3 main.py \
--num_nodes $node_num \
--num_gpus $num_gpu_per_node \
--config $config \
--output_dir $output_dir \
--deepspeed
================================================
FILE: train.sh
================================================
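# Usage: bash train.sh <node_rank> (set master_ip below first; e.g. "bash train.sh 0" on a single node)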
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export num_gpu_per_node=8
export node_num=1
export node_rank=$1
export master_ip= # [your master ip here]
############## vae ##############
# export config=configs/train_vae_refine.yaml
# export output_dir=outputs/vae_ultrashape/exp1_token8192
# bash scripts/train_deepspeed.sh $node_num $node_rank $num_gpu_per_node $master_ip $config $output_dir
############## dit ##############
export config=configs/train_dit_refine.yaml
export output_dir=outputs/dit_ultrashape/exp1_token8192
bash scripts/train_deepspeed.sh $node_num $node_rank $num_gpu_per_node $master_ip $config $output_dir
================================================
FILE: ultrashape/__init__.py
================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from .pipelines import UltraShapePipeline
from .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier
from .preprocessors import ImageProcessorV2, IMAGE_PROCESSORS, DEFAULT_IMAGEPROCESSOR
================================================
FILE: ultrashape/data/objaverse_dit.py
================================================
# -*- coding: utf-8 -*-
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import math
import os
import json
from dataclasses import dataclass, field
import random
import imageio
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import pickle
from ultrashape.utils.typing import *
import pandas as pd
import cv2
import torchvision.transforms as transforms
from pytorch_lightning.utilities import rank_zero_info
def padding(image, mask, center=True, padding_ratio_range=[1.15, 1.15]):
"""
Pad the input image and mask to a square shape with padding ratio.
Args:
image (np.ndarray): Input image array of shape (H, W, C).
mask (np.ndarray): Corresponding mask array of shape (H, W).
center (bool): Whether to center the original image in the padded output.
padding_ratio_range (list): Range [min, max] to randomly select padding ratio.
Returns:
newimg (np.ndarray): Padded image of shape (resize_side, resize_side, 3).
newmask (np.ndarray): Padded mask of shape (resize_side, resize_side).
"""
h, w = image.shape[:2]
max_side = max(h, w)
# Select padding ratio either fixed or randomly within the given range
if padding_ratio_range[0] == padding_ratio_range[1]:
padding_ratio = padding_ratio_range[0]
else:
padding_ratio = random.uniform(padding_ratio_range[0], padding_ratio_range[1])
resize_side = int(max_side * padding_ratio)
pad_h = resize_side - h
pad_w = resize_side - w
if center:
start_h = pad_h // 2
else:
start_h = pad_h - resize_side // 20
start_w = pad_w // 2
# Create new white image and black mask with padded size
newimg = np.ones((resize_side, resize_side, 3), dtype=np.uint8) * 255
newmask = np.zeros((resize_side, resize_side), dtype=np.uint8)
# Place original image and mask into the padded canvas
newimg[start_h:start_h + h, start_w:start_w + w] = image
newmask[start_h:start_h + h, start_w:start_w + w] = mask
return newimg, newmask
class ObjaverseDataset(Dataset):
def __init__(
self,
data_json,
sample_root,
image_path,
image_transform = None,
pc_size: int = 2048,
pc_sharpedge_size: int = 2048,
sharpedge_label: bool = False,
return_normal: bool = False,
padding = True,
padding_ratio_range=[1.15, 1.15],
):
super().__init__()
self.uids = json.load(open(data_json))
self.sample_root = sample_root
self.image_paths = json.load(open(image_path))
self.image_transform = image_transform
self.pc_size = pc_size
self.pc_sharpedge_size = pc_sharpedge_size
self.sharpedge_label = sharpedge_label
self.return_normal = return_normal
self.padding = padding
self.padding_ratio_range = padding_ratio_range
print(f"Loaded {len(self.uids)} uids from {data_json}.")
rank_zero_info(f'*' * 50)
rank_zero_info(f'Dataset Infos:')
rank_zero_info(f'# of 3D file: {len(self.uids)}')
rank_zero_info(f'# of Surface Points: {self.pc_size}')
rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
rank_zero_info(f'*' * 50)
def __len__(self):
return len(self.uids)
def _load_shape(self, index: int) -> Dict[str, Any]:
data = np.load(f'{self.sample_root}/{self.uids[index]}.npz')
surface_og = (np.asarray(data['clean_surface_points'])-0.5) * 2
normal = np.asarray(data['clean_surface_normals'])
surface_og_n = np.concatenate([surface_og, normal], axis=1)
rng = np.random.default_rng()
# hard code: first 300k are uniform, last 300k are sharp
assert surface_og_n.shape[0] == 600000, f"expected surface points = 300k uniform + 300k curvature, but {len(surface_og_n)=}"
coarse_surface = surface_og_n[:300000]
sharp_surface = surface_og_n[300000:]
surface_normal = []
if self.pc_size > 0:
ind = rng.choice(coarse_surface.shape[0], self.pc_size // 2, replace=False)
coarse_surface = coarse_surface[ind]
if self.sharpedge_label:
sharpedge_label = np.zeros((self.pc_size // 2, 1))
coarse_surface = np.concatenate((coarse_surface, sharpedge_label), axis=1)
surface_normal.append(coarse_surface)
ind_sharpedge = rng.choice(sharp_surface.shape[0], self.pc_size // 2, replace=False)
sharp_surface = sharp_surface[ind_sharpedge]
if self.sharpedge_label:
sharpedge_label = np.ones((self.pc_size // 2, 1))
sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)
surface_normal.append(sharp_surface)
surface_normal = np.concatenate(surface_normal, axis=0)
surface_normal = torch.FloatTensor(surface_normal)
surface = surface_normal[:, 0:3]
normal = surface_normal[:, 3:6]
assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size
geo_points = 0.0
normal = torch.nn.functional.normalize(normal, p=2, dim=1)
if self.return_normal:
surface = torch.cat([surface, normal], dim=-1)
if self.sharpedge_label:
surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)
ret = {
"uid": self.uids[index],
"surface": surface,
"geo_points": geo_points
}
return ret
def _load_image(self, index: int) -> Dict[str, Any]:
ret = {}
sel_idx = random.randint(0, 15)
ret["sel_image_idx"] = sel_idx
obj_name = self.uids[index]
img_path = f'{self.image_paths[obj_name]}/{os.path.basename(self.image_paths[obj_name])}/rgba/' + f"{sel_idx:03d}.png"
images, masks = [], []
image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
assert image.shape[2] == 4
alpha = image[:, :, 3:4].astype(np.float32) / 255
foreground = image[:, :, :3]
background = np.ones_like(foreground) * 255
img_new = foreground * alpha + background * (1 - alpha)
image = img_new.astype(np.uint8)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = (alpha[:, :, 0] * 255).astype(np.uint8)
if self.padding:
h, w = image.shape[:2]
binary = mask > 0.3
non_zero_coords = np.argwhere(binary)
r_min, c_min = non_zero_coords.min(axis=0)  # np.argwhere returns (row, col) pairs
r_max, c_max = non_zero_coords.max(axis=0)
image, mask = padding(
image[max(r_min - 5, 0):min(r_max + 5, h), max(c_min - 5, 0):min(c_max + 5, w)],
mask[max(r_min - 5, 0):min(r_max + 5, h), max(c_min - 5, 0):min(c_max + 5, w)],
center=True, padding_ratio_range=self.padding_ratio_range)
if self.image_transform:
image = self.image_transform(image)
mask = np.stack((mask, mask, mask), axis=-1)
mask = self.image_transform(mask)
images.append(image)
masks.append(mask)
ret["image"] = torch.cat(images, dim=0)
ret["mask"] = torch.cat(masks, dim=0)[:1, ...]
return ret
def get_data(self, index):
ret = self._load_shape(index)
ret.update(self._load_image(index))
return ret
def __getitem__(self, index):
try:
return self.get_data(index)
except Exception as e:
print(f"Error in {self.uids[index]}: {e}")
return self.__getitem__(np.random.randint(len(self)))
def collate(self, batch):
batch = torch.utils.data.default_collate(batch)
return batch
class ObjaverseDataModule(pl.LightningDataModule):
def __init__(
self,
batch_size: int = 1,
num_workers: int = 4,
val_num_workers: int = 2,
training_data_list: str = None,
sample_pcd_dir: str = None,
image_data_json: str = None,
image_size: int = 224,
mean: Union[List[float], Tuple[float]] = (0.485, 0.456, 0.406),
std: Union[List[float], Tuple[float]] = (0.229, 0.224, 0.225),
pc_size: int = 2048,
pc_sharpedge_size: int = 2048,
sharpedge_label: bool = False,
return_normal: bool = False,
padding = True,
padding_ratio_range=[1.15, 1.15]
):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.val_num_workers = val_num_workers
self.training_data_list = training_data_list
self.sample_pcd_dir = sample_pcd_dir
self.image_data_json = image_data_json
self.image_size = image_size
self.mean = mean
self.std = std
self.train_image_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(self.image_size),
transforms.Normalize(mean=self.mean, std=self.std)])
self.val_image_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(self.image_size),
transforms.Normalize(mean=self.mean, std=self.std)])
self.pc_size = pc_size
self.pc_sharpedge_size = pc_sharpedge_size
self.sharpedge_label = sharpedge_label
self.return_normal = return_normal
self.padding = padding
self.padding_ratio_range = padding_ratio_range
def train_dataloader(self):
asl_params = {
"data_json": f'{self.training_data_list}/train.json',
"sample_root": self.sample_pcd_dir,
"image_path": self.image_data_json,
"image_transform": self.train_image_transform,
"pc_size": self.pc_size,
"pc_sharpedge_size": self.pc_sharpedge_size,
"sharpedge_label": self.sharpedge_label,
"return_normal": self.return_normal,
"padding": self.padding,
"padding_ratio_range": self.padding_ratio_range,
}
dataset = ObjaverseDataset(**asl_params)
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True,
)
def val_dataloader(self):
asl_params = {
"data_json": f'{self.training_data_list}/val.json',
"sample_root": self.sample_pcd_dir,
"image_path": self.image_data_json,
"image_transform": self.val_image_transform,
"pc_size": self.pc_size,
"pc_sharpedge_size": self.pc_sharpedge_size,
"sharpedge_label": self.sharpedge_label,
"return_normal": self.return_normal,
"padding": self.padding,
"padding_ratio_range": self.padding_ratio_range,
}
dataset = ObjaverseDataset(**asl_params)
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.val_num_workers,
pin_memory=True,
drop_last=True,
)
================================================
FILE: ultrashape/data/objaverse_vae.py
================================================
# -*- coding: utf-8 -*-
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
import cv2
import json
import math
import random
import imageio
import pickle
import numpy as np
from PIL import Image
import pandas as pd
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from pytorch_lightning.utilities import rank_zero_info
from ultrashape.utils.typing import *
class ObjaverseDataset(Dataset):
def __init__(
self,
data_json,
sample_root,
pc_size: int = 2048,
pc_sharpedge_size: int = 2048,
sup_near_uni_size: int = 4096,
sup_near_sharp_size: int = 4096,
sup_space_size: int = 4096,
tsdf_threshold: float = 0.05,
sharpedge_label: bool = False,
return_normal: bool = False,
):
super().__init__()
self.uids = json.load(open(data_json))
self.sample_root = sample_root
self.pc_size = pc_size
self.pc_sharpedge_size = pc_sharpedge_size
self.sharpedge_label = sharpedge_label
self.return_normal = return_normal
self.sup_near_uni_size = sup_near_uni_size
self.sup_near_sharp_size = sup_near_sharp_size
self.sup_space_size = sup_space_size
self.tsdf_threshold = tsdf_threshold
print(f"Loaded {len(self.uids)} uids from {data_json}.")
rank_zero_info(f'*' * 50)
rank_zero_info(f'Dataset Infos:')
rank_zero_info(f'# of 3D file: {len(self.uids)}')
rank_zero_info(f'# of Surface Points: {self.pc_size}')
rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
rank_zero_info(f'# of Uniform Near-Surface Sup-Points: {self.sup_near_uni_size}')
rank_zero_info(f'# of Sharpedge Near-Surface Sup-Points: {self.sup_near_sharp_size}')
rank_zero_info(f'# of Random Space Sup-Points: {self.sup_space_size}')
rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
rank_zero_info(f'*' * 50)
def __len__(self):
return len(self.uids)
def _load_shape(self, index: int) -> Dict[str, Any]:
rng = np.random.default_rng()
data = np.load(f'{self.sample_root}/{self.uids[index]}.npz')
##################### sup pcd&sdf ######################
uniform_near_points = (np.asarray(data['uniform_near_points'])-0.5) * 2
curvature_near_points = (np.asarray(data['curvature_near_points'])-0.5) * 2
space_points = (np.asarray(data['space_points'])-0.5) * 2
uniform_near_sdf = np.asarray(data['uniform_near_sdf']) * 2
curvature_near_sdf = np.asarray(data['curvature_near_sdf']) * 2
space_sdf = np.asarray(data['space_sdf']) * 2
uni_noisy_idx = rng.choice(uniform_near_points.shape[0], self.sup_near_uni_size, replace=False)
cur_noisy_idx = rng.choice(curvature_near_points.shape[0], self.sup_near_sharp_size, replace=False)
space_idx = rng.choice(space_points.shape[0], self.sup_space_size, replace=False)
uniform_near_points = uniform_near_points[uni_noisy_idx]
curvature_near_points = curvature_near_points[cur_noisy_idx]
space_points = space_points[space_idx]
uniform_near_sdf = uniform_near_sdf[uni_noisy_idx]
curvature_near_sdf = curvature_near_sdf[cur_noisy_idx]
space_sdf = space_sdf[space_idx]
uniform_near_sdf, curvature_near_sdf, space_sdf = map(self._clip_to_tsdf, (uniform_near_sdf, curvature_near_sdf, space_sdf))
surface_og = (np.asarray(data['clean_surface_points'])-0.5) * 2
normal = np.asarray(data['clean_surface_normals'])
surface_og_n = np.concatenate([surface_og, normal], axis=1)
# hard code: first 300k are uniform, last 300k are sharp
assert surface_og_n.shape[0] == 600000, f"expected surface points = 300k uniform + 300k curvature, but {len(surface_og_n)=}"
coarse_surface = surface_og_n[:300000]
sharp_surface = surface_og_n[300000:]
surface_normal = []
if self.pc_size > 0:
ind = rng.choice(coarse_surface.shape[0], self.pc_size // 2, replace=False)
coarse_surface = coarse_surface[ind]
if self.sharpedge_label:
sharpedge_label = np.zeros((self.pc_size // 2, 1))
coarse_surface = np.concatenate((coarse_surface, sharpedge_label), axis=1)
surface_normal.append(coarse_surface)
ind_sharpedge = rng.choice(sharp_surface.shape[0], self.pc_size // 2, replace=False)
sharp_surface = sharp_surface[ind_sharpedge]
if self.sharpedge_label:
sharpedge_label = np.ones((self.pc_size // 2, 1))
sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)
surface_normal.append(sharp_surface)
surface_normal = np.concatenate(surface_normal, axis=0)
surface_normal = torch.FloatTensor(surface_normal)
surface = surface_normal[:, 0:3]
normal = surface_normal[:, 3:6]
assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size
geo_points = 0.0
normal = torch.nn.functional.normalize(normal, p=2, dim=1)
if self.return_normal:
surface = torch.cat([surface, normal], dim=-1)
if self.sharpedge_label:
surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)
ret = {
"uid": self.uids[index],
"surface": surface,
"sup_near_uniform": np.concatenate([uniform_near_points, uniform_near_sdf[...,None]], axis=1),
"sup_near_sharp": np.concatenate([curvature_near_points, curvature_near_sdf[...,None]], axis=1),
"sup_space": np.concatenate([space_points, space_sdf[...,None]], axis=1),
"geo_points": geo_points
}
return ret
def _clip_to_tsdf(self, sdf: np.ndarray) -> np.ndarray:
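# Sanitize NaN/Inf values, truncate the SDF to +/- tsdf_threshold, then rescale to [-1, 1] (a truncated SDF).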
nan_mask = np.isnan(sdf)
if np.any(nan_mask):
sdf=np.nan_to_num(sdf, nan=1.0, posinf=1.0, neginf=-1.0)
return sdf.flatten().astype(np.float32).clip(-self.tsdf_threshold, self.tsdf_threshold) / self.tsdf_threshold
def get_data(self, index):
ret = self._load_shape(index)
return ret
def __getitem__(self, index):
return self.get_data(index)
def collate(self, batch):
batch = torch.utils.data.default_collate(batch)
return batch
class ObjaverseDataModule(pl.LightningDataModule):
def __init__(
self,
batch_size: int = 1,
num_workers: int = 4,
val_num_workers: int = 2,
training_data_list: str = None,
sample_pcd_dir: str = None,
pc_size: int = 2048,
pc_sharpedge_size: int = 2048,
sup_near_uni_size: int = 4096,
sup_near_sharp_size: int = 4096,
sup_space_size: int = 4096,
tsdf_threshold: float = 0.05,
sharpedge_label: bool = False,
return_normal: bool = False,
):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.val_num_workers = val_num_workers
self.training_data_list = training_data_list
self.sample_pcd_dir = sample_pcd_dir
self.pc_size = pc_size
self.pc_sharpedge_size = pc_sharpedge_size
self.sharpedge_label = sharpedge_label
self.return_normal = return_normal
self.sup_near_uni_size = sup_near_uni_size
self.sup_near_sharp_size = sup_near_sharp_size
self.sup_space_size = sup_space_size
self.tsdf_threshold = tsdf_threshold
def train_dataloader(self):
asl_params = {
"data_json": f'{self.training_data_list}/train.json',
"sample_root": self.sample_pcd_dir,
"pc_size": self.pc_size,
"pc_sharpedge_size": self.pc_sharpedge_size,
"sup_near_uni_size": self.sup_near_uni_size,
"sup_near_sharp_size": self.sup_near_sharp_size,
"sup_space_size": self.sup_space_size,
"tsdf_threshold": self.tsdf_threshold,
"sharpedge_label": self.sharpedge_label,
"return_normal": self.return_normal,
}
dataset = ObjaverseDataset(**asl_params)
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True,
)
def val_dataloader(self):
asl_params = {
"data_json": f'{self.training_data_list}/val.json',
"sample_root": self.sample_pcd_dir,
"pc_size": self.pc_size,
"pc_sharpedge_size": self.pc_sharpedge_size,
"sup_near_uni_size": self.sup_near_uni_size,
"sup_near_sharp_size": self.sup_near_sharp_size,
"sup_space_size": self.sup_space_size,
"tsdf_threshold": self.tsdf_threshold,
"sharpedge_label": self.sharpedge_label,
"return_normal": self.return_normal,
}
dataset = ObjaverseDataset(**asl_params)
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.val_num_workers,
pin_memory=True,
drop_last=True,
)
================================================
FILE: ultrashape/data/utils.py
================================================
# -*- coding: utf-8 -*-
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
"""Miscellaneous utility functions."""
import importlib
import itertools as itt
import os
import re
import sys
from typing import Any, Callable, Iterator, Union
import torch
import numpy as np
def make_seed(*args):
seed = 0
for arg in args:
seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
return seed
class PipelineStage:
def invoke(self, *args, **kw):
raise NotImplementedError
def identity(x: Any) -> Any:
"""Return the argument as is."""
return x
def safe_eval(s: str, expr: str = "{}"):
"""Evaluate the given expression more safely."""
if re.sub("[^A-Za-z0-9_]", "", s) != s:
raise ValueError(f"safe_eval: illegal characters in: '{s}'")
return eval(expr.format(s))
def lookup_sym(sym: str, modules: list):
"""Look up a symbol in a list of modules."""
for mname in modules:
module = importlib.import_module(mname, package="webdataset")
result = getattr(module, sym, None)
if result is not None:
return result
return None
def repeatedly0(
loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
"""Repeatedly returns batches from a DataLoader."""
for _ in range(nepochs):
yield from itt.islice(loader, nbatches)
def guess_batchsize(batch: Union[tuple, list]):
"""Guess the batch size by looking at the length of the first element in a tuple."""
return len(batch[0])
def repeatedly(
source: Iterator,
nepochs: int = None,
nbatches: int = None,
nsamples: int = None,
batchsize: Callable[..., int] = guess_batchsize,
):
"""Repeatedly yield samples from an iterator."""
epoch = 0
batch = 0
total = 0
while True:
for sample in source:
yield sample
batch += 1
if nbatches is not None and batch >= nbatches:
return
if nsamples is not None:
total += guess_batchsize(sample)
if total >= nsamples:
return
epoch += 1
if nepochs is not None and epoch >= nepochs:
return
def pytorch_worker_info(group=None): # sourcery skip: use-contextlib-suppress
"""Return node and worker info for PyTorch and some distributed environments."""
rank = 0
world_size = 1
worker = 0
num_workers = 1
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
else:
try:
import torch.distributed
if torch.distributed.is_available() and torch.distributed.is_initialized():
group = group or torch.distributed.group.WORLD
rank = torch.distributed.get_rank(group=group)
world_size = torch.distributed.get_world_size(group=group)
except ModuleNotFoundError:
pass
if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
worker = int(os.environ["WORKER"])
num_workers = int(os.environ["NUM_WORKERS"])
else:
try:
import torch.utils.data
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker = worker_info.id
num_workers = worker_info.num_workers
except ModuleNotFoundError:
pass
return rank, world_size, worker, num_workers
def pytorch_worker_seed(group=None):
"""Compute a distinct, deterministic RNG seed for each worker and node."""
rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
return rank * 1000 + worker
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id
# dataset = worker_info.dataset
# split_size = dataset.num_records // worker_info.num_workers
# # reset num_records to the true number to retain reliable length information
# dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
# current_id = np.random.choice(len(np.random.get_state()[1]), 1)
# return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
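# Offset the base NumPy seed by the worker id so each dataloader worker draws a distinct random stream.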
return np.random.seed(np.random.get_state()[1][0] + worker_id)
def collation_fn(samples, combine_tensors=True, combine_scalars=True):
"""
Args:
samples (list[dict]):
combine_tensors:
combine_scalars:
Returns:
"""
result = {}
keys = samples[0].keys()
for key in keys:
result[key] = []
for sample in samples:
for key in keys:
val = sample[key]
result[key].append(val)
for key in keys:
val_list = result[key]
if isinstance(val_list[0], (int, float)):
if combine_scalars:
result[key] = np.array(result[key])
elif isinstance(val_list[0], torch.Tensor):
if combine_tensors:
result[key] = torch.stack(val_list)
elif isinstance(val_list[0], np.ndarray):
if combine_tensors:
result[key] = np.stack(val_list)
return result
================================================
FILE: ultrashape/models/__init__.py
================================================
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from .autoencoders import ShapeVAE
from .conditioner_mask import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder
from .denoisers import RefineDiT
================================================
FILE: ultrashape/models/autoencoders/__init__.py
================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from .attention_blocks import CrossAttentionDecoder
from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, \
FlashVDMTopMCrossAttentionProcessor
from .model import ShapeVAE, VectsetVAE
from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput
from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder
from .vae_trainer import VAETrainer
================================================
FILE: ultrashape/models/autoencoders/attention_blocks.py
================================================
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
from typing import Optional, Union, List
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from .attention_processors import CrossAttentionProcessor
from ...utils import logger
from ultrashape.utils import voxelize_from_point
scaled_dot_product_attention = nn.functional.scaled_dot_product_attention
if os.environ.get('USE_SAGEATTN', '0') == '1':
try:
from sageattention import sageattn
except ImportError:
raise ImportError('Please install the "sageattention" package to use USE_SAGEATTN=1.')
scaled_dot_product_attention = sageattn
class FourierEmbedder(nn.Module):
""" The sin/cosine positional embedding. """
def __init__(self,
num_freqs: int = 6,
logspace: bool = True,
input_dim: int = 3,
include_input: bool = True,
include_pi: bool = True) -> None:
super().__init__()
if logspace:
frequencies = 2.0 ** torch.arange(
num_freqs,
dtype=torch.float32
)
else:
frequencies = torch.linspace(
1.0,
2.0 ** (num_freqs - 1),
num_freqs,
dtype=torch.float32
)
if include_pi:
frequencies *= torch.pi
self.register_buffer("frequencies", frequencies, persistent=False)
self.include_input = include_input
self.num_freqs = num_freqs
self.out_dim = self.get_dims(input_dim)
def get_dims(self, input_dim):
temp = 1 if self.include_input or self.num_freqs == 0 else 0
out_dim = input_dim * (self.num_freqs * 2 + temp)
return out_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Forward process.
Args:
x: tensor of shape [..., dim]
Returns:
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
where temp is 1 if include_input is True and 0 otherwise.
"""
if self.num_freqs > 0:
embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
if self.include_input:
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
else:
return torch.cat((embed.sin(), embed.cos()), dim=-1)
else:
return x
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if self.drop_prob == 0. or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and self.scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
def extra_repr(self):
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
class MLP(nn.Module):
def __init__(
self, *,
width: int,
expand_ratio: int = 4,
output_width: int = None,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.c_fc = nn.Linear(width, width * expand_ratio)
self.c_proj = nn.Linear(width * expand_ratio, output_width if output_width is not None else width)
self.gelu = nn.GELU()
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
heads: int,
n_data: Optional[int] = None,
width=None,
qk_norm=False,
norm_layer=nn.LayerNorm
):
super().__init__()
self.heads = heads
self.n_data = n_data
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.attn_processor = CrossAttentionProcessor()
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = self.attn_processor(self, q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool = True,
n_data: Optional[int] = None,
data_width: Optional[int] = None,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
kv_cache: bool = False,
):
super().__init__()
self.n_data = n_data
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
self.c_q = nn.Linear(width, width, bias=qkv_bias)
self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
self.c_proj = nn.Linear(width, width)
self.attention = QKVMultiheadCrossAttention(
heads=heads,
n_data=n_data,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.kv_cache = kv_cache
self.data = None
def forward(self, x, data):
x = self.c_q(x)
if self.kv_cache:
if self.data is None:
self.data = self.c_kv(data)
                logger.info('Saving kv cache; this should be called only once per mesh')
data = self.data
else:
data = self.c_kv(data)
x = self.attention(x, data)
x = self.c_proj(x)
return x
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
n_data: Optional[int] = None,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
data_width: Optional[int] = None,
qkv_bias: bool = True,
norm_layer=nn.LayerNorm,
qk_norm: bool = False
):
super().__init__()
if data_width is None:
data_width = width
self.attn = MultiheadCrossAttention(
n_data=n_data,
width=width,
heads=heads,
data_width=data_width,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
x = x + self.mlp(self.ln_3(x))
return x
class QKVMultiheadAttention(nn.Module):
def __init__(
self,
*,
heads: int,
n_ctx: int,
width=None,
qk_norm=False,
norm_layer=nn.LayerNorm
):
super().__init__()
self.heads = heads
self.n_ctx = n_ctx
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.heads // 3
qkv = qkv.view(bs, n_ctx, self.heads, -1)
q, k, v = torch.split(qkv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadAttention(nn.Module):
def __init__(
self,
*,
n_ctx: int,
width: int,
heads: int,
qkv_bias: bool,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.heads = heads
self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = nn.Linear(width, width)
self.attention = QKVMultiheadAttention(
heads=heads,
n_ctx=n_ctx,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
x = self.c_qkv(x)
x = self.attention(x)
x = self.drop_path(self.c_proj(x))
return x
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
*,
n_ctx: int,
width: int,
heads: int,
qkv_bias: bool = True,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0,
):
super().__init__()
self.attn = MultiheadAttention(
n_ctx=n_ctx,
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
def forward(self, x: torch.Tensor):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
*,
n_ctx: int,
width: int,
layers: int,
heads: int,
qkv_bias: bool = True,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
n_ctx=n_ctx,
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor):
for block in self.resblocks:
x = block(x)
return x
class CrossAttentionDecoder(nn.Module):
def __init__(
self,
*,
num_latents: int,
out_channels: int,
fourier_embedder: FourierEmbedder,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
downsample_ratio: int = 1,
enable_ln_post: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary"
):
super().__init__()
self.enable_ln_post = enable_ln_post
self.fourier_embedder = fourier_embedder
self.downsample_ratio = downsample_ratio
self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = nn.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
n_data=num_latents,
width=width,
mlp_expand_ratio=mlp_expand_ratio,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm
)
if self.enable_ln_post:
self.ln_post = nn.LayerNorm(width)
self.output_proj = nn.Linear(width, out_channels)
self.label_type = label_type
self.count = 0
def set_cross_attention_processor(self, processor):
self.cross_attn_decoder.attn.attention.attn_processor = processor
def set_default_cross_attention_processor(self):
        self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor()
def forward(self, queries=None, query_embeddings=None, latents=None):
if query_embeddings is None:
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
self.count += query_embeddings.shape[1]
if self.downsample_ratio != 1:
latents = self.latents_proj(latents)
x = self.cross_attn_decoder(query_embeddings, latents)
if self.enable_ln_post:
x = self.ln_post(x)
occ = self.output_proj(x)
return occ
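# Example (illustrative sketch, hypothetical sizes): decoding occupancy logits
# for query points against shape latents, reusing a FourierEmbedder as above.
# width must be divisible by heads.
#
#   decoder = CrossAttentionDecoder(
#       num_latents=512, out_channels=1, fourier_embedder=embedder,
#       width=768, heads=12)
#   queries = torch.rand(1, 8000, 3)    # [B, N, 3] points in the bounding box
#   latents = torch.rand(1, 512, 768)   # [B, num_latents, width]
#   occ = decoder(queries=queries, latents=latents)  # [B, N, 1] logits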
def fps(
src: torch.Tensor,
batch: Optional[Tensor] = None,
ratio: Optional[Union[Tensor, float]] = None,
random_start: bool = True,
batch_size: Optional[int] = None,
ptr: Optional[Union[Tensor, List[int]]] = None,
):
src = src.float()
from torch_cluster import fps as fps_fn
output = fps_fn(src, batch, ratio, random_start, batch_size, ptr)
return output
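# Example (illustrative sketch): batched farthest point sampling via the
# torch_cluster wrapper above. `batch` maps each point to its cloud; with
# ratio=0.25 roughly a quarter of the points per cloud are kept.
#
#   pts = torch.rand(2 * 4096, 3)                          # two clouds, flattened
#   batch = torch.repeat_interleave(torch.arange(2), 4096)
#   idx = fps(pts, batch, ratio=0.25)                      # indices into pts
#   sampled = pts[idx].view(2, -1, 3)                      # [2, 1024, 3]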
class PointCrossAttentionEncoder(nn.Module):
def __init__(
self, *,
num_latents: int,
downsample_ratio: float,
pc_size: int,
pc_sharpedge_size: int,
fourier_embedder: FourierEmbedder,
point_feats: int,
width: int,
heads: int,
layers: int,
voxel_query_res: int,
normal_pe: bool = False,
qkv_bias: bool = True,
use_ln_post: bool = False,
use_checkpoint: bool = False,
qk_norm: bool = False,
jitter_query: bool = False,
voxel_query: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.num_latents = num_latents
self.downsample_ratio = downsample_ratio
self.point_feats = point_feats
self.normal_pe = normal_pe
self.jitter_query = jitter_query
self.voxel_query = voxel_query
self.voxel_query_res = voxel_query_res
        if pc_sharpedge_size == 0:
            print('PointCrossAttentionEncoder INFO: pc_sharpedge_size is zero')
        else:
            print(f'PointCrossAttentionEncoder INFO: using pc_size={pc_size}, '
                  f'pc_sharpedge_size={pc_sharpedge_size}')
self.pc_size = pc_size
self.pc_sharpedge_size = pc_sharpedge_size
self.fourier_embedder = fourier_embedder
if self.jitter_query or self.voxel_query:
self.input_proj_q = nn.Linear(self.fourier_embedder.out_dim, width)
self.input_proj_kv = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
else:
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
self.cross_attn = ResidualCrossAttentionBlock(
width=width,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm
)
self.self_attn = None
if layers > 0:
self.self_attn = Transformer(
n_ctx=num_latents,
width=width,
layers=layers,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm
)
if use_ln_post:
self.ln_post = nn.LayerNorm(width)
else:
self.ln_post = None
def sample_points_and_latents(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
B, N, D = pc.shape
num_pts = self.num_latents * self.downsample_ratio
# Compute number of latents
num_latents = int(num_pts / self.downsample_ratio)
# Compute the number of random and sharpedge latents
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
num_sharpedge_query = num_latents - num_random_query
# Split random and sharpedge surface points
random_pc, sharpedge_pc = torch.split(pc, [self.pc_size, self.pc_sharpedge_size], dim=1)
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
# Randomly select random surface points and random query points
input_random_pc_size = int(num_random_query * self.downsample_ratio)
random_query_ratio = num_random_query / input_random_pc_size
idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[:input_random_pc_size]
input_random_pc = random_pc[:, idx_random_pc, :]
if self.voxel_query:
query_random_pc, query_voxel_indices = voxelize_from_point(pc, num_latents, resolution=self.voxel_query_res)
else:
flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D)
N_down = int(flatten_input_random_pc.shape[0] / B)
batch_down = torch.arange(B).to(pc.device)
batch_down = torch.repeat_interleave(batch_down, N_down)
idx_query_random = fps(flatten_input_random_pc, batch_down, ratio=random_query_ratio)
query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D)
# Randomly select sharpedge surface points and sharpedge query points
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
if input_sharpedge_pc_size == 0 or self.voxel_query:
input_sharpedge_pc = torch.zeros(B, 0, D, dtype=input_random_pc.dtype).to(pc.device)
query_sharpedge_pc = torch.zeros(B, 0, D, dtype=query_random_pc.dtype).to(pc.device)
else:
sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size
idx_sharpedge_pc = torch.randperm(sharpedge_pc.shape[1], device=sharpedge_pc.device)[
:input_sharpedge_pc_size]
input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :]
flatten_input_sharpedge_surface_points = input_sharpedge_pc.view(B * input_sharpedge_pc_size, D)
N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B)
batch_down = torch.arange(B).to(pc.device)
batch_down = torch.repeat_interleave(batch_down, N_down)
idx_query_sharpedge = fps(flatten_input_sharpedge_surface_points, batch_down, ratio=sharpedge_query_ratio)
query_sharpedge_pc = flatten_input_sharpedge_surface_points[idx_query_sharpedge].view(B, -1, D)
# Concatenate random and sharpedge surface points and query points
query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)
input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)
if self.jitter_query:
R = self.voxel_query_res // 2
noise = torch.rand_like(query_pc)
query_pc += (noise - 0.5) / R
# PE
query = self.fourier_embedder(query_pc)
data = self.fourier_embedder(input_pc)
# Concat normal if given
if self.point_feats != 0:
random_surface_feats, sharpedge_surface_feats = torch.split(feats, [self.pc_size, self.pc_sharpedge_size],
dim=1)
input_random_surface_feats = random_surface_feats[:, idx_random_pc, :]
if not self.voxel_query and not self.jitter_query:
flatten_input_random_surface_feats = input_random_surface_feats.view(B * input_random_pc_size, -1)
query_random_feats = flatten_input_random_surface_feats[idx_query_random].view(B, -1,
flatten_input_random_surface_feats.shape[
-1])
if input_sharpedge_pc_size == 0:
input_sharpedge_surface_feats = torch.zeros(B, 0, self.point_feats,
dtype=input_random_surface_feats.dtype).to(pc.device)
if not self.voxel_query and not self.jitter_query:
query_sharpedge_feats = torch.zeros(B, 0, self.point_feats, dtype=query_random_feats.dtype).to(
pc.device)
else:
input_sharpedge_surface_feats = sharpedge_surface_feats[:, idx_sharpedge_pc, :]
if not self.voxel_query and not self.jitter_query:
flatten_input_sharpedge_surface_feats = input_sharpedge_surface_feats.view(B * input_sharpedge_pc_size,
-1)
query_sharpedge_feats = flatten_input_sharpedge_surface_feats[idx_query_sharpedge].view(B, -1,
flatten_input_sharpedge_surface_feats.shape[
-1])
if not self.voxel_query and not self.jitter_query:
query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)
input_feats = torch.cat([input_random_surface_feats, input_sharpedge_surface_feats], dim=1)
if self.normal_pe:
if not self.voxel_query and not self.jitter_query:
query_normal_pe = self.fourier_embedder(query_feats[..., :3])
query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)
input_normal_pe = self.fourier_embedder(input_feats[..., :3])
input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)
if not self.voxel_query and not self.jitter_query:
query = torch.cat([query, query_feats], dim=-1)
data = torch.cat([data, input_feats], dim=-1)
if input_sharpedge_pc_size == 0:
query_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
input_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
if self.voxel_query:
pc_infos = [query_voxel_indices, query_random_pc]
else:
pc_infos = [query_pc, input_pc, query_random_pc, input_random_pc, query_sharpedge_pc, input_sharpedge_pc]
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]), pc_infos
    def forward(self, pc, feats):
        """
        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
            pc_infos (List[torch.Tensor]): query/input point tensors collected during sampling
        """
query, data, pc_infos = self.sample_points_and_latents(pc, feats)
        if self.jitter_query or self.voxel_query:
            query = self.input_proj_q(query)
            data = self.input_proj_kv(data)
        else:
            query = self.input_proj(query)
            data = self.input_proj(data)
latents = self.cross_attn(query, data)
if self.self_attn is not None:
latents = self.self_attn(latents)
if self.ln_post is not None:
latents = self.ln_post(latents)
return latents, pc_infos
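# Example (illustrative sketch, hypothetical hyper-parameters): encoding a
# sampled surface (xyz + normals) into a fixed set of shape latents. The point
# cloud packs pc_size random points followed by pc_sharpedge_size sharp-edge
# points along dim 1.
#
#   embedder = FourierEmbedder(num_freqs=8)
#   encoder = PointCrossAttentionEncoder(
#       num_latents=512, downsample_ratio=20, pc_size=5120,
#       pc_sharpedge_size=5120, fourier_embedder=embedder, point_feats=3,
#       width=768, heads=12, layers=8, voxel_query_res=128)
#   pc = torch.rand(1, 10240, 3)             # surface points
#   feats = torch.rand(1, 10240, 3)          # per-point normals
#   latents, pc_infos = encoder(pc, feats)   # latents: [1, 512, 768]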
================================================
FILE: ultrashape/models/autoencoders/attention_processors.py
================================================
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
import torch
import torch.nn.functional as F
scaled_dot_product_attention = F.scaled_dot_product_attention
if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
try:
from sageattention import sageattn
except ImportError:
        raise ImportError('Please install the package "sageattention" to use CA_USE_SAGEATTN=1.')
scaled_dot_product_attention = sageattn
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = scaled_dot_product_attention(q, k, v)
return out
class FlashVDMCrossAttentionProcessor:
def __init__(self, topk=None):
self.topk = topk
def __call__(self, attn, q, k, v):
if k.shape[-2] == 3072:
topk = 1024
elif k.shape[-2] == 512:
topk = 256
else:
topk = k.shape[-2] // 3
if self.topk is True:
q1 = q[:, :, ::100, :]
sim = q1 @ k.transpose(-1, -2)
sim = torch.mean(sim, -2)
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
v0 = torch.gather(v, dim=-2, index=topk_ind)
k0 = torch.gather(k, dim=-2, index=topk_ind)
out = scaled_dot_product_attention(q, k0, v0)
elif self.topk is False:
out = scaled_dot_product_attention(q, k, v)
else:
idx, counts = self.topk
start = 0
outs = []
for grid_coord, count in zip(idx, counts):
end = start + count
q_chunk = q[:, :, start:end, :]
k0, v0 = self.select_topkv(q_chunk, k, v, topk)
out = scaled_dot_product_attention(q_chunk, k0, v0)
outs.append(out)
start += count
out = torch.cat(outs, dim=-2)
self.topk = False
return out
def select_topkv(self, q_chunk, k, v, topk):
q1 = q_chunk[:, :, ::50, :]
sim = q1 @ k.transpose(-1, -2)
sim = torch.mean(sim, -2)
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
v0 = torch.gather(v, dim=-2, index=topk_ind)
k0 = torch.gather(k, dim=-2, index=topk_ind)
return k0, v0
class FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor):
def select_topkv(self, q_chunk, k, v, topk):
q1 = q_chunk[:, :, ::30, :]
sim = q1 @ k.transpose(-1, -2)
# sim = sim.to(torch.float32)
sim = sim.softmax(-1)
sim = torch.mean(sim, 1)
activated_token = torch.where(sim > 1e-6)[2]
index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
index = index.expand(-1, v.shape[1], -1, v.shape[-1])
v0 = torch.gather(v, dim=-2, index=index)
k0 = torch.gather(k, dim=-2, index=index)
return k0, v0
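# Example (illustrative sketch): swapping the decoder's attention processor at
# inference time. FlashVDMCrossAttentionProcessor restricts cross-attention to
# the top-k most relevant latent tokens per query chunk, trading a little
# accuracy for speed during volume decoding. `geo_decoder` is assumed to be a
# CrossAttentionDecoder (see attention_blocks.py).
#
#   processor = FlashVDMCrossAttentionProcessor()
#   geo_decoder.set_cross_attention_processor(processor)
#   processor.topk = True   # re-select top-k on the next call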
================================================
FILE: ultrashape/models/autoencoders/model.py
================================================
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
from typing import Union, List
import numpy as np
import torch
import torch.nn as nn
import yaml
from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder, PointCrossAttentionEncoder
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding
from ...utils import logger, synchronize_timer, smart_load_model
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
"""
Initialize a diagonal Gaussian distribution with mean and log-variance parameters.
Args:
parameters (Union[torch.Tensor, List[torch.Tensor]]):
Either a single tensor containing concatenated mean and log-variance along `feat_dim`,
or a list of two tensors [mean, logvar].
deterministic (bool, optional): If True, the distribution is deterministic (zero variance).
                Default is False.
            feat_dim (int, optional): Dimension along which mean and logvar are
                concatenated if parameters is a single tensor. Default is 1.
"""
self.feat_dim = feat_dim
self.parameters = parameters
if isinstance(parameters, list):
self.mean = parameters[0]
self.logvar = parameters[1]
else:
self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean)
def sample(self):
"""
Sample from the diagonal Gaussian distribution.
Returns:
torch.Tensor: A sample tensor with the same shape as the mean.
"""
x = self.mean + self.std * torch.randn_like(self.mean)
return x
def kl(self, other=None, dims=(1, 2)):
"""
Compute the Kullback-Leibler (KL) divergence between this distribution and another.
If `other` is None, compute KL divergence to a standard normal distribution N(0, I).
Args:
other (DiagonalGaussianDistribution, optional): Another diagonal Gaussian distribution.
            dims (tuple, optional): Dimensions along which to compute the mean KL divergence.
                Default is (1, 2).
Returns:
torch.Tensor: The mean KL divergence value.
"""
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims)
else:
return 0.5 * torch.mean(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=dims)
def nll(self, sample, dims=(1, 2, 3)):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
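# Example (illustrative sketch): reparameterized sampling and KL regularization
# from concatenated [mean, logvar] moments, as produced by ShapeVAE.pre_kl.
#
#   moments = torch.randn(4, 512, 2 * 64)    # [B, num_latents, 2 * embed_dim]
#   posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
#   z = posterior.sample()                   # [4, 512, 64]
#   kl = posterior.kl()                      # [4], mean KL over dims (1, 2)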
class VectsetVAE(nn.Module):
@classmethod
@synchronize_timer('VectsetVAE Model Loading')
def from_single_file(
cls,
ckpt_path,
config_path=None,
params=None,
device='cuda',
dtype=torch.float16,
use_safetensors=None,
**kwargs,
):
# load config
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
# load ckpt
if use_safetensors:
ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
if not os.path.exists(ckpt_path):
raise FileNotFoundError(f"Model file {ckpt_path} not found")
logger.info(f"Loading model from {ckpt_path}")
if use_safetensors:
import safetensors.torch
ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
else:
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
if params is not None:
model_kwargs = params
else:
model_kwargs = config['params']
model_kwargs.update(kwargs)
model = cls(**model_kwargs)
model.load_state_dict(ckpt)
model.to(device=device, dtype=dtype)
return model
@classmethod
def from_pretrained(
cls,
model_path,
device='cuda',
params=None,
dtype=torch.float16,
use_safetensors=False,
variant='fp16',
subfolder='hunyuan3d-vae-v2-1',
**kwargs,
):
config_path, ckpt_path = smart_load_model(
model_path,
subfolder=subfolder,
use_safetensors=use_safetensors,
variant=variant
)
return cls.from_single_file(
ckpt_path,
config_path=config_path,
params=params,
device=device,
dtype=dtype,
use_safetensors=use_safetensors,
**kwargs
)
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")
state_dict = state_dict.get("state_dict", state_dict)
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def __init__(
self,
volume_decoder=None,
surface_extractor=None
):
super().__init__()
if volume_decoder is None:
volume_decoder = VanillaVolumeDecoder()
if surface_extractor is None:
surface_extractor = MCSurfaceExtractor()
self.volume_decoder = volume_decoder
self.surface_extractor = surface_extractor
def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
with synchronize_timer('Volume decoding'):
grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
with synchronize_timer('Surface extraction'):
outputs = self.surface_extractor(grid_logits, **kwargs)
return outputs, grid_logits
def enable_flashvdm_decoder(
self,
enabled: bool = True,
adaptive_kv_selection=True,
topk_mode='mean',
mc_algo='mc',
):
if enabled:
if adaptive_kv_selection:
self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)
else:
self.volume_decoder = HierarchicalVolumeDecoding()
if mc_algo not in SurfaceExtractors.keys():
                raise ValueError(f'Unsupported mc_algo {mc_algo}, available: {list(SurfaceExtractors.keys())}')
self.surface_extractor = SurfaceExtractors[mc_algo]()
else:
self.volume_decoder = VanillaVolumeDecoder()
self.surface_extractor = MCSurfaceExtractor()
class ShapeVAE(VectsetVAE):
def __init__(
self,
*,
num_latents: int,
embed_dim: int,
width: int,
heads: int,
num_decoder_layers: int,
num_encoder_layers: int = 8,
pc_size: int = 5120,
pc_sharpedge_size: int = 5120,
point_feats: int = 3,
downsample_ratio: int = 20,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
include_pi: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
use_ln_post: bool = True,
enable_flashvdm: bool = False,
ckpt_path = None,
jitter_query: bool = False,
voxel_query: bool = False,
voxel_query_res: int = 128,
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.downsample_ratio = downsample_ratio
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.encoder = PointCrossAttentionEncoder(
fourier_embedder=self.fourier_embedder,
num_latents=num_latents,
downsample_ratio=self.downsample_ratio,
pc_size=pc_size,
pc_sharpedge_size=pc_sharpedge_size,
point_feats=point_feats,
width=width,
heads=heads,
layers=num_encoder_layers,
qkv_bias=qkv_bias,
use_ln_post=use_ln_post,
qk_norm=qk_norm,
jitter_query=jitter_query,
voxel_query=voxel_query,
voxel_query_res=voxel_query_res
)
self.pre_kl = nn.Linear(width, embed_dim * 2)
self.post_kl = nn.Linear(embed_dim, width)
self.transformer = Transformer(
n_ctx=num_latents,
width=width,
layers=num_decoder_layers,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.geo_decoder = CrossAttentionDecoder(
fourier_embedder=self.fourier_embedder,
out_channels=1,
num_latents=num_latents,
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
downsample_ratio=geo_decoder_downsample_ratio,
enable_ln_post=self.geo_decoder_ln_post,
width=width // geo_decoder_downsample_ratio,
heads=heads // geo_decoder_downsample_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
label_type=label_type,
)
self.scale_factor = scale_factor
self.latent_shape = (num_latents, embed_dim)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path)
if enable_flashvdm:
self.enable_flashvdm_decoder()
def forward(self, latents):
latents = self.post_kl(latents)
latents = self.transformer(latents)
return latents
def encode(self, surface, sample_posterior=True, need_kl=False, need_voxel=False):
pc, feats = surface[:, :, :3], surface[:, :, 3:]
latents, pc_infos = self.encoder(pc, feats)
# print(latents.shape, self.pre_kl.weight.shape)
moments = self.pre_kl(latents)
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
if sample_posterior:
latents = posterior.sample()
else:
latents = posterior.mode()
if need_kl:
return latents, posterior
if need_voxel:
return latents, pc_infos[0]
return latents
def decode(self, latents, voxel_idx=None):
latents = self.post_kl(latents)
latents = self.transformer(latents)
return latents
def query(self, latents, queries, voxel_idx=None):
"""
Args:
queries (torch.FloatTensor): [B, N, 3]
            latents (torch.FloatTensor): [B, num_latents, width], decoded latents
Returns:
logits (torch.FloatTensor): [B, N], occupancy logits
"""
logits = self.geo_decoder(queries=queries, latents=latents).squeeze(-1)
return logits
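# Example (illustrative sketch): the encode -> decode -> query path, assuming a
# ShapeVAE `vae` built with point_feats=3. `surface` packs xyz and per-point
# features along the last dimension.
#
#   surface = torch.rand(1, 10240, 6)                    # [B, N, 3 + point_feats]
#   latents = vae.encode(surface)                        # [B, num_latents, embed_dim]
#   latents = vae.decode(latents)                        # [B, num_latents, width]
#   logits = vae.query(latents, torch.rand(1, 4096, 3))  # [B, 4096]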
================================================
FILE: ultrashape/models/autoencoders/surface_extractors.py
================================================
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from typing import Union, Tuple, List
import numpy as np
import torch
from skimage import measure
import cubvh
class Latent2MeshOutput:
def __init__(self, mesh_v=None, mesh_f=None):
self.mesh_v = mesh_v
self.mesh_f = mesh_f
def center_vertices(vertices):
"""Translate the vertices so that bounding box is centered at zero."""
vert_min = vertices.min(dim=0)[0]
vert_max = vertices.max(dim=0)[0]
vert_center = 0.5 * (vert_min + vert_max)
return vertices - vert_center
class SurfaceExtractor:
def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):
"""
Compute grid size, bounding box minimum coordinates, and bounding box size based on input
bounds and resolution.
Args:
bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or a single
float representing half side length.
If float, bounds are assumed symmetric around zero in all axes.
Expected format if list/tuple: [xmin, ymin, zmin, xmax, ymax, zmax].
octree_resolution (int): Resolution of the octree grid.
Returns:
grid_size (List[int]): Grid size along each axis (x, y, z), each equal to octree_resolution + 1.
bbox_min (np.ndarray): Minimum coordinates of the bounding box (xmin, ymin, zmin).
bbox_size (np.ndarray): Size of the bounding box along each axis (xmax - xmin, etc.).
"""
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
bbox_size = bbox_max - bbox_min
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
return grid_size, bbox_min, bbox_size
def run(self, *args, **kwargs):
"""
Abstract method to extract surface mesh from grid logits.
This method should be implemented by subclasses.
Raises:
NotImplementedError: Always, since this is an abstract method.
"""
        raise NotImplementedError
def __call__(self, grid_logits, **kwargs):
"""
Process a batch of grid logits to extract surface meshes.
Args:
grid_logits (torch.Tensor): Batch of grid logits with shape (batch_size, ...).
**kwargs: Additional keyword arguments passed to the `run` method.
Returns:
List[Optional[Latent2MeshOutput]]: List of mesh outputs for each grid in the batch.
If extraction fails for a grid, None is appended at that position.
"""
outputs = []
for i in range(grid_logits.shape[0]):
try:
vertices, faces = self.run(grid_logits[i], **kwargs)
vertices = vertices.astype(np.float32)
faces = np.ascontiguousarray(faces)
outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))
except Exception:
import traceback
traceback.print_exc()
outputs.append(None)
return outputs
def get_sparse_valid_voxels(grid_logit: torch.Tensor):
if not isinstance(grid_logit, torch.Tensor):
raise TypeError("Input must be a PyTorch tensor.")
if grid_logit.dim() != 3 or grid_logit.shape[0] != grid_logit.shape[1] or grid_logit.shape[0] != grid_logit.shape[2]:
raise ValueError("Input tensor must have shape (N, N, N)")
N = grid_logit.shape[0]
device = grid_logit.device
# Chunk processing to save memory
chunk_size = 128
all_sparse_coords = []
all_sparse_logits = []
# Process in chunks along x-axis
for start_x in range(0, N - 1, chunk_size):
end_x = min(start_x + chunk_size, N - 1)
# Determine slice range including +1 for neighbor checks
# slice_end needs to be end_x + 1 to include the neighbors for the last voxel in chunk
slice_end = end_x + 1
chunk = grid_logit[start_x:slice_end, :, :]
nan_mask = torch.isnan(chunk)
# Compute mask for this chunk (valid voxels are 0 to end_x - start_x)
# Note: chunk shape is [D_chunk, N, N].
# We want to check validity for [0..D_chunk-1, :-1, :-1]
sub_nan_mask = nan_mask
# Validity check requires looking at i and i+1
# Invalid if ANY corner is NaN
invalid_voxel_mask = (
sub_nan_mask[:-1, :-1, :-1] |
sub_nan_mask[1:, :-1, :-1] |
sub_nan_mask[:-1, 1:, :-1] |
sub_nan_mask[:-1, :-1, 1:] |
sub_nan_mask[:-1, 1:, 1:] |
sub_nan_mask[1:, :-1, 1:] |
sub_nan_mask[1:, 1:, :-1] |
sub_nan_mask[1:, 1:, 1:]
)
valid_voxel_mask = ~invalid_voxel_mask
# Get local coordinates
local_coords = valid_voxel_mask.nonzero(as_tuple=False)
if local_coords.shape[0] > 0:
lx, ly, lz = local_coords[:, 0], local_coords[:, 1], local_coords[:, 2]
# Extract logits using local indices on the chunk
# v0 is at lx, v1 is at lx+1, etc.
sparse_vertex_logits = torch.stack([
chunk[lx, ly, lz], # v0
chunk[lx + 1, ly, lz], # v1
chunk[lx + 1, ly + 1, lz], # v2
chunk[lx, ly + 1, lz], # v3
chunk[lx, ly, lz + 1], # v4
chunk[lx + 1, ly, lz + 1], # v5
chunk[lx + 1, ly + 1, lz + 1], # v6
chunk[lx, ly + 1, lz + 1] # v7
], dim=1)
# Convert local coords to global coords
# x coordinate needs offset added
global_coords = local_coords.clone()
global_coords[:, 0] += start_x
all_sparse_coords.append(global_coords)
all_sparse_logits.append(sparse_vertex_logits)
# Free memory
del chunk, nan_mask, invalid_voxel_mask, valid_voxel_mask, local_coords
if not all_sparse_coords:
return torch.empty((0, 3), dtype=torch.long, device=device), torch.empty((0, 8), dtype=grid_logit.dtype, device=device)
sparse_coords = torch.cat(all_sparse_coords, dim=0)
sparse_vertex_logits = torch.cat(all_sparse_logits, dim=0)
return sparse_coords, sparse_vertex_logits
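# Example (illustrative sketch): NaN cells (e.g. regions skipped by the
# hierarchical decoder) are treated as invalid; only voxels whose eight corners
# are all finite survive into the sparse set.
#
#   grid = torch.randn(257, 257, 257)
#   grid[grid.abs() > 2.0] = float('nan')
#   coords, corner_logits = get_sparse_valid_voxels(grid)
#   # coords: [M, 3] voxel origin indices, corner_logits: [M, 8]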
class MCSurfaceExtractor(SurfaceExtractor):
def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
"""
Extract surface mesh using the Marching Cubes algorithm.
Args:
grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
mc_level (float): The level (iso-value) at which to extract the surface.
bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or half side length.
octree_resolution (int): Resolution of the octree grid.
**kwargs: Additional keyword arguments (ignored).
Returns:
Tuple[np.ndarray, np.ndarray]: Tuple containing:
- vertices (np.ndarray): Extracted mesh vertices, scaled and translated to bounding
box coordinates.
- faces (np.ndarray): Extracted mesh faces (triangles).
"""
grid_logit = grid_logit.detach()
sparse_coords, sparse_logits = get_sparse_valid_voxels(grid_logit)
# Convert to float32 only for the sparse set
vertices, faces = cubvh.sparse_marching_cubes(sparse_coords, sparse_logits.float(), mc_level)
vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
# vertices, faces, normals, _ = measure.marching_cubes(grid_logit,
# mc_level, method="lewiner", mask=(~np.isnan(grid_logit)))
grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)
vertices = vertices / grid_size * bbox_size + bbox_min
return vertices, faces
class DMCSurfaceExtractor(SurfaceExtractor):
def run(self, grid_logit, *, octree_resolution, **kwargs):
"""
Extract surface mesh using Differentiable Marching Cubes (DMC) algorithm.
Args:
grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
octree_resolution (int): Resolution of the octree grid.
**kwargs: Additional keyword arguments (ignored).
Returns:
Tuple[np.ndarray, np.ndarray]: Tuple containing:
- vertices (np.ndarray): Extracted mesh vertices, centered and converted to numpy.
- faces (np.ndarray): Extracted mesh faces (triangles), with reversed vertex order.
Raises:
ImportError: If the 'diso' package is not installed.
"""
device = grid_logit.device
if not hasattr(self, 'dmc'):
try:
from diso import DiffDMC
self.dmc = DiffDMC(dtype=torch.float32).to(device)
            except ImportError:
                raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
sdf = -grid_logit / octree_resolution
sdf = sdf.to(torch.float32).contiguous()
verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
verts = center_vertices(verts)
vertices = verts.detach().cpu().numpy()
faces = faces.detach().cpu().numpy()[:, ::-1]
return vertices, faces
SurfaceExtractors = {
'mc': MCSurfaceExtractor,
'dmc': DMCSurfaceExtractor,
}
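# Example (illustrative sketch): selecting an extractor by name and running it
# on a batch of decoded logit grids. Failed extractions yield None entries.
#
#   extractor = SurfaceExtractors['mc']()
#   meshes = extractor(grid_logits, mc_level=0.0, bounds=1.01,
#                      octree_resolution=512)
#   # meshes[i] is a Latent2MeshOutput (mesh_v, mesh_f) or None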
================================================
FILE: ultrashape/models/autoencoders/vae_trainer.py
================================================
import os
from contextlib import contextmanager
from typing import List, Tuple, Optional, Union
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities import rank_zero_only
import trimesh
from ...utils.misc import instantiate_from_config, instantiate_non_trainable_model, instantiate_vae_model
def export_to_trimesh(mesh_output):
if isinstance(mesh_output, list):
outputs = []
for mesh in mesh_output:
if mesh is None:
outputs.append(None)
else:
mesh.mesh_f = mesh.mesh_f[:, ::-1]
mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
outputs.append(mesh_output)
return outputs
else:
mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]
mesh_output = trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)
return mesh_output
class VAETrainer(pl.LightningModule):
def __init__(
self,
*,
vae_config,
optimizer_cfg,
loss_cfg,
save_dir,
mc_res,
ckpt_path: Optional[str] = None,
ignore_keys: Union[Tuple[str], List[str]] = (),
torch_compile: bool = False,
):
super().__init__()
# ========= init optimizer config ========= #
self.optimizer_cfg = optimizer_cfg
self.loss_cfg = loss_cfg
self.ckpt_path = ckpt_path
self.vae_model = instantiate_vae_model(vae_config, requires_grad=True)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.mc_res = mc_res
self.save_root = save_dir
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# ========= torch compile to accelerate ========= #
self.torch_compile = torch_compile
if self.torch_compile:
            self.vae_model.compile()
            print('*' * 100)
            print('Compiling model for acceleration')
            print('*' * 100)
def init_from_ckpt(self, path, ignore_keys=()):
ckpt = torch.load(path, map_location="cpu")
if 'state_dict' not in ckpt:
# deepspeed ckpt
state_dict = {}
for k in ckpt.keys():
new_k = k.replace('_forward_module.', '')
state_dict[new_k] = ckpt[k]
else:
state_dict = ckpt["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if ik in k:
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
# # ==================== Weight Surgery Start ====================
# old_key_base = "vae_model.encoder.input_proj"
# old_weight_key = f"{old_key_base}.weight"
# old_bias_key = f"{old_key_base}.bias"
# if old_weight_key in state_dict:
# print(f"[*] Detected legacy '{old_key_base}' in checkpoint. Performing weight surgery...")
# src_weight = state_dict[old_weight_key]
# src_bias = state_dict[old_bias_key]
# encoder = self.vae_model.encoder
# fourier_dim = encoder.fourier_embedder.out_dim
# # --- A. input_proj_kv ---
# # shape: [width, fourier_dim + point_feats]
# encoder.input_proj_kv.weight.data.copy_(src_weight)
# encoder.input_proj_kv.bias.data.copy_(src_bias)
# print(f" -> Loaded input_proj_kv from {old_key_base}")
# # --- B. input_proj_q ---
# # shape: [width, fourier_dim]
# sliced_weight = src_weight[:, :fourier_dim]
# encoder.input_proj_q.weight.data.copy_(sliced_weight)
# encoder.input_proj_q.bias.data.copy_(src_bias)
# print(f" -> Loaded input_proj_q (sliced) from {old_key_base}")
# del state_dict[old_weight_key]
# if old_bias_key in state_dict:
# del state_dict[old_bias_key]
# # ==================== Weight Surgery End ====================
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
params_list = []
trainable_parameters = list(self.vae_model.parameters())
params_list.append({'params': trainable_parameters, 'lr': lr})
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)
if hasattr(self.optimizer_cfg, 'scheduler'):
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
schedulers = [scheduler]
else:
schedulers = []
optimizers = [optimizer]
return optimizers, schedulers
def on_train_epoch_start(self) -> None:
pl.seed_everything(self.trainer.global_rank)
def forward(self, batch):
sup_pc_s_list = [batch["sup_near_uniform"], batch["sup_near_sharp"], batch["sup_space"]]
rand_points = [sup_pc_s[:,:,:3] for sup_pc_s in sup_pc_s_list]
rand_points_val = [sup_pc_s[:,:,3:] for sup_pc_s in sup_pc_s_list]
rand_points = torch.cat(rand_points, dim=1)
target = torch.cat(rand_points_val, dim=1)[...,0]
target = -target
latents, posterior = self.vae_model.encode(
batch['surface'], sample_posterior=True, need_kl=True)
latents = self.vae_model.decode(latents)
logits = self.vae_model.query(latents, rand_points)
loss_kl = posterior.kl()
loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
criteria = torch.nn.MSELoss()
criteria2 = torch.nn.L1Loss()
loss_logits = criteria(logits, target).mean() + criteria2(logits, target).mean()
loss = self.loss_cfg.lambda_logits * loss_logits + self.loss_cfg.lambda_kl * loss_kl
loss_dict = {
"loss": loss,
"loss_logits": loss_logits,
"loss_kl": loss_kl
}
return loss_dict, latents
def training_step(self, batch, batch_idx, optimizer_idx=0):
loss, latents = self.forward(batch)
split = 'train'
loss_dict = {
f"{split}/total_loss": loss["loss"].detach(),
f"{split}/loss_logits": loss["loss_logits"].detach(),
f"{split}/loss_kl": loss["loss_kl"].detach(),
f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
}
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
def validation_step(self, batch, batch_idx, optimizer_idx=0):
loss, latents = self.forward(batch)
split = 'val'
loss_dict = {
f"{split}/total_loss": loss["loss"].detach(),
f"{split}/loss_logits": loss["loss_logits"].detach(),
f"{split}/loss_kl": loss["loss_kl"].detach(),
f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
}
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
if self.trainer.global_rank < 2:
with torch.no_grad():
save_dir = f"{self.save_root}/gs{self.global_step:010d}_rank{self.trainer.global_rank}"
if not os.path.exists(save_dir):
os.makedirs(save_dir)
uids = batch.get('uid')
for i, latent in enumerate(latents[:5]):
mesh, grid_logits = self.vae_model.latents2mesh(
latent[None],
output_type='trimesh',
bounds=1.01,
mc_level=0.0,
num_chunks=20000,
octree_resolution=self.mc_res,
mc_algo='mc',
enable_pbar=True
)
mesh = export_to_trimesh(mesh[0])
save_path = f"{save_dir}/recon_{os.path.splitext(os.path.basename(uids[i]))[0]}_mc{self.mc_res}.obj"
mesh.export(save_path)
return loss
================================================
FILE: ultrashape/models/autoencoders/volume_decoders.py
================================================
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from typing import Union, Tuple, List, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from tqdm import tqdm
from .attention_blocks import CrossAttentionDecoder
from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
from ...utils import logger
def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
val = input_tensor + alpha
valid_mask = val > -9000
mask = torch.ones_like(val, dtype=torch.int32)
sign = torch.sign(val.to(torch.float32))
# Helper to compute neighbor for a single direction
def check_neighbor_sign(shift, axis):
if shift == 0:
return
pad_dims = [0, 0, 0, 0, 0, 0]
if axis == 0:
pad_idx = 0 if shift > 0 else 1
pad_dims[pad_idx] = abs(shift)
elif axis == 1:
pad_idx = 2 if shift > 0 else 3
pad_dims[pad_idx] = abs(shift)
elif axis == 2:
pad_idx = 4 if shift > 0 else 5
pad_dims[pad_idx] = abs(shift)
padded = F.pad(val.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')
slice_dims = [slice(None)] * 3
if axis == 0:
if shift > 0: slice_dims[0] = slice(shift, None)
else: slice_dims[0] = slice(None, shift)
elif axis == 1:
if shift > 0: slice_dims[1] = slice(shift, None)
else: slice_dims[1] = slice(None, shift)
elif axis == 2:
if shift > 0: slice_dims[2] = slice(shift, None)
else: slice_dims[2] = slice(None, shift)
padded = padded.squeeze(0).squeeze(0)
neighbor = padded[slice_dims]
neighbor = torch.where(neighbor > -9000, neighbor, val)
# Check sign consistency
neighbor_sign = torch.sign(neighbor.to(torch.float32))
return (neighbor_sign == sign)
# Iteratively check neighbors and update mask
# directions: (shift, axis)
directions = [(1, 0), (-1, 0), (1, 1), (-1, 1), (1, 2), (-1, 2)]
for shift, axis in directions:
is_same = check_neighbor_sign(shift, axis)
mask = mask & is_same.to(torch.int32)
# Invert mask: we want 1 where ANY neighbor has different sign
mask = (~(mask.bool())).to(torch.int32)
return mask * valid_mask.to(torch.int32)
def generate_dense_grid_points(
bbox_min: np.ndarray,
bbox_max: np.ndarray,
octree_resolution: int,
indexing: str = "ij",
):
length = bbox_max - bbox_min
num_cells = octree_resolution
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1)
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
return xyz, grid_size, length
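# Example (illustrative sketch): a dense (R + 1)^3 query grid over a symmetric
# bounding box, matching the convention used by the decoders below.
#
#   xyz, grid_size, length = generate_dense_grid_points(
#       bbox_min=np.array([-1.01] * 3), bbox_max=np.array([1.01] * 3),
#       octree_resolution=256)
#   # xyz: (257, 257, 257, 3), grid_size: [257, 257, 257]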
class VanillaVolumeDecoder:
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
octree_resolution: int = None,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
batch_size = latents.shape[0]
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=octree_resolution,
indexing="ij"
)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
# 2. latents to 3d volume
batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
chunk_queries = xyz_samples[start: start + num_chunks, :]
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=chunk_queries, latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1)
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
return grid_logits
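# Example (illustrative sketch): dense decoding of latents into a logit grid,
# assuming a ShapeVAE `vae` (see model.py) supplying the geo_decoder, and
# `latents` already passed through vae.decode (shape [B, num_latents, width]).
#
#   decoder = VanillaVolumeDecoder()
#   grid_logits = decoder(latents, vae.geo_decoder, bounds=1.01,
#                         octree_resolution=256, num_chunks=10000)
#   # grid_logits: [B, 257, 257, 257]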
class HierarchicalVolumeDecoding:
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
mc_level: float = 0.0,
octree_resolution: int = None,
min_resolution: int = 63,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
resolutions = []
if octree_resolution < min_resolution:
resolutions.append(octree_resolution)
while octree_resolution >= min_resolution:
resolutions.append(octree_resolution)
octree_resolution = octree_resolution // 2
resolutions.reverse()
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min = np.array(bounds[0:3])
bbox_max = np.array(bounds[3:6])
bbox_size = bbox_max - bbox_min
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=resolutions[0],
indexing="ij"
)
dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
grid_size = np.array(grid_size)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
# 2. latents to 3d volume
batch_logits = []
batch_size = latents.shape[0]
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]"):
queries = xyz_samples[start: start + num_chunks, :]
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=batch_queries, latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
for octree_depth_now in resolutions[1:]:
grid_size = np.array([octree_depth_now + 1] * 3)
resolution = bbox_size / octree_depth_now
next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
            # next_logits is allocated lazily below, once next_index has been freed
curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
curr_points += grid_logits.squeeze(0).abs() < 0.95
if octree_depth_now == resolutions[-1]:
expand_num = 0
else:
expand_num = 1
for i in range(expand_num):
curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
(cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
for i in range(2 - expand_num):
next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
nidx = torch.where(next_index > 0)
# Store shape before deleting
next_index_shape = next_index.shape
del next_index
torch.cuda.empty_cache()
next_points = torch.stack(nidx, dim=1)
next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) +
torch.tensor(bbox_min, dtype=next_points.dtype, device=device))
batch_logits = []
for start in tqdm(range(0, next_points.shape[0], num_chunks),
desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]"):
queries = next_points[start: start + num_chunks, :]
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
batch_logits.append(logits)
# Delayed allocation of next_logits
next_logits = torch.full(next_index_shape, -10000., dtype=dtype, device=device)
grid_logits = torch.cat(batch_logits, dim=1)
next_logits[nidx] = grid_logits[0, ..., 0]
grid_logits = next_logits.unsqueeze(0)
grid_logits[grid_logits == -10000.] = float('nan')
return grid_logits
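# Illustrative sketch (not part of the upstream pipeline): the coarse-to-fine
# schedule built above halves the target resolution until it drops below
# `min_resolution`, then decodes from coarsest to finest. A standalone
# re-derivation for clarity:
def _demo_resolution_schedule(octree_resolution: int = 512, min_resolution: int = 63):
    resolutions = []
    if octree_resolution < min_resolution:
        resolutions.append(octree_resolution)
    while octree_resolution >= min_resolution:
        resolutions.append(octree_resolution)
        octree_resolution //= 2
    resolutions.reverse()
    return resolutions  # [64, 128, 256, 512] for the defaults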
class FlashVDMVolumeDecoding:
def __init__(self, topk_mode='mean'):
if topk_mode not in ['mean', 'merge']:
raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')
if topk_mode == 'mean':
self.processor = FlashVDMCrossAttentionProcessor()
else:
self.processor = FlashVDMTopMCrossAttentionProcessor()
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: CrossAttentionDecoder,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
mc_level: float = 0.0,
octree_resolution: int = None,
min_resolution: int = 63,
mini_grid_num: int = 4,
enable_pbar: bool = True,
**kwargs,
):
processor = self.processor
geo_decoder.set_cross_attention_processor(processor)
device = latents.device
dtype = latents.dtype
resolutions = []
if octree_resolution < min_resolution:
resolutions.append(octree_resolution)
while octree_resolution >= min_resolution:
resolutions.append(octree_resolution)
octree_resolution = octree_resolution // 2
resolutions.reverse()
resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
for i, resolution in enumerate(resolutions[1:]):
resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min = np.array(bounds[0:3])
bbox_max = np.array(bounds[3:6])
bbox_size = bbox_max - bbox_min
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=resolutions[0],
indexing="ij"
)
dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
grid_size = np.array(grid_size)
# 2. latents to 3d volume
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
batch_size = latents.shape[0]
mini_grid_size = xyz_samples.shape[0] // mini_grid_num
xyz_samples = xyz_samples.view(
mini_grid_num, mini_grid_size,
mini_grid_num, mini_grid_size,
mini_grid_num, mini_grid_size, 3
).permute(
0, 2, 4, 1, 3, 5, 6
).reshape(
-1, mini_grid_size * mini_grid_size * mini_grid_size, 3
)
batch_logits = []
num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
desc=f"FlashVDM Volume Decoding", disable=not enable_pbar):
queries = xyz_samples[start: start + num_batchs, :]
batch = queries.shape[0]
batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
processor.topk = True
# Chunk queries along dim 1 if too large
if queries.shape[1] > num_chunks:
batch_logits_sub = []
for sub_start in range(0, queries.shape[1], num_chunks):
sub_queries = queries[:, sub_start: sub_start + num_chunks, :]
logits = geo_decoder(queries=sub_queries, latents=batch_latents)
batch_logits_sub.append(logits)
logits = torch.cat(batch_logits_sub, dim=1)
else:
logits = geo_decoder(queries=queries, latents=batch_latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=0).reshape(
mini_grid_num, mini_grid_num, mini_grid_num,
mini_grid_size, mini_grid_size,
mini_grid_size
).permute(0, 3, 1, 4, 2, 5).contiguous().view(
(batch_size, grid_size[0], grid_size[1], grid_size[2])
)
for octree_depth_now in resolutions[1:]:
grid_size = np.array([octree_depth_now + 1] * 3)
resolution = bbox_size / octree_depth_now
next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
curr_points += grid_logits.squeeze(0).abs() < 0.95
if octree_depth_now == resolutions[-1]:
expand_num = 0
else:
expand_num = 1
for i in range(expand_num):
curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
(cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
for i in range(2 - expand_num):
next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
nidx = torch.where(next_index > 0)
# Store shape before deleting
next_index_shape = next_index.shape
del next_index
torch.cuda.empty_cache()
next_points = torch.stack(nidx, dim=1)
next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
torch.tensor(bbox_min, dtype=torch.float32, device=device))
query_grid_num = 6
min_val = next_points.min(axis=0).values
max_val = next_points.max(axis=0).values
vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
index = torch.floor(vol_queries_index).long()
index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
index = index.sort()
next_points = next_points[index.indices].unsqueeze(0).contiguous()
unique_values = torch.unique(index.values, return_counts=True)
grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
input_grid = [[], []]
logits_grid_list = []
start_num = 0
sum_num = 0
for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
remaining_count = count
while remaining_count > 0:
space_left = num_chunks - sum_num
# If buffer is full, flush it
if space_left <= 0:
processor.topk = input_grid
logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
start_num = start_num + sum_num
logits_grid_list.append(logits_grid)
input_grid = [[], []]
sum_num = 0
space_left = num_chunks
take = min(remaining_count, space_left)
input_grid[0].append(grid_index)
input_grid[1].append(take)
sum_num += take
remaining_count -= take
if sum_num > 0:
processor.topk = input_grid
logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
logits_grid_list.append(logits_grid)
logits_grid = torch.cat(logits_grid_list, dim=1)
grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
# Delayed allocation of next_logits
next_logits = torch.full(next_index_shape, -10000., dtype=dtype, device=device)
next_logits[nidx] = grid_logits
grid_logits = next_logits.unsqueeze(0)
grid_logits[grid_logits == -10000.] = float('nan')
return grid_logits
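# Illustrative sketch (not part of the upstream pipeline): the mini-grid split
# above partitions an (R, R, R) grid of query points into mini_grid_num**3
# contiguous spatial blocks, so each decoder call covers a compact region and
# the top-k cross-attention cache stays effective. A minimal shape check,
# assuming R is divisible by mini_grid_num:
def _demo_mini_grid_split(R: int = 8, mini_grid_num: int = 4):
    m = R // mini_grid_num
    xyz = torch.randn(R, R, R, 3)
    blocks = (
        xyz.view(mini_grid_num, m, mini_grid_num, m, mini_grid_num, m, 3)
        .permute(0, 2, 4, 1, 3, 5, 6)
        .reshape(-1, m * m * m, 3)
    )
    return blocks.shape  # (mini_grid_num**3, m**3, 3)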
================================================
FILE: ultrashape/models/conditioner_mask.py
================================================
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import (
CLIPVisionModelWithProjection,
CLIPVisionConfig,
Dinov2Model,
Dinov2Config,
)
from transformers import AutoImageProcessor, AutoModel
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
return np.concatenate([emb_sin, emb_cos], axis=1)
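# Illustrative sketch: for embed_dim=4 the frequencies are omega = [1, 1/100],
# so position p is embedded as [sin(p), sin(p/100), cos(p), cos(p/100)].
def _demo_sincos_embed():
    emb = get_1d_sincos_pos_embed_from_grid(4, np.arange(2))
    return emb.shape  # (2, 4)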
class ImageEncoder(nn.Module):
def __init__(
self,
version=None,
config=None,
use_cls_token=True,
image_size=224,
**kwargs,
):
super().__init__()
if config is None:
self.model = AutoModel.from_pretrained(version)
else:
self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))
self.model.eval()
self.model.requires_grad_(False)
self.use_cls_token = use_cls_token
self.size = image_size // 14
self.num_patches = (image_size // 14) ** 2
if self.use_cls_token:
self.num_patches += 1
self.transform = transforms.Compose(
[
transforms.Resize(image_size, transforms.InterpolationMode.BILINEAR, antialias=True),
transforms.CenterCrop(image_size),
transforms.Normalize(
mean=self.mean,
std=self.std,
),
]
)
self.mask_transform = transforms.Compose(
[
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
transforms.CenterCrop(image_size),
]
)
def forward(self, image, mask=None, value_range=(-1, 1), **kwargs):
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
image = image.to(self.model.device, dtype=self.model.dtype)
inputs = self.transform(image)
outputs = self.model(inputs)
last_hidden_state = outputs.last_hidden_state
if not self.use_cls_token:
last_hidden_state = last_hidden_state[:, 1:, :]
if mask is not None:
pool = nn.MaxPool2d(kernel_size=(14, 14), stride=(14, 14))
mask = self.mask_transform(mask)
mask = mask.to(image.device, dtype=image.dtype)
downsampled_mask = pool(mask)
flattened_mask = downsampled_mask.view(downsampled_mask.shape[0], -1)
flattened_mask = flattened_mask.unsqueeze(-1)
if self.use_cls_token:
flattened_mask = torch.cat(
[torch.ones(flattened_mask.shape[0], 1, 1, device=flattened_mask.device, dtype=flattened_mask.dtype),
flattened_mask], dim=1)
valid_mask = (flattened_mask != -1).float()
masked_hidden_state = last_hidden_state * valid_mask
valid_mask_bool = valid_mask.squeeze(-1) > 0
valid_counts = valid_mask_bool.sum(dim=1)
max_valid_tokens = valid_counts.max().item()
batch_indices = torch.arange(valid_mask_bool.shape[0], device=valid_mask_bool.device)
batch_indices = batch_indices.unsqueeze(1).expand(-1, valid_mask_bool.shape[1])
flat_batch_indices = batch_indices[valid_mask_bool]
flat_token_indices = torch.arange(valid_mask_bool.shape[1], device=valid_mask_bool.device)
flat_token_indices = flat_token_indices.unsqueeze(0).expand(valid_mask_bool.shape[0], -1)
flat_token_indices = flat_token_indices[valid_mask_bool]
valid_tokens = masked_hidden_state[flat_batch_indices, flat_token_indices]
# Create output tensor with special padding value (-1) instead of zeros
final_output = torch.full(
(valid_mask_bool.shape[0], max_valid_tokens, last_hidden_state.shape[-1]),
-1.0, # Use -1 as padding value to clearly distinguish from valid tokens
device=last_hidden_state.device, dtype=last_hidden_state.dtype
)
cum_counts = torch.cumsum(valid_counts, dim=0) - valid_counts
for i in range(valid_mask_bool.shape[0]):
if valid_counts[i] > 0:
start_idx = cum_counts[i]
end_idx = start_idx + valid_counts[i]
final_output[i, :valid_counts[i]] = valid_tokens[start_idx:end_idx]
return final_output
return last_hidden_state
def unconditional_embedding(self, batch_size, **kwargs):
device = next(self.model.parameters()).device
dtype = next(self.model.parameters()).dtype
num_tokens = kwargs.get('num_tokens', self.num_patches)
zero = torch.zeros(
batch_size,
num_tokens,
self.model.config.hidden_size,
device=device,
dtype=dtype,
)
return zero
class CLIPImageEncoder(ImageEncoder):
MODEL_CLASS = CLIPVisionModelWithProjection
MODEL_CONFIG_CLASS = CLIPVisionConfig
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
class DinoImageEncoder(ImageEncoder):
MODEL_CLASS = Dinov2Model
MODEL_CONFIG_CLASS = Dinov2Config
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
class DinoImageEncoderMV(DinoImageEncoder):
def __init__(
self,
version=None,
config=None,
use_cls_token=True,
image_size=224,
view_num=4,
**kwargs,
):
super().__init__(version, config, use_cls_token, image_size, **kwargs)
self.view_num = view_num
pos = np.arange(self.view_num, dtype=np.float32)
view_embedding = torch.from_numpy(
get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()
view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
self.view_embed = view_embedding.unsqueeze(0)
def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None):
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
image = image.to(self.model.device, dtype=self.model.dtype)
bs, num_views, c, h, w = image.shape
image = image.view(bs * num_views, c, h, w)
inputs = self.transform(image)
outputs = self.model(inputs)
last_hidden_state = outputs.last_hidden_state
last_hidden_state = last_hidden_state.view(
bs, num_views, last_hidden_state.shape[-2],
last_hidden_state.shape[-1]
)
view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
if view_idxs is not None:
assert len(view_idxs) == bs
view_embeddings = []
for i in range(bs):
view_idx = view_idxs[i]
assert num_views == len(view_idx)
view_embeddings.append(self.view_embed[:, view_idx, ...])
view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)
if num_views != self.view_num:
view_embedding = view_embedding[:, :num_views, ...]
last_hidden_state = last_hidden_state + view_embedding
last_hidden_state = last_hidden_state.view(bs, num_views * last_hidden_state.shape[-2],
last_hidden_state.shape[-1])
return last_hidden_state
def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
device = next(self.model.parameters()).device
dtype = next(self.model.parameters()).dtype
zero = torch.zeros(
batch_size,
self.num_patches * len(view_idxs[0]),
self.model.config.hidden_size,
device=device,
dtype=dtype,
)
return zero
def build_image_encoder(config):
if config['type'] == 'CLIPImageEncoder':
return CLIPImageEncoder(**config['kwargs'])
elif config['type'] == 'DinoImageEncoder':
return DinoImageEncoder(**config['kwargs'])
elif config['type'] == 'DinoImageEncoderMV':
return DinoImageEncoderMV(**config['kwargs'])
else:
raise ValueError(f'Unknown image encoder type: {config["type"]}')
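# Illustrative sketch of a config accepted by build_image_encoder; the
# `version` string is a placeholder checkpoint id, not a confirmed repo
# default:
def _demo_build_encoder():
    config = {
        'type': 'DinoImageEncoder',
        'kwargs': {
            'version': 'facebook/dinov2-giant',  # placeholder checkpoint id
            'use_cls_token': True,
            'image_size': 518,
        },
    }
    return build_image_encoder(config)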
class DualImageEncoder(nn.Module):
def __init__(
self,
main_image_encoder,
additional_image_encoder,
):
super().__init__()
self.main_image_encoder = build_image_encoder(main_image_encoder)
self.additional_image_encoder = build_image_encoder(additional_image_encoder)
def forward(self, image, mask=None, **kwargs):
outputs = {
'main': self.main_image_encoder(image, mask=mask, **kwargs),
'additional': self.additional_image_encoder(image, mask=mask, **kwargs),
}
return outputs
def unconditional_embedding(self, batch_size, **kwargs):
outputs = {
'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
'additional': self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs),
}
return outputs
class SingleImageEncoder(nn.Module):
def __init__(
self,
main_image_encoder,
drop_ratio=0.1,
):
super().__init__()
self.main_image_encoder = build_image_encoder(main_image_encoder)
self.drop_ratio = drop_ratio
def forward(self, image, disable_drop=True, mask=None, **kwargs):
outputs = {
'main': self.main_image_encoder(image, mask=mask, **kwargs),
}
if disable_drop:
return outputs
else:
random_p = torch.rand(len(image), device=outputs['main'].device)
remain_bool_tensor = random_p > self.drop_ratio
outputs['main'] *= remain_bool_tensor.view(-1, 1, 1)
return outputs
def unconditional_embedding(self, batch_size, **kwargs):
outputs = {
'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
}
return outputs
================================================
FILE: ultrashape/models/denoisers/__init__.py
================================================
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from .dit_mask import RefineDiT
================================================
FILE: ultrashape/models/denoisers/dit_mask.py
================================================
# ==============================================================================
# Original work Copyright (c) 2025 Tencent.
# Modified work Copyright (c) 2025 UltraShape Team.
#
# Modified by UltraShape on 2025.12.25
# ==============================================================================
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
import yaml
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .moe_layers import MoEBlock
from ...utils import logger, synchronize_timer, smart_load_model
from flash_attn import flash_attn_varlen_func
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class Timesteps(nn.Module):
def __init__(self,
num_channels: int,
downscale_freq_shift: float = 0.0,
scale: int = 1,
max_period: int = 10000
):
super().__init__()
self.num_channels = num_channels
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.max_period = max_period
def forward(self, timesteps):
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
embedding_dim = self.num_channels
half_dim = embedding_dim // 2
exponent = -math.log(self.max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - self.downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
emb = self.scale * emb
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, cond_proj_dim=None, out_size=None):
super().__init__()
if out_size is None:
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(hidden_size, frequency_embedding_size, bias=True),
nn.GELU(),
nn.Linear(frequency_embedding_size, out_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, frequency_embedding_size, bias=False)
self.time_embed = Timesteps(hidden_size)
def forward(self, t, condition):
t_freq = self.time_embed(t).type(self.mlp[0].weight.dtype)
if condition is not None:
t_freq = t_freq + self.cond_proj(condition)
t = self.mlp(t_freq)
t = t.unsqueeze(dim=1)
return t
class MLP(nn.Module):
def __init__(self, *, width: int):
super().__init__()
self.width = width
self.fc1 = nn.Linear(width, width * 4)
self.fc2 = nn.Linear(width * 4, width)
self.gelu = nn.GELU()
def forward(self, x):
return self.fc2(self.gelu(self.fc1(x)))
class CrossAttention(nn.Module):
def __init__(
self,
qdim,
kdim,
num_heads,
qkv_bias=True,
qk_norm=False,
norm_layer=nn.LayerNorm,
**kwargs,
):
super().__init__()
self.qdim = qdim
self.kdim = kdim
self.num_heads = num_heads
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
self.head_dim = self.qdim // num_heads
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim ** -0.5
self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias)
self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias)
self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.out_proj = nn.Linear(qdim, qdim, bias=True)
def forward(self, x, y):
"""
Parameters
----------
x: torch.Tensor
(batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
y: torch.Tensor
(batch, seqlen2, hidden_dim2) - may contain padding (marked with -1)
freqs_cis_img: torch.Tensor
(batch, hidden_dim // 2), RoPE for image
"""
b, s1, c = x.shape # [b, s1, D]
# Detect padding tokens: check if all values in the feature dimension are -1
# y_mask: [b, s2], True for valid tokens, False for padding
y_mask = (y != -1).any(dim=-1) # [b, s2]
has_padding = not y_mask.all()
_, s2, c = y.shape # [b, s2, 1024]
q = self.to_q(x)
k = self.to_k(y)
v = self.to_v(y)
kv = torch.cat((k, v), dim=-1)
split_size = kv.shape[-1] // self.num_heads // 2
kv = kv.view(1, -1, self.num_heads, split_size * 2)
k, v = torch.split(kv, split_size, dim=-1)
q = q.view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
k = k.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d]
v = v.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d]
q = self.q_norm(q)
k = self.k_norm(k)
if has_padding:
seqlens_k = y_mask.sum(dim=1).int()
q_flat = q.reshape(-1, self.num_heads, self.head_dim)
# For k, v: only keep valid tokens (remove padding)
# Create indices for valid tokens
valid_indices = []
cu_seqlens_k = [0]
for i in range(b):
valid_len = seqlens_k[i].item()
batch_indices = torch.arange(valid_len, device=y.device) + i * s2
valid_indices.append(batch_indices)
cu_seqlens_k.append(cu_seqlens_k[-1] + valid_len)
valid_indices = torch.cat(valid_indices)
k_flat = k.reshape(b * s2, self.num_heads, self.head_dim)[valid_indices] # [total_k, h, d]
v_flat = v.reshape(b * s2, self.num_heads, self.head_dim)[valid_indices] # [total_k, h, d]
# Create cumulative sequence lengths
cu_seqlens_q = torch.arange(0, (b + 1) * s1, s1, dtype=torch.int32, device=x.device)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, device=x.device)
# Call flash attention varlen
q_flat = q_flat.to(torch.bfloat16)
k_flat = k_flat.to(torch.bfloat16)
v_flat = v_flat.to(torch.bfloat16)
context = flash_attn_varlen_func(
q_flat, k_flat, v_flat,
cu_seqlens_q, cu_seqlens_k,
s1, seqlens_k.max().item(),
dropout_p=0.0,
softmax_scale=None,
causal=False
)
context = context.reshape(b, s1, -1)
else:
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=True
):
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.num_heads), (q, k, v))
attn_mask = None
context = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask
).transpose(1, 2).reshape(b, s1, -1)
out = self.out_proj(context)
return out
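# Illustrative sketch of the padding convention consumed above: context tokens
# whose features are all -1 are treated as padding, matching the -1-padded
# outputs produced by ImageEncoder.forward in conditioner_mask.py:
def _demo_padded_context(b: int = 2, s2: int = 7, c: int = 1024):
    y = torch.full((b, s2, c), -1.0)  # start fully padded
    y[0, :7] = torch.randn(7, c)      # sample 0: all 7 tokens valid
    y[1, :4] = torch.randn(4, c)      # sample 1: 4 valid tokens, 3 padding
    y_mask = (y != -1).any(dim=-1)    # [b, s2], True for valid tokens
    return y, y_mask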
class Attention(nn.Module):
"""
We rename some layer names to align with flash attention
"""
def __init__(
self,
dim,
num_heads,
qkv_bias=True,
qk_norm=False,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
self.head_dim = self.dim // num_heads
# This assertion is aligned with flash attention
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim ** -0.5
self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.out_proj = nn.Linear(dim, dim)
def forward(self, x, rotary_cos=None, rotary_sin=None):
B, N, C = x.shape
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
qkv = torch.cat((q, k, v), dim=-1)
split_size = qkv.shape[-1] // self.num_heads // 3
qkv = qkv.view(1, -1, self.num_heads, split_size * 3)
q, k, v = torch.split(qkv, split_size, dim=-1)
q = q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, s, d]
k = k.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, s, d]
v = v.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
q = self.q_norm(q) # [b, h, s, d]
k = self.k_norm(k) # [b, h, s, d]
# ========================= Apply RoPE =========================
if rotary_cos is not None:
q = apply_rotary_emb(q, rotary_cos, rotary_sin)
k = apply_rotary_emb(k, rotary_cos, rotary_sin)
# ==============================================================
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=True
):
x = F.scaled_dot_product_attention(q, k, v)
x = x.transpose(1, 2).reshape(B, N, -1)
x = self.out_proj(x)
return x
class DiTBlock(nn.Module):
def __init__(
self,
hidden_size,
c_emb_size,
num_heads,
text_states_dim=1024,
use_flash_attn=False,
qk_norm=False,
norm_layer=nn.LayerNorm,
qk_norm_layer=nn.RMSNorm,
init_scale=1.0,
qkv_bias=True,
skip_connection=True,
timested_modulate=False,
use_moe: bool = False,
num_experts: int = 8,
moe_top_k: int = 2,
**kwargs,
):
super().__init__()
self.use_flash_attn = use_flash_attn
use_ele_affine = True
# ========================= Self-Attention =========================
self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
norm_layer=qk_norm_layer)
# ========================= FFN =========================
self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
# ========================= Add =========================
# Simply use add like SDXL.
self.timested_modulate = timested_modulate
if self.timested_modulate:
self.default_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(c_emb_size, hidden_size, bias=True)
)
# ========================= Cross-Attention =========================
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
qk_norm=qk_norm, norm_layer=qk_norm_layer, init_scale=init_scale)
self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
if skip_connection:
self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
else:
self.skip_linear = None
self.use_moe = use_moe
if self.use_moe:
self.moe = MoEBlock(
hidden_size,
num_experts=num_experts,
moe_top_k=moe_top_k,
dropout=0.0,
activation_fn="gelu",
final_dropout=False,
ff_inner_dim=int(hidden_size * 4.0),
ff_bias=True,
)
else:
self.mlp = MLP(width=hidden_size)
def forward(self, x, c=None, text_states=None, skip_value=None, rotary_cos=None, rotary_sin=None):
if self.skip_linear is not None:
cat = torch.cat([skip_value, x], dim=-1)
x = self.skip_linear(cat)
x = self.skip_norm(x)
# Self-Attention
if self.timested_modulate:
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
x = x + shift_msa
attn_out = self.attn1(self.norm1(x), rotary_cos=rotary_cos, rotary_sin=rotary_sin)
x = x + attn_out
# Cross-Attention
x = x + self.attn2(self.norm2(x), text_states)
# FFN Layer
mlp_inputs = self.norm3(x)
if self.use_moe:
x = x + self.moe(mlp_inputs)
else:
x = x + self.mlp(mlp_inputs)
return x
class AttentionPool(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x, attention_mask=None):
x = x.permute(1, 0, 2) # NLC -> LNC
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(-1).permute(1, 0, 2)
global_emb = (x * attention_mask).sum(dim=0) / attention_mask.sum(dim=0)
x = torch.cat([global_emb[None,], x], dim=0)
else:
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, final_hidden_size, out_channels):
super().__init__()
self.final_hidden_size = final_hidden_size
self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=True, eps=1e-6)
self.linear = nn.Linear(final_hidden_size, out_channels, bias=True)
def forward(self, x):
x = self.norm_final(x)
x = x[:, 1:]
x = self.linear(x)
return x
class RefineDiT(nn.Module):
@classmethod
@synchronize_timer('Refine Model Loading')
def from_single_file(
cls,
ckpt_path,
config_path,
device='cuda',
dtype=torch.float16,
use_safetensors=None,
**kwargs,
):
# load config
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
# load ckpt
if use_safetensors:
ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
if not os.path.exists(ckpt_path):
raise FileNotFoundError(f"Model file {ckpt_path} not found")
logger.info(f"Loading model from {ckpt_path}")
if use_safetensors:
import safetensors.torch
ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
else:
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
if 'model' in ckpt:
ckpt = ckpt['model']
if 'model' in config:
config = config['model']
model_kwargs = config['params']
model_kwargs.update(kwargs)
model = cls(**model_kwargs)
model.load_state_dict(ckpt)
model.to(device=device, dtype=dtype)
return model
@classmethod
def from_pretrained(
cls,
model_path,
device='cuda',
dtype=torch.float16,
use_safetensors=False,
variant='fp16',
subfolder='hunyuan3d-dit-v2-1',
**kwargs,
):
config_path, ckpt_path = smart_load_model(
model_path,
subfolder=subfolder,
use_safetensors=use_safetensors,
variant=variant
)
return cls.from_single_file(
ckpt_path,
config_path,
device=device,
dtype=dtype,
use_safetensors=use_safetensors,
**kwargs
)
def __init__(
self,
input_size=1024,
in_channels=4,
hidden_size=1024,
context_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4.0,
norm_type='layer',
qk_norm_type='rms',
qk_norm=False,
text_len=257,
guidance_cond_proj_dim=None,
qkv_bias=True,
num_moe_layers: int = 6,
num_experts: int = 8,
moe_top_k: int = 2,
voxel_query_res: int = 128,
**kwargs
):
super().__init__()
self.input_size = input_size
self.depth = depth
self.in_channels = in_channels
self.out_channels = in_channels
self.num_heads = num_heads
self.hidden_size = hidden_size
self.norm = nn.LayerNorm if norm_type == 'layer' else nn.RMSNorm
self.qk_norm = nn.RMSNorm if qk_norm_type == 'rms' else nn.LayerNorm
self.context_dim = context_dim
self.voxel_query_res = voxel_query_res
self.guidance_cond_proj_dim = guidance_cond_proj_dim
self.text_len = text_len
self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim)
self.blocks = nn.ModuleList([
DiTBlock(hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=context_dim,
qk_norm=qk_norm,
norm_layer=self.norm,
qk_norm_layer=self.qk_norm,
skip_connection=layer > depth // 2,
qkv_bias=qkv_bias,
use_moe=True if depth - layer <= num_moe_layers else False,
num_experts=num_experts,
moe_top_k=moe_top_k
)
for layer in range(depth)
])
self.depth = depth
self.final_layer = FinalLayer(hidden_size, self.out_channels)
def forward(self, x, t, contexts, **kwargs):
cond = contexts['main']
t = self.t_embedder(t, condition=kwargs.get('guidance_cond'))
x = self.x_embedder(x)
c = t
##########################################
head_dim = self.blocks[0].attn1.head_dim
num_cond_tokens = c.shape[1] if c.dim() == 3 else 1
device = x.device
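# Condition tokens get an identity rotation (cos=1, sin=0) so RoPE leaves
# them unchanged; only the voxel tokens carry 3D positional phases.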
cond_cos = torch.ones(x.shape[0], num_cond_tokens, head_dim, device=device)
cond_sin = torch.zeros(x.shape[0], num_cond_tokens, head_dim, device=device)
voxel_cond = kwargs.get('voxel_cond')
# rotary_cos_vox, rotary_sin_vox = precompute_freqs_cis_3d(head_dim, voxel_cond)
rotary_cos_vox, rotary_sin_vox = precompute_freqs_cis_3d_interpolated(
head_dim, voxel_cond, current_res=self.voxel_query_res)
rotary_cos = torch.cat([cond_cos, rotary_cos_vox], dim=1)
rotary_sin = torch.cat([cond_sin, rotary_sin_vox], dim=1)
##########################################
x = torch.cat([c, x], dim=1)
skip_value_list = []
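# U-ViT-style long skips: the first half of the blocks push their outputs
# onto a stack, and the later blocks pop and fuse them via skip_linear.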
for layer, block in enumerate(self.blocks):
skip_value = None if layer <= self.depth // 2 else skip_value_list.pop()
x = block(x, c, cond, rotary_cos=rotary_cos, rotary_sin=rotary_sin, skip_value=skip_value)
if layer < self.depth // 2:
skip_value_list.append(x)
x = self.final_layer(x)
return x
def apply_rotary_emb(x, cos, sin):
"""
x: [B, H, N, D]
cos, sin: [B, N, D]
"""
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
return (x * cos) + (rotate_half(x) * sin)
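# Illustrative sketch: with cos/sin built as duplicated halves (as in the
# precompute functions below), each (x1, x2) channel pair is rotated by the
# same angle, so RoPE preserves vector norms. A quick self-check:
def _demo_rope_norm_preserved():
    B, H, N, D = 1, 2, 3, 8
    x = torch.randn(B, H, N, D)
    angles = torch.randn(B, N, D // 2)
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
    y = apply_rotary_emb(x, cos, sin)
    return torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)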
def precompute_freqs_cis_3d(dim: int, grid_indices: torch.Tensor, theta: float = 10000.0):
"""
grid_indices: [B, N, 3] voxel idx
"""
dim_x = dim // 3
dim_y = dim // 3
dim_z = dim - dim_x - dim_y
device = grid_indices.device
freqs_x = 1.0 / (theta ** (torch.arange(0, dim_x, 2, device=device).float() / dim_x))
freqs_y = 1.0 / (theta ** (torch.arange(0, dim_y, 2, device=device).float() / dim_y))
freqs_z = 1.0 / (theta ** (torch.arange(0, dim_z, 2, device=device).float() / dim_z))
x_idx = grid_indices[..., 0].float()
y_idx = grid_indices[..., 1].float()
z_idx = grid_indices[..., 2].float()
args_x = x_idx.unsqueeze(-1) * freqs_x.unsqueeze(0).unsqueeze(0)
args_y = y_idx.unsqueeze(-1) * freqs_y.unsqueeze(0).unsqueeze(0)
args_z = z_idx.unsqueeze(-1) * freqs_z.unsqueeze(0).unsqueeze(0)
args = torch.cat([args_x, args_y, args_z], dim=-1)
args = torch.cat([args, args], dim=-1)
return torch.cos(args), torch.sin(args)
def precompute_freqs_cis_3d_interpolated(
dim: int,
grid_indices: torch.Tensor,
theta: float = 10000.0,
trained_res: float = 128.0, # training resolution
current_res: float = 256.0, # inference resolution
):
scale_factor = current_res / trained_res
dim_x = dim // 3
dim_y = dim // 3
dim_z = dim - dim_x - dim_y
device = grid_indices.device
freqs_x = 1.0 / (theta ** (torch.arange(0, dim_x, 2, device=device).float() / dim_x))
freqs_y = 1.0 / (theta ** (torch.arange(0, dim_y, 2, device=device).float() / dim_y))
freqs_z = 1.0 / (theta ** (torch.arange(0, dim_z, 2, device=device).float() / dim_z))
num_freqs_x = dim_x // 2 + (dim_x % 2)
num_freqs_y = dim_y // 2 + (dim_y % 2)
target_len = dim // 2
freqs_x = freqs_x[:num_freqs_x]
freqs_y = freqs_y[:num_freqs_y]
freqs_z = freqs_z[:(target_len - len(freqs_x) - len(freqs_y))]
input_x = grid_indices[..., 0].float()
input_y = grid_indices[..., 1].float()
input_z = grid_indices[..., 2].float()
# Apply Scaling
pos_x = input_x / scale_factor
pos_y = input_y / scale_factor
pos_z = input_z / scale_factor
# pos * freq
args_x = pos_x.unsqueeze(-1) * freqs_x.unsqueeze(0).unsqueeze(0)
args_y = pos_y.unsqueeze(-1) * freqs_y.unsqueeze(0).unsqueeze(0)
args_z = pos_z.unsqueeze(-1) * freqs_z.unsqueeze(0).unsqueeze(0)
args = torch.cat([args_x, args_y, args_z], dim=-1)
args = torch.cat([args, args], dim=-1)
return torch.cos(args), torch.sin(args)
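# Illustrative sketch of the position interpolation above: indices are scaled
# by trained_res / current_res, so a voxel queried on the 256^3 inference grid
# is encoded like the corresponding position on the 128^3 training grid. For a
# dim divisible by 6 the frequency layout matches precompute_freqs_cis_3d, so
# an even index at 256^3 reproduces half its value at 128^3:
def _demo_rope_interpolation():
    idx = torch.tensor([[[256, 64, 0]]])  # one voxel index on the 256^3 grid
    cos_hi, sin_hi = precompute_freqs_cis_3d_interpolated(
        48, idx, trained_res=128.0, current_res=256.0)
    cos_ref, sin_ref = precompute_freqs_cis_3d(48, idx // 2)
    return torch.allclose(cos_hi, cos_ref) and torch.allclose(sin_hi, sin_ref)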
================================================
FILE: ultrashape/models/denoisers/moe_layers.py
================================================
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
import torch.nn.functional as F
from diffusers.models.attention import FeedForward
class AddAuxiliaryLoss(torch.autograd.Function):
"""
The trick function of adding auxiliary (aux) loss,
which includes the gradient of the aux loss during backpropagation.
"""
@staticmethod
def forward(ctx, x, loss):
assert loss.numel() == 1
ctx.dtype = loss.dtype
ctx.required_aux_loss = loss.requires_grad
return x
@staticmethod
def backward(ctx, grad_output):
grad_loss = None
if ctx.required_aux_loss:
grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
return grad_output, grad_loss
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):
super().__init__()
self.top_k = num_experts_per_tok
self.n_routed_experts = num_experts
self.scoring_func = 'softmax'
self.alpha = aux_loss_alpha
self.seq_aux = False
# topk selection algorithm
self.norm_topk_prob = False
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(
1,
topk_idx_for_aux_loss,
torch.ones(
bsz, seq_len * aux_topk,
device=hidden_states.device
)
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean()
aux_loss = aux_loss * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1),
num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None
return topk_idx, topk_weight, aux_loss
class MoEBlock(nn.Module):
def __init__(self, dim, num_experts=8, moe_top_k=2,
activation_fn = "gelu", dropout=0.0, final_dropout = False,
ff_inner_dim = None, ff_bias = True):
super().__init__()
self.moe_top_k = moe_top_k
self.experts = nn.ModuleList([
FeedForward(dim,dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias)
for i in range(num_experts)])
self.gate = MoEGate(embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k)
self.shared_experts = FeedForward(dim,dropout=dropout, activation_fn=activation_fn,
final_dropout=final_dropout, inner_dim=ff_inner_dim,
bias=ff_bias)
def initialize_weight(self):
pass
def forward(self, hidden_states):
identity = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)
for i, expert in enumerate(self.experts):
tmp = expert(hidden_states[flat_topk_idx == i])
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape)
y = AddAuxiliaryLoss.apply(y, aux_loss)
else:
y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.moe_top_k
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# for fp16 and other dtype
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
expert_out,
reduce='sum')
return expert_cache
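# Illustrative sketch of the routing above: each token is dispatched to its
# moe_top_k experts, the expert outputs are mixed by the gate weights, and the
# shared expert is always added; the output shape matches the input:
def _demo_moe_block():
    torch.manual_seed(0)
    block = MoEBlock(dim=16, num_experts=4, moe_top_k=2).eval()
    x = torch.randn(2, 5, 16)  # (batch, seq_len, dim)
    with torch.no_grad():
        y = block(x)
    return y.shape  # torch.Size([2, 5, 16])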
================================================
FILE: ultrashape/models/diffusion/flow_matching_dit_trainer.py
================================================
import os
from contextlib import contextmanager
from typing import List, Tuple, Optional, Union
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities import rank_zero_only
from ultrashape.pipelines import export_to_trimesh
from ...utils.ema import LitEma
from ...utils.misc import instantiate_from_config, instantiate_non_trainable_model, instantiate_vae_model, instantiate_vae_model_local
class Diffuser(pl.LightningModule):
def __init__(
self,
*,
vae_config,
cond_config,
dit_cfg,
scheduler_cfg,
optimizer_cfg,
pipeline_cfg=None,
image_processor_cfg=None,
lora_config=None,
ema_config=None,
scale_by_std: bool = False,
z_scale_factor: float = 1.0,
ckpt_path: Optional[str] = None,
ignore_keys: Union[Tuple[str], List[str]] = (),
torch_compile: bool = False,
):
super().__init__()
# ========= init optimizer config ========= #
self.optimizer_cfg = optimizer_cfg
# ========= init diffusion scheduler ========= #
self.scheduler_cfg = scheduler_cfg
self.sampler = None
if 'transport' in scheduler_cfg:
self.transport = instantiate_from_config(scheduler_cfg.transport)
self.sampler = instantiate_from_config(scheduler_cfg.sampler, transport=self.transport)
self.sample_fn = self.sampler.sample_ode(**scheduler_cfg.sampler.ode_params)
# ========= init the model ========= #
self.dit_cfg = dit_cfg
self.model = instantiate_from_config(dit_cfg, device=None, dtype=None)
self.cond_stage_model = instantiate_from_config(cond_config)
self.ckpt_path = ckpt_path
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
# ========= config lora model ========= #
if lora_config is not None:
from peft import LoraConfig, get_peft_model
loraconfig = LoraConfig(
r=lora_config.rank,
lora_alpha=lora_config.rank,
target_modules=lora_config.get('target_modules')
)
self.model = get_peft_model(self.model, loraconfig)
# ========= config ema model ========= #
self.ema_config = ema_config
if self.ema_config is not None:
if self.ema_config.ema_model == 'DSEma':
# from michelangelo.models.modules.ema_deepspeed import DSEma
from ..utils.ema_deepspeed import DSEma
self.model_ema = DSEma(self.model, decay=self.ema_config.ema_decay)
else:
self.model_ema = LitEma(self.model, decay=self.ema_config.ema_decay)
# do not initialize EMA weights from ckpt_path, since the MoE layers need to be changed
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
# ========= init vae at last to prevent it is overridden by loaded ckpt ========= #
self.first_stage_model = instantiate_vae_model_local(vae_config)
self.first_stage_model.enable_flashvdm_decoder()
self.scale_by_std = scale_by_std
if scale_by_std:
self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
else:
self.z_scale_factor = z_scale_factor
# ========= init pipeline for inference ========= #
self.image_processor_cfg = image_processor_cfg
self.image_processor = None
if self.image_processor_cfg is not None:
self.image_processor = instantiate_from_config(self.image_processor_cfg)
self.pipeline_cfg = pipeline_cfg
from ...schedulers import FlowMatchEulerDiscreteScheduler
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
self.pipeline = instantiate_from_config(
pipeline_cfg,
vae=self.first_stage_model,
model=self.model,
scheduler=scheduler,
conditioner=self.cond_stage_model,
image_processor=self.image_processor,
)
# ========= torch compile to accelerate ========= #
self.torch_compile = torch_compile
if self.torch_compile:
self.model.compile()
self.first_stage_model.compile()
self.cond_stage_model.compile()
print('*' * 100)
print('Compile model for acceleration')
print('*' * 100)
@contextmanager
def ema_scope(self, context=None):
if self.ema_config is not None and self.ema_config.get('ema_inference', False):
self.model_ema.store(self.model)
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.ema_config is not None and self.ema_config.get('ema_inference', False):
self.model_ema.restore(self.model)
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=()):
ckpt = torch.load(path, map_location="cpu")
if 'state_dict' not in ckpt:
# deepspeed ckpt
state_dict = {}
for k in ckpt.keys():
new_k = k.replace('_forward_module.', '')
state_dict[new_k] = ckpt[k]
else:
state_dict = ckpt["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if ik in k:
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def on_load_checkpoint(self, checkpoint):
"""
The pt_model is trained separately, so we already have access to its
checkpoint and load it separately with `self.set_pt_model`.
However, the PL Trainer is strict about
checkpoint loading (not configurable), so it expects the loaded state_dict
to match exactly the keys in the model state_dict.
So, when loading the checkpoint, before matching keys, we add all pt_model keys
from self.state_dict() to the checkpoint state dict, so that they match
"""
for key in self.state_dict().keys():
if key.startswith("model_ema") and key not in checkpoint["state_dict"]:
checkpoint["state_dict"][key] = self.state_dict()[key]
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
params_list = []
trainable_parameters = list(self.model.parameters())
params_list.append({'params': trainable_parameters, 'lr': lr})
no_decay = ['bias', 'norm.weight', 'norm.bias', 'norm1.weight', 'norm1.bias', 'norm2.weight', 'norm2.bias']
if self.optimizer_cfg.get('train_image_encoder', False):
image_encoder_parameters = list(self.cond_stage_model.named_parameters())
image_encoder_parameters_decay = [param for name, param in image_encoder_parameters if
not any((no_decay_name in name) for no_decay_name in no_decay)]
image_encoder_parameters_nodecay = [param for name, param in image_encoder_parameters if
any((no_decay_name in name) for no_decay_name in no_decay)]
# filter trainable params
image_encoder_parameters_decay = [param for param in image_encoder_parameters_decay if
param.requires_grad]
image_encoder_parameters_nodecay = [param for param in image_encoder_parameters_nodecay if
param.requires_grad]
print(f"Image Encoder Params: {len(image_encoder_parameters_decay)} decay, ")
print(f"Image Encoder Params: {len(image_encoder_parameters_nodecay)} nodecay, ")
image_encoder_lr = self.optimizer_cfg['image_encoder_lr']
image_encoder_lr_multiply = self.optimizer_cfg.get('image_encoder_lr_multiply', 1.0)
image_encoder_lr = image_encoder_lr if image_encoder_lr is not None else lr * image_encoder_lr_multiply
params_list.append(
{'params': image_encoder_parameters_decay, 'lr': image_encoder_lr,
'weight_decay': 0.05})
params_list.append(
{'params': image_encoder_parameters_nodecay, 'lr': image_encoder_lr,
'weight_decay': 0.})
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)
if hasattr(self.optimizer_cfg, 'scheduler'):
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
schedulers = [scheduler]
else:
schedulers = []
optimizers = [optimizer]
return optimizers, schedulers
def on_train_batch_end(self, *args, **kwargs):
if self.ema_config is not None:
self.model_ema(self.model)
def on_train_epoch_start(self) -> None:
pl.seed_everything(self.trainer.global_rank)
def forward(self, batch, disable_drop):
with torch.autocast(device_type="cuda", dtype=torch.bfloat16): #float32 for text
contexts = self.cond_stage_model(image=batch.get('image'), text=batch.get('text'), mask=batch.get('mask'), disable_drop=disable_drop)
with torch.autocast(device_type="cuda", dtype=torch.float16):
with torch.no_grad():
latents, voxel_idx = self.first_stage_model.encode(batch["surface"], sample_posterior=True, need_voxel=True)
latents = self.z_scale_factor * latents
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
loss = self.transport.training_losses(self.model, latents,
dict(contexts=contexts, voxel_cond=voxel_idx))["loss"].mean()
return loss
def training_step(self, batch, batch_idx, optimizer_idx=0):
loss = self.forward(batch, disable_drop=False)
split = 'train'
loss_dict = {
f"{split}/total_loss": loss.detach(),
f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
}
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
def validation_step(self, batch, batch_idx, optimizer_idx=0):
loss = self.forward(batch, disable_drop=True)
split = 'val'
loss_dict = {
f"{split}/total_loss": loss.detach(),
f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
}
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
@torch.no_grad()
def sample(self, batch, output_type='trimesh', **kwargs):
self.cond_stage_model.disable_drop = True
generator = torch.Generator().manual_seed(0)
with self.ema_scope("Sample"):
with torch.amp.autocast(device_type='cuda'):
try:
self.pipeline.device = self.device
self.pipeline.dtype = self.dtype
print("### USING PIPELINE ###")
print(f'device: {self.device} dtype : {self.dtype}')
additional_params = {'output_type':output_type}
image = batch.get("image", None)
mask = batch.get('mask', None)
outputs = self.pipeline(image=image,
mask=mask,
generator=generator,
box_v=1.0,
mc_level=0.0,
octree_resolution=1024,
**additional_params)
except Exception as e:
import traceback
traceback.print_exc()
print(f"Unexpected {e=}, {type(e)=}")
with open("error.txt", "a") as f:
f.write(str(e))
f.write(traceback.format_exc())
f.write("\n")
outputs = [None]
self.cond_stage_model.disable_drop = False
return [outputs]
================================================
FILE: ultrashape/models/diffusion/transport/__init__.py
================================================
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from .transport import Transport, ModelType, WeightType, PathType, Sampler
def create_transport(
path_type='Linear',
prediction="velocity",
loss_weight=None,
train_eps=None,
sample_eps=None,
train_sample_type="uniform",
mean = 0.0,
std = 1.0,
shift_scale = 1.0,
):
"""function for creating Transport object
**Note**: model prediction defaults to velocity
Args:
- path_type: type of path to use; default to linear
- learn_score: set model prediction to score
- learn_noise: set model prediction to noise
- velocity_weighted: weight loss by velocity weight
- likelihood_weighted: weight loss by likelihood weight
- train_eps: small epsilon for avoiding instability during training
- sample_eps: small epsilon for avoiding instability during sampling
"""
if prediction == "noise":
model_type = ModelType.NOISE
elif prediction == "score":
model_type = ModelType.SCORE
else:
model_type = ModelType.VELOCITY
if loss_weight == "velocity":
loss_type = WeightType.VELOCITY
elif loss_weight == "likelihood":
loss_type = WeightType.LIKELIHOOD
else:
loss_type = WeightType.NONE
path_choice = {
"Linear": PathType.LINEAR,
"GVP": PathType.GVP,
"VP": PathType.VP,
}
path_type = path_choice[path_type]
if (path_type in [PathType.VP]):
train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
else: # velocity & [GVP, LINEAR] is stable everywhere
train_eps = 0
sample_eps = 0
# create flow state
state = Transport(
model_type=model_type,
path_type=path_type,
loss_type=loss_type,
train_eps=train_eps,
sample_eps=sample_eps,
train_sample_type=train_sample_type,
mean=mean,
std=std,
shift_scale =shift_scale,
)
return state
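A minimal usage sketch of the factory above, assuming the velocity/linear configuration with logit-normal timestep sampling (whether this matches the released training config is an assumption):

```python
from ultrashape.models.diffusion.transport import create_transport

# Velocity-prediction flow matching on the linear path; training timesteps
# drawn from a (shifted) logit-normal distribution.
transport = create_transport(
    path_type="Linear",
    prediction="velocity",
    train_sample_type="logit_normal",
    mean=0.0,
    std=1.0,
    shift_scale=1.0,
)
# During training: transport.training_losses(model, latents, model_kwargs)["loss"]
```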
================================================
FILE: ultrashape/models/diffusion/transport/integrators.py
================================================
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import numpy as np
import torch as th
import torch.nn as nn
from torchdiffeq import odeint
from functools import partial
from tqdm import tqdm
class sde:
"""SDE solver class"""
def __init__(
self,
drift,
diffusion,
*,
t0,
t1,
num_steps,
sampler_type,
):
assert t0 < t1, "SDE sampler has to be in forward time"
self.num_timesteps = num_steps
self.t = th.linspace(t0, t1, num_steps)
self.dt = self.t[1] - self.t[0]
self.drift = drift
self.diffusion = diffusion
self.sampler_type = sampler_type
def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
w_cur = th.randn(x.size()).to(x)
t = th.ones(x.size(0)).to(x) * t
dw = w_cur * th.sqrt(self.dt)
drift = self.drift(x, t, model, **model_kwargs)
diffusion = self.diffusion(x, t)
mean_x = x + drift * self.dt
x = mean_x + th.sqrt(2 * diffusion) * dw
return x, mean_x
def __Heun_step(self, x, _, t, model, **model_kwargs):
w_cur = th.randn(x.size()).to(x)
dw = w_cur * th.sqrt(self.dt)
t_cur = th.ones(x.size(0)).to(x) * t
diffusion = self.diffusion(x, t_cur)
xhat = x + th.sqrt(2 * diffusion) * dw
K1 = self.drift(xhat, t_cur, model, **model_kwargs)
xp = xhat + self.dt * K1
K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
def __forward_fn(self):
"""TODO: generalize here by adding all private functions ending with steps to it"""
sampler_dict = {
"Euler": self.__Euler_Maruyama_step,
"Heun": self.__Heun_step,
}
try:
sampler = sampler_dict[self.sampler_type]
        except KeyError:
            raise NotImplementedError("Sampler type not implemented.")
return sampler
def sample(self, init, model, **model_kwargs):
"""forward loop of sde"""
x = init
mean_x = init
samples = []
sampler = self.__forward_fn()
for ti in self.t[:-1]:
with th.no_grad():
x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
samples.append(x)
return samples
class ode:
"""ODE solver class"""
def __init__(
self,
drift,
*,
t0,
t1,
sampler_type,
num_steps,
atol,
rtol,
):
assert t0 < t1, "ODE sampler has to be in forward time"
self.drift = drift
self.t = th.linspace(t0, t1, num_steps)
self.atol = atol
self.rtol = rtol
self.sampler_type = sampler_type
def sample(self, x, model, **model_kwargs):
device = x[0].device if isinstance(x, tuple) else x.device
def _fn(t, x):
t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
model_output = self.drift(x, t, model, **model_kwargs)
return model_output
t = self.t.to(device)
atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
samples = odeint(
_fn,
x,
t,
method=self.sampler_type,
atol=atol,
rtol=rtol
)
return samples
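As a quick sanity check of the `ode` wrapper above, a constant drift dx/dt = 1 should integrate from 0 to 1 exactly under the fixed-step Euler method. The stand-in "model" below is just a callable; this snippet is illustrative, not part of the repo, and assumes `torchdiffeq` is installed:

```python
import torch as th

def drift(x, t, model, **model_kwargs):
    return model(x, t)  # delegate to the stand-in vector field

solver = ode(drift=drift, t0=0.0, t1=1.0, sampler_type="euler",
             num_steps=50, atol=1e-6, rtol=1e-3)
x0 = th.zeros(4, 8)
xs = solver.sample(x0, lambda x, t: th.ones_like(x))  # dx/dt = 1
assert th.allclose(xs[-1], th.ones_like(x0), atol=1e-4)  # x(1) ≈ 1
```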
================================================
FILE: ultrashape/models/diffusion/transport/path.py
================================================
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch as th
import numpy as np
from functools import partial
def expand_t_like_x(t, x):
"""Function to reshape time t to broadcastable dimension of x
Args:
t: [batch_dim,], time vector
x: [batch_dim,...], data point
"""
dims = [1] * (len(x.size()) - 1)
t = t.view(t.size(0), *dims)
return t
#################### Coupling Plans ####################
class ICPlan:
"""Linear Coupling Plan"""
def __init__(self, sigma=0.0):
self.sigma = sigma
def compute_alpha_t(self, t):
"""Compute the data coefficient along the path"""
return t, 1
def compute_sigma_t(self, t):
"""Compute the noise coefficient along the path"""
return 1 - t, -1
def compute_d_alpha_alpha_ratio_t(self, t):
"""Compute the ratio between d_alpha and alpha"""
return 1 / t
def compute_drift(self, x, t):
"""We always output sde according to score parametrization; """
t = expand_t_like_x(t, x)
alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
sigma_t, d_sigma_t = self.compute_sigma_t(t)
drift = alpha_ratio * x
diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
return -drift, diffusion
def compute_diffusion(self, x, t, form="constant", norm=1.0):
"""Compute the diffusion term of the SDE
Args:
x: [batch_dim, ...], data point
t: [batch_dim,], time vector
form: str, form of the diffusion term
norm: float, norm of the diffusion term
"""
t = expand_t_like_x(t, x)
choices = {
"constant": norm,
"SBDM": norm * self.compute_drift(x, t)[1],
"sigma": norm * self.compute_sigma_t(t)[0],
"linear": norm * (1 - t),
"decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
"inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
}
try:
diffusion = choices[form]
except KeyError:
raise NotImplementedError(f"Diffusion form {form} not implemented")
return diffusion
def get_score_from_velocity(self, velocity, x, t):
"""Wrapper function: transfrom velocity prediction model to score
Args:
velocity: [batch_dim, ...] shaped tensor; velocity model output
x: [batch_dim, ...] shaped tensor; x_t data point
t: [batch_dim,] time tensor
"""
t = expand_t_like_x(t, x)
alpha_t, d_alpha_t = self.compute_alpha_t(t)
sigma_t, d_sigma_t = self.compute_sigma_t(t)
mean = x
reverse_alpha_ratio = alpha_t / d_alpha_t
var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
score = (reverse_alpha_ratio * velocity - mean) / var
return score
def get_noise_from_velocity(self, velocity, x, t):
"""Wrapper function: transfrom velocity prediction model to denoiser
Args:
velocity: [batch_dim, ...] shaped tensor; velocity model output
x: [batch_dim, ...] shaped tensor; x_t data point
t: [batch_dim,] time tensor
"""
t = expand_t_like_x(t, x)
alpha_t, d_alpha_t = self.compute_alpha_t(t)
sigma_t, d_sigma_t = self.compute_sigma_t(t)
mean = x
reverse_alpha_ratio = alpha_t / d_alpha_t
var = reverse_alpha_ratio * d_sigma_t - sigma_t
noise = (reverse_alpha_ratio * velocity - mean) / var
return noise
def get_velocity_from_score(self, score, x, t):
"""Wrapper function: transfrom score prediction model to velocity
Args:
score: [batch_dim, ...] shaped tensor; score model output
x: [batch_dim, ...] shaped tensor; x_t data point
t: [batch_dim,] time tensor
"""
t = expand_t_like_x(t, x)
drift, var = self.compute_drift(x, t)
velocity = var * score - drift
return velocity
def compute_mu_t(self, t, x0, x1):
"""Compute the mean of time-dependent density p_t"""
t = expand_t_like_x(t, x1)
alpha_t, _ = self.compute_alpha_t(t)
sigma_t, _ = self.compute_sigma_t(t)
# t*x1 + (1-t)*x0 ; t=0 x0; t=1 x1
return alpha_t * x1 + sigma_t * x0
def compute_xt(self, t, x0, x1):
"""Sample xt from time-dependent density p_t; rng is required"""
xt = self.compute_mu_t(t, x0, x1)
return xt
def compute_ut(self, t, x0, x1, xt):
"""Compute the vector field corresponding to p_t"""
t = expand_t_like_x(t, x1)
_, d_alpha_t = self.compute_alpha_t(t)
_, d_sigma_t = self.compute_sigma_t(t)
return d_alpha_t * x1 + d_sigma_t * x0
def plan(self, t, x0, x1):
xt = self.compute_xt(t, x0, x1)
ut = self.compute_ut(t, x0, x1, xt)
return t, xt, ut
class VPCPlan(ICPlan):
"""class for VP path flow matching"""
def __init__(self, sigma_min=0.1, sigma_max=20.0):
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * \
(self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * \
(self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
def compute_alpha_t(self, t):
"""Compute coefficient of x1"""
alpha_t = self.log_mean_coeff(t)
alpha_t = th.exp(alpha_t)
d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
return alpha_t, d_alpha_t
def compute_sigma_t(self, t):
"""Compute coefficient of x0"""
p_sigma_t = 2 * self.log_mean_coeff(t)
sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
return sigma_t, d_sigma_t
def compute_d_alpha_alpha_ratio_t(self, t):
"""Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
return self.d_log_mean_coeff(t)
def compute_drift(self, x, t):
"""Compute the drift term of the SDE"""
t = expand_t_like_x(t, x)
beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
return -0.5 * beta_t * x, beta_t / 2
class GVPCPlan(ICPlan):
def __init__(self, sigma=0.0):
super().__init__(sigma)
def compute_alpha_t(self, t):
"""Compute coefficient of x1"""
alpha_t = th.sin(t * np.pi / 2)
d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
return alpha_t, d_alpha_t
def compute_sigma_t(self, t):
"""Compute coefficient of x0"""
sigma_t = th.cos(t * np.pi / 2)
d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
return sigma_t, d_sigma_t
def compute_d_alpha_alpha_ratio_t(self, t):
"""Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
return np.pi / (2 * th.tan(t * np.pi / 2))
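For reference, the three coupling plans above all instantiate the interpolant between noise x_0 and data x_1, with coefficients read directly from the code:

```latex
x_t = \alpha_t\, x_1 + \sigma_t\, x_0
\qquad
\begin{aligned}
\text{Linear (ICPlan):}\quad & \alpha_t = t, & \sigma_t &= 1 - t \\
\text{GVP:}\quad & \alpha_t = \sin\!\left(\tfrac{\pi t}{2}\right), & \sigma_t &= \cos\!\left(\tfrac{\pi t}{2}\right) \\
\text{VP:}\quad & \alpha_t = e^{-\frac{1}{4}(1-t)^2(\sigma_{\max}-\sigma_{\min}) - \frac{1}{2}(1-t)\sigma_{\min}}, & \sigma_t &= \sqrt{1 - \alpha_t^2}
\end{aligned}
```

The training target is the vector field u_t = \dot{\alpha}_t x_1 + \dot{\sigma}_t x_0, which for the linear path reduces to u_t = x_1 - x_0.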
================================================
FILE: ultrashape/models/diffusion/transport/transport.py
================================================
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch as th
import numpy as np
import logging
import enum
from . import path
from .utils import EasyDict, log_state, mean_flat
from .integrators import ode, sde
class ModelType(enum.Enum):
"""
Which type of output the model predicts.
"""
NOISE = enum.auto() # the model predicts epsilon
SCORE = enum.auto() # the model predicts \nabla \log p(x)
VELOCITY = enum.auto() # the model predicts v(x)
class PathType(enum.Enum):
"""
Which type of path to use.
"""
LINEAR = enum.auto()
GVP = enum.auto()
VP = enum.auto()
class WeightType(enum.Enum):
"""
Which type of weighting to use.
"""
NONE = enum.auto()
VELOCITY = enum.auto()
LIKELIHOOD = enum.auto()
class Transport:
def __init__(
self,
*,
model_type,
path_type,
loss_type,
train_eps,
sample_eps,
train_sample_type = "uniform",
**kwargs,
):
path_options = {
PathType.LINEAR: path.ICPlan,
PathType.GVP: path.GVPCPlan,
PathType.VP: path.VPCPlan,
}
self.loss_type = loss_type
self.model_type = model_type
self.path_sampler = path_options[path_type]()
self.train_eps = train_eps
self.sample_eps = sample_eps
self.train_sample_type = train_sample_type
if self.train_sample_type == "logit_normal":
self.mean = kwargs['mean']
self.std = kwargs['std']
self.shift_scale = kwargs['shift_scale']
print(f"using logit normal sample, shift scale is {self.shift_scale}")
def prior_logp(self, z):
'''
Standard multivariate normal prior
Assume z is batched
'''
shape = th.tensor(z.size())
N = th.prod(shape[1:])
_fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
return th.vmap(_fn)(z)
def check_interval(
self,
train_eps,
sample_eps,
*,
diffusion_form="SBDM",
sde=False,
reverse=False,
eval=False,
last_step_size=0.0,
):
t0 = 0
t1 = 1
eps = train_eps if not eval else sample_eps
if (type(self.path_sampler) in [path.VPCPlan]):
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
and (
self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
if reverse:
t0, t1 = 1 - t0, 1 - t1
return t0, t1
def sample(self, x1):
"""Sampling x0 & t based on shape of x1 (if needed)
Args:
x1 - data point; [batch, *dim]
"""
x0 = th.randn_like(x1)
if self.train_sample_type=="uniform":
t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
t = t.to(x1)
elif self.train_sample_type=="logit_normal":
t = th.randn((x1.shape[0],)) * self.std + self.mean
t = t.to(x1)
t = 1/(1+th.exp(-t))
t = np.sqrt(self.shift_scale)*t/(1+(np.sqrt(self.shift_scale)-1)*t)
return t, x0, x1
def training_losses(
self,
model,
x1,
model_kwargs=None
):
"""Loss for training the score model
Args:
- model: backbone model; could be score, noise, or velocity
- x1: datapoint
- model_kwargs: additional arguments for the model
"""
        if model_kwargs is None:
model_kwargs = {}
t, x0, x1 = self.sample(x1)
t, xt, ut = self.path_sampler.plan(t, x0, x1)
model_output = model(xt, t, **model_kwargs)
B, *_, C = xt.shape
assert model_output.size() == (B, *xt.size()[1:-1], C)
terms = {}
terms['pred'] = model_output
if self.model_type == ModelType.VELOCITY:
terms['loss'] = mean_flat(((model_output - ut) ** 2))
else:
_, drift_var = self.path_sampler.compute_drift(xt, t)
sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
if self.loss_type in [WeightType.VELOCITY]:
weight = (drift_var / sigma_t) ** 2
elif self.loss_type in [WeightType.LIKELIHOOD]:
weight = drift_var / (sigma_t ** 2)
elif self.loss_type in [WeightType.NONE]:
weight = 1
else:
raise NotImplementedError()
if self.model_type == ModelType.NOISE:
terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
else:
terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
return terms
def get_drift(
self
):
"""member function for obtaining the drift of the probability flow ODE"""
def score_ode(x, t, model, **model_kwargs):
drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
model_output = model(x, t, **model_kwargs)
return (-drift_mean + drift_var * model_output) # by change of variable
def noise_ode(x, t, model, **model_kwargs):
drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
model_output = model(x, t, **model_kwargs)
score = model_output / -sigma_t
return (-drift_mean + drift_var * score)
def velocity_ode(x, t, model, **model_kwargs):
model_output = model(x, t, **model_kwargs)
return model_output
if self.model_type == ModelType.NOISE:
drift_fn = noise_ode
elif self.model_type == ModelType.SCORE:
drift_fn = score_ode
else:
drift_fn = velocity_ode
def body_fn(x, t, model, **model_kwargs):
model_output = drift_fn(x, t, model, **model_kwargs)
assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
return model_output
return body_fn
def get_score(
self,
):
"""member function for obtaining score of
x_t = alpha_t * x + sigma_t * eps"""
if self.model_type == ModelType.NOISE:
score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / - \
self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
elif self.model_type == ModelType.SCORE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
elif self.model_type == ModelType.VELOCITY:
score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x,
t)
else:
raise NotImplementedError()
return score_fn
class Sampler:
"""Sampler class for the transport model"""
def __init__(
self,
transport,
):
"""Constructor for a general sampler; supporting different sampling methods
Args:
        - transport: a transport object specifying model prediction & interpolant type
"""
self.transport = transport
self.drift = self.transport.get_drift()
self.score = self.transport.get_score()
def __get_sde_diffusion_and_drift(
self,
*,
diffusion_form="SBDM",
diffusion_norm=1.0,
):
def diffusion_fn(x, t):
diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
return diffusion
sde_drift = \
lambda x, t, model, **kwargs: \
self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
sde_diffusion = diffusion_fn
return sde_drift, sde_diffusion
def __get_last_step(
self,
sde_drift,
*,
last_step,
last_step_size,
):
"""Get the last step function of the SDE solver"""
if last_step is None:
last_step_fn = \
lambda x, t, model, **model_kwargs: \
x
elif last_step == "Mean":
last_step_fn = \
lambda x, t, model, **model_kwargs: \
x + sde_drift(x, t, model, **model_kwargs) * last_step_size
elif last_step == "Tweedie":
alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
sigma = self.transport.path_sampler.compute_sigma_t
last_step_fn = \
lambda x, t, model, **model_kwargs: \
x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model,
**model_kwargs)
elif last_step == "Euler":
last_step_fn = \
lambda x, t, model, **model_kwargs: \
x + self.drift(x, t, model, **model_kwargs) * last_step_size
else:
raise NotImplementedError()
return last_step_fn
def sample_sde(
self,
*,
sampling_method="Euler",
diffusion_form="SBDM",
diffusion_norm=1.0,
last_step="Mean",
last_step_size=0.04,
num_steps=250,
):
"""returns a sampling function with given SDE settings
Args:
- sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
- diffusion_form: function form of diffusion coefficient; default to be matching SBDM
- diffusion_norm: function magnitude of diffusion coefficient; default to 1
- last_step: type of the last step; default to identity
- last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
- num_steps: total integration step of SDE
"""
if last_step is None:
last_step_size = 0.0
sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
diffusion_form=diffusion_form,
diffusion_norm=diffusion_norm,
)
t0, t1 = self.transport.check_interval(
self.transport.train_eps,
self.transport.sample_eps,
diffusion_form=diffusion_form,
sde=True,
eval=True,
reverse=False,
last_step_size=last_step_size,
)
_sde = sde(
sde_drift,
sde_diffusion,
t0=t0,
t1=t1,
num_steps=num_steps,
sampler_type=sampling_method
)
last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
def _sample(init, model, **model_kwargs):
xs = _sde.sample(init, model, **model_kwargs)
ts = th.ones(init.size(0), device=init.device) * t1
x = last_step_fn(xs[-1], ts, model, **model_kwargs)
xs.append(x)
            assert len(xs) == num_steps, "Number of samples does not match the number of steps"
return xs
return _sample
def sample_ode(
self,
*,
sampling_method="dopri5",
num_steps=50,
atol=1e-6,
rtol=1e-3,
reverse=False,
):
"""returns a sampling function with given ODE settings
Args:
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
- num_steps:
- fixed solver (Euler, Heun): the actual number of integration steps performed
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
- atol: absolute error tolerance for the solver
- rtol: relative error tolerance for the solver
- reverse: whether solving the ODE in reverse (data to noise); default to False
"""
if reverse:
drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
else:
drift = self.drift
t0, t1 = self.transport.check_interval(
self.transport.train_eps,
self.transport.sample_eps,
sde=False,
eval=True,
reverse=reverse,
last_step_size=0.0,
)
_ode = ode(
drift=drift,
t0=t0,
t1=t1,
sampler_type=sampling_method,
num_steps=num_steps,
atol=atol,
rtol=rtol,
)
return _ode.sample
def sample_ode_intermediate(
self,
*,
sampling_method="dopri5",
num_steps=50,
atol=1e-6,
rtol=1e-3,
t=0.5,
reverse=False,
):
"""returns a sampling function with given ODE settings
Args:
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
- num_steps:
- fixed solver (Euler, Heun): the actual number of integration steps performed
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
- atol: absolute error tolerance for the solver
- rtol: relative error tolerance for the solver
- reverse: whether solving the ODE in reverse (data to noise); default to False
"""
if reverse:
drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
else:
drift = self.drift
t0, t1 = self.transport.check_interval(
self.transport.train_eps,
self.transport.sample_eps,
sde=False,
eval=True,
reverse=reverse,
last_step_size=0.0,
)
_ode = ode(
drift=drift,
t0=t,
t1=t1,
sampler_type=sampling_method,
num_steps=num_steps,
atol=atol,
rtol=rtol,
)
return _ode.sample
def sample_ode_likelihood(
self,
*,
sampling_method="dopri5",
num_steps=50,
atol=1e-6,
rtol=1e-3,
):
"""returns a sampling function for calculating likelihood with given ODE settings
Args:
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
- num_steps:
- fixed solver (Euler, Heun): the actual number of integration steps performed
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
- atol: absolute error tolerance for the solver
- rtol: relative error tolerance for the solver
"""
def _likelihood_drift(x, t, model, **model_kwargs):
x, _ = x
eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
t = th.ones_like(t) * (1 - t)
with th.enable_grad():
x.requires_grad = True
grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
drift = self.drift(x, t, model, **model_kwargs)
return (-drift, logp_grad)
t0, t1 = self.transport.check_interval(
self.transport.train_eps,
self.transport.sample_eps,
sde=False,
eval=True,
reverse=False,
last_step_size=0.0,
)
_ode = ode(
drift=_likelihood_drift,
t0=t0,
t1=t1,
sampler_type=sampling_method,
num_steps=num_steps,
atol=atol,
rtol=rtol,
)
def _sample_fn(x, model, **model_kwargs):
init_logp = th.zeros(x.size(0)).to(x)
            inputs = (x, init_logp)
            drift, delta_logp = _ode.sample(inputs, model, **model_kwargs)
drift, delta_logp = drift[-1], delta_logp[-1]
prior_logp = self.transport.prior_logp(drift)
logp = prior_logp - delta_logp
return logp, drift
return _sample_fn
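An end-to-end sketch tying `Transport` and `Sampler` together; the toy velocity field below stands in for the conditioned DiT (which additionally receives `contexts` and `voxel_cond` through `model_kwargs`), so this is illustrative rather than the repo's inference path:

```python
import torch as th

# Velocity prediction on the linear path; eps = 0 is stable in this regime.
transport = Transport(model_type=ModelType.VELOCITY,
                      path_type=PathType.LINEAR,
                      loss_type=WeightType.NONE,
                      train_eps=0, sample_eps=0)
sampler = Sampler(transport)
sample_fn = sampler.sample_ode(sampling_method="euler", num_steps=25)

model = lambda x, t, **kwargs: -x      # stand-in velocity field
x_init = th.randn(2, 16, 8)            # latent-shaped noise at t0
xs = sample_fn(x_init, model)          # trajectory of ODE states
x_final = xs[-1]                       # sample at t1
```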
================================================
FILE: ultrashape/models/diffusion/transport/utils.py
================================================
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch as th
class EasyDict:
def __init__(self, sub_dict):
for k, v in sub_dict.items():
setattr(self, k, v)
def __getitem__(self, key):
return getattr(self, key)
def mean_flat(x):
"""
Take the mean over all non-batch dimensions.
"""
return th.mean(x, dim=list(range(1, len(x.size()))))
def log_state(state):
result = []
sorted_state = dict(sorted(state.items()))
for key, value in sorted_state.items():
# Check if the value is an instance of a class
if "