[
  {
    "path": ".gitignore",
    "content": "venv*\n__pycache__\n.ipynb_checkpoints\n*.pkl\nout*\n*.egg-info\n*.ini\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"CLIP_JAX\"]\n\tpath = CLIP_JAX\n\turl = https://github.com/kingoflolz/CLIP_JAX\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2021 Katherine Crowson and John David Pressman\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\nTHE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# v-diffusion-jax\n\nv objective diffusion inference code for JAX, by Katherine Crowson ([@RiversHaveWings](https://twitter.com/RiversHaveWings)) and Chainbreakers AI ([@jd_pressman](https://twitter.com/jd_pressman)).\n\nThe models are denoising diffusion probabilistic models (https://arxiv.org/abs/2006.11239), which are trained to reverse a gradual noising process, allowing the models to generate samples from the learned data distributions starting from random noise. DDIM-style deterministic sampling (https://arxiv.org/abs/2010.02502) is also supported. The models are also trained on continuous timesteps. They use the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI).\n\nThank you to Google's [TPU Research Cloud](https://sites.research.google/trc/about/) and [stability.ai](https://www.stability.ai) for compute to train these models!\n\n## Dependencies\n\n- JAX ([installation instructions](https://github.com/google/jax#installation))\n\n- dm-haiku, einops, numpy, optax, Pillow, tqdm (install with `pip install`)\n\n- CLIP_JAX (https://github.com/kingoflolz/CLIP_JAX), and its additional pip-installable dependencies: ftfy, regex, torch, torchvision (it does not need GPU PyTorch). 
**If you `git clone --recursive` this repo, it should fetch CLIP_JAX automatically.**\n\n## Model checkpoints\n\n- [Danbooru SFW 128x128](https://the-eye.eu/public/AI/models/v-diffusion/danbooru_128.pkl), SHA-256 `8551fe663dae988e619444efd99995775c7618af2f15ab5d8caf6b123513c334`\n\n- [ImageNet 128x128](https://the-eye.eu/public/AI/models/v-diffusion/imagenet_128.pkl), SHA-256 `4fc7c817b9aaa9018c6dbcbf5cd444a42f4a01856b34c49039f57fe48e090530`\n\n- [WikiArt 128x128](https://the-eye.eu/public/AI/models/v-diffusion/wikiart_128.pkl), SHA-256 `8fbe4e0206262996ff76d3f82a18dc67d3edd28631d4725e0154b51d00b9f91a`\n\n- [WikiArt 256x256](https://the-eye.eu/public/AI/models/v-diffusion/wikiart_256.pkl), SHA-256 `ebc6e77865bbb2d91dad1a0bfb670079c4992684a0e97caa28f784924c3afd81`\n\n## Sampling\n\n### Example\n\nIf the model checkpoints are stored in `checkpoints/`, the following will generate an image:\n\n```\n./clip_sample.py \"a friendly robot, watercolor by James Gurney\" --model wikiart_256 --seed 0\n```\n\nIf they are somewhere else, you need to specify the path to the checkpoint with `--checkpoint`.\n\n### Unconditional sampling\n\n```\nusage: sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT] [--eta ETA]\n                 [--init INIT] --model {danbooru_128,imagenet_128,wikiart_128,wikiart_256}\n                 [-n N] [--seed SEED] [--starting-timestep STARTING_TIMESTEP]\n                 [--steps STEPS]\n```\n\n`--batch-size`: sample this many images at a time (default 1)\n\n`--checkpoint`: manually specify the model checkpoint file\n\n`--eta`: set to 0 for deterministic (DDIM) sampling, 1 (the default) for stochastic (DDPM) sampling, and in between to interpolate between the two. DDIM is preferred for low numbers of timesteps.\n\n`--init`: specify the init image (optional)\n\n`--model`: specify the model to use\n\n`-n`: sample until this many images are sampled (default 1)\n\n`--seed`: specify the random seed (default 0)\n\n`--starting-timestep`: specify the starting timestep if an init image is used (range 0-1, default 0.9)\n\n`--steps`: specify the number of diffusion timesteps (default is 1000, can lower for faster but lower quality sampling)\n\n### CLIP guided sampling\n\nCLIP guided sampling lets you generate images with diffusion models conditioned on the output matching a text prompt.\n\n```\nusage: clip_sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT]\n                      [--clip-guidance-scale CLIP_GUIDANCE_SCALE] [--eta ETA] [--init INIT]\n                      --model {danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N]\n                      [--seed SEED] [--starting-timestep STARTING_TIMESTEP] [--steps STEPS]\n                      prompt\n```\n\n`clip_sample.py` has the same options as `sample.py` and these additional ones:\n\n`prompt`: the text prompt to use\n\n`--clip-guidance-scale`: how strongly the result should match the text prompt (default 1000)\n"
  },
  {
    "path": "clip_sample.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"CLIP guided sampling from a diffusion model.\"\"\"\n\nimport argparse\nfrom functools import partial\nfrom pathlib import Path\nimport sys\n\nfrom einops import repeat\nimport jax\nimport jax.numpy as jnp\nfrom PIL import Image\nfrom tqdm import tqdm, trange\n\nfrom diffusion import get_model, get_models, load_params, sampling, utils\n\nMODULE_DIR = Path(__file__).resolve().parent\nsys.path.append(str(MODULE_DIR / 'CLIP_JAX'))\n\nimport clip_jax\n\n\ndef make_normalize(mean, std):\n    mean = jnp.array(mean).reshape([3, 1, 1])\n    std = jnp.array(std).reshape([3, 1, 1])\n\n    def inner(image):\n        return (image - mean) / std\n    return inner\n\n\ndef norm2(x):\n    \"\"\"Normalizes a batch of vectors to the unit sphere.\"\"\"\n    return x / jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True))\n\n\ndef spherical_dist_loss(x, y):\n    \"\"\"Computes 1/2 the squared spherical distance between the two arguments.\"\"\"\n    return jnp.square(jnp.arccos(jnp.sum(norm2(x) * norm2(y), axis=-1))) / 2\n\n\ndef main():\n    p = argparse.ArgumentParser(description=__doc__,\n                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    p.add_argument('prompt', type=str,\n                   help='the text prompt')\n    p.add_argument('--batch-size', '-bs', type=int, default=1,\n                   help='the number of images per batch')\n    p.add_argument('--checkpoint', type=str,\n                   help='the checkpoint to use')\n    p.add_argument('--clip-guidance-scale', '-cs', type=float, default=1000.,\n                   help='the CLIP guidance scale')\n    p.add_argument('--eta', type=float, default=1.,\n                   help='the amount of noise to add during sampling (0-1)')\n    p.add_argument('--init', type=str,\n                   help='the init image')\n    p.add_argument('--model', type=str, choices=get_models(), required=True,\n                   help='the model to use')\n  
  p.add_argument('-n', type=int, default=1,\n                   help='the number of images to sample')\n    p.add_argument('--seed', type=int, default=0,\n                   help='the random seed')\n    p.add_argument('--starting-timestep', '-st', type=float, default=0.9,\n                   help='the timestep to start at (used with init images)')\n    p.add_argument('--steps', type=int, default=1000,\n                   help='the number of timesteps')\n    args = p.parse_args()\n\n    model = get_model(args.model)\n    checkpoint = args.checkpoint\n    if not checkpoint:\n        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl'\n    params = load_params(checkpoint)\n\n    image_fn, text_fn, clip_params, _ = clip_jax.load('ViT-B/16')\n    clip_patch_size = 16\n    clip_size = 224\n    normalize = make_normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n                               std=[0.26862954, 0.26130258, 0.27577711])\n\n    target_embed = text_fn(clip_params, clip_jax.tokenize([args.prompt]))\n\n    if args.init:\n        _, y, x = model.shape\n        init = Image.open(args.init).convert('RGB').resize((x, y), Image.LANCZOS)\n        init = utils.from_pil_image(init)[None]\n\n    key = jax.random.PRNGKey(args.seed)\n\n    def clip_cond_fn_loss(x, key, params, clip_params, t, extra_args):\n        dummy_key = jax.random.PRNGKey(0)\n        v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)\n        alpha, sigma = utils.t_to_alpha_sigma(t)\n        pred = x * alpha - v * sigma\n        clip_in = jax.image.resize(pred, (*pred.shape[:2], clip_size, clip_size), 'cubic')\n        extent = clip_patch_size // 2\n        clip_in = jnp.pad(clip_in, [(0, 0), (0, 0), (extent, extent), (extent, extent)], 'edge')\n        sat_vmap = jax.vmap(partial(jax.image.scale_and_translate, method='cubic'),\n                            in_axes=(0, None, None, 0, 0))\n        scales = jnp.ones([pred.shape[0], 2])\n        translates = 
jax.random.uniform(key, [pred.shape[0], 2], minval=-extent, maxval=extent)\n        clip_in = sat_vmap(clip_in, (3, clip_size, clip_size), (1, 2), scales, translates)\n        image_embeds = image_fn(clip_params, normalize((clip_in + 1) / 2))\n        return jnp.sum(spherical_dist_loss(image_embeds, target_embed))\n\n    def clip_cond_fn(x, key, t, extra_args, params, clip_params):\n        grad_fn = jax.grad(clip_cond_fn_loss)\n        grad = grad_fn(x, key, params, clip_params, t, extra_args)\n        return grad * -args.clip_guidance_scale\n\n    def run(key, n):\n        tqdm.write('Sampling...')\n        key, subkey = jax.random.split(key)\n        noise = jax.random.normal(subkey, [n, *model.shape])\n        key, subkey = jax.random.split(key)\n        cond_params = {'params': params, 'clip_params': clip_params}\n        sample_step = partial(sampling.jit_cond_sample_step,\n                              extra_args={},\n                              cond_fn=clip_cond_fn,\n                              cond_params=cond_params)\n        steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1])\n        if args.init:\n            steps = steps[steps < args.starting_timestep]\n            alpha, sigma = utils.t_to_alpha_sigma(steps[0])\n            noise = init * alpha + noise * sigma\n        return sampling.sample_loop(model, params, subkey, noise, steps, args.eta, sample_step)\n\n    def run_all(key, n, batch_size):\n        for i in trange(0, n, batch_size):\n            key, subkey = jax.random.split(key)\n            cur_batch_size = min(n - i, batch_size)\n            outs = run(key, cur_batch_size)\n            for j, out in enumerate(outs):\n                utils.to_pil_image(out).save(f'out_{i + j:05}.png')\n\n    try:\n        run_all(key, args.n, args.batch_size)\n    except KeyboardInterrupt:\n        pass\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "diffusion/__init__.py",
    "content": "from . import sampling, utils\nfrom .models import get_model, get_models, load_params\n"
  },
  {
    "path": "diffusion/models/__init__.py",
    "content": "from .models import get_model, get_models, load_params\n"
  },
  {
    "path": "diffusion/models/danbooru_128.py",
    "content": "import haiku as hk\nimport jax\nimport jax.numpy as jnp\n\nfrom .. import utils\n\n\nclass FourierFeatures(hk.Module):\n    def __init__(self, output_size, std=1., name=None):\n        super().__init__(name=name)\n        assert output_size % 2 == 0\n        self.output_size = output_size\n        self.std = std\n\n    def __call__(self, x):\n        w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]],\n                             init=hk.initializers.RandomNormal(self.std, 0))\n        f = 2 * jnp.pi * x @ w.T\n        return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)\n\n\nclass Dropout2d(hk.Module):\n    def __init__(self, rate=0.5, name=None):\n        super().__init__(name=name)\n        self.rate = rate\n\n    def __call__(self, x, enabled):\n        rate = self.rate * enabled\n        key = hk.next_rng_key()\n        p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None]\n        return x * p / (1.0 - rate)\n\n\nclass SelfAttention2d(hk.Module):\n    def __init__(self, n_head=1, dropout_rate=0.1, name=None):\n        super().__init__(name=name)\n        self.n_head = n_head\n        self.dropout_rate = dropout_rate\n\n    def __call__(self, x, dropout_enabled):\n        n, c, h, w = x.shape\n        assert c % self.n_head == 0\n        qkv_proj = hk.Conv2D(c * 3, 1, data_format='NCHW', name='qkv_proj')\n        out_proj = hk.Conv2D(c, 1, data_format='NCHW', name='out_proj')\n        dropout = Dropout2d(self.dropout_rate)\n        qkv = qkv_proj(x)\n        qkv = jnp.swapaxes(qkv.reshape([n, self.n_head * 3, c // self.n_head, h * w]), 2, 3)\n        q, k, v = jnp.split(qkv, 3, axis=1)\n        scale = k.shape[3]**-0.25\n        att = jax.nn.softmax((q * scale) @ (jnp.swapaxes(k, 2, 3) * scale), axis=3)\n        y = jnp.swapaxes(att @ v, 2, 3).reshape([n, c, h, w])\n        return x + dropout(out_proj(y), dropout_enabled)\n\n\ndef res_conv_block(c_mid, c_out, dropout_last=True):\n    def inner(x, 
is_training):\n        x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW')\n        x_skip = x if x.shape[1] == c_out else x_skip_layer(x)\n        x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x)\n        x = jax.nn.relu(x)\n        x = Dropout2d(0.1)(x, is_training)\n        x = hk.Conv2D(c_out, 3, data_format='NCHW')(x)\n        x = jax.nn.relu(x)\n        if dropout_last:\n            x = Dropout2d(0.1)(x, is_training)\n        return x + x_skip\n    return inner\n\n\ndef diffusion_model(x, t, extra_args):\n    c = 256\n    is_training = jnp.array(0.)\n    log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t))\n    timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None])\n    te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]])\n    x = jnp.concatenate([x, te_planes], axis=1)  # 128x128\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, c)(x, is_training)\n    x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x)  # 64x64\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2)  # 32x32\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3)  # 16x16\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4 // 128)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4 // 128)(x_4, is_training)\n    x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4)  # 8x8\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training)\n    x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5)  # 4x4\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = 
SelfAttention2d(c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 4)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 4 // 128)(x_6, is_training)\n    x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest')\n    x_5 = jnp.concatenate([x_5, x_6], axis=1)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training)\n    x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')\n    x_4 = jnp.concatenate([x_4, x_5], axis=1)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4 // 128)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 2)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 2 // 128)(x_4, is_training)\n    x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')\n    x_3 = jnp.concatenate([x_3, x_4], axis=1)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')\n    x_2 = jnp.concatenate([x_2, x_3], axis=1)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c)(x_2, is_training)\n    x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')\n    x = jnp.concatenate([x, x_2], axis=1)\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, 3, dropout_last=False)(x, is_training)\n    return x\n\n\nclass Danbooru128Model:\n    init, apply = hk.transform(diffusion_model)\n    shape = (3, 128, 128)\n    min_t = 
float(utils.get_ddpm_schedule(jnp.array(0.)))\n    max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))\n"
  },
  {
    "path": "diffusion/models/imagenet_128.py",
    "content": "import haiku as hk\nimport jax\nimport jax.numpy as jnp\n\nfrom .. import utils\n\n\nclass FourierFeatures(hk.Module):\n    def __init__(self, output_size, std=1., name=None):\n        super().__init__(name=name)\n        assert output_size % 2 == 0\n        self.output_size = output_size\n        self.std = std\n\n    def __call__(self, x):\n        w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]],\n                             init=hk.initializers.RandomNormal(self.std, 0))\n        f = 2 * jnp.pi * x @ w.T\n        return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)\n\n\nclass Dropout2d(hk.Module):\n    def __init__(self, rate=0.5, name=None):\n        super().__init__(name=name)\n        self.rate = rate\n\n    def __call__(self, x, enabled):\n        rate = self.rate * enabled\n        key = hk.next_rng_key()\n        p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None]\n        return x * p / (1.0 - rate)\n\n\nclass SelfAttention2d(hk.Module):\n    def __init__(self, c_in, n_head=1, dropout_rate=0.1, name=None):\n        super().__init__(name=name)\n        assert c_in % n_head == 0\n        self.c_in = c_in\n        self.n_head = n_head\n        self.dropout_rate = dropout_rate\n\n    def __call__(self, x, dropout_enabled):\n        n, c, h, w = x.shape\n        qkv_proj = hk.Conv2D(self.c_in * 3, 1, data_format='NCHW', name='qkv_proj')\n        out_proj = hk.Conv2D(self.c_in, 1, data_format='NCHW', name='out_proj')\n        dropout = Dropout2d(self.dropout_rate)\n        qkv = qkv_proj(x)\n        qkv = jnp.swapaxes(qkv.reshape([n, self.n_head * 3, c // self.n_head, h * w]), 2, 3)\n        q, k, v = jnp.split(qkv, 3, axis=1)\n        scale = k.shape[3]**-0.25\n        att = jax.nn.softmax((q * scale) @ (jnp.swapaxes(k, 2, 3) * scale), axis=3)\n        y = jnp.swapaxes(att @ v, 2, 3).reshape([n, c, h, w])\n        return x + dropout(out_proj(y), dropout_enabled)\n\n\ndef res_conv_block(c_mid, c_out, 
dropout_last=True):\n    @hk.remat\n    def inner(x, is_training):\n        x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW')\n        x_skip = x if x.shape[1] == c_out else x_skip_layer(x)\n        x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x)\n        x = jax.nn.relu(x)\n        x = Dropout2d(0.1)(x, is_training)\n        x = hk.Conv2D(c_out, 3, data_format='NCHW')(x)\n        if dropout_last:\n            x = jax.nn.relu(x)\n            x = Dropout2d(0.1)(x, is_training)\n        return x + x_skip\n    return inner\n\n\ndef diffusion_model(x, t, extra_args):\n    c = 128\n    is_training = jnp.array(0.)\n    log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t))\n    timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None])\n    te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]])\n    x = jnp.concatenate([x, te_planes], axis=1)  # 128x128\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, c)(x, is_training)\n    x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x)  # 64x64\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2)  # 32x32\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3)  # 16x16\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)\n    x_4 = 
res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)\n    x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4)  # 8x8\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)\n    x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5)  # 4x4\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 4)(x_6, is_training)\n    x_6 = SelfAttention2d(c * 4, c * 4 // 128)(x_6, is_training)\n    x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest')\n    x_5 = jnp.concatenate([x_5, x_6], axis=1)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, 
is_training)\n    x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)\n    x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')\n    x_4 = jnp.concatenate([x_4, x_5], axis=1)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 2)(x_4, is_training)\n    x_4 = SelfAttention2d(c * 2, c * 2 // 128)(x_4, is_training)\n    x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')\n    x_3 = jnp.concatenate([x_3, x_4], axis=1)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')\n    x_2 = jnp.concatenate([x_2, x_3], axis=1)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c)(x_2, is_training)\n    x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')\n    x = jnp.concatenate([x, x_2], axis=1)\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, c)(x, is_training)\n    x 
= res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, 3, dropout_last=False)(x, is_training)\n    return x\n\n\nclass ImageNet128Model:\n    init, apply = hk.transform(diffusion_model)\n    shape = (3, 128, 128)\n    min_t = float(utils.get_ddpm_schedule(jnp.array(0.)))\n    max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))\n"
  },
  {
    "path": "diffusion/models/models.py",
    "content": "import pickle\n\nimport jax\nimport jax.numpy as jnp\n\nfrom . import danbooru_128, imagenet_128, wikiart_128, wikiart_256\n\n\nmodels = {\n    'danbooru_128': danbooru_128.Danbooru128Model,\n    'imagenet_128': imagenet_128.ImageNet128Model,\n    'wikiart_128': wikiart_128.WikiArt128Model,\n    'wikiart_256': wikiart_256.WikiArt256Model,\n}\n\n\ndef get_model(model):\n    return models[model]\n\n\ndef get_models():\n    return list(models.keys())\n\n\ndef load_params(checkpoint):\n    with open(checkpoint, 'rb') as fp:\n        return jax.tree_map(jnp.array, pickle.load(fp)['params_ema'])\n"
  },
  {
    "path": "diffusion/models/wikiart_128.py",
    "content": "import haiku as hk\nimport jax\nimport jax.numpy as jnp\n\nfrom .. import utils\n\n\nclass FourierFeatures(hk.Module):\n    def __init__(self, output_size, std=1., name=None):\n        super().__init__(name=name)\n        assert output_size % 2 == 0\n        self.output_size = output_size\n        self.std = std\n\n    def __call__(self, x):\n        w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]],\n                             init=hk.initializers.RandomNormal(self.std, 0))\n        f = 2 * jnp.pi * x @ w.T\n        return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)\n\n\nclass Dropout2d(hk.Module):\n    def __init__(self, rate=0.5, name=None):\n        super().__init__(name=name)\n        self.rate = rate\n\n    def __call__(self, x, enabled):\n        rate = self.rate * enabled\n        key = hk.next_rng_key()\n        p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None]\n        return x * p / (1.0 - rate)\n\n\ndef res_conv_block(c_mid, c_out, dropout_last=True):\n    @hk.remat\n    def inner(x, is_training):\n        x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW')\n        x_skip = x if x.shape[1] == c_out else x_skip_layer(x)\n        x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x)\n        x = jax.nn.relu(x)\n        x = Dropout2d(0.1)(x, is_training)\n        x = hk.Conv2D(c_out, 3, data_format='NCHW')(x)\n        x = jax.nn.relu(x)\n        if dropout_last:\n            x = Dropout2d(0.1)(x, is_training)\n        return x + x_skip\n    return inner\n\n\ndef diffusion_model(x, t, extra_args):\n    c = 128\n    is_training = jnp.array(0.)\n    log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t))\n    timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None])\n    te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]])\n    x = jnp.concatenate([x, te_planes], axis=1)  # 128x128\n    x = res_conv_block(c, c)(x, is_training)\n    x = 
res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, c)(x, is_training)\n    x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x)  # 64x64\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2)  # 32x32\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3)  # 16x16\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4)  # 8x8\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5)  # 4x4\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)\n    x_6 = res_conv_block(c * 8, c * 4)(x_6, is_training)\n    x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest')\n    x_5 = jnp.concatenate([x_5, x_6], axis=1)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 
4)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')\n    x_4 = jnp.concatenate([x_4, x_5], axis=1)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)\n    x_4 = res_conv_block(c * 4, c * 2)(x_4, is_training)\n    x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')\n    x_3 = jnp.concatenate([x_3, x_4], axis=1)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')\n    x_2 = jnp.concatenate([x_2, x_3], axis=1)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)\n    x_2 = res_conv_block(c * 2, c)(x_2, is_training)\n    x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')\n    x = jnp.concatenate([x, x_2], axis=1)\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, c)(x, is_training)\n    x = res_conv_block(c, 3, dropout_last=False)(x, is_training)\n    return x\n\n\nclass WikiArt128Model:\n    init, apply = hk.transform(diffusion_model)\n    shape = (3, 128, 128)\n    min_t = float(utils.get_ddpm_schedule(jnp.array(0.)))\n    max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))\n"
  },
  {
    "path": "diffusion/models/wikiart_256.py",
    "content": "import haiku as hk\nimport jax\nimport jax.numpy as jnp\n\nfrom .. import utils\n\n\nclass FourierFeatures(hk.Module):\n    def __init__(self, output_size, std=1., name=None):\n        super().__init__(name=name)\n        assert output_size % 2 == 0\n        self.output_size = output_size\n        self.std = std\n\n    def __call__(self, x):\n        w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]],\n                             init=hk.initializers.RandomNormal(self.std, 0))\n        f = 2 * jnp.pi * x @ w.T\n        return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)\n\n\nclass Dropout2d(hk.Module):\n    def __init__(self, rate=0.5, name=None):\n        super().__init__(name=name)\n        self.rate = rate\n\n    def __call__(self, x, enabled):\n        rate = self.rate * enabled\n        key = hk.next_rng_key()\n        p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None]\n        return x * p / (1.0 - rate)\n\n\nclass SelfAttention2d(hk.Module):\n    def __init__(self, n_head=1, dropout_rate=0.1, name=None):\n        super().__init__(name=name)\n        self.n_head = n_head\n        self.dropout_rate = dropout_rate\n\n    def __call__(self, x, dropout_enabled):\n        n, c, h, w = x.shape\n        assert c % self.n_head == 0\n        qkv_proj = hk.Conv2D(c * 3, 1, data_format='NCHW', name='qkv_proj')\n        out_proj = hk.Conv2D(c, 1, data_format='NCHW', name='out_proj')\n        dropout = Dropout2d(self.dropout_rate)\n        qkv = qkv_proj(x)\n        qkv = jnp.swapaxes(qkv.reshape([n, self.n_head * 3, c // self.n_head, h * w]), 2, 3)\n        q, k, v = jnp.split(qkv, 3, axis=1)\n        scale = k.shape[3]**-0.25\n        att = jax.nn.softmax((q * scale) @ (jnp.swapaxes(k, 2, 3) * scale), axis=3)\n        y = jnp.swapaxes(att @ v, 2, 3).reshape([n, c, h, w])\n        return x + dropout(out_proj(y), dropout_enabled)\n\n\ndef res_conv_block(c_mid, c_out, dropout_last=True):\n    @hk.remat\n    def 
inner(x, is_training):\n        x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW')\n        x_skip = x if x.shape[1] == c_out else x_skip_layer(x)\n        x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x)\n        x = jax.nn.relu(x)\n        x = Dropout2d(0.1)(x, is_training)\n        x = hk.Conv2D(c_out, 3, data_format='NCHW')(x)\n        if dropout_last:\n            x = jax.nn.relu(x)\n            x = Dropout2d(0.1)(x, is_training)\n        return x + x_skip\n    return inner\n\n\ndef diffusion_model(x, t, extra_args):\n    c = 128\n    is_training = jnp.array(0.)\n    log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t))\n    timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None])\n    te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]])\n    x = jnp.concatenate([x, te_planes], axis=1)  # 256x256\n    x = res_conv_block(c // 2, c // 2)(x, is_training)\n    x = res_conv_block(c // 2, c // 2)(x, is_training)\n    x = res_conv_block(c // 2, c // 2)(x, is_training)\n    x = res_conv_block(c // 2, c // 2)(x, is_training)\n    x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x)  # 128x128\n    x_2 = res_conv_block(c, c)(x_2, is_training)\n    x_2 = res_conv_block(c, c)(x_2, is_training)\n    x_2 = res_conv_block(c, c)(x_2, is_training)\n    x_2 = res_conv_block(c, c)(x_2, is_training)\n    x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2)  # 64x64\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3)  # 32x32\n    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)\n    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)\n    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)\n    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)\n    x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4)  # 16x16\n    x_5 = 
res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)\n    x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5)  # 8x8\n    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)\n    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)\n    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)\n    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)\n    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)\n    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)\n    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)\n    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)\n    x_7 = hk.AvgPool(2, 2, 'SAME', 1)(x_6)  # 4x4\n    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)\n    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)\n    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)\n    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)\n    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)\n    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)\n    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)\n    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)\n    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)\n    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)\n    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)\n    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)\n    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)\n    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)\n   
 x_7 = res_conv_block(c * 8, c * 4)(x_7, is_training)\n    x_7 = hk.remat(SelfAttention2d(c * 4 // 128))(x_7, is_training)\n    x_7 = jax.image.resize(x_7, [*x_7.shape[:2], *x_6.shape[2:]], 'nearest')\n    x_6 = jnp.concatenate([x_6, x_7], axis=1)\n    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)\n    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)\n    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)\n    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)\n    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)\n    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)\n    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)\n    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)\n    x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest')\n    x_5 = jnp.concatenate([x_5, x_6], axis=1)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)\n    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)\n    x_5 = res_conv_block(c * 4, c * 2)(x_5, is_training)\n    x_5 = hk.remat(SelfAttention2d(c * 2 // 128))(x_5, is_training)\n    x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')\n    x_4 = jnp.concatenate([x_4, x_5], axis=1)\n    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)\n    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)\n    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)\n    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)\n    x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')\n    x_3 = jnp.concatenate([x_3, x_4], axis=1)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = 
res_conv_block(c * 2, c * 2)(x_3, is_training)\n    x_3 = res_conv_block(c * 2, c)(x_3, is_training)\n    x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')\n    x_2 = jnp.concatenate([x_2, x_3], axis=1)\n    x_2 = res_conv_block(c, c)(x_2, is_training)\n    x_2 = res_conv_block(c, c)(x_2, is_training)\n    x_2 = res_conv_block(c, c)(x_2, is_training)\n    x_2 = res_conv_block(c, c // 2)(x_2, is_training)\n    x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')\n    x = jnp.concatenate([x, x_2], axis=1)\n    x = res_conv_block(c // 2, c // 2)(x, is_training)\n    x = res_conv_block(c // 2, c // 2)(x, is_training)\n    x = res_conv_block(c // 2, c // 2)(x, is_training)\n    x = res_conv_block(c // 2, 3, dropout_last=False)(x, is_training)\n    return x\n\n\nclass WikiArt256Model:\n    init, apply = hk.transform(diffusion_model)\n    shape = (3, 256, 256)\n    min_t = float(utils.get_ddpm_schedule(jnp.array(0.)))\n    max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))\n"
  },
  {
    "path": "diffusion/sampling.py",
    "content": "from einops import repeat\nimport jax\nimport jax.numpy as jnp\nfrom tqdm import trange\n\nfrom . import utils\n\n\ndef sample_step(model, params, key, x, t, t_next, eta, extra_args):\n    dummy_key = jax.random.PRNGKey(0)\n    v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)\n    alpha, sigma = utils.t_to_alpha_sigma(t)\n    key, subkey = jax.random.split(key)\n    pred = x * alpha - v * sigma\n    eps = x * sigma + v * alpha\n    alpha_next, sigma_next = utils.t_to_alpha_sigma(t_next)\n    ddim_sigma = eta * jnp.sqrt(sigma_next**2 / sigma**2) * \\\n        jnp.sqrt(1 - alpha**2 / alpha_next**2)\n    adjusted_sigma = jnp.sqrt(sigma_next**2 - ddim_sigma**2)\n    x = pred * alpha_next + eps * adjusted_sigma\n    x = x + jax.random.normal(key, x.shape) * ddim_sigma\n    return x, pred\n\n\njit_sample_step = jax.jit(sample_step, static_argnums=0)\n\n\ndef cond_sample_step(model, params, key, x, t, t_next, eta, extra_args, cond_fn, cond_params):\n    dummy_key = jax.random.PRNGKey(0)\n    v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)\n    alpha, sigma = utils.t_to_alpha_sigma(t)\n    key, subkey = jax.random.split(key)\n    cond_grad = cond_fn(x, subkey, t, extra_args, **cond_params)\n    v = v - cond_grad * (sigma / alpha)\n    pred = x * alpha - v * sigma\n    eps = x * sigma + v * alpha\n    alpha_next, sigma_next = utils.t_to_alpha_sigma(t_next)\n    ddim_sigma = eta * jnp.sqrt(sigma_next**2 / sigma**2) * \\\n        jnp.sqrt(1 - alpha**2 / alpha_next**2)\n    adjusted_sigma = jnp.sqrt(sigma_next**2 - ddim_sigma**2)\n    x = pred * alpha_next + eps * adjusted_sigma\n    x = x + jax.random.normal(key, x.shape) * ddim_sigma\n    return x, pred\n\n\njit_cond_sample_step = jax.jit(cond_sample_step, static_argnums=(0, 8))\n\n\ndef sample_loop(model, params, key, x, steps, eta, sample_step):\n    for i in trange(len(steps)):\n        key, subkey = jax.random.split(key)\n        if i < 
len(steps) - 1:\n            x, _ = sample_step(model, params, subkey, x, steps[i], steps[i + 1], eta)\n        else:\n            _, pred = sample_step(model, params, subkey, x, steps[i], steps[i], eta)\n    return pred\n\n\ndef reverse_sample_step(model, params, key, x, t, t_next, extra_args):\n    dummy_key = jax.random.PRNGKey(0)\n    v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)\n    alpha, sigma = utils.t_to_alpha_sigma(t)\n    pred = x * alpha - v * sigma\n    eps = x * sigma + v * alpha\n    alpha_next, sigma_next = utils.t_to_alpha_sigma(t_next)\n    x = pred * alpha_next + eps * sigma_next\n    return x, pred\n\n\njit_reverse_sample_step = jax.jit(reverse_sample_step, static_argnums=0)\n\n\ndef reverse_sample_loop(model, params, key, x, steps, sample_step):\n    for i in trange(len(steps) - 1):\n        key, subkey = jax.random.split(key)\n        x, _ = sample_step(model, params, subkey, x, steps[i], steps[i + 1])\n    return x\n"
  },
  {
    "path": "diffusion/utils.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom PIL import Image\n\n\ndef from_pil_image(x):\n    \"\"\"Converts from a PIL image to a JAX array.\"\"\"\n    x = jnp.array(x)\n    if x.ndim == 2:\n        x = x[..., None]\n    return x.transpose((2, 0, 1)) / 127.5 - 1\n\n\ndef to_pil_image(x):\n    \"\"\"Converts from a JAX array to a PIL image.\"\"\"\n    if x.ndim == 4:\n        assert x.shape[0] == 1\n        x = x[0]\n    if x.shape[0] == 1:\n        x = x[0]\n    else:\n        x = x.transpose((1, 2, 0))\n    arr = np.array(jnp.round(jnp.clip((x + 1) * 127.5, 0, 255)).astype(jnp.uint8))\n    return Image.fromarray(arr)\n\n\ndef log_snr_to_alpha_sigma(log_snr):\n    \"\"\"Returns the scaling factors for the clean image and for the noise, given\n    the log SNR for a timestep.\"\"\"\n    return jnp.sqrt(jax.nn.sigmoid(log_snr)), jnp.sqrt(jax.nn.sigmoid(-log_snr))\n\n\ndef alpha_sigma_to_log_snr(alpha, sigma):\n    \"\"\"Returns a log snr, given the scaling factors for the clean image and for\n    the noise.\"\"\"\n    return jnp.log(alpha**2 / sigma**2)\n\n\ndef t_to_alpha_sigma(t):\n    \"\"\"Returns the scaling factors for the clean image and for the noise, given\n    a timestep.\"\"\"\n    return jnp.cos(t * jnp.pi / 2), jnp.sin(t * jnp.pi / 2)\n\n\ndef alpha_sigma_to_t(alpha, sigma):\n    \"\"\"Returns a timestep, given the scaling factors for the clean image and for\n    the noise.\"\"\"\n    return jnp.arctan2(sigma, alpha) / jnp.pi * 2\n\n\ndef get_ddpm_schedule(ddpm_t):\n    \"\"\"Returns timesteps for the noise schedule from the DDPM paper.\"\"\"\n    log_snr = -jnp.log(jnp.expm1(1e-4 + 10 * ddpm_t**2))\n    alpha, sigma = log_snr_to_alpha_sigma(log_snr)\n    return alpha_sigma_to_t(alpha, sigma)\n"
  },
  {
    "path": "interpolate.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"Interpolation in a diffusion model's latent space.\"\"\"\n\nimport argparse\nfrom functools import partial\nfrom pathlib import Path\n\nimport jax\nimport jax.numpy as jnp\nfrom PIL import Image\nfrom tqdm import trange\n\nfrom diffusion import get_model, get_models, load_params, sampling, utils\n\nMODULE_DIR = Path(__file__).resolve().parent\n\n\ndef main():\n    p = argparse.ArgumentParser(description=__doc__,\n                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    p.add_argument('--batch-size', '-bs', type=int, default=4,\n                   help='the number of images per batch')\n    p.add_argument('--checkpoint', type=str,\n                   help='the checkpoint to use')\n    p.add_argument('--init-1', type=str,\n                   help='the init image for the starting point')\n    p.add_argument('--init-2', type=str,\n                   help='the init image for the ending point')\n    p.add_argument('--model', type=str, choices=get_models(), required=True,\n                   help='the model to use')\n    p.add_argument('-n', type=int, default=16,\n                   help='the number of images to sample')\n    p.add_argument('--seed-1', type=int, default=0,\n                   help='the random seed for the starting point')\n    p.add_argument('--seed-2', type=int, default=1,\n                   help='the random seed for the ending point')\n    p.add_argument('--steps', type=int, default=1000,\n                   help='the number of timesteps')\n    args = p.parse_args()\n\n    model = get_model(args.model)\n    checkpoint = args.checkpoint\n    if not checkpoint:\n        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl'\n    params = load_params(checkpoint)\n\n    key_1 = jax.random.PRNGKey(args.seed_1)\n    key_2 = jax.random.PRNGKey(args.seed_2)\n    latent_1 = jax.random.normal(key_1, [1, *model.shape])\n    latent_2 = jax.random.normal(key_2, [1, 
*model.shape])\n\n    _, y, x = model.shape\n\n    reverse_sample_step = partial(sampling.jit_reverse_sample_step, extra_args={})\n    reverse_steps = utils.get_ddpm_schedule(jnp.linspace(0, 1, args.steps + 1))\n\n    if args.init_1:\n        init_1 = Image.open(args.init_1).convert('RGB').resize((x, y), Image.LANCZOS)\n        init_1 = utils.from_pil_image(init_1)[None]\n        print('Inverting the starting init image...')\n        latent_1 = sampling.reverse_sample_loop(model, params, key_1, init_1, reverse_steps,\n                                                reverse_sample_step)\n\n    if args.init_2:\n        init_2 = Image.open(args.init_2).convert('RGB').resize((x, y), Image.LANCZOS)\n        init_2 = utils.from_pil_image(init_2)[None]\n        print('Inverting the ending init image...')\n        latent_2 = sampling.reverse_sample_loop(model, params, key_2, init_2, reverse_steps,\n                                                reverse_sample_step)\n\n    def run(weights):\n        alphas, sigmas = utils.t_to_alpha_sigma(weights)\n        latents = latent_1 * alphas[:, None, None, None] + latent_2 * sigmas[:, None, None, None]\n        sample_step = partial(sampling.jit_sample_step, extra_args={})\n        steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1])\n        dummy_key = jax.random.PRNGKey(0)\n        return sampling.sample_loop(model, params, dummy_key, latents, steps, 0., sample_step)\n\n    def run_all(weights):\n        for i in trange(0, len(weights), args.batch_size):\n            outs = run(weights[i:i+args.batch_size])\n            for j, out in enumerate(outs):\n                utils.to_pil_image(out).save(f'out_{i + j:05}.png')\n\n    try:\n        print('Sampling...')\n        run_all(jnp.linspace(0, 1, args.n))\n    except KeyboardInterrupt:\n        pass\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "make_grid.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"Assembles images into a grid.\"\"\"\n\nimport argparse\nimport math\nimport sys\n\nfrom PIL import Image\n\n\ndef main():\n    p = argparse.ArgumentParser(description=__doc__)\n    p.add_argument('images', type=str, nargs='+', metavar='image',\n                   help='the input images')\n    p.add_argument('--output', '-o', type=str, default='out.png',\n                   help='the output image')\n    p.add_argument('--nrow', type=int,\n                   help='the number of images per row')\n    args = p.parse_args()\n\n    images = [Image.open(image) for image in args.images]\n    mode = images[0].mode\n    size = images[0].size\n    for image, name in zip(images, args.images):\n        if image.mode != mode:\n            print(f'Error: Image {name} had mode {image.mode}, expected {mode}', file=sys.stderr)\n            sys.exit(1)\n        if image.size != size:\n            print(f'Error: Image {name} had size {image.size}, expected {size}', file=sys.stderr)\n            sys.exit(1)\n\n    n = len(images)\n    x = args.nrow if args.nrow else math.ceil(n**0.5)\n    y = math.ceil(n / x)\n\n    output = Image.new(mode, (size[0] * x, size[1] * y))\n    for i, image in enumerate(images):\n        cur_x, cur_y = i % x, i // x\n        output.paste(image, (size[0] * cur_x, size[1] * cur_y))\n\n    output.save(args.output)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "requirements.txt",
    "content": "dm-haiku\neinops\nftfy\njax\njaxlib\nnumpy\noptax\nregex\nPillow\ntorch\ntorchvision\ntqdm\n"
  },
  {
    "path": "sample.py",
    "content": "#!/usr/bin/env python3\n\n\"\"\"Unconditional sampling from a diffusion model.\"\"\"\n\nimport argparse\nfrom functools import partial\nfrom pathlib import Path\n\nimport jax\nimport jax.numpy as jnp\nfrom PIL import Image\nfrom tqdm import tqdm, trange\n\nfrom diffusion import get_model, get_models, load_params, sampling, utils\n\nMODULE_DIR = Path(__file__).resolve().parent\n\n\ndef main():\n    p = argparse.ArgumentParser(description=__doc__,\n                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    p.add_argument('--batch-size', '-bs', type=int, default=1,\n                   help='the number of images per batch')\n    p.add_argument('--checkpoint', type=str,\n                   help='the checkpoint to use')\n    p.add_argument('--eta', type=float, default=1.,\n                   help='the amount of noise to add during sampling (0-1)')\n    p.add_argument('--init', type=str,\n                   help='the init image')\n    p.add_argument('--model', type=str, choices=get_models(), required=True,\n                   help='the model to use')\n    p.add_argument('-n', type=int, default=1,\n                   help='the number of images to sample')\n    p.add_argument('--seed', type=int, default=0,\n                   help='the random seed')\n    p.add_argument('--starting-timestep', '-st', type=float, default=0.9,\n                   help='the timestep to start at (used with init images)')\n    p.add_argument('--steps', type=int, default=1000,\n                   help='the number of timesteps')\n    args = p.parse_args()\n\n    model = get_model(args.model)\n    checkpoint = args.checkpoint\n    if not checkpoint:\n        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl'\n    params = load_params(checkpoint)\n\n    if args.init:\n        _, y, x = model.shape\n        init = Image.open(args.init).convert('RGB').resize((x, y), Image.LANCZOS)\n        init = utils.from_pil_image(init)[None]\n\n    key = 
jax.random.PRNGKey(args.seed)\n\n    def run(key, n):\n        tqdm.write('Sampling...')\n        key, subkey = jax.random.split(key)\n        noise = jax.random.normal(subkey, [n, *model.shape])\n        key, subkey = jax.random.split(key)\n        sample_step = partial(sampling.jit_sample_step, extra_args={})\n        steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1])\n        if args.init:\n            steps = steps[steps < args.starting_timestep]\n            alpha, sigma = utils.t_to_alpha_sigma(steps[0])\n            noise = init * alpha + noise * sigma\n        return sampling.sample_loop(model, params, subkey, noise, steps, args.eta, sample_step)\n\n    def run_all(key, n, batch_size):\n        for i in trange(0, n, batch_size):\n            key, subkey = jax.random.split(key)\n            cur_batch_size = min(n - i, batch_size)\n            outs = run(key, cur_batch_size)\n            for j, out in enumerate(outs):\n                utils.to_pil_image(out).save(f'out_{i + j:05}.png')\n\n    try:\n        run_all(key, args.n, args.batch_size)\n    except KeyboardInterrupt:\n        pass\n\n\nif __name__ == '__main__':\n    main()\n"
  }
]