[
  {
    "path": "README.md",
    "content": "# Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model\n### [Paper](https://arxiv.org/abs/2305.03486) | [Blog post](https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html) | [2D tutorial](https://tchambon.github.io/posts/iadb-2D/)\n<br />\n\nThis repository is the official implementation of IADB ([Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486)), published at SIGGRAPH 2023.\n\n\n\n![image](imgs/teaser.png)\n\nFor a simple and intuitive explanation of our method, you can read our [blog post](https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html) and check out our [2D tutorial](https://tchambon.github.io/posts/iadb-2D/).\n\n\n# Quick start\n\nTo set up a new Conda environment, download the CelebA dataset, and start training, run the following commands:\n\n```\nconda env create -f environment.yml\nconda activate iadb\npython iadb.py\n```\n\n# Setup\n\nPython 3 dependencies:\n- [PyTorch](https://pytorch.org/) \n- [torchvision](https://pytorch.org/) \n- [Diffusers](https://github.com/huggingface/diffusers)\n\nThis code has been tested with Python 3.8 on Ubuntu 22.04. We recommend setting up a dedicated Conda environment using Python 3.8 and PyTorch 2.0.1.\n\n# Code description\n\nThe `iadb.py` script contains a simple training loop.\n\nIt demonstrates how to train a new IADB model and how to generate results (using the provided `sample_iadb` function).\n"
  },
  {
    "path": "environment.yml",
    "content": "name: iadb\nchannels:\n  - pytorch\n  - nvidia\n  - defaults\ndependencies:\n  - python=3.8.16=h7a1cb2a_3\n  - pytorch=2.0.1=py3.8_cuda11.8_cudnn8.7.0_0\n  - pytorch-cuda=11.8=h7e8668a_5\n  - torchvision=0.15.2=py38_cu118\n  - pip:\n      - diffusers==0.16.1\n"
  },
  {
    "path": "iadb.py",
    "content": "import torch\nimport torchvision\nfrom torchvision import transforms\nfrom diffusers import UNet2DModel\nfrom torch.optim import Adam\n\ndef get_model():\n    block_out_channels=(128, 128, 256, 256, 512, 512)\n    down_block_types=( \n        \"DownBlock2D\",  # a regular ResNet downsampling block\n        \"DownBlock2D\", \n        \"DownBlock2D\", \n        \"DownBlock2D\", \n        \"AttnDownBlock2D\",  # a ResNet downsampling block with spatial self-attention\n        \"DownBlock2D\",\n    )\n    up_block_types=(\n        \"UpBlock2D\",  # a regular ResNet upsampling block\n        \"AttnUpBlock2D\",  # a ResNet upsampling block with spatial self-attention\n        \"UpBlock2D\", \n        \"UpBlock2D\", \n        \"UpBlock2D\", \n        \"UpBlock2D\"  \n    )\n    return UNet2DModel(block_out_channels=block_out_channels,out_channels=3, in_channels=3, up_block_types=up_block_types, down_block_types=down_block_types, add_attention=True)\n\n@torch.no_grad()\ndef sample_iadb(model, x0, nb_step):\n    x_alpha = x0\n    for t in range(nb_step):\n        alpha_start = (t/nb_step)\n        alpha_end =((t+1)/nb_step)\n\n        d = model(x_alpha, torch.tensor(alpha_start, device=x_alpha.device))['sample']\n        x_alpha = x_alpha + (alpha_end-alpha_start)*d\n\n    return x_alpha\n\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\nCELEBA_FOLDER = './datasets/celeba/'\ntransform = transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64), transforms.RandomHorizontalFlip(0.5),transforms.ToTensor()])\ntrain_dataset = torchvision.datasets.CelebA(root=CELEBA_FOLDER, split='train',\n                                        download=True, transform=transform)\n\ndataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True)\n\nmodel = get_model()\nmodel = model.to(device)\n\noptimizer = Adam(model.parameters(), lr=1e-4)\nnb_iter = 0\nprint('Start training')\nfor 
current_epoch in range(100):\n    for i, data in enumerate(dataloader):\n        x1 = (data[0].to(device)*2)-1\n        x0 = torch.randn_like(x1)\n        bs = x0.shape[0]\n\n        alpha = torch.rand(bs, device=device)\n        x_alpha = alpha.view(-1,1,1,1) * x1 + (1-alpha).view(-1,1,1,1) * x0\n        \n        d = model(x_alpha, alpha)['sample']\n        loss = torch.sum((d - (x1-x0))**2)\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        nb_iter += 1\n\n        if nb_iter % 200 == 0:\n            with torch.no_grad():\n                print(f'Save export {nb_iter}')\n                sample = (sample_iadb(model, x0, nb_step=128) * 0.5) + 0.5\n                torchvision.utils.save_image(sample, f'export_{str(nb_iter).zfill(8)}.png')\n                torch.save(model.state_dict(), f'celeba.ckpt')\n"
  }
]