Repository: crowsonkb/v-diffusion-jax
Branch: master
Commit: f082e75aa036
Files: 18
Total size: 51.7 KB
Directory structure:
gitextract_k6dpc7k6/
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── clip_sample.py
├── diffusion/
│ ├── __init__.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── danbooru_128.py
│ │ ├── imagenet_128.py
│ │ ├── models.py
│ │ ├── wikiart_128.py
│ │ └── wikiart_256.py
│ ├── sampling.py
│ └── utils.py
├── interpolate.py
├── make_grid.py
├── requirements.txt
└── sample.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
venv*
__pycache__
.ipynb_checkpoints
*.pkl
out*
*.egg-info
*.ini
================================================
FILE: .gitmodules
================================================
[submodule "CLIP_JAX"]
path = CLIP_JAX
url = https://github.com/kingoflolz/CLIP_JAX
================================================
FILE: LICENSE
================================================
Copyright (c) 2021 Katherine Crowson and John David Pressman
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
================================================
FILE: README.md
================================================
# v-diffusion-jax
v objective diffusion inference code for JAX, by Katherine Crowson ([@RiversHaveWings](https://twitter.com/RiversHaveWings)) and Chainbreakers AI ([@jd_pressman](https://twitter.com/jd_pressman)).
The models are denoising diffusion probabilistic models (https://arxiv.org/abs/2006.11239), which are trained to reverse a gradual noising process, allowing the models to generate samples from the learned data distributions starting from random noise. DDIM-style deterministic sampling (https://arxiv.org/abs/2010.02502) is also supported. The models are also trained on continuous timesteps. They use the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI).
Thank you to Google's [TPU Research Cloud](https://sites.research.google/trc/about/) and [stability.ai](https://www.stability.ai) for compute to train these models!
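For orientation, here is a minimal sketch of the v-objective bookkeeping the code relies on. It mirrors `t_to_alpha_sigma` in `diffusion/utils.py` and the `pred`/`eps` lines in `diffusion/sampling.py`; the helper name `v_to_pred_and_eps` is illustrative, not part of the codebase:
```
import jax.numpy as jnp

def t_to_alpha_sigma(t):
    # Timesteps in [0, 1] trace a quarter circle, so alpha**2 + sigma**2 == 1.
    return jnp.cos(t * jnp.pi / 2), jnp.sin(t * jnp.pi / 2)

def v_to_pred_and_eps(x_t, v, t):
    # Invert the v parameterization: recover the predicted clean image and the
    # predicted noise from the model output v at timestep t.
    alpha, sigma = t_to_alpha_sigma(t)
    pred = x_t * alpha - v * sigma
    eps = x_t * sigma + v * alpha
    return pred, eps
```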
## Dependencies
- JAX ([installation instructions](https://github.com/google/jax#installation))
- dm-haiku, einops, numpy, optax, Pillow, tqdm (install with `pip install`)
- CLIP_JAX (https://github.com/kingoflolz/CLIP_JAX), and its additional pip-installable dependencies: ftfy, regex, torch, torchvision (it does not need GPU PyTorch). **If you `git clone --recursive` this repo, it should fetch CLIP_JAX automatically.**
## Model checkpoints
- [Danbooru SFW 128x128](https://the-eye.eu/public/AI/models/v-diffusion/danbooru_128.pkl), SHA-256 `8551fe663dae988e619444efd99995775c7618af2f15ab5d8caf6b123513c334`
- [ImageNet 128x128](https://the-eye.eu/public/AI/models/v-diffusion/imagenet_128.pkl), SHA-256 `4fc7c817b9aaa9018c6dbcbf5cd444a42f4a01856b34c49039f57fe48e090530`
- [WikiArt 128x128](https://the-eye.eu/public/AI/models/v-diffusion/wikiart_128.pkl), SHA-256 `8fbe4e0206262996ff76d3f82a18dc67d3edd28631d4725e0154b51d00b9f91a`
- [WikiArt 256x256](https://the-eye.eu/public/AI/models/v-diffusion/wikiart_256.pkl), SHA-256 `ebc6e77865bbb2d91dad1a0bfb670079c4992684a0e97caa28f784924c3afd81`
## Sampling
### Example
If the model checkpoints are stored in `checkpoints/`, the following will generate an image:
```
./clip_sample.py "a friendly robot, watercolor by James Gurney" --model wikiart_256 --seed 0
```
If they are somewhere else, you need to specify the path to the checkpoint with `--checkpoint`.
### Unconditional sampling
```
usage: sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT] [--eta ETA]
                 [--init INIT] --model {danbooru_128,imagenet_128,wikiart_128,wikiart_256}
                 [-n N] [--seed SEED] [--starting-timestep STARTING_TIMESTEP]
                 [--steps STEPS]
```
`--batch-size`: sample this many images at a time (default 1)
`--checkpoint`: manually specify the model checkpoint file
`--eta`: set to 0 for deterministic (DDIM) sampling, to 1 (the default) for stochastic (DDPM) sampling, or to a value in between to interpolate between the two; DDIM is preferred for low numbers of timesteps (see the sketch after this list)
`--init`: specify the init image (optional)
`--model`: specify the model to use
`-n`: the total number of images to sample (default 1)
`--seed`: specify the random seed (default 0)
`--starting-timestep`: specify the starting timestep if an init image is used (range 0-1, default 0.9)
`--steps`: specify the number of diffusion timesteps (default 1000; fewer steps sample faster at some cost in quality)
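For reference, a minimal sketch of how `--eta` enters each update, mirroring `sample_step` in `diffusion/sampling.py` (the function name `ddim_ddpm_update` and its argument list are illustrative):
```
import jax
import jax.numpy as jnp

def ddim_ddpm_update(pred, eps, alpha, sigma, alpha_next, sigma_next, eta, key):
    # Portion of the next noise level replaced by *fresh* noise; zero when eta == 0.
    ddim_sigma = eta * jnp.sqrt(sigma_next**2 / sigma**2) * \
        jnp.sqrt(1 - alpha**2 / alpha_next**2)
    # The remainder is carried by the deterministic noise estimate eps.
    adjusted_sigma = jnp.sqrt(sigma_next**2 - ddim_sigma**2)
    x = pred * alpha_next + eps * adjusted_sigma
    return x + jax.random.normal(key, x.shape) * ddim_sigma
```
At `eta=0` the update is fully determined by the model's predictions; at `eta=1` it matches ancestral DDPM sampling.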
### CLIP guided sampling
CLIP guided sampling lets you generate images with diffusion models conditional on the output matching a text prompt.
```
usage: clip_sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT]
                      [--clip-guidance-scale CLIP_GUIDANCE_SCALE] [--eta ETA] [--init INIT]
                      --model {danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N]
                      [--seed SEED] [--starting-timestep STARTING_TIMESTEP] [--steps STEPS]
                      prompt
```
`clip_sample.py` has the same options as `sample.py` and these additional ones:
`prompt`: the text prompt to use
`--clip-guidance-scale`: how strongly the result should match the text prompt (default 1000)
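Under the hood, the CLIP loss gradient is folded into the model's velocity prediction before the ordinary update step. A minimal sketch of the conversion used in `cond_sample_step` in `diffusion/sampling.py` (the function name `apply_guidance` is illustrative):
```
def apply_guidance(v, cond_grad, alpha, sigma):
    # cond_grad points x toward a lower CLIP loss; re-expressing it as a change
    # in velocity moves the predicted clean image pred = x * alpha - v * sigma
    # along cond_grad.
    return v - cond_grad * (sigma / alpha)
```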
================================================
FILE: clip_sample.py
================================================
#!/usr/bin/env python3
"""CLIP guided sampling from a diffusion model."""
import argparse
from functools import partial
from pathlib import Path
import sys
from einops import repeat
import jax
import jax.numpy as jnp
from PIL import Image
from tqdm import tqdm, trange
from diffusion import get_model, get_models, load_params, sampling, utils
MODULE_DIR = Path(__file__).resolve().parent
sys.path.append(str(MODULE_DIR / 'CLIP_JAX'))
import clip_jax
def make_normalize(mean, std):
mean = jnp.array(mean).reshape([3, 1, 1])
std = jnp.array(std).reshape([3, 1, 1])
def inner(image):
return (image - mean) / std
return inner
def norm2(x):
"""Normalizes a batch of vectors to the unit sphere."""
return x / jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True))
def spherical_dist_loss(x, y):
"""Computes 1/2 the squared spherical distance between the two arguments."""
return jnp.square(jnp.arccos(jnp.sum(norm2(x) * norm2(y), axis=-1))) / 2
def main():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
p.add_argument('prompt', type=str,
help='the text prompt')
p.add_argument('--batch-size', '-bs', type=int, default=1,
help='the number of images per batch')
p.add_argument('--checkpoint', type=str,
help='the checkpoint to use')
p.add_argument('--clip-guidance-scale', '-cs', type=float, default=1000.,
help='the CLIP guidance scale')
p.add_argument('--eta', type=float, default=1.,
help='the amount of noise to add during sampling (0-1)')
p.add_argument('--init', type=str,
help='the init image')
p.add_argument('--model', type=str, choices=get_models(), required=True,
help='the model to use')
p.add_argument('-n', type=int, default=1,
help='the number of images to sample')
p.add_argument('--seed', type=int, default=0,
help='the random seed')
p.add_argument('--starting-timestep', '-st', type=float, default=0.9,
help='the timestep to start at (used with init images)')
p.add_argument('--steps', type=int, default=1000,
help='the number of timesteps')
args = p.parse_args()
model = get_model(args.model)
checkpoint = args.checkpoint
if not checkpoint:
checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl'
params = load_params(checkpoint)
image_fn, text_fn, clip_params, _ = clip_jax.load('ViT-B/16')
clip_patch_size = 16
clip_size = 224
normalize = make_normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
target_embed = text_fn(clip_params, clip_jax.tokenize([args.prompt]))
if args.init:
_, y, x = model.shape
init = Image.open(args.init).convert('RGB').resize((x, y), Image.LANCZOS)
init = utils.from_pil_image(init)[None]
key = jax.random.PRNGKey(args.seed)
def clip_cond_fn_loss(x, key, params, clip_params, t, extra_args):
dummy_key = jax.random.PRNGKey(0)
v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)
alpha, sigma = utils.t_to_alpha_sigma(t)
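        # Invert the v parameterization to get the model's current estimate of
        # the clean image.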
pred = x * alpha - v * sigma
clip_in = jax.image.resize(pred, (*pred.shape[:2], clip_size, clip_size), 'cubic')
extent = clip_patch_size // 2
clip_in = jnp.pad(clip_in, [(0, 0), (0, 0), (extent, extent), (extent, extent)], 'edge')
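        # Randomly translate each image by up to half a CLIP patch (using the
        # edge padding above), so CLIP sees a slightly shifted crop at each step.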
sat_vmap = jax.vmap(partial(jax.image.scale_and_translate, method='cubic'),
in_axes=(0, None, None, 0, 0))
scales = jnp.ones([pred.shape[0], 2])
translates = jax.random.uniform(key, [pred.shape[0], 2], minval=-extent, maxval=extent)
clip_in = sat_vmap(clip_in, (3, clip_size, clip_size), (1, 2), scales, translates)
image_embeds = image_fn(clip_params, normalize((clip_in + 1) / 2))
return jnp.sum(spherical_dist_loss(image_embeds, target_embed))
def clip_cond_fn(x, key, t, extra_args, params, clip_params):
grad_fn = jax.grad(clip_cond_fn_loss)
grad = grad_fn(x, key, params, clip_params, t, extra_args)
return grad * -args.clip_guidance_scale
def run(key, n):
tqdm.write('Sampling...')
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, [n, *model.shape])
key, subkey = jax.random.split(key)
cond_params = {'params': params, 'clip_params': clip_params}
sample_step = partial(sampling.jit_cond_sample_step,
extra_args={},
cond_fn=clip_cond_fn,
cond_params=cond_params)
steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1])
if args.init:
steps = steps[steps < args.starting_timestep]
alpha, sigma = utils.t_to_alpha_sigma(steps[0])
noise = init * alpha + noise * sigma
return sampling.sample_loop(model, params, subkey, noise, steps, args.eta, sample_step)
def run_all(key, n, batch_size):
for i in trange(0, n, batch_size):
key, subkey = jax.random.split(key)
cur_batch_size = min(n - i, batch_size)
outs = run(key, cur_batch_size)
for j, out in enumerate(outs):
utils.to_pil_image(out).save(f'out_{i + j:05}.png')
try:
run_all(key, args.n, args.batch_size)
except KeyboardInterrupt:
pass
if __name__ == '__main__':
main()
================================================
FILE: diffusion/__init__.py
================================================
from . import sampling, utils
from .models import get_model, get_models, load_params
================================================
FILE: diffusion/models/__init__.py
================================================
from .models import get_model, get_models, load_params
================================================
FILE: diffusion/models/danbooru_128.py
================================================
import haiku as hk
import jax
import jax.numpy as jnp
from .. import utils
class FourierFeatures(hk.Module):
def __init__(self, output_size, std=1., name=None):
super().__init__(name=name)
assert output_size % 2 == 0
self.output_size = output_size
self.std = std
def __call__(self, x):
w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]],
init=hk.initializers.RandomNormal(self.std, 0))
f = 2 * jnp.pi * x @ w.T
return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)
class Dropout2d(hk.Module):
def __init__(self, rate=0.5, name=None):
super().__init__(name=name)
self.rate = rate
def __call__(self, x, enabled):
rate = self.rate * enabled
key = hk.next_rng_key()
p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None]
return x * p / (1.0 - rate)
class SelfAttention2d(hk.Module):
def __init__(self, n_head=1, dropout_rate=0.1, name=None):
super().__init__(name=name)
self.n_head = n_head
self.dropout_rate = dropout_rate
def __call__(self, x, dropout_enabled):
n, c, h, w = x.shape
assert c % self.n_head == 0
qkv_proj = hk.Conv2D(c * 3, 1, data_format='NCHW', name='qkv_proj')
out_proj = hk.Conv2D(c, 1, data_format='NCHW', name='out_proj')
dropout = Dropout2d(self.dropout_rate)
qkv = qkv_proj(x)
qkv = jnp.swapaxes(qkv.reshape([n, self.n_head * 3, c // self.n_head, h * w]), 2, 3)
q, k, v = jnp.split(qkv, 3, axis=1)
scale = k.shape[3]**-0.25
att = jax.nn.softmax((q * scale) @ (jnp.swapaxes(k, 2, 3) * scale), axis=3)
y = jnp.swapaxes(att @ v, 2, 3).reshape([n, c, h, w])
return x + dropout(out_proj(y), dropout_enabled)
def res_conv_block(c_mid, c_out, dropout_last=True):
def inner(x, is_training):
x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW')
x_skip = x if x.shape[1] == c_out else x_skip_layer(x)
x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x)
x = jax.nn.relu(x)
x = Dropout2d(0.1)(x, is_training)
x = hk.Conv2D(c_out, 3, data_format='NCHW')(x)
x = jax.nn.relu(x)
if dropout_last:
x = Dropout2d(0.1)(x, is_training)
return x + x_skip
return inner
def diffusion_model(x, t, extra_args):
c = 256
is_training = jnp.array(0.)
log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t))
timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None])
te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]])
x = jnp.concatenate([x, te_planes], axis=1) # 128x128
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x) # 64x64
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2) # 32x32
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3) # 16x16
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4 // 128)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4 // 128)(x_4, is_training)
x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4) # 8x8
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training)
x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5) # 4x4
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 4)(x_6, is_training)
x_6 = SelfAttention2d(c * 4 // 128)(x_6, is_training)
x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest')
x_5 = jnp.concatenate([x_5, x_6], axis=1)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4 // 128)(x_5, is_training)
x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')
x_4 = jnp.concatenate([x_4, x_5], axis=1)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4 // 128)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 2)(x_4, is_training)
x_4 = SelfAttention2d(c * 2 // 128)(x_4, is_training)
x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')
x_3 = jnp.concatenate([x_3, x_4], axis=1)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')
x_2 = jnp.concatenate([x_2, x_3], axis=1)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c)(x_2, is_training)
x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')
x = jnp.concatenate([x, x_2], axis=1)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, 3, dropout_last=False)(x, is_training)
return x
class Danbooru128Model:
init, apply = hk.transform(diffusion_model)
shape = (3, 128, 128)
min_t = float(utils.get_ddpm_schedule(jnp.array(0.)))
max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))
================================================
FILE: diffusion/models/imagenet_128.py
================================================
import haiku as hk
import jax
import jax.numpy as jnp
from .. import utils
class FourierFeatures(hk.Module):
def __init__(self, output_size, std=1., name=None):
super().__init__(name=name)
assert output_size % 2 == 0
self.output_size = output_size
self.std = std
def __call__(self, x):
w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]],
init=hk.initializers.RandomNormal(self.std, 0))
f = 2 * jnp.pi * x @ w.T
return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)
class Dropout2d(hk.Module):
def __init__(self, rate=0.5, name=None):
super().__init__(name=name)
self.rate = rate
def __call__(self, x, enabled):
rate = self.rate * enabled
key = hk.next_rng_key()
p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None]
return x * p / (1.0 - rate)
class SelfAttention2d(hk.Module):
def __init__(self, c_in, n_head=1, dropout_rate=0.1, name=None):
super().__init__(name=name)
assert c_in % n_head == 0
self.c_in = c_in
self.n_head = n_head
self.dropout_rate = dropout_rate
def __call__(self, x, dropout_enabled):
n, c, h, w = x.shape
qkv_proj = hk.Conv2D(self.c_in * 3, 1, data_format='NCHW', name='qkv_proj')
out_proj = hk.Conv2D(self.c_in, 1, data_format='NCHW', name='out_proj')
dropout = Dropout2d(self.dropout_rate)
qkv = qkv_proj(x)
qkv = jnp.swapaxes(qkv.reshape([n, self.n_head * 3, c // self.n_head, h * w]), 2, 3)
q, k, v = jnp.split(qkv, 3, axis=1)
scale = k.shape[3]**-0.25
att = jax.nn.softmax((q * scale) @ (jnp.swapaxes(k, 2, 3) * scale), axis=3)
y = jnp.swapaxes(att @ v, 2, 3).reshape([n, c, h, w])
return x + dropout(out_proj(y), dropout_enabled)
def res_conv_block(c_mid, c_out, dropout_last=True):
@hk.remat
def inner(x, is_training):
x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW')
x_skip = x if x.shape[1] == c_out else x_skip_layer(x)
x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x)
x = jax.nn.relu(x)
x = Dropout2d(0.1)(x, is_training)
x = hk.Conv2D(c_out, 3, data_format='NCHW')(x)
if dropout_last:
x = jax.nn.relu(x)
x = Dropout2d(0.1)(x, is_training)
return x + x_skip
return inner
def diffusion_model(x, t, extra_args):
c = 128
is_training = jnp.array(0.)
log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t))
timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None])
te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]])
x = jnp.concatenate([x, te_planes], axis=1) # 128x128
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x) # 64x64
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2) # 32x32
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3) # 16x16
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)
x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4) # 8x8
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)
x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5) # 4x4
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = SelfAttention2d(c * 8, c * 8 // 128)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 4)(x_6, is_training)
x_6 = SelfAttention2d(c * 4, c * 4 // 128)(x_6, is_training)
x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest')
x_5 = jnp.concatenate([x_5, x_6], axis=1)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = SelfAttention2d(c * 4, c * 4 // 128)(x_5, is_training)
x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')
x_4 = jnp.concatenate([x_4, x_5], axis=1)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = SelfAttention2d(c * 4, c * 4 // 128)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 2)(x_4, is_training)
x_4 = SelfAttention2d(c * 2, c * 2 // 128)(x_4, is_training)
x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')
x_3 = jnp.concatenate([x_3, x_4], axis=1)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')
x_2 = jnp.concatenate([x_2, x_3], axis=1)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c)(x_2, is_training)
x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')
x = jnp.concatenate([x, x_2], axis=1)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, 3, dropout_last=False)(x, is_training)
return x
class ImageNet128Model:
init, apply = hk.transform(diffusion_model)
shape = (3, 128, 128)
min_t = float(utils.get_ddpm_schedule(jnp.array(0.)))
max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))
================================================
FILE: diffusion/models/models.py
================================================
import pickle
import jax
import jax.numpy as jnp
from . import danbooru_128, imagenet_128, wikiart_128, wikiart_256
models = {
'danbooru_128': danbooru_128.Danbooru128Model,
'imagenet_128': imagenet_128.ImageNet128Model,
'wikiart_128': wikiart_128.WikiArt128Model,
'wikiart_256': wikiart_256.WikiArt256Model,
}
def get_model(model):
return models[model]
def get_models():
return list(models.keys())
def load_params(checkpoint):
with open(checkpoint, 'rb') as fp:
return jax.tree_map(jnp.array, pickle.load(fp)['params_ema'])
================================================
FILE: diffusion/models/wikiart_128.py
================================================
import haiku as hk
import jax
import jax.numpy as jnp
from .. import utils
class FourierFeatures(hk.Module):
def __init__(self, output_size, std=1., name=None):
super().__init__(name=name)
assert output_size % 2 == 0
self.output_size = output_size
self.std = std
def __call__(self, x):
w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]],
init=hk.initializers.RandomNormal(self.std, 0))
f = 2 * jnp.pi * x @ w.T
return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)
class Dropout2d(hk.Module):
def __init__(self, rate=0.5, name=None):
super().__init__(name=name)
self.rate = rate
def __call__(self, x, enabled):
rate = self.rate * enabled
key = hk.next_rng_key()
p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None]
return x * p / (1.0 - rate)
def res_conv_block(c_mid, c_out, dropout_last=True):
@hk.remat
def inner(x, is_training):
x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW')
x_skip = x if x.shape[1] == c_out else x_skip_layer(x)
x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x)
x = jax.nn.relu(x)
x = Dropout2d(0.1)(x, is_training)
x = hk.Conv2D(c_out, 3, data_format='NCHW')(x)
x = jax.nn.relu(x)
if dropout_last:
x = Dropout2d(0.1)(x, is_training)
return x + x_skip
return inner
def diffusion_model(x, t, extra_args):
c = 128
is_training = jnp.array(0.)
log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t))
timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None])
te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]])
x = jnp.concatenate([x, te_planes], axis=1) # 128x128
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x) # 64x64
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2) # 32x32
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3) # 16x16
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4) # 8x8
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5) # 4x4
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 8)(x_6, is_training)
x_6 = res_conv_block(c * 8, c * 4)(x_6, is_training)
x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest')
x_5 = jnp.concatenate([x_5, x_6], axis=1)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')
x_4 = jnp.concatenate([x_4, x_5], axis=1)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 4)(x_4, is_training)
x_4 = res_conv_block(c * 4, c * 2)(x_4, is_training)
x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')
x_3 = jnp.concatenate([x_3, x_4], axis=1)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')
x_2 = jnp.concatenate([x_2, x_3], axis=1)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c * 2)(x_2, is_training)
x_2 = res_conv_block(c * 2, c)(x_2, is_training)
x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')
x = jnp.concatenate([x, x_2], axis=1)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, c)(x, is_training)
x = res_conv_block(c, 3, dropout_last=False)(x, is_training)
return x
class WikiArt128Model:
init, apply = hk.transform(diffusion_model)
shape = (3, 128, 128)
min_t = float(utils.get_ddpm_schedule(jnp.array(0.)))
max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))
================================================
FILE: diffusion/models/wikiart_256.py
================================================
import haiku as hk
import jax
import jax.numpy as jnp
from .. import utils
class FourierFeatures(hk.Module):
def __init__(self, output_size, std=1., name=None):
super().__init__(name=name)
assert output_size % 2 == 0
self.output_size = output_size
self.std = std
def __call__(self, x):
w = hk.get_parameter('w', [self.output_size // 2, x.shape[1]],
init=hk.initializers.RandomNormal(self.std, 0))
f = 2 * jnp.pi * x @ w.T
return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1)
class Dropout2d(hk.Module):
def __init__(self, rate=0.5, name=None):
super().__init__(name=name)
self.rate = rate
def __call__(self, x, enabled):
rate = self.rate * enabled
key = hk.next_rng_key()
p = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape[:2])[..., None, None]
return x * p / (1.0 - rate)
class SelfAttention2d(hk.Module):
def __init__(self, n_head=1, dropout_rate=0.1, name=None):
super().__init__(name=name)
self.n_head = n_head
self.dropout_rate = dropout_rate
def __call__(self, x, dropout_enabled):
n, c, h, w = x.shape
assert c % self.n_head == 0
qkv_proj = hk.Conv2D(c * 3, 1, data_format='NCHW', name='qkv_proj')
out_proj = hk.Conv2D(c, 1, data_format='NCHW', name='out_proj')
dropout = Dropout2d(self.dropout_rate)
qkv = qkv_proj(x)
qkv = jnp.swapaxes(qkv.reshape([n, self.n_head * 3, c // self.n_head, h * w]), 2, 3)
q, k, v = jnp.split(qkv, 3, axis=1)
scale = k.shape[3]**-0.25
att = jax.nn.softmax((q * scale) @ (jnp.swapaxes(k, 2, 3) * scale), axis=3)
y = jnp.swapaxes(att @ v, 2, 3).reshape([n, c, h, w])
return x + dropout(out_proj(y), dropout_enabled)
def res_conv_block(c_mid, c_out, dropout_last=True):
@hk.remat
def inner(x, is_training):
x_skip_layer = hk.Conv2D(c_out, 1, with_bias=False, data_format='NCHW')
x_skip = x if x.shape[1] == c_out else x_skip_layer(x)
x = hk.Conv2D(c_mid, 3, data_format='NCHW')(x)
x = jax.nn.relu(x)
x = Dropout2d(0.1)(x, is_training)
x = hk.Conv2D(c_out, 3, data_format='NCHW')(x)
if dropout_last:
x = jax.nn.relu(x)
x = Dropout2d(0.1)(x, is_training)
return x + x_skip
return inner
def diffusion_model(x, t, extra_args):
c = 128
is_training = jnp.array(0.)
log_snr = utils.alpha_sigma_to_log_snr(*utils.t_to_alpha_sigma(t))
timestep_embed = FourierFeatures(16, 0.2)(log_snr[:, None])
te_planes = jnp.tile(timestep_embed[..., None, None], [1, 1, x.shape[2], x.shape[3]])
x = jnp.concatenate([x, te_planes], axis=1) # 256x256
x = res_conv_block(c // 2, c // 2)(x, is_training)
x = res_conv_block(c // 2, c // 2)(x, is_training)
x = res_conv_block(c // 2, c // 2)(x, is_training)
x = res_conv_block(c // 2, c // 2)(x, is_training)
x_2 = hk.AvgPool(2, 2, 'SAME', 1)(x) # 128x128
x_2 = res_conv_block(c, c)(x_2, is_training)
x_2 = res_conv_block(c, c)(x_2, is_training)
x_2 = res_conv_block(c, c)(x_2, is_training)
x_2 = res_conv_block(c, c)(x_2, is_training)
x_3 = hk.AvgPool(2, 2, 'SAME', 1)(x_2) # 64x64
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_4 = hk.AvgPool(2, 2, 'SAME', 1)(x_3) # 32x32
x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
x_5 = hk.AvgPool(2, 2, 'SAME', 1)(x_4) # 16x16
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
x_6 = hk.AvgPool(2, 2, 'SAME', 1)(x_5) # 8x8
x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
x_7 = hk.AvgPool(2, 2, 'SAME', 1)(x_6) # 4x4
x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
x_7 = res_conv_block(c * 8, c * 8)(x_7, is_training)
x_7 = hk.remat(SelfAttention2d(c * 8 // 128))(x_7, is_training)
x_7 = res_conv_block(c * 8, c * 4)(x_7, is_training)
x_7 = hk.remat(SelfAttention2d(c * 4 // 128))(x_7, is_training)
x_7 = jax.image.resize(x_7, [*x_7.shape[:2], *x_6.shape[2:]], 'nearest')
x_6 = jnp.concatenate([x_6, x_7], axis=1)
x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
x_6 = res_conv_block(c * 4, c * 4)(x_6, is_training)
x_6 = hk.remat(SelfAttention2d(c * 4 // 128))(x_6, is_training)
x_6 = jax.image.resize(x_6, [*x_6.shape[:2], *x_5.shape[2:]], 'nearest')
x_5 = jnp.concatenate([x_5, x_6], axis=1)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 4)(x_5, is_training)
x_5 = hk.remat(SelfAttention2d(c * 4 // 128))(x_5, is_training)
x_5 = res_conv_block(c * 4, c * 2)(x_5, is_training)
x_5 = hk.remat(SelfAttention2d(c * 2 // 128))(x_5, is_training)
x_5 = jax.image.resize(x_5, [*x_5.shape[:2], *x_4.shape[2:]], 'nearest')
x_4 = jnp.concatenate([x_4, x_5], axis=1)
x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
x_4 = res_conv_block(c * 2, c * 2)(x_4, is_training)
x_4 = jax.image.resize(x_4, [*x_4.shape[:2], *x_3.shape[2:]], 'nearest')
x_3 = jnp.concatenate([x_3, x_4], axis=1)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c * 2)(x_3, is_training)
x_3 = res_conv_block(c * 2, c)(x_3, is_training)
x_3 = jax.image.resize(x_3, [*x_3.shape[:2], *x_2.shape[2:]], 'nearest')
x_2 = jnp.concatenate([x_2, x_3], axis=1)
x_2 = res_conv_block(c, c)(x_2, is_training)
x_2 = res_conv_block(c, c)(x_2, is_training)
x_2 = res_conv_block(c, c)(x_2, is_training)
x_2 = res_conv_block(c, c // 2)(x_2, is_training)
x_2 = jax.image.resize(x_2, [*x_2.shape[:2], *x.shape[2:]], 'nearest')
x = jnp.concatenate([x, x_2], axis=1)
x = res_conv_block(c // 2, c // 2)(x, is_training)
x = res_conv_block(c // 2, c // 2)(x, is_training)
x = res_conv_block(c // 2, c // 2)(x, is_training)
x = res_conv_block(c // 2, 3, dropout_last=False)(x, is_training)
return x
class WikiArt256Model:
init, apply = hk.transform(diffusion_model)
shape = (3, 256, 256)
min_t = float(utils.get_ddpm_schedule(jnp.array(0.)))
max_t = float(utils.get_ddpm_schedule(jnp.array(1.)))
================================================
FILE: diffusion/sampling.py
================================================
from einops import repeat
import jax
import jax.numpy as jnp
from tqdm import trange
from . import utils
def sample_step(model, params, key, x, t, t_next, eta, extra_args):
dummy_key = jax.random.PRNGKey(0)
v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)
alpha, sigma = utils.t_to_alpha_sigma(t)
key, subkey = jax.random.split(key)
pred = x * alpha - v * sigma
eps = x * sigma + v * alpha
alpha_next, sigma_next = utils.t_to_alpha_sigma(t_next)
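    # Split the next noise level in two: eta controls how much is fresh noise
    # (none at eta == 0, i.e. DDIM; the full DDPM amount at eta == 1).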
ddim_sigma = eta * jnp.sqrt(sigma_next**2 / sigma**2) * \
jnp.sqrt(1 - alpha**2 / alpha_next**2)
adjusted_sigma = jnp.sqrt(sigma_next**2 - ddim_sigma**2)
x = pred * alpha_next + eps * adjusted_sigma
x = x + jax.random.normal(key, x.shape) * ddim_sigma
return x, pred
jit_sample_step = jax.jit(sample_step, static_argnums=0)
def cond_sample_step(model, params, key, x, t, t_next, eta, extra_args, cond_fn, cond_params):
dummy_key = jax.random.PRNGKey(0)
v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)
alpha, sigma = utils.t_to_alpha_sigma(t)
key, subkey = jax.random.split(key)
cond_grad = cond_fn(x, subkey, t, extra_args, **cond_params)
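    # Fold the conditioning gradient into v, so the predicted clean image below
    # moves along cond_grad.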
v = v - cond_grad * (sigma / alpha)
pred = x * alpha - v * sigma
eps = x * sigma + v * alpha
alpha_next, sigma_next = utils.t_to_alpha_sigma(t_next)
ddim_sigma = eta * jnp.sqrt(sigma_next**2 / sigma**2) * \
jnp.sqrt(1 - alpha**2 / alpha_next**2)
adjusted_sigma = jnp.sqrt(sigma_next**2 - ddim_sigma**2)
x = pred * alpha_next + eps * adjusted_sigma
x = x + jax.random.normal(key, x.shape) * ddim_sigma
return x, pred
jit_cond_sample_step = jax.jit(cond_sample_step, static_argnums=(0, 8))
def sample_loop(model, params, key, x, steps, eta, sample_step):
for i in trange(len(steps)):
key, subkey = jax.random.split(key)
if i < len(steps) - 1:
x, _ = sample_step(model, params, subkey, x, steps[i], steps[i + 1], eta)
else:
_, pred = sample_step(model, params, subkey, x, steps[i], steps[i], eta)
return pred
def reverse_sample_step(model, params, key, x, t, t_next, extra_args):
dummy_key = jax.random.PRNGKey(0)
v = model.apply(params, dummy_key, x, repeat(t, '-> n', n=x.shape[0]), extra_args)
alpha, sigma = utils.t_to_alpha_sigma(t)
pred = x * alpha - v * sigma
eps = x * sigma + v * alpha
alpha_next, sigma_next = utils.t_to_alpha_sigma(t_next)
x = pred * alpha_next + eps * sigma_next
return x, pred
jit_reverse_sample_step = jax.jit(reverse_sample_step, static_argnums=0)
def reverse_sample_loop(model, params, key, x, steps, sample_step):
for i in trange(len(steps) - 1):
key, subkey = jax.random.split(key)
x, _ = sample_step(model, params, subkey, x, steps[i], steps[i + 1])
return x
================================================
FILE: diffusion/utils.py
================================================
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
def from_pil_image(x):
"""Converts from a PIL image to a JAX array."""
x = jnp.array(x)
if x.ndim == 2:
x = x[..., None]
return x.transpose((2, 0, 1)) / 127.5 - 1
def to_pil_image(x):
"""Converts from a JAX array to a PIL image."""
if x.ndim == 4:
assert x.shape[0] == 1
x = x[0]
if x.shape[0] == 1:
x = x[0]
else:
x = x.transpose((1, 2, 0))
arr = np.array(jnp.round(jnp.clip((x + 1) * 127.5, 0, 255)).astype(jnp.uint8))
return Image.fromarray(arr)
def log_snr_to_alpha_sigma(log_snr):
"""Returns the scaling factors for the clean image and for the noise, given
the log SNR for a timestep."""
return jnp.sqrt(jax.nn.sigmoid(log_snr)), jnp.sqrt(jax.nn.sigmoid(-log_snr))
def alpha_sigma_to_log_snr(alpha, sigma):
"""Returns a log snr, given the scaling factors for the clean image and for
the noise."""
return jnp.log(alpha**2 / sigma**2)
def t_to_alpha_sigma(t):
"""Returns the scaling factors for the clean image and for the noise, given
a timestep."""
return jnp.cos(t * jnp.pi / 2), jnp.sin(t * jnp.pi / 2)
def alpha_sigma_to_t(alpha, sigma):
"""Returns a timestep, given the scaling factors for the clean image and for
the noise."""
return jnp.arctan2(sigma, alpha) / jnp.pi * 2
def get_ddpm_schedule(ddpm_t):
"""Returns timesteps for the noise schedule from the DDPM paper."""
log_snr = -jnp.log(jnp.expm1(1e-4 + 10 * ddpm_t**2))
alpha, sigma = log_snr_to_alpha_sigma(log_snr)
return alpha_sigma_to_t(alpha, sigma)
================================================
FILE: interpolate.py
================================================
#!/usr/bin/env python3
"""Interpolation in a diffusion model's latent space."""
import argparse
from functools import partial
from pathlib import Path
import jax
import jax.numpy as jnp
from PIL import Image
from tqdm import trange
from diffusion import get_model, get_models, load_params, sampling, utils
MODULE_DIR = Path(__file__).resolve().parent
def main():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
p.add_argument('--batch-size', '-bs', type=int, default=4,
help='the number of images per batch')
p.add_argument('--checkpoint', type=str,
help='the checkpoint to use')
p.add_argument('--init-1', type=str,
help='the init image for the starting point')
p.add_argument('--init-2', type=str,
help='the init image for the ending point')
p.add_argument('--model', type=str, choices=get_models(), required=True,
help='the model to use')
p.add_argument('-n', type=int, default=16,
help='the number of images to sample')
p.add_argument('--seed-1', type=int, default=0,
help='the random seed for the starting point')
p.add_argument('--seed-2', type=int, default=1,
help='the random seed for the ending point')
p.add_argument('--steps', type=int, default=1000,
help='the number of timesteps')
args = p.parse_args()
model = get_model(args.model)
checkpoint = args.checkpoint
if not checkpoint:
checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl'
params = load_params(checkpoint)
key_1 = jax.random.PRNGKey(args.seed_1)
key_2 = jax.random.PRNGKey(args.seed_2)
latent_1 = jax.random.normal(key_1, [1, *model.shape])
latent_2 = jax.random.normal(key_2, [1, *model.shape])
_, y, x = model.shape
reverse_sample_step = partial(sampling.jit_reverse_sample_step, extra_args={})
reverse_steps = utils.get_ddpm_schedule(jnp.linspace(0, 1, args.steps + 1))
if args.init_1:
init_1 = Image.open(args.init_1).convert('RGB').resize((x, y), Image.LANCZOS)
init_1 = utils.from_pil_image(init_1)[None]
print('Inverting the starting init image...')
latent_1 = sampling.reverse_sample_loop(model, params, key_1, init_1, reverse_steps,
reverse_sample_step)
if args.init_2:
init_2 = Image.open(args.init_2).convert('RGB').resize((x, y), Image.LANCZOS)
init_2 = utils.from_pil_image(init_2)[None]
print('Inverting the ending init image...')
latent_2 = sampling.reverse_sample_loop(model, params, key_2, init_2, reverse_steps,
reverse_sample_step)
def run(weights):
alphas, sigmas = utils.t_to_alpha_sigma(weights)
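        # Spherical interpolation: alpha**2 + sigma**2 == 1, so the blended
        # latent keeps unit variance.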
latents = latent_1 * alphas[:, None, None, None] + latent_2 * sigmas[:, None, None, None]
sample_step = partial(sampling.jit_sample_step, extra_args={})
steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1])
dummy_key = jax.random.PRNGKey(0)
return sampling.sample_loop(model, params, dummy_key, latents, steps, 0., sample_step)
def run_all(weights):
for i in trange(0, len(weights), args.batch_size):
outs = run(weights[i:i+args.batch_size])
for j, out in enumerate(outs):
utils.to_pil_image(out).save(f'out_{i + j:05}.png')
try:
print('Sampling...')
run_all(jnp.linspace(0, 1, args.n))
except KeyboardInterrupt:
pass
if __name__ == '__main__':
main()
================================================
FILE: make_grid.py
================================================
#!/usr/bin/env python3
"""Assembles images into a grid."""
import argparse
import math
import sys
from PIL import Image
def main():
p = argparse.ArgumentParser(description=__doc__)
p.add_argument('images', type=str, nargs='+', metavar='image',
help='the input images')
p.add_argument('--output', '-o', type=str, default='out.png',
help='the output image')
p.add_argument('--nrow', type=int,
help='the number of images per row')
args = p.parse_args()
images = [Image.open(image) for image in args.images]
mode = images[0].mode
size = images[0].size
for image, name in zip(images, args.images):
if image.mode != mode:
print(f'Error: Image {name} had mode {image.mode}, expected {mode}', file=sys.stderr)
sys.exit(1)
if image.size != size:
print(f'Error: Image {name} had size {image.size}, expected {size}', file=sys.stderr)
sys.exit(1)
n = len(images)
x = args.nrow if args.nrow else math.ceil(n**0.5)
y = math.ceil(n / x)
output = Image.new(mode, (size[0] * x, size[1] * y))
for i, image in enumerate(images):
cur_x, cur_y = i % x, i // x
output.paste(image, (size[0] * cur_x, size[1] * cur_y))
output.save(args.output)
if __name__ == '__main__':
main()
================================================
FILE: requirements.txt
================================================
dm-haiku
einops
ftfy
jax
jaxlib
numpy
optax
regex
Pillow
torch
torchvision
tqdm
================================================
FILE: sample.py
================================================
#!/usr/bin/env python3
"""Unconditional sampling from a diffusion model."""
import argparse
from functools import partial
from pathlib import Path
import jax
import jax.numpy as jnp
from PIL import Image
from tqdm import tqdm, trange
from diffusion import get_model, get_models, load_params, sampling, utils
MODULE_DIR = Path(__file__).resolve().parent
def main():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
p.add_argument('--batch-size', '-bs', type=int, default=1,
help='the number of images per batch')
p.add_argument('--checkpoint', type=str,
help='the checkpoint to use')
p.add_argument('--eta', type=float, default=1.,
help='the amount of noise to add during sampling (0-1)')
p.add_argument('--init', type=str,
help='the init image')
p.add_argument('--model', type=str, choices=get_models(), required=True,
help='the model to use')
p.add_argument('-n', type=int, default=1,
help='the number of images to sample')
p.add_argument('--seed', type=int, default=0,
help='the random seed')
p.add_argument('--starting-timestep', '-st', type=float, default=0.9,
help='the timestep to start at (used with init images)')
p.add_argument('--steps', type=int, default=1000,
help='the number of timesteps')
args = p.parse_args()
model = get_model(args.model)
checkpoint = args.checkpoint
if not checkpoint:
checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl'
params = load_params(checkpoint)
if args.init:
_, y, x = model.shape
init = Image.open(args.init).convert('RGB').resize((x, y), Image.LANCZOS)
init = utils.from_pil_image(init)[None]
key = jax.random.PRNGKey(args.seed)
def run(key, n):
tqdm.write('Sampling...')
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, [n, *model.shape])
key, subkey = jax.random.split(key)
sample_step = partial(sampling.jit_sample_step, extra_args={})
steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1])
if args.init:
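            # Drop the early high-noise steps and start from a noised copy of
            # the init image.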
steps = steps[steps < args.starting_timestep]
alpha, sigma = utils.t_to_alpha_sigma(steps[0])
noise = init * alpha + noise * sigma
return sampling.sample_loop(model, params, subkey, noise, steps, args.eta, sample_step)
def run_all(key, n, batch_size):
for i in trange(0, n, batch_size):
key, subkey = jax.random.split(key)
cur_batch_size = min(n - i, batch_size)
outs = run(key, cur_batch_size)
for j, out in enumerate(outs):
utils.to_pil_image(out).save(f'out_{i + j:05}.png')
try:
run_all(key, args.n, args.batch_size)
except KeyboardInterrupt:
pass
if __name__ == '__main__':
main()