Repository: crowsonkb/v-diffusion-jax Branch: master Commit: f082e75aa036 Files: 18 Total size: 51.7 KB Directory structure: gitextract_k6dpc7k6/ ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── clip_sample.py ├── diffusion/ │ ├── __init__.py │ ├── models/ │ │ ├── __init__.py │ │ ├── danbooru_128.py │ │ ├── imagenet_128.py │ │ ├── models.py │ │ ├── wikiart_128.py │ │ └── wikiart_256.py │ ├── sampling.py │ └── utils.py ├── interpolate.py ├── make_grid.py ├── requirements.txt └── sample.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ venv* __pycache__ .ipynb_checkpoints *.pkl out* *.egg-info *.ini ================================================ FILE: .gitmodules ================================================ [submodule "CLIP_JAX"] path = CLIP_JAX url = https://github.com/kingoflolz/CLIP_JAX ================================================ FILE: LICENSE ================================================ Copyright (c) 2021 Katherine Crowson and John David Pressman Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # v-diffusion-jax v objective diffusion inference code for JAX, by Katherine Crowson ([@RiversHaveWings](https://twitter.com/RiversHaveWings)) and Chainbreakers AI ([@jd_pressman](https://twitter.com/jd_pressman)). The models are denoising diffusion probabilistic models (https://arxiv.org/abs/2006.11239), which are trained to reverse a gradual noising process, allowing the models to generate samples from the learned data distributions starting from random noise. DDIM-style deterministic sampling (https://arxiv.org/abs/2010.02502) is also supported. The models are also trained on continuous timesteps. They use the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI). Thank you to Google's [TPU Research Cloud](https://sites.research.google/trc/about/) and [stability.ai](https://www.stability.ai) for compute to train these models! ## Dependencies - JAX ([installation instructions](https://github.com/google/jax#installation)) - dm-haiku, einops, numpy, optax, Pillow, tqdm (install with `pip install`) - CLIP_JAX (https://github.com/kingoflolz/CLIP_JAX), and its additional pip-installable dependencies: ftfy, regex, torch, torchvision (it does not need GPU PyTorch). 
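The 'v' objective described above has a compact closed form. The following is a minimal sketch of the parameterization, assuming the continuous-time cosine alpha/sigma schedule this repo uses (`utils.t_to_alpha_sigma`); the helper names here are illustrative only and are not part of the package API:

```python
import jax.numpy as jnp

def t_to_alpha_sigma(t):
    # Continuous cosine schedule; note alpha**2 + sigma**2 == 1.
    return jnp.cos(t * jnp.pi / 2), jnp.sin(t * jnp.pi / 2)

def v_from_eps(x, eps, t):
    # The 'v' training target combines the noise and the clean image:
    # v = alpha * eps - sigma * x
    alpha, sigma = t_to_alpha_sigma(t)
    return alpha * eps - sigma * x

def pred_from_v(z, v, t):
    # Recover the predicted clean image from the noised input z and the
    # model output v (this is the `pred = x * alpha - v * sigma` step
    # that appears in clip_sample.py):
    alpha, sigma = t_to_alpha_sigma(t)
    return alpha * z - sigma * v
```

Because `alpha**2 + sigma**2 == 1`, substituting `z = alpha * x + sigma * eps` and `v = alpha * eps - sigma * x` into `pred_from_v` recovers `x` exactly, which is what makes the 'v' parameterization convenient for sampling.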
**If you `git clone --recursive` this repo, it should fetch CLIP_JAX automatically.** ## Model checkpoints: - [Danbooru SFW 128x128](https://the-eye.eu/public/AI/models/v-diffusion/danbooru_128.pkl), SHA-256 `8551fe663dae988e619444efd99995775c7618af2f15ab5d8caf6b123513c334` - [ImageNet 128x128](https://the-eye.eu/public/AI/models/v-diffusion/imagenet_128.pkl), SHA-256 `4fc7c817b9aaa9018c6dbcbf5cd444a42f4a01856b34c49039f57fe48e090530` - [WikiArt 128x128](https://the-eye.eu/public/AI/models/v-diffusion/wikiart_128.pkl), SHA-256 `8fbe4e0206262996ff76d3f82a18dc67d3edd28631d4725e0154b51d00b9f91a` - [WikiArt 256x256](https://the-eye.eu/public/AI/models/v-diffusion/wikiart_256.pkl), SHA-256 `ebc6e77865bbb2d91dad1a0bfb670079c4992684a0e97caa28f784924c3afd81` ## Sampling ### Example If the model checkpoints are stored in `checkpoints/`, the following will generate an image: ``` ./clip_sample.py "a friendly robot, watercolor by James Gurney" --model wikiart_256 --seed 0 ``` If they are somewhere else, you need to specify the path to the checkpoint with `--checkpoint`. ### Unconditional sampling ``` usage: sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT] [--eta ETA] --model {danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N] [--seed SEED] [--steps STEPS] ``` `--batch-size`: sample this many images at a time (default 1) `--checkpoint`: manually specify the model checkpoint file `--eta`: set to 0 for deterministic (DDIM) sampling, 1 (the default) for stochastic (DDPM) sampling, and in between to interpolate between the two. DDIM is preferred for low numbers of timesteps. 
`--init`: specify the init image (optional) `--model`: specify the model to use `-n`: sample until this many images are sampled (default 1) `--seed`: specify the random seed (default 0) `--starting-timestep`: specify the starting timestep if an init image is used (range 0-1, default 0.9) `--steps`: specify the number of diffusion timesteps (default is 1000, can lower for faster but lower quality sampling) ### CLIP guided sampling CLIP guided sampling lets you generate images with diffusion models conditional on the output matching a text prompt. ``` usage: clip_sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT] [--clip-guidance-scale CLIP_GUIDANCE_SCALE] [--eta ETA] --model {danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N] [--seed SEED] [--steps STEPS] prompt ``` `clip_sample.py` has the same options as `sample.py` and these additional ones: `prompt`: the text prompt to use `--clip-guidance-scale`: how strongly the result should match the text prompt (default 1000) ================================================ FILE: clip_sample.py ================================================ #!/usr/bin/env python3 """CLIP guided sampling from a diffusion model.""" import argparse from functools import partial from pathlib import Path import sys from einops import repeat import jax import jax.numpy as jnp from PIL import Image from tqdm import tqdm, trange from diffusion import get_model, get_models, load_params, sampling, utils MODULE_DIR = Path(__file__).resolve().parent sys.path.append(str(MODULE_DIR / 'CLIP_JAX')) import clip_jax def make_normalize(mean, std): mean = jnp.array(mean).reshape([3, 1, 1]) std = jnp.array(std).reshape([3, 1, 1]) def inner(image): return (image - mean) / std return inner def norm2(x): """Normalizes a batch of vectors to the unit sphere.""" return x / jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True)) def spherical_dist_loss(x, y): """Computes 1/2 the squared spherical distance between the two arguments.""" return 
jnp.square(jnp.arccos(jnp.sum(norm2(x) * norm2(y), axis=-1))) / 2 def main(): p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) p.add_argument('prompt', type=str, help='the text prompt') p.add_argument('--batch-size', '-bs', type=int, default=1, help='the number of images per batch') p.add_argument('--checkpoint', type=str, help='the checkpoint to use') p.add_argument('--clip-guidance-scale', '-cs', type=float, default=1000., help='the CLIP guidance scale') p.add_argument('--eta', type=float, default=1., help='the amount of noise to add during sampling (0-1)') p.add_argument('--init', type=str, help='the init image') p.add_argument('--model', type=str, choices=get_models(), required=True, help='the model to use') p.add_argument('-n', type=int, default=1, help='the number of images to sample') p.add_argument('--seed', type=int, default=0, help='the random seed') p.add_argument('--starting-timestep', '-st', type=float, default=0.9, help='the timestep to start at (used with init images)') p.add_argument('--steps', type=int, default=1000, help='the number of timesteps') args = p.parse_args() model = get_model(args.model) checkpoint = args.checkpoint if not checkpoint: checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl' params = load_params(checkpoint) image_fn, text_fn, clip_params, _ = clip_jax.load('ViT-B/16') clip_patch_size = 16 clip_size = 224 normalize = make_normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]) target_embed = text_fn(clip_params, clip_jax.tokenize([args.prompt])) if args.init: _, y, x = model.shape init = Image.open(args.init).convert('RGB').resize((x, y), Image.LANCZOS) init = utils.from_pil_image(init)[None] key = jax.random.PRNGKey(args.seed) def clip_cond_fn_loss(x, key, params, clip_params, t, extra_args): dummy_key = jax.random.PRNGKey(0) v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args) alpha, sigma = 
utils.t_to_alpha_sigma(t) pred = x * alpha - v * sigma clip_in = jax.image.resize(pred, (*pred.shape[:2], clip_size, clip_size), 'cubic') extent = clip_patch_size // 2 clip_in = jnp.pad(clip_in, [(0, 0), (0, 0), (extent, extent), (extent, extent)], 'edge') sat_vmap = jax.vmap(partial(jax.image.scale_and_translate, method='cubic'), in_axes=(0, None, None, 0, 0)) scales = jnp.ones([pred.shape[0], 2]) translates = jax.random.uniform(key, [pred.shape[0], 2], minval=-extent, maxval=extent) clip_in = sat_vmap(clip_in, (3, clip_size, clip_size), (1, 2), scales, translates) image_embeds = image_fn(clip_params, normalize((clip_in + 1) / 2)) return jnp.sum(spherical_dist_loss(image_embeds, target_embed)) def clip_cond_fn(x, key, t, extra_args, params, clip_params): grad_fn = jax.grad(clip_cond_fn_loss) grad = grad_fn(x, key, params, clip_params, t, extra_args) return grad * -args.clip_guidance_scale def run(key, n): tqdm.write('Sampling...') key, subkey = jax.random.split(key) noise = jax.random.normal(subkey, [n, *model.shape]) key, subkey = jax.random.split(key) cond_params = {'params': params, 'clip_params': clip_params} sample_step = partial(sampling.jit_cond_sample_step, extra_args={}, cond_fn=clip_cond_fn, cond_params=cond_params) steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1]) if args.init: steps = steps[steps < args.starting_timestep] alpha, sigma = utils.t_to_alpha_sigma(steps[0]) noise = init * alpha + noise * sigma return sampling.sample_loop(model, params, subkey, noise, steps, args.eta, sample_step) def run_all(key, n, batch_size): for i in trange(0, n, batch_size): key, subkey = jax.random.split(key) cur_batch_size = min(n - i, batch_size) outs = run(key, cur_batch_size) for j, out in enumerate(outs): utils.to_pil_image(out).save(f'out_{i + j:05}.png') try: run_all(key, args.n, args.batch_size) except KeyboardInterrupt: pass if __name__ == '__main__': main() ================================================ FILE: diffusion/__init__.py 
================================================ from . import sampling, utils from .models import get_model, get_models, load_params ================================================ FILE: diffusion/models/__init__.py ================================================ from .models import get_model, get_models, load_params ================================================ FILE: diffusion/models/danbooru_128.py ================================================ import haiku as hk import jax import jax.numpy as jnp from .. import utils class FourierFeatures(hk.Module): def __init__(self, output_size, std=1., name=None): super().__init__(name=name) assert output_size % 2 == 0 self.output_size = output_size self.std = std def __call__(self, x): w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]], init=hk.initializers.RandomNormal(self.std, 0)) f = 2 * jnp.pi * x @ w.T return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1) class Dropout2d(hk.Module): def __init__(self, rate=0.5, name=None): super().__init__(name=name) self.rate = rate def __call__(self, x, enabled): rate = self.rate * enabled key = hk.next_rng_key() p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None] return x * p / (1.0 - rate) class SelfAttention2d(hk.Module): def __init__(self, n_head=1, dropout_rate=0.1, name=None): super().__init__(name=name) self.n_head = n_head self.dropout_rate = dropout_rate def __call__(self, x, dropout_enabled): n, c, h, w = x.shape assert c % self.n_head == 0 qkv_proj = hk.Conv2D(c * 3, 1, data_format='NCHW', name='qkv_proj') out_proj = hk.Conv2D(c, 1, data_format='NCHW', name='out_proj') dropout = Dropout2d(self.dropout_rate) qkv = qkv_proj(x) qkv = jnp.swapaxes(qkv.reshape([n, self.n_head * 3, c // self.n_head, h * w]), 2, 3) q, k, v = jnp.split(qkv, 3, axis=1) scale = k.shape[3]**-0.25 att = jax.nn.softmax((q * scale) @ (jnp.swapaxes(k, 2, 3) * scale), axis=3) y = jnp.swapaxes(att @ v, 2, 3).reshape([n, c, h, w]) return x + 
dropout(out_proj(y), dropout_enabled) def res_conv_block(c_mid, c_out, dropout_last=True): def inner(x, is_training): x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW') x_skip = x if x.shape[1] == c_out else x_skip_layer(x) x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x) x = jax.nn.relu(x) x = Dropout2d(0.1)(x, is_training) x = hk.Conv2D(c_out, 3, data_format='NCHW')(x) x = jax.nn.relu(x) if dropout_last: x = Dropout2d(0.1)(x, is_training) return x + x_skip return inner def diffusion_model(x, t, extra_args): c = 256 is_training = jnp.array(0.) log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t)) timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None]) te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]]) x = jnp.concatenate([x, te_planes], axis=1) # 128x128 x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, c)(x, is_training) x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x) # 64x64 x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2) # 32x32 x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3) # 16x16 x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4 // 128)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4 // 128)(x_4, is_training) x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4) # 8x8 x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training) x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5) # 4x4 x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8 // 128)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8 // 128)(x_6, 
is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8 // 128)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 4)(x_6, is_training) x_6 = SelfAttention2d(c * 4 // 128)(x_6, is_training) x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest') x_5 = jnp.concatenate([x_5, x_6], axis=1) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training) x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest') x_4 = jnp.concatenate([x_4, x_5], axis=1) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4 // 128)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 2)(x_4, is_training) x_4 = SelfAttention2d(c * 2 // 128)(x_4, is_training) x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest') x_3 = jnp.concatenate([x_3, x_4], axis=1) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest') x_2 = jnp.concatenate([x_2, x_3], axis=1) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c)(x_2, is_training) x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest') x = jnp.concatenate([x, x_2], axis=1) x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, 3, dropout_last=False)(x, is_training) return x class Danbooru128Model: init, apply = hk.transform(diffusion_model) shape = (3, 128, 128) min_t = float(utils.get_ddpm_schedule(jnp.array(0.))) max_t = float(utils.get_ddpm_schedule(jnp.array(1.))) ================================================ FILE: diffusion/models/imagenet_128.py ================================================ import haiku as hk import jax import jax.numpy as jnp from .. 
import utils class FourierFeatures(hk.Module): def __init__(self, output_size, std=1., name=None): super().__init__(name=name) assert output_size % 2 == 0 self.output_size = output_size self.std = std def __call__(self, x): w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]], init=hk.initializers.RandomNormal(self.std, 0)) f = 2 * jnp.pi * x @ w.T return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1) class Dropout2d(hk.Module): def __init__(self, rate=0.5, name=None): super().__init__(name=name) self.rate = rate def __call__(self, x, enabled): rate = self.rate * enabled key = hk.next_rng_key() p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None] return x * p / (1.0 - rate) class SelfAttention2d(hk.Module): def __init__(self, c_in, n_head=1, dropout_rate=0.1, name=None): super().__init__(name=name) assert c_in % n_head == 0 self.c_in = c_in self.n_head = n_head self.dropout_rate = dropout_rate def __call__(self, x, dropout_enabled): n, c, h, w = x.shape qkv_proj = hk.Conv2D(self.c_in * 3, 1, data_format='NCHW', name='qkv_proj') out_proj = hk.Conv2D(self.c_in, 1, data_format='NCHW', name='out_proj') dropout = Dropout2d(self.dropout_rate) qkv = qkv_proj(x) qkv = jnp.swapaxes(qkv.reshape([n, self.n_head * 3, c // self.n_head, h * w]), 2, 3) q, k, v = jnp.split(qkv, 3, axis=1) scale = k.shape[3]**-0.25 att = jax.nn.softmax((q * scale) @ (jnp.swapaxes(k, 2, 3) * scale), axis=3) y = jnp.swapaxes(att @ v, 2, 3).reshape([n, c, h, w]) return x + dropout(out_proj(y), dropout_enabled) def res_conv_block(c_mid, c_out, dropout_last=True): @hk.remat def inner(x, is_training): x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW') x_skip = x if x.shape[1] == c_out else x_skip_layer(x) x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x) x = jax.nn.relu(x) x = Dropout2d(0.1)(x, is_training) x = hk.Conv2D(c_out, 3, data_format='NCHW')(x) if dropout_last: x = jax.nn.relu(x) x = Dropout2d(0.1)(x, is_training) return x + x_skip return 
inner def diffusion_model(x, t, extra_args): c = 128 is_training = jnp.array(0.) log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t)) timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None]) te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]]) x = jnp.concatenate([x, te_planes], axis=1) # 128x128 x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, c)(x, is_training) x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x) # 64x64 x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2) # 32x32 x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3) # 16x16 x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training) x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4) # 8x8 x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4, 
c * 4 // 128)(x_5, is_training) x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5) # 4x4 x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 4)(x_6, is_training) x_6 = SelfAttention2d(c * 4, c * 4 // 128)(x_6, is_training) x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest') x_5 = jnp.concatenate([x_5, x_6], axis=1) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training) x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest') x_4 = jnp.concatenate([x_4, x_5], axis=1) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = SelfAttention2d(c * 4, c * 4 // 
128)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 2)(x_4, is_training) x_4 = SelfAttention2d(c * 2, c * 2 // 128)(x_4, is_training) x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest') x_3 = jnp.concatenate([x_3, x_4], axis=1) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest') x_2 = jnp.concatenate([x_2, x_3], axis=1) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c)(x_2, is_training) x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest') x = jnp.concatenate([x, x_2], axis=1) x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, 3, dropout_last=False)(x, is_training) return x class ImageNet128Model: init, apply = hk.transform(diffusion_model) shape = (3, 128, 128) min_t = float(utils.get_ddpm_schedule(jnp.array(0.))) max_t = float(utils.get_ddpm_schedule(jnp.array(1.))) ================================================ FILE: diffusion/models/models.py ================================================ import pickle import jax import jax.numpy as jnp from . 
import danbooru_128, imagenet_128, wikiart_128, wikiart_256 models = { 'danbooru_128': danbooru_128.Danbooru128Model, 'imagenet_128': imagenet_128.ImageNet128Model, 'wikiart_128': wikiart_128.WikiArt128Model, 'wikiart_256': wikiart_256.WikiArt256Model, } def get_model(model): return models[model] def get_models(): return list(models.keys()) def load_params(checkpoint): with open(checkpoint, 'rb') as fp: return jax.tree_map(jnp.array, pickle.load(fp)['params_ema']) ================================================ FILE: diffusion/models/wikiart_128.py ================================================ import haiku as hk import jax import jax.numpy as jnp from .. import utils class FourierFeatures(hk.Module): def __init__(self, output_size, std=1., name=None): super().__init__(name=name) assert output_size % 2 == 0 self.output_size = output_size self.std = std def __call__(self, x): w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]], init=hk.initializers.RandomNormal(self.std, 0)) f = 2 * jnp.pi * x @ w.T return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1) class Dropout2d(hk.Module): def __init__(self, rate=0.5, name=None): super().__init__(name=name) self.rate = rate def __call__(self, x, enabled): rate = self.rate * enabled key = hk.next_rng_key() p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None] return x * p / (1.0 - rate) def res_conv_block(c_mid, c_out, dropout_last=True): @hk.remat def inner(x, is_training): x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW') x_skip = x if x.shape[1] == c_out else x_skip_layer(x) x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x) x = jax.nn.relu(x) x = Dropout2d(0.1)(x, is_training) x = hk.Conv2D(c_out, 3, data_format='NCHW')(x) x = jax.nn.relu(x) if dropout_last: x = Dropout2d(0.1)(x, is_training) return x + x_skip return inner def diffusion_model(x, t, extra_args): c = 128 is_training = jnp.array(0.) 
log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t)) timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None]) te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]]) x = jnp.concatenate([x, te_planes], axis=1) # 128x128 x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, c)(x, is_training) x = res_conv_block(c, c)(x, is_training) x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x) # 64x64 x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training) x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2) # 32x32 x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training) x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3) # 16x16 x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training) x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4) # 8x8 x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training) x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5) # 4x4 x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training) x_6 = res_conv_block(c * 8, c * 4)(x_6, is_training) x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 
                             'nearest')
    x_5 = jnp.concatenate([x_5, x_6], axis=1)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')
    x_4 = jnp.concatenate([x_4, x_5], axis=1)
    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
    x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
    x_4 = res_conv_block(c * 4, c * 2)(x_4, is_training)
    x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')
    x_3 = jnp.concatenate([x_3, x_4], axis=1)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')
    x_2 = jnp.concatenate([x_2, x_3], axis=1)
    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
    x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
    x_2 = res_conv_block(c * 2, c)(x_2, is_training)
    x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')
    x = jnp.concatenate([x, x_2], axis=1)
    x = res_conv_block(c, c)(x, is_training)
    x = res_conv_block(c, c)(x, is_training)
    x = res_conv_block(c, c)(x, is_training)
    x = res_conv_block(c, 3, dropout_last=False)(x, is_training)
    return x


class WikiArt128Model:
    init, apply = hk.transform(diffusion_model)
    shape = (3, 128, 128)
    min_t = float(utils.get_ddpm_schedule(jnp.array(0.)))
    max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))


================================================
FILE: diffusion/models/wikiart_256.py
================================================
import haiku as hk
import jax
import jax.numpy as jnp

from .. import utils


class FourierFeatures(hk.Module):
    def __init__(self, output_size, std=1., name=None):
        super().__init__(name=name)
        assert output_size % 2 == 0
        self.output_size = output_size
        self.std = std

    def __call__(self, x):
        w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]],
                             init=hk.initializers.RandomNormal(self.std, 0))
        f = 2 * jnp.pi * x @ w.T
        return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)


class Dropout2d(hk.Module):
    def __init__(self, rate=0.5, name=None):
        super().__init__(name=name)
        self.rate = rate

    def __call__(self, x, enabled):
        rate = self.rate * enabled
        key = hk.next_rng_key()
        p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None]
        return x * p / (1.0 - rate)


class SelfAttention2d(hk.Module):
    def __init__(self, n_head=1, dropout_rate=0.1, name=None):
        super().__init__(name=name)
        self.n_head = n_head
        self.dropout_rate = dropout_rate

    def __call__(self, x, dropout_enabled):
        n, c, h, w = x.shape
        assert c % self.n_head == 0
        qkv_proj = hk.Conv2D(c * 3, 1, data_format='NCHW', name='qkv_proj')
        out_proj = hk.Conv2D(c, 1, data_format='NCHW', name='out_proj')
        dropout = Dropout2d(self.dropout_rate)
        qkv = qkv_proj(x)
        qkv = jnp.swapaxes(qkv.reshape([n, self.n_head * 3, c // self.n_head, h * w]), 2, 3)
        q, k, v = jnp.split(qkv, 3, axis=1)
        scale = k.shape[3]**-0.25
        att = jax.nn.softmax((q * scale) @ (jnp.swapaxes(k, 2, 3) * scale), axis=3)
        y = jnp.swapaxes(att @ v, 2, 3).reshape([n, c, h, w])
        return x + dropout(out_proj(y), dropout_enabled)


def res_conv_block(c_mid, c_out, dropout_last=True):
    @hk.remat
    def inner(x, is_training):
        x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW')
        x_skip = x if x.shape[1] == c_out else x_skip_layer(x)
        x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x)
        x = jax.nn.relu(x)
        x = Dropout2d(0.1)(x, is_training)
        x = hk.Conv2D(c_out, 3, data_format='NCHW')(x)
        if dropout_last:
            x = jax.nn.relu(x)
            x = Dropout2d(0.1)(x, is_training)
        return x + x_skip
    return inner


def diffusion_model(x, t, extra_args):
    c = 128
    is_training = jnp.array(0.)
    log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t))
    timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None])
    te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]])
    x = jnp.concatenate([x, te_planes], axis=1)  # 256x256
    x = res_conv_block(c // 2, c // 2)(x, is_training)
    x = res_conv_block(c // 2, c // 2)(x, is_training)
    x = res_conv_block(c // 2, c // 2)(x, is_training)
    x = res_conv_block(c // 2, c // 2)(x, is_training)
    x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x)  # 128x128
    x_2 = res_conv_block(c, c)(x_2, is_training)
    x_2 = res_conv_block(c, c)(x_2, is_training)
    x_2 = res_conv_block(c, c)(x_2, is_training)
    x_2 = res_conv_block(c, c)(x_2, is_training)
    x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2)  # 64x64
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3)  # 32x32
    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
    x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4)  # 16x16
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
    x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5)  # 8x8
    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
    x_7 = hk.AvgPool(2, 2, 'SAME', 1)(x_6)  # 4x4
    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
    x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
    x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
    x_7 = res_conv_block(c * 8, c * 4)(x_7, is_training)
    x_7 = hk.remat(SelfAttention2d(c * 4 // 128))(x_7, is_training)
    x_7 = jax.image.resize(x_7, [*x_7.shape[:2], *x_6.shape[2:]], 'nearest')
    x_6 = jnp.concatenate([x_6, x_7], axis=1)
    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
    x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
    x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
    x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest')
    x_5 = jnp.concatenate([x_5, x_6], axis=1)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
    x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
    x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
    x_5 = res_conv_block(c * 4, c * 2)(x_5, is_training)
    x_5 = hk.remat(SelfAttention2d(c * 2 // 128))(x_5, is_training)
    x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')
    x_4 = jnp.concatenate([x_4, x_5], axis=1)
    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
    x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
    x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')
    x_3 = jnp.concatenate([x_3, x_4], axis=1)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
    x_3 = res_conv_block(c * 2, c)(x_3, is_training)
    x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')
    x_2 = jnp.concatenate([x_2, x_3], axis=1)
    x_2 = res_conv_block(c, c)(x_2, is_training)
    x_2 = res_conv_block(c, c)(x_2, is_training)
    x_2 = res_conv_block(c, c)(x_2, is_training)
    x_2 = res_conv_block(c, c // 2)(x_2, is_training)
    x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')
    x = jnp.concatenate([x, x_2], axis=1)
    x = res_conv_block(c // 2, c // 2)(x, is_training)
    x = res_conv_block(c // 2, c // 2)(x, is_training)
    x = res_conv_block(c // 2, c // 2)(x, is_training)
    x = res_conv_block(c // 2, 3, dropout_last=False)(x, is_training)
    return x


class WikiArt256Model:
    init, apply = hk.transform(diffusion_model)
    shape = (3, 256, 256)
    min_t = float(utils.get_ddpm_schedule(jnp.array(0.)))
    max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))


================================================
FILE: diffusion/sampling.py
================================================
from einops import repeat
import jax
import jax.numpy as jnp
from tqdm import trange

from . import utils


def sample_step(model, params, key, x, t, t_next, eta, extra_args):
    dummy_key = jax.random.PRNGKey(0)
    v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)
    alpha, sigma = utils.t_to_alpha_sigma(t)
    key, subkey = jax.random.split(key)
    pred = x * alpha - v * sigma
    eps = x * sigma + v * alpha
    alpha_next, sigma_next = utils.t_to_alpha_sigma(t_next)
    ddim_sigma = eta * jnp.sqrt(sigma_next**2 / sigma**2) * \
        jnp.sqrt(1 - alpha**2 / alpha_next**2)
    adjusted_sigma = jnp.sqrt(sigma_next**2 - ddim_sigma**2)
    x = pred * alpha_next + eps * adjusted_sigma
    x = x + jax.random.normal(key, x.shape) * ddim_sigma
    return x, pred


jit_sample_step = jax.jit(sample_step, static_argnums=0)


def cond_sample_step(model, params, key, x, t, t_next, eta, extra_args, cond_fn, cond_params):
    dummy_key = jax.random.PRNGKey(0)
    v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)
    alpha, sigma = utils.t_to_alpha_sigma(t)
    key, subkey = jax.random.split(key)
    cond_grad = cond_fn(x, subkey, t, extra_args, **cond_params)
    v = v - cond_grad * (sigma / alpha)
    pred = x * alpha - v * sigma
    eps = x * sigma + v * alpha
    alpha_next, sigma_next = utils.t_to_alpha_sigma(t_next)
    ddim_sigma = eta * jnp.sqrt(sigma_next**2 / sigma**2) * \
        jnp.sqrt(1 - alpha**2 / alpha_next**2)
    adjusted_sigma = jnp.sqrt(sigma_next**2 - ddim_sigma**2)
    x = pred * alpha_next + eps * adjusted_sigma
    x = x + jax.random.normal(key, x.shape) * ddim_sigma
    return x, pred


jit_cond_sample_step = jax.jit(cond_sample_step, static_argnums=(0, 8))


def sample_loop(model, params, key, x, steps, eta, sample_step):
    for i in trange(len(steps)):
        key, subkey = jax.random.split(key)
        if i < len(steps) - 1:
            x, _ = sample_step(model, params, subkey, x, steps[i], steps[i + 1], eta)
        else:
            _, pred = sample_step(model, params, subkey, x, steps[i], steps[i], eta)
    return pred


def reverse_sample_step(model, params, key, x, t, t_next, extra_args):
    dummy_key = jax.random.PRNGKey(0)
    v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)
    alpha, sigma = utils.t_to_alpha_sigma(t)
    pred = x * alpha - v * sigma
    eps = x * sigma + v * alpha
    alpha_next, sigma_next = utils.t_to_alpha_sigma(t_next)
    x = pred * alpha_next + eps * sigma_next
    return x, pred


jit_reverse_sample_step = jax.jit(reverse_sample_step, static_argnums=0)


def reverse_sample_loop(model, params, key, x, steps, sample_step):
    for i in trange(len(steps) - 1):
        key, subkey = jax.random.split(key)
        x, _ = sample_step(model, params, subkey, x, steps[i], steps[i + 1])
    return x


================================================
FILE: diffusion/utils.py
================================================
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image


def from_pil_image(x):
    """Converts from a PIL image to a JAX array."""
    x = jnp.array(x)
    if x.ndim == 2:
        x = x[..., None]
    return x.transpose((2, 0, 1)) / 127.5 - 1


def to_pil_image(x):
    """Converts from a JAX array to a PIL image."""
    if x.ndim == 4:
        assert x.shape[0] == 1
        x = x[0]
    if x.shape[0] == 1:
        x = x[0]
    else:
        x = x.transpose((1, 2, 0))
    arr = np.array(jnp.round(jnp.clip((x + 1) * 127.5, 0, 255)).astype(jnp.uint8))
    return Image.fromarray(arr)


def log_snr_to_alpha_sigma(log_snr):
    """Returns the scaling factors for the clean image and for the noise, given
    the log SNR for a timestep."""
    return jnp.sqrt(jax.nn.sigmoid(log_snr)), jnp.sqrt(jax.nn.sigmoid(-log_snr))


def alpha_sigma_to_log_snr(alpha, sigma):
    """Returns a log snr, given the scaling factors for the clean image and for
    the noise."""
    return jnp.log(alpha**2 / sigma**2)


def t_to_alpha_sigma(t):
    """Returns the scaling factors for the clean image and for the noise, given
    a timestep."""
    return jnp.cos(t * jnp.pi / 2), jnp.sin(t * jnp.pi / 2)


def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for
    the noise."""
    return jnp.arctan2(sigma, alpha) / jnp.pi * 2


def get_ddpm_schedule(ddpm_t):
    """Returns timesteps for the noise schedule from the DDPM paper."""
    log_snr = -jnp.log(jnp.expm1(1e-4 + 10 * ddpm_t**2))
    alpha, sigma = log_snr_to_alpha_sigma(log_snr)
    return alpha_sigma_to_t(alpha, sigma)


================================================
FILE: interpolate.py
================================================
#!/usr/bin/env python3

"""Interpolation in a diffusion model's latent space."""

import argparse
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
from PIL import Image
from tqdm import trange

from diffusion import get_model, get_models, load_params, sampling, utils

MODULE_DIR = Path(__file__).resolve().parent


def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--batch-size', '-bs', type=int, default=4,
                   help='the number of images per batch')
    p.add_argument('--checkpoint', type=str,
                   help='the checkpoint to use')
    p.add_argument('--init-1', type=str,
                   help='the init image for the starting point')
    p.add_argument('--init-2', type=str,
                   help='the init image for the ending point')
    p.add_argument('--model', type=str, choices=get_models(), required=True,
                   help='the model to use')
    p.add_argument('-n', type=int, default=16,
                   help='the number of images to sample')
    p.add_argument('--seed-1', type=int, default=0,
                   help='the random seed for the starting point')
    p.add_argument('--seed-2', type=int, default=1,
                   help='the random seed for the ending point')
    p.add_argument('--steps', type=int, default=1000,
                   help='the number of timesteps')
    args = p.parse_args()

    model = get_model(args.model)
    checkpoint = args.checkpoint
    if not checkpoint:
        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl'
    params = load_params(checkpoint)

    key_1 = jax.random.PRNGKey(args.seed_1)
    key_2 = jax.random.PRNGKey(args.seed_2)
    latent_1 = jax.random.normal(key_1, [1, *model.shape])
    latent_2 = jax.random.normal(key_2, [1, *model.shape])
    _, y, x = model.shape

    reverse_sample_step = partial(sampling.jit_reverse_sample_step, extra_args={})
    reverse_steps = utils.get_ddpm_schedule(jnp.linspace(0, 1, args.steps + 1))

    if args.init_1:
        init_1 = Image.open(args.init_1).convert('RGB').resize((x, y), Image.LANCZOS)
        init_1 = utils.from_pil_image(init_1)[None]
        print('Inverting the starting init image...')
        latent_1 = sampling.reverse_sample_loop(model, params, key_1, init_1,
                                                reverse_steps, reverse_sample_step)

    if args.init_2:
        init_2 = Image.open(args.init_2).convert('RGB').resize((x, y), Image.LANCZOS)
        init_2 = utils.from_pil_image(init_2)[None]
        print('Inverting the ending init image...')
        latent_2 = sampling.reverse_sample_loop(model, params, key_2, init_2,
                                                reverse_steps, reverse_sample_step)

    def run(weights):
        alphas, sigmas = utils.t_to_alpha_sigma(weights)
        latents = latent_1 * alphas[:, None, None, None] + latent_2 * sigmas[:, None, None, None]
        sample_step = partial(sampling.jit_sample_step, extra_args={})
        steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1])
        dummy_key = jax.random.PRNGKey(0)
        return sampling.sample_loop(model, params, dummy_key, latents, steps, 0., sample_step)

    def run_all(weights):
        for i in trange(0, len(weights), args.batch_size):
            outs = run(weights[i:i+args.batch_size])
            for j, out in enumerate(outs):
                utils.to_pil_image(out).save(f'out_{i + j:05}.png')

    try:
        print('Sampling...')
        run_all(jnp.linspace(0, 1, args.n))
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()


================================================
FILE: make_grid.py
================================================
#!/usr/bin/env python3

"""Assembles images into a grid."""

import argparse
import math
import sys

from PIL import Image


def main():
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('images', type=str, nargs='+', metavar='image',
                   help='the input images')
    p.add_argument('--output', '-o', type=str, default='out.png',
                   help='the output image')
    p.add_argument('--nrow', type=int,
                   help='the number of images per row')
    args = p.parse_args()

    images = [Image.open(image) for image in args.images]
    mode = images[0].mode
    size = images[0].size
    for image, name in zip(images, args.images):
        if image.mode != mode:
            print(f'Error: Image {name} had mode {image.mode}, expected {mode}',
                  file=sys.stderr)
            sys.exit(1)
        if image.size != size:
            print(f'Error: Image {name} had size {image.size}, expected {size}',
                  file=sys.stderr)
            sys.exit(1)

    n = len(images)
    x = args.nrow if args.nrow else math.ceil(n**0.5)
    y = math.ceil(n / x)

    output = Image.new(mode, (size[0] * x, size[1] * y))
    for i, image in enumerate(images):
        cur_x, cur_y = i % x, i // x
        output.paste(image, (size[0] * cur_x, size[1] * cur_y))
    output.save(args.output)


if __name__ == '__main__':
    main()


================================================
FILE: requirements.txt
================================================
dm-haiku
einops
ftfy
jax
jaxlib
numpy
optax
regex
Pillow
torch
torchvision
tqdm


================================================
FILE: sample.py
================================================
#!/usr/bin/env python3

"""Unconditional sampling from a diffusion model."""

import argparse
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
from PIL import Image
from tqdm import tqdm, trange

from diffusion import get_model, get_models, load_params, sampling, utils

MODULE_DIR = Path(__file__).resolve().parent


def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--batch-size', '-bs', type=int, default=1,
                   help='the number of images per batch')
    p.add_argument('--checkpoint', type=str,
                   help='the checkpoint to use')
    p.add_argument('--eta', type=float, default=1.,
                   help='the amount of noise to add during sampling (0-1)')
    p.add_argument('--init', type=str,
                   help='the init image')
    p.add_argument('--model', type=str, choices=get_models(), required=True,
                   help='the model to use')
    p.add_argument('-n', type=int, default=1,
                   help='the number of images to sample')
    p.add_argument('--seed', type=int, default=0,
                   help='the random seed')
    p.add_argument('--starting-timestep', '-st', type=float, default=0.9,
                   help='the timestep to start at (used with init images)')
    p.add_argument('--steps', type=int, default=1000,
                   help='the number of timesteps')
    args = p.parse_args()

    model = get_model(args.model)
    checkpoint = args.checkpoint
    if not checkpoint:
        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl'
    params = load_params(checkpoint)

    if args.init:
        _, y, x = model.shape
        init = Image.open(args.init).convert('RGB').resize((x, y), Image.LANCZOS)
        init = utils.from_pil_image(init)[None]

    key = jax.random.PRNGKey(args.seed)

    def run(key, n):
        tqdm.write('Sampling...')
        key, subkey = jax.random.split(key)
        noise = jax.random.normal(subkey, [n, *model.shape])
        key, subkey = jax.random.split(key)
        sample_step = partial(sampling.jit_sample_step, extra_args={})
        steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1])
        if args.init:
            steps = steps[steps < args.starting_timestep]
            alpha, sigma = utils.t_to_alpha_sigma(steps[0])
            noise = init * alpha + noise * sigma
        return sampling.sample_loop(model, params, subkey, noise, steps, args.eta, sample_step)

    def run_all(key, n, batch_size):
        for i in trange(0, n, batch_size):
            key, subkey = jax.random.split(key)
            cur_batch_size = min(n - i, batch_size)
            outs = run(key, cur_batch_size)
            for j, out in enumerate(outs):
                utils.to_pil_image(out).save(f'out_{i + j:05}.png')

    try:
        run_all(key, args.n, args.batch_size)
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
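
Not part of the repository: a short numerical sanity check of the v-objective algebra that `sample_step` in `diffusion/sampling.py` relies on. With `alpha = cos(t*pi/2)` and `sigma = sin(t*pi/2)` (as in `diffusion/utils.py`), a noised input `x = x0 * alpha + eps * sigma` and target `v = eps * alpha - x0 * sigma` satisfy `pred = x * alpha - v * sigma == x0` and `x * sigma + v * alpha == eps`. The sketch below uses plain NumPy rather than JAX for portability; the variable names mirror the repository's, but the script itself is illustrative only.

```python
import math

import numpy as np


def t_to_alpha_sigma(t):
    # Same cosine schedule as diffusion/utils.py: alpha = cos(t*pi/2), sigma = sin(t*pi/2).
    return math.cos(t * math.pi / 2), math.sin(t * math.pi / 2)


rng = np.random.default_rng(0)
x0 = rng.normal(size=(3, 8, 8))   # stand-in for a clean image
eps = rng.normal(size=(3, 8, 8))  # stand-in for the noise
t = 0.3

alpha, sigma = t_to_alpha_sigma(t)
x = x0 * alpha + eps * sigma      # noised input at timestep t
v = eps * alpha - x0 * sigma      # the 'v' target the models predict

pred = x * alpha - v * sigma      # as in sample_step: recovers x0
eps_hat = x * sigma + v * alpha   # as in sample_step: recovers eps

assert np.allclose(pred, x0)
assert np.allclose(eps_hat, eps)
```

This is why `sample_step` can reconstruct both the denoised prediction and the noise estimate from a single model output before stepping to `t_next`.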