Repository: tchambon/IADB
Branch: main
Commit: c398457663db
Files: 3
Total size: 4.3 KB

Directory structure:
gitextract_a679gf1m/
├── README.md
├── environment.yml
└── iadb.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model

### [Paper](https://arxiv.org/abs/2305.03486) | [Blog post](https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html) | [2D tutorial](https://tchambon.github.io/posts/iadb-2D/)
This repository is the official implementation of IADB ([Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486)), published at SIGGRAPH 2023.

![image](imgs/teaser.png)

For a simple and intuitive explanation of our method, you can read our [blog post](https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html) and check our [2D tutorial](https://tchambon.github.io/posts/iadb-2D/).

# Quick start

To set up a new conda environment, download a dataset (CelebA), and launch a training run, use the following commands:

```
conda env create -f environment.yml
conda activate iadb
python iadb.py
```

# Setup

Python 3 dependencies:

- [PyTorch](https://pytorch.org/)
- [torchvision](https://pytorch.org/)
- [Diffusers](https://github.com/huggingface/diffusers)

This code has been tested with Python 3.8 on Ubuntu 22.04. We recommend setting up a dedicated Conda environment using Python 3.8 and PyTorch 2.0.1.

# Code description

`iadb.py` contains a simple training loop. It demonstrates how to train a new IADB model and how to generate results (using the provided `sample_iadb` function).
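As a rough illustration of what `sample_iadb` does, the sketch below applies the same step-by-step update in plain Python to a single scalar. It is not part of the repository: the trained UNet is replaced by a hypothetical closed-form `oracle_direction`, which is only valid under the assumption that the data distribution is a single point `target`. The names `blend`, `oracle_direction`, `sample_scalar`, and `target` are introduced here for illustration only.

```python
def blend(x0, x1, alpha):
    """Linear alpha-blend between a noise sample x0 and a data sample x1."""
    return alpha * x1 + (1.0 - alpha) * x0


def oracle_direction(x_alpha, alpha, target):
    """Hypothetical stand-in for the trained network (assumption).

    If the data distribution is the single point `target`, then
    x_alpha = alpha * target + (1 - alpha) * x0, and the quantity the
    network is trained to predict, x1 - x0, has this closed form.
    """
    return (target - x_alpha) / (1.0 - alpha)


def sample_scalar(x0, nb_step, target):
    """Same update rule as sample_iadb, applied to one float."""
    x_alpha = x0
    for t in range(nb_step):
        alpha_start = t / nb_step
        alpha_end = (t + 1) / nb_step
        # Direction predicted at the current blending level.
        d = oracle_direction(x_alpha, alpha_start, target)
        # Move from alpha_start to alpha_end along that direction.
        x_alpha = x_alpha + (alpha_end - alpha_start) * d
    return x_alpha


result = sample_scalar(x0=-1.7, nb_step=8, target=0.5)
print(result)
```

In this toy setting the direction is exact and constant along the path, so the loop reaches `target` up to floating-point error for any `nb_step`. For a real dataset the direction changes along the path, so the number of steps matters; the training script samples with `nb_step=128`.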


================================================
FILE: environment.yml
================================================
name: iadb
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - python=3.8.16=h7a1cb2a_3
  - pytorch=2.0.1=py3.8_cuda11.8_cudnn8.7.0_0
  - pytorch-cuda=11.8=h7e8668a_5
  - torchvision=0.15.2=py38_cu118
  - pip:
    - diffusers==0.16.1


================================================
FILE: iadb.py
================================================
import torch
import torchvision
from torchvision import transforms
from diffusers import UNet2DModel
from torch.optim import Adam

def get_model():
    block_out_channels=(128, 128, 256, 256, 512, 512)
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    )
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D"
    )
    return UNet2DModel(block_out_channels=block_out_channels,out_channels=3, in_channels=3, up_block_types=up_block_types, down_block_types=down_block_types, add_attention=True)

@torch.no_grad()
def sample_iadb(model, x0, nb_step):
    x_alpha = x0
    for t in range(nb_step):
        alpha_start = (t/nb_step)
        alpha_end =((t+1)/nb_step)

        d = model(x_alpha, torch.tensor(alpha_start, device=x_alpha.device))['sample']
        x_alpha = x_alpha + (alpha_end-alpha_start)*d

    return x_alpha

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
CELEBA_FOLDER = './datasets/celeba/'

transform = transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64), transforms.RandomHorizontalFlip(0.5),transforms.ToTensor()])
train_dataset = torchvision.datasets.CelebA(root=CELEBA_FOLDER, split='train', download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

model = get_model()
model = model.to(device)

optimizer = Adam(model.parameters(), lr=1e-4)
nb_iter = 0

print('Start training')
for current_epoch in range(100):
    for i, data in enumerate(dataloader):
        # x1: images rescaled from [0, 1] to [-1, 1]; x0: Gaussian noise of the same shape
        x1 = (data[0].to(device)*2)-1
        x0 = torch.randn_like(x1)
        bs = x0.shape[0]

        # One random blending factor per sample, then the blended input
        alpha = torch.rand(bs, device=device)
        x_alpha = alpha.view(-1,1,1,1) * x1 + (1-alpha).view(-1,1,1,1) * x0

        # The model is trained to predict x1 - x0 from (x_alpha, alpha)
        d = model(x_alpha, alpha)['sample']
        loss = torch.sum((d - (x1-x0))**2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        nb_iter += 1

        # Every 200 iterations: sample from the current noise batch, save images and a checkpoint
        if nb_iter % 200 == 0:
            with torch.no_grad():
                print(f'Save export {nb_iter}')
                sample = (sample_iadb(model, x0, nb_step=128) * 0.5) + 0.5
                torchvision.utils.save_image(sample, f'export_{str(nb_iter).zfill(8)}.png')
                torch.save(model.state_dict(), 'celeba.ckpt')