Repository: tomeramit/SegDiff
Branch: main
Commit: e6592c983b5b
Files: 31
Total size: 202.6 KB
Directory structure:
gitextract_40nqd4bf/
├── .gitignore
├── README.md
├── datasets/
│ ├── city.py
│ ├── monu.py
│ ├── preprocess_vaihingen.py
│ ├── transforms.py
│ └── vaih.py
├── environment.yml
├── image_sample_diff_city.py
├── image_sample_diff_medical.py
├── image_sample_diff_vaih.py
├── image_train_diff_city.py
├── image_train_diff_medical.py
├── image_train_diff_vaih.py
└── improved_diffusion/
├── RRDB.py
├── __init__.py
├── dist_util.py
├── fp16_util.py
├── gaussian_diffusion.py
├── image_datasets.py
├── logger.py
├── losses.py
├── metrics.py
├── nn.py
├── resample.py
├── respace.py
├── sampling_util.py
├── script_util.py
├── train_util.py
├── unet.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
data
.vscode
.idea
# custom
*.pkl
*.pkl.json
*.log.json
work_dirs/
work_dirs
pretrained
pretrained/
# Pytorch
*.pth
trash/
trash
================================================
FILE: README.md
================================================
This is the official repository of the paper [SegDiff: Image Segmentation with Diffusion Probabilistic Models](https://arxiv.org/abs/2112.00390).
The code is based on [Improved Denoising Diffusion Probabilistic Models](https://github.com/openai/improved-diffusion).
## Installation
### Conda environment
To create the environment, use the following conda command:
```
conda env create -f environment.yml
```
## Project structure and data preparations
The project needs to be arranged in the following layout:
```
segdiff/ # git clone the source code here
data/ # the root of the data folders
Vaihingen/
Medical/MoNuSeg/
cityscapes_instances/
```
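As a convenience, the data roots in the layout above can be created with a few lines of Python (the `data` base path is an assumption; place it wherever your checkout expects it):

```python
from pathlib import Path

# Folder names taken from the layout above; adjust `base` as needed.
base = Path("data")
for sub in ("Vaihingen", "Medical/MoNuSeg", "cityscapes_instances"):
    (base / sub).mkdir(parents=True, exist_ok=True)
```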
### Vaihingen
Download the dataset from [this link](https://drive.google.com/file/d/1nenpWH4BdplSiHdfXs0oYfiA5qL42plB/view),
unzip its contents (a folder named `buildings`), and run the preprocessing script:
```
python datasets/preprocess_vaihingen.py --path building-folder-path
```
After preprocessing, the Vaihingen dataset should have the following structure:
```
Vaihingen/
full_test_vaih.hdf5
full_training_vaih.hdf5
```
### MoNuSeg
See the challenge's general [website](https://monuseg.grand-challenge.org/),
and download the
[train](https://drive.google.com/file/d/1ZgqFJomqQGNnsx7w7QBzQQMVA16lbVCA/view?usp=sharing)
and [test](https://drive.google.com/file/d/1NKkSQ5T0ZNQ8aUhh0a8Dt2YKYCQXIViw/view?usp=sharing) sets.
Then run the MATLAB [code](https://drive.google.com/file/d/1YDtIiLZX0lQzZp_JbqneHXHvRo45ZWGX/view)
for preprocessing.
The MoNuSeg dataset should have the following structure:
```
MoNuSeg/
Test/
img/
XX.tif
mask/
XX.png
Training/
img/
XX.tif
mask/
XX.png
```
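The data loader pairs each image with its mask by filename, replacing the `.tif` suffix with `.png` as implied by the `img/` and `mask/` folders above. A minimal sketch of that pairing rule:

```python
from pathlib import Path

def mask_name(img_name: str) -> str:
    # XX.tif -> XX.png, mirroring the img/ and mask/ layout above
    return Path(img_name).stem + ".png"

print(mask_name("XX.tif"))
```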
### Cityscapes
Download the [Cityscapes](https://www.cityscapes-dataset.com) dataset with the splits from
[PolyRNN++](https://github.com/fidler-lab/polyrnn-pp), and follow the instructions [here](https://github.com/shirgur/ACDRNet) for preparation.
To obtain the cityscapes_final_v5 annotations, sign up for the PolygonRNN++ code at http://www.cs.toronto.edu/polyrnn/code_signup/; the cityscapes_final_v5 folder is inside its data folder.
The Cityscapes dataset should have the following structure:
```
cityscapes_instances/
full/
all_classes_instances.json
train/
all_classes_instances.json
train_val/
all_classes_instances.json
val/
all_classes_instances.json
all_images.hdf5
```
## Train and Evaluate
Execute the commands below. Multi-GPU training is supported: select GPUs with CUDA_VISIBLE_DEVICES and pass the matching GPU count to mpiexec via -n.
Training options:
```
# Training
--batch_size       Batch size
--lr               Learning rate
# Architecture
--rrdb_blocks      Number of RRDB blocks
--dropout          Dropout
--diffusion_steps  Number of steps for the diffusion model
# Cityscapes
--class_name       Name of the Cityscapes class; options are ["bike", "bus", "person", "train", "motorcycle", "car", "rider"]
--expansion        Boolean flag enabling the bounding-box expansion setting
# Misc
--save_interval    Interval for saving model weights
```
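The rank count given to `mpiexec -n` should match the number of GPUs listed in CUDA_VISIBLE_DEVICES. A quick sanity-check helper (hypothetical, not part of the repo):

```python
def n_ranks(cuda_visible_devices: str) -> int:
    # "0,1,2,3" -> 4 ranks; ignores empty entries
    return len([d for d in cuda_visible_devices.split(",") if d.strip()])

print(n_ranks("0,1,2,3"))
```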
### MoNuSeg
Training script example:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 mpiexec -n 4 python image_train_diff_medical.py --rrdb_blocks 12 --batch_size 2 --lr 0.0001 --diffusion_steps 100
```
Evaluation script example:
```
CUDA_VISIBLE_DEVICES=0 mpiexec -n 1 python image_sample_diff_medical.py --model_path path-for-model-weights
```
### Cityscapes
Training script example:
```
CUDA_VISIBLE_DEVICES=0,1 mpiexec -n 2 python image_train_diff_city.py --class_name "train" --expansion True --rrdb_blocks 15 --lr 0.0001 --batch_size 15 --diffusion_steps 100
```
Evaluation script example:
```
CUDA_VISIBLE_DEVICES=0 mpiexec -n 1 python image_sample_diff_city.py --model_path path-for-model-weights
```
### Vaihingen
Training script example:
```
CUDA_VISIBLE_DEVICES=0,1 mpiexec -n 2 python image_train_diff_vaih.py --lr 0.0001 --batch_size 4 --dropout 0.1 --rrdb_blocks 6 --diffusion_steps 100
```
Evaluation script example:
```
CUDA_VISIBLE_DEVICES=0 mpiexec -n 1 python image_sample_diff_vaih.py --model_path path-for-model-weights
```
## Citation
```
@article{amit2021segdiff,
title={Segdiff: Image segmentation with diffusion probabilistic models},
author={Amit, Tomer and Nachmani, Eliya and Shaharbany, Tal and Wolf, Lior},
journal={arXiv preprint arXiv:2112.00390},
year={2021}
}
```
================================================
FILE: datasets/city.py
================================================
import json
import os
import random
from pathlib import Path
import h5py
import numpy as np
import pycocotools.mask as maskUtils
import torch
from PIL import Image
from matplotlib import pyplot as plt
from mpi4py import MPI
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import resize
from tqdm import tqdm
from datasets.transforms import \
Compose, ToPILImage, RandomHorizontalFlip, ToTensor, Normalize, RandomAffine
def create_dataset(mode="train", class_name="train", expansion=False):
shard=MPI.COMM_WORLD.Get_rank()
num_shards = MPI.COMM_WORLD.Get_size()
data_inst_path = str(Path(__file__).absolute().parent.parent.parent / "data/cityscapes_instances/")
print('loading \"{}\" annotations into memory...'.format(mode))
data = json.load(open(os.path.join(data_inst_path, mode, 'all_classes_instances.json'), 'r'))
annotations = data['data'][class_name][shard::num_shards]
hdf5_obj = h5py.File(os.path.join(data_inst_path, 'all_images.hdf5'), 'r')
images = [hdf5_obj[ann['img']['file_name']] for ann in annotations]
return CityscapesInstances(
images,
annotations,
mode=mode,
expansion=expansion
)
def load_data(
*, data_dir, batch_size, image_size, class_name, class_cond=False, expansion, deterministic=False
):
"""
For a dataset, create a generator over (images, kwargs) pairs.
Each images is an NCHW float tensor, and the kwargs dict contains zero or
more keys, each of which map to a batched Tensor of their own.
The kwargs dict can be used for class labels, in which case the key is "y"
and the values are integer tensors of class labels.
:param data_dir: a dataset directory.
:param batch_size: the batch size of each returned pair.
:param image_size: the size to which images are resized.
:param class_cond: if True, include a "y" key in returned dicts for class
label. If classes are not available and this is true, an
exception will be raised.
:param deterministic: if True, yield results in a deterministic order.
"""
dataset = create_dataset(mode="train", class_name=class_name, expansion=expansion)
if deterministic:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True
)
else:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True
)
while True:
yield from loader
class CityscapesInstances(Dataset):
CLASSES = ('person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle')
def __init__(self,
images,
annotations,
no_aug=False,
mode='train',
loops=100,
expansion=False,
std=np.array([58.395, 57.12, 57.375]),
mean=np.array([123.675, 116.28, 103.53]),
):
super(CityscapesInstances, self).__init__()
self.loops = loops
self.mode = mode
self.mean = torch.from_numpy(mean)
self.std = torch.from_numpy(std)
self.expansion = expansion
image_size = 128
if mode == 'train' and not no_aug:
self.transformations = Compose([
ToPILImage(),
# Resize((image_size, image_size)),
RandomHorizontalFlip(),
RandomAffine(22, scale=(0.75, 1.25)),
ToTensor(),
Normalize(self.mean, self.std)
# transforms.NormalizeInstance()
])
else:
self.transformations = Compose([
ToPILImage(),
# Resize((image_size, image_size), do_mask=False),
ToTensor(),
Normalize(self.mean, self.std),
# transforms.NormalizeInstance()
])
self.instance_images = []
self.instance_masks = []
self.annotations = annotations
for item in tqdm(range(len(images))):
ann = self.annotations[item]
mask = self._poly2mask(ann['segmentation'], ann['img']['height'], ann['img']['width'])
bbox = np.maximum(0, np.array(ann['bbox']).astype(np.int32))
if self.expansion:
if self.mode == 'train':
bounding_box_expansion = random.randint(10, 20)
else:
bounding_box_expansion = 15
increase_axis_by = bbox[3] * (bounding_box_expansion / 100)
increase_each_coordinate = increase_axis_by / 2
x_1 = bbox[1] - increase_each_coordinate
x_2 = bbox[1] + bbox[3] + increase_each_coordinate
increase_axis_by = bbox[2] * (bounding_box_expansion / 100)
increase_each_coordinate = increase_axis_by / 2
y_1 = bbox[0] - increase_each_coordinate
y_2 = bbox[0] + bbox[2] + increase_each_coordinate
# check the axis order
x_2 = round(min(x_2, images[item].shape[0]))
y_2 = round(min(y_2, images[item].shape[1]))
x_1 = round(max(x_1, 0))
y_1 = round(max(y_1, 0))
instance_image = images[item][x_1:x_2, y_1:y_2]
instance_mask = mask[x_1:x_2, y_1:y_2]
else:
instance_image = images[item][bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]]
instance_mask = mask[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]]
size = [image_size, image_size]
self.instance_images.append(resize(torch.from_numpy(instance_image).permute(2, 0, 1), size, Image.BILINEAR).permute(1, 2, 0).numpy())
if mode == 'train' and not no_aug:
self.instance_masks.append(resize(torch.from_numpy(instance_mask).unsqueeze(0), size, Image.NEAREST).squeeze(0).numpy())
else:
self.instance_masks.append(instance_mask)
@staticmethod
def _poly2mask(mask_ann, img_h, img_w):
if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
rle = maskUtils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
else:
# rle
rle = mask_ann
mask = maskUtils.decode(rle)
return mask
def __len__(self):
return len(self.annotations)
def __getitem__(self, item):
ann = self.annotations[item]
instance_image, instance_mask = self.transformations(self.instance_images[item], self.instance_masks[item])
out_dict = {"conditioned_image": instance_image}
instance_mask = 2 * instance_mask - 1.0
return instance_mask.unsqueeze(0), out_dict, Path(ann["img"]['file_name']).stem
def main():
mean = np.array([0, 0, 0])
std = np.array([1, 1, 1])
dataset = create_dataset(class_name="train", mode='train')
for i in range(10):
# mask, out_dict, _ = dataset[i]
# img = out_dict["conditioned_image"]
# plt.imshow(img.permute(1, 2, 0).numpy().astype(np.uint8))
# plt.show()
#
# plt.imshow(mask.permute(1, 2, 0).numpy(), cmap='gray')
# plt.show()
masks, out_dict, _ = dataset[i]
imgs = out_dict["conditioned_image"]
for index in range(10):
plt.imshow(imgs[index * 10].permute(1, 2, 0).numpy().astype(np.uint8))
plt.show()
for index in range(10):
plt.imshow(masks[index * 10].permute(1, 2, 0).numpy(), cmap='gray')
plt.show()
pass
if __name__ == '__main__':
main()
================================================
FILE: datasets/monu.py
================================================
import os
from pathlib import Path
import imageio
import matplotlib.pyplot as plt
import numpy as np
import tifffile
import torch
from mpi4py import MPI
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets.transforms import \
Compose, ToPILImage, ColorJitter, RandomHorizontalFlip, ToTensor, Normalize, RandomVerticalFlip, RandomAffine, \
Resize, RandomCrop
def cv2_loader(path, is_mask):
if is_mask:
# img = cv2.imread(path, 0)
img = imageio.imread(path)
img[img > 0] = 1
else:
# img = cv2.cvtColor(cv2.imread(path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
# img = imageio.imread(path)
img = tifffile.imread(path)
return img
def get_monu_transform(image_size):
transform_train = Compose([
ToPILImage(),
Resize((512, 512)),
RandomCrop((image_size, image_size)),
RandomHorizontalFlip(),
RandomVerticalFlip(),
RandomAffine(int(22), scale=(float(0.75), float(1.25))),
ColorJitter(brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.1),
ToTensor(),
Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
])
transform_test = Compose([
ToPILImage(),
Resize((512, 512)),
ToTensor(),
Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
])
return transform_train, transform_test
def create_dataset(mode="train", image_size=256):
datadir = str(Path(__file__).absolute().parent.parent.parent / "data/Medical/MoNuSeg")
transform_train, transform_test = get_monu_transform(image_size)
if mode == "train":
return MonuDataset(datadir, train=True, transform=transform_train, image_size=image_size)
else:
return MonuDataset(datadir, train=False, transform=transform_test)
def load_data(
*, data_dir, batch_size, image_size, class_name, class_cond=False, expansion, deterministic=False
):
"""
For a dataset, create a generator over (images, kwargs) pairs.
Each images is an NCHW float tensor, and the kwargs dict contains zero or
more keys, each of which map to a batched Tensor of their own.
The kwargs dict can be used for class labels, in which case the key is "y"
and the values are integer tensors of class labels.
:param data_dir: a dataset directory.
:param batch_size: the batch size of each returned pair.
:param image_size: the size to which images are resized.
:param class_cond: if True, include a "y" key in returned dicts for class
label. If classes are not available and this is true, an
exception will be raised.
:param deterministic: if True, yield results in a deterministic order.
"""
dataset = create_dataset(mode="train")
if deterministic:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True
)
else:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True
)
while True:
yield from loader
class MonuDataset(torch.utils.data.Dataset):
def __init__(self, root, transform=None, target_transform=None, train=False, loader=cv2_loader, pSize=8, image_size=256):
self.root = root
if train:
self.imgs_root = os.path.join(self.root, 'Training', 'img')
self.masks_root = os.path.join(self.root, 'Training', 'mask')
else:
self.imgs_root = os.path.join(self.root, 'Test', 'img')
self.masks_root = os.path.join(self.root, 'Test', 'mask')
self.image_size = image_size
self.paths = sorted(os.listdir(self.imgs_root))
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.train = train
self.pSize = pSize
self.masks = []
self.imgs = []
self.mean = torch.from_numpy(np.array([142.07, 98.48, 132.96]))
self.std = torch.from_numpy(np.array([65.78, 57.05, 57.78]))
shard = MPI.COMM_WORLD.Get_rank()
num_shards = MPI.COMM_WORLD.Get_size()
for file_path in tqdm(self.paths):
mask_path = file_path.split('.')[0] + '.png'
self.imgs.append(self.loader(os.path.join(self.imgs_root, file_path), is_mask=False))
self.masks.append(self.loader(os.path.join(self.masks_root, mask_path), is_mask=True))
self.imgs = self.imgs[shard::num_shards]
self.masks = self.masks[shard::num_shards]
self.paths = self.paths[shard::num_shards]
print('num of data:{}'.format(len(self.paths)))
def __getitem__(self, index):
img = self.imgs[index]
mask = self.masks[index]
img, mask = self.transform(img, mask)
out_dict = {"conditioned_image": img}
mask = 2 * mask - 1.0
return mask.unsqueeze(0), out_dict, f"{Path(self.paths[index]).stem}_{index}"
def __len__(self):
return len(self.paths)
if __name__ == "__main__":
val_dataset = create_dataset(
mode='val',
image_size=256,
)
ds = torch.utils.data.DataLoader(val_dataset,
batch_size=1,
num_workers=0,
shuffle=False,
drop_last=True)
pbar = tqdm(ds)
mean0_list = []
mean1_list = []
mean2_list = []
std0_list = []
std1_list = []
std2_list = []
for i, (mask, out_dict, _) in enumerate(pbar):
img = out_dict["conditioned_image"]
plt.imshow(img.squeeze().permute(1,2,0).numpy().astype(np.uint8))
plt.show()
plt.imshow(mask.squeeze().numpy(), cmap='gray')
plt.show()
a = img.mean(dim=(0, 2, 3))
b = img.std(dim=(0, 2, 3))
mean0_list.append(a[0].item())
mean1_list.append(a[1].item())
mean2_list.append(a[2].item())
std0_list.append(b[0].item())
std1_list.append(b[1].item())
std2_list.append(b[2].item())
print(np.mean(mean0_list))
print(np.mean(mean1_list))
print(np.mean(mean2_list))
print(np.mean(std0_list))
print(np.mean(std1_list))
print(np.mean(std2_list))
# a = img.squeeze().permute(1, 2, 0).cpu().numpy()
# b = mask.squeeze().cpu().numpy()
# a = (a - a.min()) / (a.max() - a.min())
# cv2.imwrite('kaki.jpg', 255*a)
# cv2.imwrite('kaki_mask.jpg', 255*b)
================================================
FILE: datasets/preprocess_vaihingen.py
================================================
from pathlib import Path
import h5py
import os
import cv2
import numpy as np
from cv2 import resize
def get_img(cfile):
img = cv2.cvtColor(cv2.imread(cfile, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
img = resize(img, (256,256), interpolation=cv2.INTER_NEAREST)
return img
def get_mask(cfile):
GT = cv2.imread(cfile, 0)
GT = resize(GT, (256, 256), interpolation=cv2.INTER_LINEAR)
GT[GT >= 0.5] = 1
GT[GT < 0.5] = 0
return GT
def main(args, out_path):
data_folder_path = Path(args['path'])
imgs_list = sorted(list(data_folder_path.glob("building_[0-9]*.tif")))
masks_list = sorted(list(data_folder_path.glob("building_mask_[0-9]*.tif")))
hf_tri = h5py.File(str(out_path / "full_training_vaih.hdf5"), 'w')
hf_test = h5py.File(str(out_path / "full_test_vaih.hdf5"), 'w')
imgs_tri = hf_tri.create_group('imgs')
mask_single_tri = hf_tri.create_group('mask_single')
imgs_test = hf_test.create_group('imgs')
mask_single_test = hf_test.create_group('mask_single')
for image_path in imgs_list[:100]:
print('training: ' + str(image_path))
img = get_img(str(image_path))
imgs_tri.create_dataset(image_path.stem, data=img, dtype=np.uint8)
for image_path in imgs_list[100:]:
print('validation: ' + str(image_path))
img = get_img(str(image_path))
imgs_test.create_dataset(image_path.stem, data=img, dtype=np.uint8)
for mask_path in masks_list[:100]:
print('training: ' + str(mask_path))
mask = get_mask(str(mask_path))
mask_single_tri.create_dataset(mask_path.stem, data=mask, dtype=np.uint8)
for mask_path in masks_list[100:]:
print('validation: ' + str(mask_path))
mask = get_mask(str(mask_path))
mask_single_test.create_dataset(mask_path.stem, data=mask, dtype=np.uint8)
hf_tri.close()
hf_test.close()
if __name__ == '__main__':
import argparse
folder_path = Path(__file__).absolute().parent.parent.parent / "data" / "Vaihingen"
folder_path.mkdir(parents=True, exist_ok=True)
parser = argparse.ArgumentParser(description='Preprocess the Vaihingen buildings dataset')
parser.add_argument('-path',
'--path',
default='',
help='Data path; should point to the "buildings" folder',
required=True)
args = vars(parser.parse_args())
main(args, out_path=folder_path)
================================================
FILE: datasets/transforms.py
================================================
from __future__ import division
import torch
import math
import sys
import random
from PIL import Image
try:
import accimage
except ImportError:
accimage = None
import numpy as np
import numbers
import types
import collections
import warnings
from torchvision.transforms import functional as F
if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "FiveCrop", "TenCrop",
"ColorJitter", "RandomRotation", "RandomAffine",
"RandomPerspective"]
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Image.BILINEAR: 'PIL.Image.BILINEAR',
Image.BICUBIC: 'PIL.Image.BICUBIC',
Image.LANCZOS: 'PIL.Image.LANCZOS',
Image.HAMMING: 'PIL.Image.HAMMING',
Image.BOX: 'PIL.Image.BOX',
}
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, mask):
for t in self.transforms:
img, mask = t(img, mask)
return img, mask
class ToTensor(object):
def __call__(self, img, mask):
# return F.to_tensor(img), F.to_tensor(mask)
img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float()
mask = torch.from_numpy(np.array(mask)).float()
return img, mask
class ToPILImage(object):
def __init__(self, mode=None):
self.mode = mode
def __call__(self, img, mask):
return F.to_pil_image(img, self.mode), F.to_pil_image(mask, self.mode)
class Normalize(object):
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, img, mask):
return F.normalize(img, self.mean, self.std, self.inplace), mask
class Resize(object):
def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True):
assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
self.size = size
self.interpolation = interpolation
self.do_mask = do_mask
def __call__(self, img, mask):
if self.do_mask:
return F.resize(img, self.size, self.interpolation), F.resize(mask, self.size, Image.NEAREST)
else:
return F.resize(img, self.size, self.interpolation), mask
class CenterCrop(object):
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, img, mask):
return F.center_crop(img, self.size), F.center_crop(mask, self.size)
class Pad(object):
def __init__(self, padding, fill=0, padding_mode='constant'):
assert isinstance(padding, (numbers.Number, tuple))
assert isinstance(fill, (numbers.Number, str, tuple))
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
self.padding = padding
self.fill = fill
self.padding_mode = padding_mode
def __call__(self, img, mask):
return F.pad(img, self.padding, self.fill, self.padding_mode), \
F.pad(mask, self.padding, self.fill, self.padding_mode)
class Lambda(object):
def __init__(self, lambd):
assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
self.lambd = lambd
def __call__(self, img, mask):
return self.lambd(img), self.lambd(mask)
class Lambda_image(object):
def __init__(self, lambd):
assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
self.lambd = lambd
def __call__(self, img, mask):
return self.lambd(img), mask
class RandomTransforms(object):
def __init__(self, transforms):
assert isinstance(transforms, (list, tuple))
self.transforms = transforms
def __call__(self, *args, **kwargs):
raise NotImplementedError()
class RandomApply(RandomTransforms):
def __init__(self, transforms, p=0.5):
super(RandomApply, self).__init__(transforms)
self.p = p
def __call__(self, img, mask):
if self.p < random.random():
return img, mask
for t in self.transforms:
img, mask = t(img, mask)
return img, mask
class RandomOrder(RandomTransforms):
def __call__(self, img, mask):
order = list(range(len(self.transforms)))
random.shuffle(order)
for i in order:
img, mask = self.transforms[i](img, mask)
return img, mask
class RandomChoice(RandomTransforms):
def __call__(self, img, mask):
t = random.choice(self.transforms)
return t(img, mask)
class RandomCrop(object):
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill = fill
self.padding_mode = padding_mode
@staticmethod
def get_params(img, output_size):
w, h = img.size
th, tw = output_size
if w == tw and h == th:
return 0, 0, h, w
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
return i, j, th, tw
def __call__(self, img, mask):
if self.padding is not None:
img = F.pad(img, self.padding, self.fill, self.padding_mode)
# pad the width if needed
if self.pad_if_needed and img.size[0] < self.size[1]:
img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
# pad the height if needed
if self.pad_if_needed and img.size[1] < self.size[0]:
img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
i, j, h, w = self.get_params(img, self.size)
return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w)
class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, img, mask):
if random.random() < self.p:
return F.hflip(img), F.hflip(mask)
return img, mask
class RandomVerticalFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, img, mask):
if random.random() < self.p:
return F.vflip(img), F.vflip(mask)
return img, mask
class RandomPerspective(object):
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC):
self.p = p
self.interpolation = interpolation
self.distortion_scale = distortion_scale
def __call__(self, img, mask):
if not F._is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if random.random() < self.p:
width, height = img.size
startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
return F.perspective(img, startpoints, endpoints, self.interpolation), \
F.perspective(mask, startpoints, endpoints, Image.NEAREST)
return img, mask
@staticmethod
def get_params(width, height, distortion_scale):
half_height = int(height / 2)
half_width = int(width / 2)
topleft = (random.randint(0, int(distortion_scale * half_width)),
random.randint(0, int(distortion_scale * half_height)))
topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
random.randint(0, int(distortion_scale * half_height)))
botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
botleft = (random.randint(0, int(distortion_scale * half_width)),
random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
endpoints = [topleft, topright, botright, botleft]
return startpoints, endpoints
class RandomResizedCrop(object):
def __init__(self, size, mask_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
if isinstance(size, tuple):
self.size = size
self.mask_size = mask_size
else:
self.size = (size, size)
self.mask_size = (mask_size, mask_size)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
self.interpolation = interpolation
self.scale = scale
self.ratio = ratio
@staticmethod
def get_params(img, scale, ratio):
area = img.size[0] * img.size[1]
for attempt in range(10):
target_area = random.uniform(*scale) * area
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
if (in_ratio < min(ratio)):
w = img.size[0]
h = w / min(ratio)
elif (in_ratio > max(ratio)):
h = img.size[1]
w = h * max(ratio)
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
return i, j, h, w
def __call__(self, img, mask):
i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \
F.resized_crop(mask, i, j, h, w, self.mask_size, Image.NEAREST)
class FiveCrop(object):
def __init__(self, size):
self.size = size
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
self.size = size
def __call__(self, img, mask):
return F.five_crop(img, self.size), F.five_crop(mask, self.size)
class TenCrop(object):
def __init__(self, size, vertical_flip=False):
self.size = size
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
self.size = size
self.vertical_flip = vertical_flip
def __call__(self, img, mask):
return F.ten_crop(img, self.size, self.vertical_flip), F.ten_crop(mask, self.size, self.vertical_flip)
class ColorJitter(object):
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
clip_first_on_zero=False)
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
if isinstance(value, numbers.Number):
if value < 0:
raise ValueError("If {} is a single number, it must be non negative.".format(name))
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError("{} values should be between {}".format(name, bound))
else:
raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
return value
@staticmethod
def get_params(brightness, contrast, saturation, hue):
transforms = []
if brightness is not None:
brightness_factor = random.uniform(brightness[0], brightness[1])
transforms.append(Lambda_image(lambda img: F.adjust_brightness(img, brightness_factor)))
if contrast is not None:
contrast_factor = random.uniform(contrast[0], contrast[1])
transforms.append(Lambda_image(lambda img: F.adjust_contrast(img, contrast_factor)))
if saturation is not None:
saturation_factor = random.uniform(saturation[0], saturation[1])
transforms.append(Lambda_image(lambda img: F.adjust_saturation(img, saturation_factor)))
if hue is not None:
hue_factor = random.uniform(hue[0], hue[1])
transforms.append(Lambda_image(lambda img: F.adjust_hue(img, hue_factor)))
random.shuffle(transforms)
transform = Compose(transforms)
return transform
def __call__(self, img, mask):
transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue)
return transform(img, mask)
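A minimal sketch of the range logic in `_check_input` above: a scalar jitter strength becomes a symmetric sampling interval around a center (1 for brightness/contrast/saturation, 0 for hue), optionally clipped at zero. The function name here is illustrative, not part of the codebase.

```python
import random

def factor_range(value, center=1.0, clip_first_on_zero=True):
    # A scalar strength v becomes the interval [center - v, center + v];
    # brightness/contrast/saturation clip the lower bound at 0, hue does not.
    low, high = center - value, center + value
    if clip_first_on_zero:
        low = max(low, 0.0)
    return [low, high]

# brightness=0.6 (as VaihDataset uses below) -> sample factors from [0.4, 1.6]
rng = factor_range(0.6)
factor = random.uniform(rng[0], rng[1])
```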
class RandomRotation(object):
def __init__(self, degrees, resample=False, expand=False, center=None):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
if len(degrees) != 2:
raise ValueError("If degrees is a sequence, it must be of length 2.")
self.degrees = degrees
self.resample = resample
self.expand = expand
self.center = center
@staticmethod
def get_params(degrees):
angle = random.uniform(degrees[0], degrees[1])
return angle
def __call__(self, img, mask):
angle = self.get_params(self.degrees)
return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), \
F.rotate(mask, angle, Image.NEAREST, self.expand, self.center)
class RandomAffine(object):
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
"degrees should be a list or tuple and it must be of length 2."
self.degrees = degrees
if translate is not None:
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"translate should be a list or tuple and it must be of length 2."
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
"scale should be a list or tuple and it must be of length 2."
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError("If shear is a single number, it must be positive.")
self.shear = (-shear, shear)
else:
assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
"shear should be a list or tuple and it must be of length 2."
self.shear = shear
else:
self.shear = shear
self.resample = resample
self.fillcolor = fillcolor
@staticmethod
def get_params(degrees, translate, scale_ranges, shears, img_size):
angle = random.uniform(degrees[0], degrees[1])
if translate is not None:
max_dx = translate[0] * img_size[0]
max_dy = translate[1] * img_size[1]
translations = (np.round(random.uniform(-max_dx, max_dx)),
np.round(random.uniform(-max_dy, max_dy)))
else:
translations = (0, 0)
if scale_ranges is not None:
scale = random.uniform(scale_ranges[0], scale_ranges[1])
else:
scale = 1.0
if shears is not None:
shear = random.uniform(shears[0], shears[1])
else:
shear = 0.0
return angle, translations, scale, shear
def __call__(self, img, mask):
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
return F.affine(img, *ret, resample=Image.BILINEAR, fillcolor=self.fillcolor), \
F.affine(mask, *ret, resample=Image.NEAREST, fillcolor=self.fillcolor)
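The key point of `get_params` above is that one parameter set is drawn per call and then applied to both image and mask, keeping them aligned. A self-contained sketch of the sampling math (shear omitted; it follows the same `random.uniform` pattern):

```python
import random

def sample_affine_params(degrees, translate, scale_range, img_size):
    # One shared parameter set per call; translate fractions are scaled
    # by the image size to get a pixel offset range.
    angle = random.uniform(*degrees)
    max_dx = translate[0] * img_size[0]
    max_dy = translate[1] * img_size[1]
    translations = (round(random.uniform(-max_dx, max_dx)),
                    round(random.uniform(-max_dy, max_dy)))
    scale = random.uniform(*scale_range)
    return angle, translations, scale

angle, (dx, dy), scale = sample_affine_params(
    degrees=(0, 360), translate=(0.1, 0.1),
    scale_range=(0.75, 1.5), img_size=(256, 256))
```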
class RandomAffineFromSet(object):
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
assert isinstance(degrees, (tuple, list)), \
"degrees should be a list or tuple."
self.degrees = degrees
if translate is not None:
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"translate should be a list or tuple and it must be of length 2."
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
"scale should be a list or tuple and it must be of length 2."
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError("If shear is a single number, it must be positive.")
self.shear = (-shear, shear)
else:
assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
"shear should be a list or tuple and it must be of length 2."
self.shear = shear
else:
self.shear = shear
self.resample = resample
self.fillcolor = fillcolor
@staticmethod
def get_params(degrees, translate, scale_ranges, shears, img_size):
angle = random.choice(degrees)
if translate is not None:
max_dx = translate[0] * img_size[0]
max_dy = translate[1] * img_size[1]
translations = (np.round(random.uniform(-max_dx, max_dx)),
np.round(random.uniform(-max_dy, max_dy)))
else:
translations = (0, 0)
if scale_ranges is not None:
scale = random.uniform(scale_ranges[0], scale_ranges[1])
else:
scale = 1.0
if shears is not None:
shear = random.uniform(shears[0], shears[1])
else:
shear = 0.0
return angle, translations, scale, shear
def __call__(self, img, mask):
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
return F.affine(img, *ret, resample=Image.BILINEAR, fillcolor=self.fillcolor), \
F.affine(mask, *ret, resample=Image.NEAREST, fillcolor=self.fillcolor)
================================================
FILE: datasets/vaih.py
================================================
from pathlib import Path
import h5py
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from mpi4py import MPI
from torch.utils.data import Dataset, DataLoader
from datasets.transforms import \
Compose, ToPILImage, Resize, RandomHorizontalFlip, ToTensor, Normalize, \
RandomAffine, RandomVerticalFlip, ColorJitter
def load_data(
*, data_dir, batch_size, image_size, class_cond=False, deterministic=False
):
"""
For a dataset, create a generator over (images, kwargs) pairs.
Each image is an NCHW float tensor, and the kwargs dict contains zero or
more keys, each of which maps to a batched Tensor of its own.
The kwargs dict can be used for class labels, in which case the key is "y"
and the values are integer tensors of class labels.
:param data_dir: a dataset directory (unused here; VaihDataset reads from
a fixed HDF5 location).
:param batch_size: the batch size of each returned pair.
:param image_size: the size to which images are resized.
:param class_cond: if True, include a "y" key in returned dicts for class
labels. If classes are not available and this is true, an
exception will be raised.
:param deterministic: if True, yield results in a deterministic order.
"""
dataset = VaihDataset(
mode='train',
image_size=image_size,
shard=MPI.COMM_WORLD.Get_rank(),
num_shards=MPI.COMM_WORLD.Get_size(),
)
if deterministic:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True
)
else:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True
)
while True:
yield from loader
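The `while True: yield from loader` idiom above turns a finite DataLoader into an endless batch stream, restarting (and, in the shuffled branch, reshuffling) each epoch. A sketch with a plain list standing in for the loader:

```python
def endless(loader):
    # Re-iterates the underlying iterable forever; callers just keep
    # calling next() without worrying about epoch boundaries.
    while True:
        yield from loader

stream = endless([1, 2, 3])
first_five = [next(stream) for _ in range(5)]  # wraps around: [1, 2, 3, 1, 2]
```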
class VaihDataset(Dataset):
CLASSES = ('building',)
PALETTE = [[255, 0, 0]]
def __init__(self, mode, std=np.array([0.22645572 * 255, 0.15276193 * 255, 0.140702 * 255]),
mean=np.array([0.47341759 * 255, 0.28791303 * 255, 0.2850705 * 255]), no_aug=False,
image_size=256, max_data_size=None, shard=0, num_shards=1, small_image_size=None):
self.mode = mode
self.mean = torch.from_numpy(mean)
self.std = torch.from_numpy(std)
if mode == 'train' and not no_aug:
self.transformations = Compose([ToPILImage(),
Resize(size=(image_size, image_size)),
RandomAffine(degrees=[0, 360], scale=(0.75, 1.5)),
ColorJitter(brightness=0.6,
contrast=0.5,
saturation=0.4,
hue=0.025),
RandomVerticalFlip(),
RandomHorizontalFlip(),
ToTensor(),
Normalize(self.mean, self.std)])
else:
self.transformations = Compose([ToPILImage(),
Resize(size=(image_size, image_size)),
ToTensor(),
Normalize(self.mean, self.std)])
if mode == 'train':
self.data_length = 100
else:
self.data_length = 68
if max_data_size is not None:
self.data_length = max_data_size
if self.mode == 'train':
self.data = h5py.File(
str(Path(__file__).absolute().parent.parent.parent / "data/Vaihingen/full_training_vaih.hdf5"), 'r')
else:
self.data = h5py.File(
str(Path(__file__).absolute().parent.parent.parent / "data/Vaihingen/full_test_vaih.hdf5"), 'r')
self.small_image_size = small_image_size
self.mask = self.data['mask_single']
self.imgs = self.data['imgs']
self.img_list = list(self.imgs)[shard::num_shards]
self.mask_list = list(self.mask)[shard::num_shards]
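The `[shard::num_shards]` slices above split the HDF5 keys across MPI ranks: each rank takes a strided slice, so the union over all ranks covers the dataset exactly once with no overlap. A sketch (function name illustrative):

```python
def shard_keys(keys, shard, num_shards):
    # Strided slicing: rank r takes items r, r + n, r + 2n, ...
    return keys[shard::num_shards]

keys = [f"img_{i}" for i in range(7)]
parts = [shard_keys(keys, rank, 3) for rank in range(3)]
```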
def __len__(self):
return len(self.img_list)
def __getitem__(self, item):
cimage = self.img_list[item]
img = np.array(self.imgs.get(cimage))
cmask = self.mask_list[item]
mask = np.array(self.mask.get(cmask))
img = img.astype(np.uint8)
mask = mask.astype(np.uint8)
img, mask = self.transformations(img, mask)
out_dict = {"conditioned_image": img}
mask = (2 * mask - 1.0).unsqueeze(0)
if self.small_image_size is not None:
out_dict["low_res"] = F.interpolate(mask.unsqueeze(0), self.small_image_size, mode="nearest").squeeze(0)
return mask, out_dict, str(Path(cimage).stem)
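The `2 * mask - 1.0` step in `__getitem__` rescales binary mask values from {0, 1} to {-1, 1}, matching the [-1, 1] value range diffusion models are trained on. The element-wise mapping, sketched without torch:

```python
def to_diffusion_range(mask_values):
    # {0, 1} -> {-1.0, 1.0}; the same affine map the dataset applies
    # to the mask tensor before returning it.
    return [2 * m - 1.0 for m in mask_values]
```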
if __name__ == '__main__':
mean = np.array([0, 0, 0])
std = np.array([1, 1, 1])
dataset = VaihDataset('train', mean=mean, std=std, image_size=256)
dataset2 = VaihDataset('train', mean=mean, std=std, image_size=256, no_aug=True)
for i in range(10):
mask, out_dict, _ = dataset[0]
img = out_dict["conditioned_image"]
plt.imshow(img.permute(1,2,0).numpy().astype(np.uint8))
plt.show()
plt.imshow(mask.permute(1,2,0).numpy(), cmap='gray')
plt.show()
mask, out_dict, _ = dataset2[0]
img = out_dict["conditioned_image"]
plt.imshow(img.permute(1,2,0).numpy().astype(np.uint8))
plt.show()
================================================
FILE: environment.yml
================================================
name: segdiff
channels:
- anaconda
- pytorch
- conda-forge
- defaults
dependencies:
- python=3.8.12
- pip=21.2.4
- pytorch=1.9.0
- torchvision=0.10.0
- cudatoolkit=11.1
- mpi4py=3.1.2
- tqdm=4.62.3
- scikit-learn=0.24.2
- scikit-image=0.18.3
- matplotlib=3.4.3
- seaborn=0.11.2
- pip:
- opencv-python==4.5.1.48
- blobfile==1.2.3
- pycocotools==2.0.2
- gitpython==3.1.24
- kornia==0.5.11
- h5py==3.4.0
- imagecodecs==2021.11.20
================================================
FILE: image_sample_diff_city.py
================================================
"""
Sample segmentation masks from a trained model, fusing multiple diffusion
runs per image by majority voting, and save the results for evaluation.
"""
import argparse
import datetime
import json
from pathlib import Path
import torch.distributed as dist
from improved_diffusion.sampling_util import sampling_major_vote_func
from improved_diffusion import dist_util, logger
from datasets.city import create_dataset
from improved_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
from improved_diffusion.utils import set_random_seed
import warnings
warnings.filterwarnings('ignore')
def main():
args = create_argparser().parse_args()
original_logs_path = Path(args.model_path).parent
logs_path = original_logs_path / f"{Path(args.model_path).stem}_major_vote"
args.__dict__.update(json.loads((original_logs_path / 'args.json').read_text()))
logger.info(args.__dict__)
dist_util.setup_dist()
logger.configure(dir=str(logs_path), log_suffix=f"val_{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}")
logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu")
)
model.to(dist_util.dev())
model.eval()
test_dataset = create_dataset(
class_name=args.class_name,
mode='val',
expansion=args.expansion,
)
if args.__dict__.get("seed") is None:
seed = 1234
else:
seed = int(args.__dict__.get("seed"))
set_random_seed(seed, deterministic=True)
logger.log("sampling major vote val")
(logs_path / "major_vote").mkdir(exist_ok=True)
step = int(Path(args.model_path).stem.split("_")[-1])
sampling_major_vote_func(diffusion, model, str(logs_path / "major_vote"), test_dataset, logger, args.clip_denoised,
step=step, n_rounds=len(test_dataset))
dist.barrier()
logger.log("sampling complete")
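The `step` extraction in `main` assumes checkpoint filenames end with the training step as the last underscore-separated field of the stem (e.g. something like `ema_0.9999_150000.pt`; the exact naming comes from the training loop). A sketch of that parsing:

```python
from pathlib import Path

def step_from_checkpoint(model_path):
    # Path.stem drops the ".pt" suffix; the step is assumed to be the
    # last "_"-separated field (hypothetical filename shown in the test).
    return int(Path(model_path).stem.split("_")[-1])
```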
def create_argparser():
defaults = dict(
clip_denoised=True,
num_samples=10000,
batch_size=16,
use_ddim=False,
model_path="",
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
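`add_dict_to_argparser` (from `improved_diffusion.script_util`, outside this excerpt) turns the defaults dict above into CLI flags. A hedged minimal reimplementation of that pattern, with illustrative names; note the special handling booleans need, since `bool("false")` is truthy:

```python
import argparse

def parser_from_defaults(defaults):
    # Each dict entry becomes a --flag whose type and default come from
    # the value, so a new hyperparameter is a one-line dict change.
    parser = argparse.ArgumentParser()
    for key, value in defaults.items():
        arg_type = type(value) if value is not None else str
        if arg_type is bool:
            # argparse would otherwise treat any non-empty string as True.
            arg_type = lambda s: s.lower() in ("1", "true", "yes")
        parser.add_argument(f"--{key}", default=value, type=arg_type)
    return parser

parser = parser_from_defaults(dict(batch_size=16, model_path="", clip_denoised=True))
args = parser.parse_args(["--batch_size", "8", "--clip_denoised", "false"])
```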
if __name__ == "__main__":
main()
================================================
FILE: image_sample_diff_medical.py
================================================
"""
Sample segmentation masks from a trained model, fusing multiple diffusion
runs per image by majority voting, and save the results for evaluation.
"""
import argparse
import datetime
import json
from pathlib import Path
import torch.distributed as dist
from improved_diffusion import dist_util, logger
from datasets.monu import create_dataset
from improved_diffusion.sampling_util import sampling_major_vote_func
from improved_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
from improved_diffusion.utils import set_random_seed
import warnings
warnings.filterwarnings('ignore')
def main():
args = create_argparser().parse_args()
original_logs_path = Path(args.model_path).parent
logs_path = original_logs_path / f"{Path(args.model_path).stem}_major_vote"
args.__dict__.update(json.loads((original_logs_path / 'args.json').read_text()))
logger.info(args.__dict__)
dist_util.setup_dist()
logger.configure(dir=str(logs_path), log_suffix=f"val_{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}")
logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu")
)
model.to(dist_util.dev())
model.eval()
test_dataset = create_dataset(
mode='val',
)
if args.__dict__.get("seed") is None:
seed = 1234
else:
seed = int(args.__dict__.get("seed"))
set_random_seed(seed, deterministic=True)
logger.log("sampling major vote val")
(logs_path / "major_vote").mkdir(exist_ok=True)
step = int(Path(args.model_path).stem.split("_")[-1])
sampling_major_vote_func(diffusion, model, str(logs_path / "major_vote"), test_dataset, logger, args.clip_denoised,
step=step, n_rounds=len(test_dataset))
dist.barrier()
logger.log("sampling complete")
def create_argparser():
defaults = dict(
clip_denoised=True,
num_samples=10000,
batch_size=16,
use_ddim=False,
model_path="",
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
if __name__ == "__main__":
main()
================================================
FILE: image_sample_diff_vaih.py
================================================
"""
Sample segmentation masks from a trained model, fusing multiple diffusion
runs per image by majority voting, and save the results for evaluation.
"""
import argparse
import datetime
import json
from pathlib import Path
import torch.distributed as dist
from mpi4py import MPI
from improved_diffusion import dist_util, logger
from improved_diffusion.sampling_util import sampling_major_vote_func
from improved_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
from improved_diffusion.utils import set_random_seed
from datasets.vaih import VaihDataset
import warnings
warnings.filterwarnings('ignore')
def main():
args = create_argparser().parse_args()
original_logs_path = Path(args.model_path).parent
logs_path = original_logs_path / f"{Path(args.model_path).stem}_major_vote"
args.__dict__.update(json.loads((original_logs_path / 'args.json').read_text()))
logger.info(args.__dict__)
dist_util.setup_dist()
logger.configure(dir=str(logs_path), log_suffix=f"val_{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}")
logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu")
)
model.to(dist_util.dev())
model.eval()
test_dataset = VaihDataset(
mode='val',
image_size=args.image_size,
shard=MPI.COMM_WORLD.Get_rank(),
num_shards=MPI.COMM_WORLD.Get_size(),
)
if args.__dict__.get("seed") is None:
seed = 1234
else:
seed = int(args.__dict__.get("seed"))
set_random_seed(seed, deterministic=True)
logger.log("sampling major vote val")
(logs_path / "major_vote").mkdir(exist_ok=True)
step = int(Path(args.model_path).stem.split("_")[-1])
sampling_major_vote_func(diffusion, model, str(logs_path / "major_vote"), test_dataset, logger, args.clip_denoised,
step=step, n_rounds=len(test_dataset))
dist.barrier()
logger.log("sampling complete")
def create_argparser():
defaults = dict(
clip_denoised=True,
num_samples=10000,
batch_size=16,
use_ddim=False,
model_path="",
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
if __name__ == "__main__":
main()
================================================
FILE: image_train_diff_city.py
================================================
"""
Train a diffusion model on images.
"""
import argparse
import datetime
import json
import os
from pathlib import Path
import git
from mpi4py import MPI
from improved_diffusion import dist_util, logger
from datasets.city import load_data, create_dataset
from improved_diffusion.resample import create_named_schedule_sampler
from improved_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
)
from improved_diffusion.train_util import TrainLoop
from improved_diffusion.utils import set_random_seed, set_random_seed_for_iterations
import warnings
warnings.filterwarnings('ignore')
def main():
args = create_argparser().parse_args()
args.use_fp16 = True
args.clip_denoised = False
args.learn_sigma = False
args.sigma_small = False
args.num_channels = 128
args.image_size = 128
args.num_res_blocks = 3
args.noise_schedule = "linear"
args.rescale_learned_sigmas = False
args.rescale_timesteps = False
args.use_scale_shift_norm = False
args.deeper_net = True
exp_name = f"city_{args.rrdb_blocks}_{args.lr}_{args.batch_size}_{args.diffusion_steps}_{str(args.dropout)}_{args.class_name}_{MPI.COMM_WORLD.Get_rank()}"
if args.expansion:
exp_name += "_expansion"
logs_root = Path(__file__).absolute().parent.parent / "logs"
log_path = logs_root / f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}_{exp_name}"
os.environ["OPENAI_LOGDIR"] = str(log_path)
set_random_seed(MPI.COMM_WORLD.Get_rank(), deterministic=True)
set_random_seed_for_iterations(MPI.COMM_WORLD.Get_rank())
dist_util.setup_dist()
logger.configure(dir=str(log_path))
if args.resume_checkpoint:
resumed_checkpoint_arg = args.resume_checkpoint
args.__dict__.update(json.loads((Path(args.resume_checkpoint) / 'args.json').read_text()))
args.resume_checkpoint = resumed_checkpoint_arg
logger.info(args.__dict__)
(Path(log_path) / 'args.json').write_text(json.dumps(args.__dict__, indent=4))
logger.info(f"log folder path: {Path(log_path).resolve()}")
repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha
logger.log(f"git commit hash {sha}")
logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.to(dist_util.dev())
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
logger.log("creating data loader...")
data = load_data(
data_dir=args.data_dir,
batch_size=args.batch_size,
image_size=args.image_size,
class_cond=args.class_cond,
class_name=args.class_name,
expansion=args.expansion
)
val_dataset = create_dataset(
class_name=args.class_name,
mode='val',
expansion=args.expansion,
)
logger.log(f"gpu {MPI.COMM_WORLD.Get_rank()} / {MPI.COMM_WORLD.Get_size()} val length {len(val_dataset)}")
logger.log("training...")
TrainLoop(
model=model,
diffusion=diffusion,
data=data,
batch_size=args.batch_size,
microbatch=args.microbatch,
lr=args.lr,
ema_rate=args.ema_rate,
log_interval=args.log_interval,
save_interval=args.save_interval,
resume_checkpoint=args.resume_checkpoint,
use_fp16=args.use_fp16,
fp16_scale_growth=args.fp16_scale_growth,
schedule_sampler=schedule_sampler,
weight_decay=args.weight_decay,
lr_anneal_steps=args.lr_anneal_steps,
clip_denoised=args.clip_denoised,
logger=logger,
image_size=args.image_size,
val_dataset=val_dataset,
run_without_test=args.run_without_test,
args=args
# dist_util=dist_util,
).run_loop(max_iter=300000, start_print_iter=args.start_print_iter)
def create_argparser():
defaults = dict(
data_dir="",
schedule_sampler="uniform",
lr=0.00002,
weight_decay=0.0,
lr_anneal_steps=0,
clip_denoised=False,
batch_size=4,
microbatch=-1, # -1 disables microbatches
ema_rate="0.9999", # comma-separated list of EMA values
save_interval=5000,
start_print_iter=75000,
log_interval=200,
run_without_test=False,
resume_checkpoint="",
use_fp16=False,
fp16_scale_growth=1e-3,
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
if __name__ == "__main__":
main()
================================================
FILE: image_train_diff_medical.py
================================================
"""
Train a diffusion model on images.
"""
import argparse
import datetime
import json
import os
from pathlib import Path
import git
from mpi4py import MPI
from improved_diffusion import dist_util, logger
from datasets.monu import load_data, create_dataset
from improved_diffusion.resample import create_named_schedule_sampler
from improved_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
)
from improved_diffusion.train_util import TrainLoop
from improved_diffusion.utils import set_random_seed, set_random_seed_for_iterations
import warnings
warnings.filterwarnings('ignore')
def main():
args = create_argparser().parse_args()
args.use_fp16 = True
args.clip_denoised = False
args.learn_sigma = False
args.sigma_small = False
args.image_size = 256
args.num_res_blocks = 3
args.noise_schedule = "linear"
args.rescale_learned_sigmas = False
args.rescale_timesteps = False
args.use_scale_shift_norm = False
args.deeper_net = True
exp_name = f"monu_{args.rrdb_blocks}_{args.lr}_{args.batch_size}_{args.diffusion_steps}_{str(args.dropout)}_{MPI.COMM_WORLD.Get_rank()}"
logs_root = Path(__file__).absolute().parent.parent / "logs"
log_path = logs_root / f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}_{exp_name}"
os.environ["OPENAI_LOGDIR"] = str(log_path)
set_random_seed(MPI.COMM_WORLD.Get_rank(), deterministic=True)
set_random_seed_for_iterations(MPI.COMM_WORLD.Get_rank())
dist_util.setup_dist()
logger.configure(dir=str(log_path))
if args.resume_checkpoint:
resumed_checkpoint_arg = args.resume_checkpoint
args.__dict__.update(json.loads((Path(args.resume_checkpoint) / 'args.json').read_text()))
args.resume_checkpoint = resumed_checkpoint_arg
logger.info(args.__dict__)
(Path(log_path) / 'args.json').write_text(json.dumps(args.__dict__, indent=4))
logger.info(f"log folder path: {Path(log_path).resolve()}")
repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha
logger.log(f"git commit hash {sha}")
logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.to(dist_util.dev())
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
logger.log("creating data loader...")
data = load_data(
data_dir=args.data_dir,
batch_size=args.batch_size,
image_size=args.image_size,
class_cond=args.class_cond,
class_name=args.class_name,
expansion=args.expansion
)
val_dataset = create_dataset(
mode='val',
image_size=args.image_size
)
logger.log(f"gpu {MPI.COMM_WORLD.Get_rank()} / {MPI.COMM_WORLD.Get_size()} val length {len(val_dataset)}")
logger.log("training...")
TrainLoop(
model=model,
diffusion=diffusion,
data=data,
batch_size=args.batch_size,
microbatch=args.microbatch,
lr=args.lr,
ema_rate=args.ema_rate,
log_interval=args.log_interval,
save_interval=args.save_interval,
resume_checkpoint=args.resume_checkpoint,
use_fp16=args.use_fp16,
fp16_scale_growth=args.fp16_scale_growth,
schedule_sampler=schedule_sampler,
weight_decay=args.weight_decay,
lr_anneal_steps=args.lr_anneal_steps,
clip_denoised=args.clip_denoised,
logger=logger,
image_size=args.image_size,
val_dataset=val_dataset,
run_without_test=args.run_without_test,
args=args
# dist_util=dist_util,
).run_loop(max_iter=300000, start_print_iter=args.start_print_iter)
def create_argparser():
defaults = dict(
data_dir="",
schedule_sampler="uniform",
lr=0.00002,
weight_decay=0.0,
lr_anneal_steps=0,
clip_denoised=False,
batch_size=4,
microbatch=-1, # -1 disables microbatches
ema_rate="0.9999", # comma-separated list of EMA values
save_interval=5000,
start_print_iter=75000,
log_interval=200,
run_without_test=False,
resume_checkpoint="",
use_fp16=False,
fp16_scale_growth=1e-3,
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
if __name__ == "__main__":
main()
================================================
FILE: image_train_diff_vaih.py
================================================
"""
Train a diffusion model on images.
"""
import argparse
import datetime
import json
import os
from pathlib import Path
import git
from mpi4py import MPI
from improved_diffusion import dist_util, logger
from datasets.vaih import load_data
from improved_diffusion.resample import create_named_schedule_sampler
from improved_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
)
from improved_diffusion.train_util import TrainLoop
from improved_diffusion.utils import set_random_seed, set_random_seed_for_iterations
from datasets.vaih import VaihDataset
import warnings
warnings.filterwarnings('ignore')
def main():
args = create_argparser().parse_args()
args.use_fp16 = True
args.clip_denoised = False
args.learn_sigma = False
args.sigma_small = False
args.num_channels = 128
args.image_size = 256
args.num_res_blocks = 3
args.noise_schedule = "linear"
args.rescale_learned_sigmas = False
args.rescale_timesteps = False
args.use_scale_shift_norm = False
args.deeper_net = True
exp_name = f"vaih_256_{args.rrdb_blocks}_{args.lr}_{args.batch_size}_{args.diffusion_steps}_{str(args.dropout)}_{MPI.COMM_WORLD.Get_rank()}"
logs_root = Path(__file__).absolute().parent.parent / "logs"
log_path = logs_root / f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}_{exp_name}"
os.environ["OPENAI_LOGDIR"] = str(log_path)
set_random_seed(MPI.COMM_WORLD.Get_rank(), deterministic=True)
set_random_seed_for_iterations(MPI.COMM_WORLD.Get_rank())
dist_util.setup_dist()
logger.configure(dir=str(log_path))
if args.resume_checkpoint:
resumed_checkpoint_arg = args.resume_checkpoint
args.__dict__.update(json.loads((Path(args.resume_checkpoint) / 'args.json').read_text()))
args.resume_checkpoint = resumed_checkpoint_arg
logger.info(args.__dict__)
(Path(log_path) / 'args.json').write_text(json.dumps(args.__dict__, indent=4))
logger.info(f"log folder path: {Path(log_path).resolve()}")
repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha
logger.log(f"git commit hash {sha}")
logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.to(dist_util.dev())
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
logger.log("creating data loader...")
data = load_data(
data_dir=args.data_dir,
batch_size=args.batch_size,
image_size=args.image_size,
class_cond=args.class_cond
)
val_dataset = VaihDataset(
mode='val',
image_size=args.image_size,
shard=MPI.COMM_WORLD.Get_rank(),
num_shards=MPI.COMM_WORLD.Get_size(),
)
logger.log(f"gpu {MPI.COMM_WORLD.Get_rank()} / {MPI.COMM_WORLD.Get_size()} val length {len(val_dataset)}")
logger.log("training...")
TrainLoop(
model=model,
diffusion=diffusion,
data=data,
batch_size=args.batch_size,
microbatch=args.microbatch,
lr=args.lr,
ema_rate=args.ema_rate,
log_interval=args.log_interval,
save_interval=args.save_interval,
resume_checkpoint=args.resume_checkpoint,
use_fp16=args.use_fp16,
fp16_scale_growth=args.fp16_scale_growth,
schedule_sampler=schedule_sampler,
weight_decay=args.weight_decay,
lr_anneal_steps=args.lr_anneal_steps,
clip_denoised=args.clip_denoised,
logger=logger,
image_size=args.image_size,
val_dataset=val_dataset,
run_without_test=args.run_without_test,
args=args
# dist_util=dist_util,
).run_loop(max_iter=300000, start_print_iter=args.start_print_iter)
def create_argparser():
defaults = dict(
data_dir="",
schedule_sampler="uniform",
lr=0.00002,
weight_decay=0.0,
lr_anneal_steps=0,
clip_denoised=False,
batch_size=4,
microbatch=-1, # -1 disables microbatches
ema_rate="0.9999", # comma-separated list of EMA values
save_interval=5000,
start_print_iter=75000,
log_interval=200,
run_without_test=False,
resume_checkpoint="",
use_fp16=False,
fp16_scale_growth=1e-3,
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
if __name__ == "__main__":
main()
================================================
FILE: improved_diffusion/RRDB.py
================================================
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf=1, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc=3, out_nc=128, nf=64, nb=3, gc=32):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
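The dense wiring in ResidualDenseBlock_5C can be checked with simple channel arithmetic: conv k consumes the block input plus all earlier intermediate features. A minimal bookkeeping sketch (illustrative only, not part of the module):

```python
# Channel bookkeeping for ResidualDenseBlock_5C: conv k takes the block
# input (nf channels) concatenated with the k-1 earlier gc-channel
# features, and conv5 maps back to nf so the residual add is valid.
nf, gc = 64, 32

in_channels = [nf + k * gc for k in range(5)]  # inputs of conv1..conv5
out_channels = [gc, gc, gc, gc, nf]            # conv1..conv4 emit gc, conv5 emits nf

assert in_channels == [64, 96, 128, 160, 192]
assert out_channels[-1] == nf  # x5 * 0.2 + x needs matching channel counts
```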
================================================
FILE: improved_diffusion/__init__.py
================================================
"""
Codebase for "Improved Denoising Diffusion Probabilistic Models".
"""
================================================
FILE: improved_diffusion/dist_util.py
================================================
"""
Helpers for distributed training.
"""
import io
import os
import socket
import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3
def setup_dist():
"""
    Set up a distributed process group.
"""
if dist.is_initialized():
return
comm = MPI.COMM_WORLD
backend = "gloo" if not th.cuda.is_available() else "nccl"
if backend == "gloo":
hostname = "localhost"
else:
hostname = socket.gethostbyname(socket.getfqdn())
os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
os.environ["RANK"] = str(comm.rank)
os.environ["WORLD_SIZE"] = str(comm.size)
port = comm.bcast(_find_free_port(), root=0)
os.environ["MASTER_PORT"] = str(port)
dist.init_process_group(backend=backend, init_method="env://")
def dev():
"""
Get the device to use for torch.distributed.
"""
if th.cuda.is_available():
return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
return th.device("cpu")
def load_state_dict(path, **kwargs):
"""
Load a PyTorch file without redundant fetches across MPI ranks.
"""
if MPI.COMM_WORLD.Get_rank() == 0:
with bf.BlobFile(path, "rb") as f:
data = f.read()
else:
data = None
data = MPI.COMM_WORLD.bcast(data)
return th.load(io.BytesIO(data), **kwargs)
def sync_params(params):
"""
Synchronize a sequence of Tensors across ranks from rank 0.
"""
for p in params:
with th.no_grad():
dist.broadcast(p, 0)
def _find_free_port():
    # Create the socket outside the try block so the finally clause
    # never references an unbound name if socket creation fails.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
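dev() assigns each MPI rank a GPU via rank % GPUS_PER_NODE. A tiny sketch of that mapping, assuming the default of 8 GPUs per node:

```python
# Rank-to-GPU mapping used by dev(): ranks wrap around within a node,
# so ranks 0..7 get GPUs 0..7 and rank 8 (first rank of the next node)
# gets GPU 0 again.
GPUS_PER_NODE = 8

def gpu_for_rank(rank, gpus_per_node=GPUS_PER_NODE):
    return rank % gpus_per_node

assert [gpu_for_rank(r) for r in range(10)] == [0, 1, 2, 3, 4, 5, 6, 7, 0, 1]
```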
================================================
FILE: improved_diffusion/fp16_util.py
================================================
"""
Helpers to train with 16-bit precision.
"""
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
def convert_module_to_f16(l):
"""
Convert primitive modules to float16.
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.half()
l.bias.data = l.bias.data.half()
def convert_module_to_f32(l):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.float()
l.bias.data = l.bias.data.float()
def make_master_params(model_params):
"""
Copy model parameters into a (differently-shaped) list of full-precision
parameters.
"""
master_params = _flatten_dense_tensors(
[param.detach().float() for param in model_params]
)
master_params = nn.Parameter(master_params)
master_params.requires_grad = True
return [master_params]
def model_grads_to_master_grads(model_params, master_params):
"""
Copy the gradients from the model parameters into the master parameters
from make_master_params().
"""
master_params[0].grad = _flatten_dense_tensors(
[param.grad.data.detach().float() for param in model_params]
)
def master_params_to_model_params(model_params, master_params):
"""
Copy the master parameter data back into the model parameters.
"""
# Without copying to a list, if a generator is passed, this will
# silently not copy any parameters.
model_params = list(model_params)
for param, master_param in zip(
model_params, unflatten_master_params(model_params, master_params)
):
param.detach().copy_(master_param)
def unflatten_master_params(model_params, master_params):
"""
Unflatten the master parameters to look like model_params.
"""
return _unflatten_dense_tensors(master_params[0].detach(), model_params)
def zero_grad(model_params):
for param in model_params:
# Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
if param.grad is not None:
param.grad.detach_()
param.grad.zero_()
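make_master_params and unflatten_master_params rely on packing all parameters into one flat float32 buffer and splitting it back out by shape. A numpy sketch of that round trip (the real code uses torch's _flatten_dense_tensors / _unflatten_dense_tensors):

```python
import numpy as np

# Pack a list of fp16 "parameters" into one flat fp32 buffer, then
# recover per-parameter views by original shape -- the same round trip
# make_master_params / unflatten_master_params perform on torch tensors.
params = [np.ones((2, 3), np.float16), np.arange(4, dtype=np.float16)]

flat = np.concatenate([p.astype(np.float32).ravel() for p in params])

views, offset = [], 0
for p in params:
    views.append(flat[offset : offset + p.size].reshape(p.shape))
    offset += p.size

assert flat.dtype == np.float32 and flat.size == 10
assert all(np.allclose(v, p) for v, p in zip(views, params))
```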
================================================
FILE: improved_diffusion/gaussian_diffusion.py
================================================
"""
This code started out as a PyTorch port of Ho et al's diffusion models:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""
import enum
import math
import numpy as np
import torch as th
from .nn import mean_flat
from .losses import normal_kl, discretized_gaussian_log_likelihood
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed, to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif schedule_name == "cosine":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
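Both named schedules can be reproduced with plain numpy; a sketch reimplementing the formulas above for T = 100 (not calling the functions in this file):

```python
import math
import numpy as np

# Linear schedule: Ho et al.'s endpoints rescaled by 1000 / T.
T = 100
scale = 1000 / T
linear = np.linspace(scale * 0.0001, scale * 0.02, T, dtype=np.float64)

# Cosine schedule: discretize alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2.
alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
cosine = np.array(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)]
)

# Both are valid beta schedules: T values, each in (0, 1].
assert linear.shape == cosine.shape == (T,)
assert np.all((linear > 0) & (linear <= 1)) and np.all((cosine > 0) & (cosine <= 1))
```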
class ModelMeanType(enum.Enum):
"""
Which type of output the model predicts.
"""
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
START_X = enum.auto() # the model predicts x_0
EPSILON = enum.auto() # the model predicts epsilon
class ModelVarType(enum.Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
LEARNED = enum.auto()
FIXED_SMALL = enum.auto()
FIXED_LARGE = enum.auto()
LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = (
enum.auto()
) # use raw MSE loss (with RESCALED_KL when learning variances)
KL = enum.auto() # use the variational lower-bound
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
def is_vb(self):
return self == LossType.KL or self == LossType.RESCALED_KL
class GaussianDiffusion:
"""
Utilities for training and sampling diffusion models.
Ported directly from here, and then adapted over time to further experimentation.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
:param model_mean_type: a ModelMeanType determining what the model outputs.
:param model_var_type: a ModelVarType determining how variance is output.
:param loss_type: a LossType determining the loss function to use.
:param rescale_timesteps: if True, pass floating point timesteps into the
model so that they are always scaled like in the
original paper (0 to 1000).
"""
def __init__(
self,
*,
betas,
model_mean_type,
model_var_type,
loss_type,
rescale_timesteps=False,
):
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
self.loss_type = loss_type
self.rescale_timesteps = rescale_timesteps
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
self.betas = betas
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# log calculation clipped because the posterior variance is 0 at the
# beginning of the diffusion chain.
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev)
* np.sqrt(alphas)
/ (1.0 - self.alphas_cumprod)
)
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
)
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
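q_sample's closed form can be verified numerically: its two coefficients satisfy a_t^2 + b_t^2 = 1 for every t, so x_t keeps unit variance when x_0 and the noise do. A small numpy sketch (schedule values chosen to match the "linear" defaults, sample values hypothetical):

```python
import numpy as np

# Coefficients of q_sample: x_t = a[t] * x0 + b[t] * eps with
# a = sqrt(alphas_cumprod), b = sqrt(1 - alphas_cumprod).
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)
a = np.sqrt(alphas_cumprod)
b = np.sqrt(1.0 - alphas_cumprod)

assert np.allclose(a ** 2 + b ** 2, 1.0)  # variance-preserving mix

rng = np.random.default_rng(0)
x0, eps = rng.normal(size=4), rng.normal(size=4)
x_t = a[500] * x0 + b[500] * eps  # one draw from q(x_500 | x_0)
assert x_t.shape == x0.shape
```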
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample. Applies before
clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
if self.model_var_type == ModelVarType.LEARNED:
model_log_variance = model_var_values
model_variance = th.exp(model_log_variance)
else:
min_log = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x.shape
)
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = th.exp(model_log_variance)
else:
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
ModelVarType.FIXED_LARGE: (
np.append(self.posterior_variance[1], self.betas[1:]),
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
),
ModelVarType.FIXED_SMALL: (
self.posterior_variance,
self.posterior_log_variance_clipped,
),
}[self.model_var_type]
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
def process_xstart(x):
if denoised_fn is not None:
x = denoised_fn(x)
if clip_denoised:
return x.clamp(-1, 1)
return x
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
pred_xstart = process_xstart(
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
)
model_mean = model_output
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
else:
pred_xstart = process_xstart(
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)
else:
raise NotImplementedError(self.model_mean_type)
assert (
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
)
return {
"mean": model_mean,
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
}
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_xstart_from_xprev(self, x_t, t, xprev):
assert x_t.shape == xprev.shape
return ( # (xprev - coef2*x_t) / coef1
_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
- _extract_into_tensor(
self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
)
* x_t
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
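_predict_xstart_from_eps is the exact inverse of the forward noising relation; a numpy sketch checking the algebra (same linear schedule as above, sample values hypothetical):

```python
import numpy as np

# Given x_t = sqrt(ac)*x0 + sqrt(1-ac)*eps, the reconstruction
# sqrt(1/ac)*x_t - sqrt(1/ac - 1)*eps recovers x0 exactly, which is
# the identity _predict_xstart_from_eps implements per batch element.
betas = np.linspace(1e-4, 0.02, 1000)
ac = np.cumprod(1.0 - betas)

rng = np.random.default_rng(1)
x0, eps = rng.normal(size=8), rng.normal(size=8)
for t in (0, 499, 999):
    x_t = np.sqrt(ac[t]) * x0 + np.sqrt(1.0 - ac[t]) * eps
    pred_x0 = np.sqrt(1.0 / ac[t]) * x_t - np.sqrt(1.0 / ac[t] - 1.0) * eps
    assert np.allclose(pred_x0, x0)
```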
def _scale_timesteps(self, t):
if self.rescale_timesteps:
return t.float() * (1000.0 / self.num_timesteps)
return t
def p_sample(
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
        :param x: the current tensor x_t.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
noise = th.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.
"""
final = None
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
):
final = sample
return final["sample"]
def p_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape).to(device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
yield out
img = out["sample"]
def ddim_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
eta
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12.
noise = th.randn_like(x)
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, "Reverse ODE only for deterministic path"
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ th.sqrt(1 - alpha_bar_next) * eps
)
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
for sample in self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
eta=eta,
):
final = sample
return final["sample"]
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape).to(device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.ddim_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
eta=eta,
)
yield out
img = out["sample"]
def _vb_terms_bpd(
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
)
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
terms["loss"] = self._vb_terms_bpd(
model=model,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
model_kwargs=model_kwargs,
)["output"]
if self.loss_type == LossType.RESCALED_KL:
terms["loss"] *= self.num_timesteps
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
target = {
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0],
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
terms["sum"] = (target - model_output).pow(2).sum(dim=(1, 2, 3))
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
terms["loss"] = terms["sum"]
else:
raise NotImplementedError(self.loss_type)
return terms
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
)
return mean_flat(kl_prior) / np.log(2.0)
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
"""
Compute the entire variational lower-bound, measured in bits-per-dim,
as well as other related quantities.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param clip_denoised: if True, clip denoised samples.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- total_bpd: the total variational lower-bound, per batch element.
- prior_bpd: the prior term in the lower-bound.
- vb: an [N x T] tensor of terms in the lower-bound.
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
"""
device = x_start.device
batch_size = x_start.shape[0]
vb = []
xstart_mse = []
mse = []
for t in list(range(self.num_timesteps))[::-1]:
t_batch = th.tensor([t] * batch_size, device=device)
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
# Calculate VLB term at the current timestep
with th.no_grad():
out = self._vb_terms_bpd(
model,
x_start=x_start,
x_t=x_t,
t=t_batch,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
)
vb.append(out["output"])
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
mse.append(mean_flat((eps - noise) ** 2))
vb = th.stack(vb, dim=1)
xstart_mse = th.stack(xstart_mse, dim=1)
mse = th.stack(mse, dim=1)
prior_bpd = self._prior_bpd(x_start)
total_bpd = vb.sum(dim=1) + prior_bpd
return {
"total_bpd": total_bpd,
"prior_bpd": prior_bpd,
"vb": vb,
"xstart_mse": xstart_mse,
"mse": mse,
}
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
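_extract_into_tensor's behavior is easy to mirror in numpy: gather one schedule value per batch element, then right-pad singleton dims until the result broadcasts against the activation shape. An illustrative sketch:

```python
import numpy as np

# Gather arr[timesteps], then append singleton axes so the result
# broadcasts over an [N, C, H, W]-shaped tensor, as _extract_into_tensor
# does with torch indexing and expand().
arr = np.linspace(0.0, 1.0, 10)
timesteps = np.array([0, 4, 9])
broadcast_shape = (3, 2, 4, 4)

res = arr[timesteps].astype(np.float32)
res = res.reshape(res.shape + (1,) * (len(broadcast_shape) - 1))
res = np.broadcast_to(res, broadcast_shape)

assert res.shape == broadcast_shape
assert np.allclose(res[1], arr[4])  # batch element 1 is filled with arr[4]
```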
================================================
FILE: improved_diffusion/image_datasets.py
================================================
from PIL import Image
import blobfile as bf
from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
def load_data(
*, data_dir, batch_size, image_size, class_cond=False, deterministic=False
):
"""
For a dataset, create a generator over (images, kwargs) pairs.
    Each image is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which maps to a batched Tensor of its own.
The kwargs dict can be used for class labels, in which case the key is "y"
and the values are integer tensors of class labels.
:param data_dir: a dataset directory.
:param batch_size: the batch size of each returned pair.
:param image_size: the size to which images are resized.
:param class_cond: if True, include a "y" key in returned dicts for class
label. If classes are not available and this is true, an
exception will be raised.
:param deterministic: if True, yield results in a deterministic order.
"""
if not data_dir:
raise ValueError("unspecified data directory")
all_files = _list_image_files_recursively(data_dir)
classes = None
if class_cond:
# Assume classes are the first part of the filename,
# before an underscore.
class_names = [bf.basename(path).split("_")[0] for path in all_files]
sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
classes = [sorted_classes[x] for x in class_names]
dataset = ImageDataset(
image_size,
all_files,
classes=classes,
shard=MPI.COMM_WORLD.Get_rank(),
num_shards=MPI.COMM_WORLD.Get_size(),
)
if deterministic:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True
)
else:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True
)
while True:
yield from loader
def _list_image_files_recursively(data_dir):
results = []
for entry in sorted(bf.listdir(data_dir)):
full_path = bf.join(data_dir, entry)
ext = entry.split(".")[-1]
if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
results.append(full_path)
elif bf.isdir(full_path):
results.extend(_list_image_files_recursively(full_path))
return results
class ImageDataset(Dataset):
def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1):
super().__init__()
self.resolution = resolution
self.local_images = image_paths[shard:][::num_shards]
self.local_classes = None if classes is None else classes[shard:][::num_shards]
def __len__(self):
return len(self.local_images)
def __getitem__(self, idx):
path = self.local_images[idx]
with bf.BlobFile(path, "rb") as f:
pil_image = Image.open(f)
pil_image.load()
# We are not on a new enough PIL to support the `reducing_gap`
# argument, which uses BOX downsampling at powers of two first.
# Thus, we do it by hand to improve downsample quality.
while min(*pil_image.size) >= 2 * self.resolution:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = self.resolution / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image.convert("RGB"))
crop_y = (arr.shape[0] - self.resolution) // 2
crop_x = (arr.shape[1] - self.resolution) // 2
arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution]
arr = arr.astype(np.float32) / 127.5 - 1
out_dict = {}
if self.local_classes is not None:
out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
return np.transpose(arr, [2, 0, 1]), out_dict
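The resize/crop/normalize pipeline in `ImageDataset.__getitem__` reduces to a center crop followed by mapping uint8 pixels into [-1, 1] and a HWC-to-CHW transpose. A minimal numpy-only sketch of those last steps (the function name `normalize_uint8` is my own, and a synthetic array stands in for a PIL image):

```python
import numpy as np

def normalize_uint8(arr, resolution):
    # Center-crop to (resolution, resolution), then map [0, 255] -> [-1, 1].
    crop_y = (arr.shape[0] - resolution) // 2
    crop_x = (arr.shape[1] - resolution) // 2
    arr = arr[crop_y:crop_y + resolution, crop_x:crop_x + resolution]
    arr = arr.astype(np.float32) / 127.5 - 1
    return np.transpose(arr, [2, 0, 1])  # HWC -> CHW

img = np.full((70, 90, 3), 255, dtype=np.uint8)  # all-white test image
chw = normalize_uint8(img, 64)
assert chw.shape == (3, 64, 64)
assert float(chw.max()) == 1.0  # 255 maps exactly to +1
```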
================================================
FILE: improved_diffusion/logger.py
================================================
"""
Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
"""
import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
import warnings
from collections import defaultdict
from contextlib import contextmanager
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
DISABLED = 50
class KVWriter(object):
def writekvs(self, kvs):
raise NotImplementedError
class SeqWriter(object):
def writeseq(self, seq):
raise NotImplementedError
class HumanOutputFormat(KVWriter, SeqWriter):
def __init__(self, filename_or_file):
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "wt")
self.own_file = True
else:
assert hasattr(filename_or_file, "write"), (
"expected file or str, got %s" % filename_or_file
)
self.file = filename_or_file
self.own_file = False
def writekvs(self, kvs):
# Create strings for printing
key2str = {}
for (key, val) in sorted(kvs.items()):
if hasattr(val, "__float__"):
valstr = "%-8.3g" % val
else:
valstr = str(val)
key2str[self._truncate(key)] = self._truncate(valstr)
# Find max widths
if len(key2str) == 0:
print("WARNING: tried to write empty key-value dict")
return
else:
keywidth = max(map(len, key2str.keys()))
valwidth = max(map(len, key2str.values()))
# Write out the data
dashes = "-" * (keywidth + valwidth + 7)
lines = [dashes]
for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
lines.append(
"| %s%s | %s%s |"
% (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
)
lines.append(dashes)
self.file.write("\n".join(lines) + "\n")
# Flush the output to the file
self.file.flush()
def _truncate(self, s):
maxlen = 30
return s[: maxlen - 3] + "..." if len(s) > maxlen else s
def writeseq(self, seq):
seq = list(seq)
for (i, elem) in enumerate(seq):
self.file.write(f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')} {elem}")
if i < len(seq) - 1: # add space unless this is the last one
self.file.write(" ")
self.file.write("\n")
self.file.flush()
def close(self):
if self.own_file:
self.file.close()
class JSONOutputFormat(KVWriter):
def __init__(self, filename):
self.file = open(filename, "wt")
def writekvs(self, kvs):
for k, v in sorted(kvs.items()):
if hasattr(v, "dtype"):
kvs[k] = float(v)
self.file.write(json.dumps(kvs) + "\n")
self.file.flush()
def close(self):
self.file.close()
class CSVOutputFormat(KVWriter):
def __init__(self, filename):
self.file = open(filename, "w+t")
self.keys = []
self.sep = ","
def writekvs(self, kvs):
# Add our current row to the history
extra_keys = list(kvs.keys() - self.keys)
extra_keys.sort()
if extra_keys:
self.keys.extend(extra_keys)
self.file.seek(0)
lines = self.file.readlines()
self.file.seek(0)
for (i, k) in enumerate(self.keys):
if i > 0:
self.file.write(",")
self.file.write(k)
self.file.write("\n")
for line in lines[1:]:
self.file.write(line[:-1])
self.file.write(self.sep * len(extra_keys))
self.file.write("\n")
for (i, k) in enumerate(self.keys):
if i > 0:
self.file.write(",")
v = kvs.get(k)
if v is not None:
self.file.write(str(v))
self.file.write("\n")
self.file.flush()
def close(self):
self.file.close()
class TensorBoardOutputFormat(KVWriter):
"""
Dumps key/value pairs into TensorBoard's numeric format.
"""
def __init__(self, dir):
os.makedirs(dir, exist_ok=True)
self.dir = dir
self.step = 1
prefix = "events"
path = osp.join(osp.abspath(dir), prefix)
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.core.util import event_pb2
from tensorflow.python.util import compat
self.tf = tf
self.event_pb2 = event_pb2
self.pywrap_tensorflow = pywrap_tensorflow
self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
def writekvs(self, kvs):
def summary_val(k, v):
kwargs = {"tag": k, "simple_value": float(v)}
return self.tf.Summary.Value(**kwargs)
summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
event.step = (
self.step
) # is there any reason why you'd want to specify the step?
self.writer.WriteEvent(event)
self.writer.Flush()
self.step += 1
def close(self):
if self.writer:
self.writer.Close()
self.writer = None
def make_output_format(format, ev_dir, log_suffix=""):
os.makedirs(ev_dir, exist_ok=True)
if format == "stdout":
return HumanOutputFormat(sys.stdout)
elif format == "log":
return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
elif format == "json":
return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
elif format == "csv":
return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
elif format == "tensorboard":
return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
else:
raise ValueError("Unknown format specified: %s" % (format,))
# ================================================================
# API
# ================================================================
def logkv(key, val):
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
If called many times, last value will be used.
"""
get_current().logkv(key, val)
def logkv_mean(key, val):
"""
The same as logkv(), but if called many times, values averaged.
"""
get_current().logkv_mean(key, val)
def logkvs(d):
"""
Log a dictionary of key-value pairs
"""
for (k, v) in d.items():
logkv(k, v)
def dumpkvs():
"""
Write all of the diagnostics from the current iteration
"""
return get_current().dumpkvs()
def getkvs():
return get_current().name2val
def log(*args, level=INFO):
"""
Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
"""
get_current().log(*args, level=level)
def debug(*args):
log(*args, level=DEBUG)
def info(*args):
log(*args, level=INFO)
def warn(*args):
log(*args, level=WARN)
def error(*args):
log(*args, level=ERROR)
def set_level(level):
"""
Set logging threshold on current logger.
"""
get_current().set_level(level)
def set_comm(comm):
get_current().set_comm(comm)
def get_dir():
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call start)
"""
return get_current().get_dir()
record_tabular = logkv
dump_tabular = dumpkvs
@contextmanager
def profile_kv(scopename):
logkey = "wait_" + scopename
tstart = time.time()
try:
yield
finally:
get_current().name2val[logkey] += time.time() - tstart
def profile(n):
"""
Usage:
@profile("my_func")
def my_func(): code
"""
def decorator_with_name(func):
def func_wrapper(*args, **kwargs):
with profile_kv(n):
return func(*args, **kwargs)
return func_wrapper
return decorator_with_name
# ================================================================
# Backend
# ================================================================
def get_current():
if Logger.CURRENT is None:
_configure_default_logger()
return Logger.CURRENT
class Logger(object):
DEFAULT = None # A logger with no output files. (See right below class definition)
# So that you can still log to the terminal without setting up any output files
CURRENT = None # Current logger being used by the free functions above
def __init__(self, dir, output_formats, comm=None):
self.name2val = defaultdict(float) # values this iteration
self.name2cnt = defaultdict(int)
self.level = INFO
self.dir = dir
self.output_formats = output_formats
self.comm = comm
# Logging API, forwarded
# ----------------------------------------
def logkv(self, key, val):
self.name2val[key] = val
def logkv_mean(self, key, val):
oldval, cnt = self.name2val[key], self.name2cnt[key]
self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
self.name2cnt[key] = cnt + 1
def dumpkvs(self):
if self.comm is None:
d = self.name2val
else:
d = mpi_weighted_mean(
self.comm,
{
name: (val, self.name2cnt.get(name, 1))
for (name, val) in self.name2val.items()
},
)
if self.comm.rank != 0:
d["dummy"] = 1 # so we don't get a warning about empty dict
out = d.copy() # Return the dict for unit testing purposes
for fmt in self.output_formats:
if isinstance(fmt, KVWriter):
fmt.writekvs(d)
self.name2val.clear()
self.name2cnt.clear()
return out
def log(self, *args, level=INFO):
if self.level <= level:
self._do_log(args)
# Configuration
# ----------------------------------------
def set_level(self, level):
self.level = level
def set_comm(self, comm):
self.comm = comm
def get_dir(self):
return self.dir
def close(self):
for fmt in self.output_formats:
fmt.close()
# Misc
# ----------------------------------------
def _do_log(self, args):
for fmt in self.output_formats:
if isinstance(fmt, SeqWriter):
fmt.writeseq(map(str, args))
def get_rank_without_mpi_import():
# check environment variables here instead of importing mpi4py
# to avoid calling MPI_Init() when this module is imported
for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
if varname in os.environ:
return int(os.environ[varname])
return 0
def mpi_weighted_mean(comm, local_name2valcount):
"""
Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
Perform a weighted average over dicts that are each on a different node
Input: local_name2valcount: dict mapping key -> (value, count)
Returns: key -> mean
"""
all_name2valcount = comm.gather(local_name2valcount)
if comm.rank == 0:
name2sum = defaultdict(float)
name2count = defaultdict(float)
for n2vc in all_name2valcount:
for (name, (val, count)) in n2vc.items():
try:
val = float(val)
except ValueError:
if comm.rank == 0:
warnings.warn(
"WARNING: tried to compute mean on non-float {}={}".format(
name, val
)
)
else:
name2sum[name] += val * count
name2count[name] += count
return {name: name2sum[name] / name2count[name] for name in name2sum}
else:
return {}
def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
"""
If comm is provided, average all numerical stats across that comm
"""
if dir is None:
dir = os.getenv("OPENAI_LOGDIR")
if dir is None:
dir = osp.join(
tempfile.gettempdir(),
datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
)
assert isinstance(dir, str)
dir = os.path.expanduser(dir)
os.makedirs(os.path.expanduser(dir), exist_ok=True)
rank = get_rank_without_mpi_import()
if rank > 0:
log_suffix = log_suffix + "-rank%03i" % rank
if format_strs is None:
if rank == 0:
format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
else:
format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
format_strs = filter(None, format_strs)
output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
if output_formats:
log("Logging to %s" % dir)
def _configure_default_logger():
configure()
Logger.DEFAULT = Logger.CURRENT
def reset():
if Logger.CURRENT is not Logger.DEFAULT:
Logger.CURRENT.close()
Logger.CURRENT = Logger.DEFAULT
log("Reset logger")
@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
prevlogger = Logger.CURRENT
configure(dir=dir, format_strs=format_strs, comm=comm)
try:
yield
finally:
Logger.CURRENT.close()
Logger.CURRENT = prevlogger
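`Logger.logkv_mean` keeps only a running mean and a count rather than the full history. A standalone sanity check of that update rule (not using the module itself):

```python
# Running-mean update used by Logger.logkv_mean:
#   new_mean = old_mean * cnt / (cnt + 1) + val / (cnt + 1)
mean, cnt = 0.0, 0
for v in [1.0, 2.0, 3.0, 4.0]:
    mean = mean * cnt / (cnt + 1) + v / (cnt + 1)
    cnt += 1
assert mean == 2.5  # matches the plain average of the four values
```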
================================================
FILE: improved_diffusion/losses.py
================================================
"""
Helpers for various likelihood-based losses. These are ported from the original
Ho et al. diffusion models codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""
import numpy as np
import torch as th
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, th.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ th.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
)
def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
"""
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = th.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = th.where(
x < -0.999,
log_cdf_plus,
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs
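`normal_kl` is the closed-form KL divergence between two diagonal Gaussians parameterized by mean and log-variance. A numpy re-statement (the name `normal_kl_np` is my own) checked against two textbook values:

```python
import numpy as np

def normal_kl_np(mean1, logvar1, mean2, logvar2):
    # KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))), elementwise.
    return 0.5 * (-1.0 + logvar2 - logvar1
                  + np.exp(logvar1 - logvar2)
                  + (mean1 - mean2) ** 2 * np.exp(-logvar2))

# Identical distributions have zero KL.
assert normal_kl_np(0.0, 0.0, 0.0, 0.0) == 0.0
# KL(N(1, 1) || N(0, 1)) = 0.5.
assert abs(normal_kl_np(1.0, 0.0, 0.0, 0.0) - 0.5) < 1e-12
```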
================================================
FILE: improved_diffusion/metrics.py
================================================
import numpy as np
from skimage.morphology import binary_dilation, disk
def WCov_metric(pred, gt_mask):
A1 = float(np.count_nonzero(pred))
A2 = float(np.count_nonzero(gt_mask))
if A1 == 0 and A2 == 0: return 1.0  # both masks empty: avoid division by zero
return min(A1, A2) / max(A1, A2)
def FBound_metric(pred, gt_mask):
# Average the boundary F-measure over dilation tolerances of 1..5 pixels.
return sum(db_eval_boundary(pred, gt_mask, radius)[0] for radius in range(1, 6)) / 5.0
def db_eval_boundary(foreground_mask, gt_mask, bound_th):
"""
Compute the per-frame boundary F-measure.
Calculates precision/recall for boundaries between foreground_mask and
gt_mask, using morphological operators to speed it up.
Arguments:
foreground_mask (ndarray): binary segmentation image.
gt_mask (ndarray): binary annotated image.
Returns:
F (float): boundaries F-measure
P (float): boundaries precision
R (float): boundaries recall
plus the raw boundary-pixel counts
(sum(fg_match), n_fg, sum(gt_match), n_gt).
"""
assert np.atleast_3d(foreground_mask).shape[2] == 1
bound_pix = bound_th if bound_th >= 1 else \
np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))
# Get the pixel boundaries of both masks
fg_boundary = seg2bmap(foreground_mask)
gt_boundary = seg2bmap(gt_mask)
fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
# Get the intersection
gt_match = gt_boundary * fg_dil
fg_match = fg_boundary * gt_dil
# Area of the intersection
n_fg = np.sum(fg_boundary)
n_gt = np.sum(gt_boundary)
# % Compute precision and recall
if n_fg == 0 and n_gt > 0:
precision = 1
recall = 0
elif n_fg > 0 and n_gt == 0:
precision = 0
recall = 1
elif n_fg == 0 and n_gt == 0:
precision = 1
recall = 1
else:
precision = np.sum(fg_match) / float(n_fg)
recall = np.sum(gt_match) / float(n_gt)
# Compute F measure
if precision + recall == 0:
F = 0
else:
F = 2 * precision * recall / (precision + recall)
return F, precision, recall, np.sum(fg_match), n_fg, np.sum(gt_match), n_gt
def seg2bmap(seg, width=None, height=None):
"""
From a segmentation, compute a binary boundary map with 1 pixel wide
boundaries. The boundary pixels are offset by 1/2 pixel towards the
origin from the actual segment boundary.
Arguments:
seg : Segments labeled from 1..k.
width : Width of desired bmap <= seg.shape[1]
height : Height of desired bmap <= seg.shape[0]
Returns:
bmap (ndarray): Binary boundary map.
David Martin <dmartin@eecs.berkeley.edu>
January 2003
"""
seg = seg.astype(bool)
seg[seg > 0] = 1
assert np.atleast_3d(seg).shape[2] == 1
width = seg.shape[1] if width is None else width
height = seg.shape[0] if height is None else height
h, w = seg.shape[:2]
ar1 = float(width) / float(height)
ar2 = float(w) / float(h)
assert not (width > w or height > h or abs(ar1 - ar2) > 0.01), \
"Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
e = np.zeros_like(seg)
s = np.zeros_like(seg)
se = np.zeros_like(seg)
e[:, :-1] = seg[:, 1:]
s[:-1, :] = seg[1:, :]
se[:-1, :-1] = seg[1:, 1:]
b = seg ^ e | seg ^ s | seg ^ se
b[-1, :] = seg[-1, :] ^ e[-1, :]
b[:, -1] = seg[:, -1] ^ s[:, -1]
b[-1, -1] = 0
if w == width and h == height:
bmap = b
else:
bmap = np.zeros((height, width))
for x in range(w):
for y in range(h):
if b[y, x]:
j = int(1 + np.floor((y - 1) + height / h))
i = int(1 + np.floor((x - 1) + width / w))
bmap[j, i] = 1
return bmap
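`WCov_metric` is simply the ratio of the smaller foreground area to the larger one, so a perfect-size match scores 1.0. A tiny self-contained sketch (the local `wcov` helper is my own restatement, not an import from the module):

```python
import numpy as np

def wcov(pred, gt):
    # Ratio of the smaller foreground area to the larger, as in WCov_metric.
    a1 = float(np.count_nonzero(pred))
    a2 = float(np.count_nonzero(gt))
    return min(a1, a2) / max(a1, a2)

pred = np.zeros((8, 8)); pred[:4, :] = 1  # 32 foreground pixels
gt = np.ones((8, 8))                      # 64 foreground pixels
assert wcov(pred, gt) == 0.5
```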
================================================
FILE: improved_diffusion/nn.py
================================================
"""
Various utilities for neural networks.
"""
import math
import torch as th
import torch.nn as nn
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * th.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
def swap_ema(target_params, source_params):
"""
Swap the values of the target and source parameters in place.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
"""
for targ, src in zip(target_params, source_params):
temp = targ.data.clone()
targ.data.copy_(src.data)
src.data.copy_(temp)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(th.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with th.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with th.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = th.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
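`timestep_embedding` builds the standard sinusoidal embedding: a geometric frequency ladder multiplied by the timestep, then cosines concatenated with sines. A numpy re-statement (the name `timestep_embedding_np` is my own) showing the t = 0 row is all ones then all zeros:

```python
import numpy as np

def timestep_embedding_np(timesteps, dim, max_period=10000):
    # Geometric frequencies from 1 down to ~1/max_period, as in nn.timestep_embedding.
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half, dtype=np.float64) / half)
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)

emb = timestep_embedding_np([0, 10], 8)
assert emb.shape == (2, 8)
# t = 0: cos(0) = 1 for the first half, sin(0) = 0 for the second.
assert np.allclose(emb[0], [1, 1, 1, 1, 0, 0, 0, 0])
```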
================================================
FILE: improved_diffusion/resample.py
================================================
from abc import ABC, abstractmethod
import numpy as np
import torch as th
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
elif name == "loss-second-moment":
return LossSecondMomentResampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps.
:param device: the torch device to save to.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
indices = th.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = th.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, diffusion):
self.diffusion = diffusion
self._weights = np.ones([diffusion.num_timesteps])
def weights(self):
return self._weights
class LossAwareSampler(ScheduleSampler):
def update_with_local_losses(self, local_ts, local_losses):
"""
Update the reweighting using losses from a model.
Call this method from each rank with a batch of timesteps and the
corresponding losses for each of those timesteps.
This method will perform synchronization to make sure all of the ranks
maintain the exact same reweighting.
:param local_ts: an integer Tensor of timesteps.
:param local_losses: a 1D Tensor of losses.
"""
batch_sizes = [
th.tensor([0], dtype=th.int32, device=local_ts.device)
for _ in range(dist.get_world_size())
]
dist.all_gather(
batch_sizes,
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
)
# Pad all_gather batches to be the maximum batch size.
batch_sizes = [x.item() for x in batch_sizes]
max_bs = max(batch_sizes)
timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
dist.all_gather(timestep_batches, local_ts)
dist.all_gather(loss_batches, local_losses)
timesteps = [
x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
]
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
self.update_with_all_losses(timesteps, losses)
@abstractmethod
def update_with_all_losses(self, ts, losses):
"""
Update the reweighting using losses from a model.
Sub-classes should override this method to update the reweighting
using losses from the model.
This method directly updates the reweighting without synchronizing
between workers. It is called by update_with_local_losses from all
ranks with identical arguments. Thus, it should have deterministic
behavior to maintain state across workers.
:param ts: a list of int timesteps.
:param losses: a list of float losses, one per timestep.
"""
class LossSecondMomentResampler(LossAwareSampler):
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
self.diffusion = diffusion
self.history_per_term = history_per_term
self.uniform_prob = uniform_prob
self._loss_history = np.zeros(
[diffusion.num_timesteps, history_per_term], dtype=np.float64
)
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
def weights(self):
if not self._warmed_up():
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
weights /= np.sum(weights)
weights *= 1 - self.uniform_prob
weights += self.uniform_prob / len(weights)
return weights
def update_with_all_losses(self, ts, losses):
for t, loss in zip(ts, losses):
if self._loss_counts[t] == self.history_per_term:
# Shift out the oldest loss term.
self._loss_history[t, :-1] = self._loss_history[t, 1:]
self._loss_history[t, -1] = loss
else:
self._loss_history[t, self._loss_counts[t]] = loss
self._loss_counts[t] += 1
def _warmed_up(self):
return (self._loss_counts == self.history_per_term).all()
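`ScheduleSampler.sample` draws timesteps from weights `p` and returns loss weights `1 / (len(p) * p[t])`, which makes the reweighted estimator unbiased for the uniform objective. A numpy-only sketch of that identity (all names here are local to the example):

```python
import numpy as np

rng = np.random.default_rng(0)
w = np.array([1.0, 2.0, 3.0, 4.0])
p = w / w.sum()
idx = rng.choice(len(p), size=200_000, p=p)
weights = 1.0 / (len(p) * p[idx])
# Reweighted draws estimate the uniform mean of f(t) = t, i.e. 1.5 here.
est = np.mean(weights * idx)
assert abs(est - np.mean(np.arange(4))) < 0.05
```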
================================================
FILE: improved_diffusion/respace.py
================================================
import numpy as np
import torch as th
from .gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim") :])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {desired_count} steps with an integer stride"
)
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}"
)
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
return t
class _WrappedModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, new_ts, **kwargs)
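The even-striding loop at the top of `space_timesteps` picks `section_count` indices spread uniformly over a section of `size` original steps. A minimal standalone sketch of that logic (the function name `evenly_spaced_steps` is hypothetical, introduced only for this illustration):

```python
def evenly_spaced_steps(size, section_count, start_idx=0):
    # Standalone sketch of the striding loop in space_timesteps:
    # choose `section_count` indices evenly spread across `size` steps.
    if section_count == 1:
        frac_stride = 1.0
    else:
        frac_stride = (size - 1) / (section_count - 1)
    cur_idx = 0.0
    taken = []
    for _ in range(section_count):
        taken.append(start_idx + round(cur_idx))
        cur_idx += frac_stride
    return taken

# e.g. 4 steps out of a 10-step section: [0, 3, 6, 9]
```

`SpacedDiffusion` then recomputes betas over only the retained indices (via `1 - alpha_cumprod / last_alpha_cumprod`), so sampling with fewer steps still matches the base process's cumulative noise schedule.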
================================================
FILE: improved_diffusion/sampling_util.py
================================================
import math
import os
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision.utils as tvu
from PIL import Image
from kornia import denormalize
from sklearn.metrics import f1_score, jaccard_score
from torch.utils.data import DataLoader
from tqdm import tqdm
from . import dist_util
from .metrics import FBound_metric, WCov_metric
from datasets.monu import MonuDataset
from .utils import set_random_seed_for_iterations
cityspallete = [
0, 0, 0,
128, 64, 128,
244, 35, 232,
70, 70, 70,
102, 102, 156,
190, 153, 153,
153, 153, 153,
250, 170, 30,
220, 220, 0,
107, 142, 35,
152, 251, 152,
0, 130, 180,
220, 20, 60,
255, 0, 0,
0, 0, 142,
0, 0, 70,
0, 60, 100,
0, 80, 100,
0, 0, 230,
119, 11, 32,
]
def calculate_metrics(x, gt):
predict = x.detach().cpu().numpy().astype('uint8')
target = gt.detach().cpu().numpy().astype('uint8')
return f1_score(target.flatten(), predict.flatten()), jaccard_score(target.flatten(), predict.flatten()), \
WCov_metric(predict, target), FBound_metric(predict, target)
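For reference, on binary masks the `f1_score` and `jaccard_score` calls in `calculate_metrics` reduce to the usual confusion-matrix formulas. A toy sketch on 2x2 masks (pure NumPy, computing the same quantities by hand):

```python
import numpy as np

# Toy 2x2 binary masks, flattened like the inputs to calculate_metrics.
predict = np.array([[1, 0], [1, 1]], dtype=np.uint8).flatten()
target  = np.array([[1, 0], [0, 1]], dtype=np.uint8).flatten()

tp = int(np.sum((predict == 1) & (target == 1)))  # true positives: 2
fp = int(np.sum((predict == 1) & (target == 0)))  # false positives: 1
fn = int(np.sum((predict == 0) & (target == 1)))  # false negatives: 0

f1  = 2 * tp / (2 * tp + fp + fn)  # what f1_score returns here: 0.8
iou = tp / (tp + fp + fn)          # what jaccard_score returns here: 2/3
```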
def sampling_major_vote_func(diffusion_model, ddp_model, output_folder, dataset, logger, clip_denoised, step, n_rounds=3):
ddp_model.eval()
batch_size = 1
major_vote_number = 9
loader = DataLoader(dataset, batch_size=batch_size)
loader_iter = iter(loader)
f1_score_list = []
miou_list = []
fbound_list = []
wcov_list = []
with torch.no_grad():
for round_index in tqdm(
range(n_rounds), desc="Generating image samples for segmentation evaluation."
):
gt_mask, condition_on, name = next(loader_iter)
set_random_seed_for_iterations(step + int(name[0].split("_")[1]))
gt_mask = (gt_mask + 1.0) / 2.0
condition_on = condition_on["conditioned_image"]
former_frame_for_feature_extraction = condition_on.to(dist_util.dev())
for i in range(gt_mask.shape[0]):
gt_img = Image.fromarray(gt_mask[i][0].detach().cpu().numpy().astype('uint8'))
gt_img.putpalette(cityspallete)
gt_img.save(
os.path.join(output_folder, f"{name[i]}_gt_palette.png"))
gt_img = Image.fromarray((gt_mask[i][0].detach().cpu().numpy() - 1).astype(np.uint8))
gt_img.save(
os.path.join(output_folder, f"{name[i]}_gt.png"))
for i in range(condition_on.shape[0]):
denorm_condition_on = denormalize(condition_on.clone(), mean=dataset.mean, std=dataset.std)
tvu.save_image(
denorm_condition_on[i,] / 255.,
os.path.join(output_folder, f"{name[i]}_condition_on.png")
)
if isinstance(dataset, MonuDataset):
_, _, W, H = former_frame_for_feature_extraction.shape
kernel_size = dataset.image_size
stride = 256
patches = []
for y, x in np.ndindex((((W - kernel_size) // stride) + 1, ((H - kernel_size) // stride) + 1)):
y = y * stride
x = x * stride
patches.append(former_frame_for_feature_extraction[0,
:,
y: min(y + kernel_size, W),
x: min(x + kernel_size, H)])
patches = torch.stack(patches)
major_vote_list = []
for i in range(major_vote_number):
x_list = []
for index in range(math.ceil(patches.shape[0] / 4)):
model_kwargs = {"conditioned_image": patches[index * 4: min((index + 1) * 4, patches.shape[0])]}
x = diffusion_model.p_sample_loop(
ddp_model,
(model_kwargs["conditioned_image"].shape[0], gt_mask.shape[1], model_kwargs["conditioned_image"].shape[2], model_kwargs["conditioned_image"].shape[3]),
progress=True,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs
)
x_list.append(x)
out = torch.cat(x_list)
output = torch.zeros((former_frame_for_feature_extraction.shape[0], gt_mask.shape[1], former_frame_for_feature_extraction.shape[2], former_frame_for_feature_extraction.shape[3]))
idx_sum = torch.zeros((former_frame_for_feature_extraction.shape[0], gt_mask.shape[1], former_frame_for_feature_extraction.shape[2], former_frame_for_feature_extraction.shape[3]))
for index, val in enumerate(out):
y, x = np.unravel_index(index, (((W - kernel_size) // stride) + 1, ((H - kernel_size) // stride) + 1))
y = y * stride
x = x * stride
idx_sum[0,
:,
y: min(y + kernel_size, W),
x: min(x + kernel_size, H)] += 1
output[0,
:,
y: min(y + kernel_size, W),
x: min(x + kernel_size, H)] += val[:, :min(y + kernel_size, W) - y, :min(x + kernel_size, H) - x].cpu().data.numpy()
output = output / idx_sum
major_vote_list.append(output)
x = torch.cat(major_vote_list)
else:
model_kwargs = {
"conditioned_image": torch.cat([former_frame_for_feature_extraction] * major_vote_number)}
x = diffusion_model.p_sample_loop(
ddp_model,
(major_vote_number, gt_mask.shape[1], former_frame_for_feature_extraction.shape[2],
former_frame_for_feature_extraction.shape[3]),
progress=True,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs
)
x = (x + 1.0) / 2.0
if x.shape[2] != gt_mask.shape[2] or x.shape[3] != gt_mask.shape[3]:
x = F.interpolate(x, gt_mask.shape[2:], mode='bilinear')
x = torch.clamp(x, 0.0, 1.0)
# major vote result
x = x.mean(dim=0, keepdim=True).round()
for i in range(x.shape[0]):
# save as outer training ids
# current_output = x[i][0] + 1
# current_output[current_output == current_output.max()] = 0
out_img = Image.fromarray(x[i][0].detach().cpu().numpy().astype('uint8'))
out_img.putpalette(cityspallete)
out_img.save(
os.path.join(output_folder, f"{name[i]}_model_output_palette.png"))
out_img = Image.fromarray((x[i][0].detach().cpu().numpy() - 1).astype(np.uint8))
out_img.save(
os.path.join(output_folder, f"{name[i]}_model_output.png"))
for index, (gt_im, out_im) in enumerate(zip(gt_mask, x)):
f1, miou, wcov, fbound = calculate_metrics(out_im[0], gt_im[0])
f1_score_list.append(f1)
miou_list.append(miou)
wcov_list.append(wcov)
fbound_list.append(fbound)
logger.info(
f"{name[index]} iou {miou_list[-1]}, f1_Score {f1_score_list[-1]}, WCov {wcov_list[-1]}, boundF {fbound_list[-1]}")
my_length = len(miou_list)
length_of_data = torch.tensor(len(miou_list), device=dist_util.dev())
gathered_length_of_data = [torch.tensor(1, device=dist_util.dev()) for _ in range(dist.get_world_size())]
dist.all_gather(gathered_length_of_data, length_of_data)
max_len = torch.max(torch.stack(gathered_length_of_data))
iou_tensor = torch.tensor(miou_list + [torch.tensor(-1)] * (max_len - my_length), device=dist_util.dev())
f1_tensor = torch.tensor(f1_score_list + [torch.tensor(-1)] * (max_len - my_length), device=dist_util.dev())
wcov_tensor = torch.tensor(wcov_list + [torch.tensor(-1)] * (max_len - my_length), device=dist_util.dev())
boundf_tensor = torch.tensor(fbound_list + [torch.tensor(-1)] * (max_len - my_length), device=dist_util.dev())
gathered_miou = [torch.ones_like(iou_tensor) * -1 for _ in range(dist.get_world_size())]
gathered_f1 = [torch.ones_like(f1_tensor) * -1 for _ in range(dist.get_world_size())]
gathered_wcov = [torch.ones_like(wcov_tensor) * -1 for _ in range(dist.get_world_size())]
gathered_boundf = [torch.ones_like(boundf_tensor) * -1 for _ in range(dist.get_world_size())]
dist.all_gather(gathered_miou, iou_tensor)
dist.all_gather(gathered_f1, f1_tensor)
dist.all_gather(gathered_wcov, wcov_tensor)
dist.all_gather(gathered_boundf, boundf_tensor)
# if dist.get_rank() == 0:
logger.info("measure total avg")
gathered_miou = torch.cat(gathered_miou)
gathered_miou = gathered_miou[gathered_miou != -1]
logger.info(f"mean iou {gathered_miou.mean()}")
gathered_f1 = torch.cat(gathered_f1)
gathered_f1 = gathered_f1[gathered_f1 != -1]
logger.info(f"mean f1 {gathered_f1.mean()}")
gathered_wcov = torch.cat(gathered_wcov)
gathered_wcov = gathered_wcov[gathered_wcov != -1]
logger.info(f"mean WCov {gathered_wcov.mean()}")
gathered_boundf = torch.cat(gathered_boundf)
gathered_boundf = gathered_boundf[gathered_boundf != -1]
logger.info(f"mean boundF {gathered_boundf.mean()}")
dist.barrier()
return gathered_miou.mean().item()
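The MoNu branch above stitches overlapping patch predictions back together by accumulating each patch's output into `output` and a per-pixel coverage count into `idx_sum`, then dividing. A small NumPy sketch of that patch grid and coverage accounting, using toy sizes (kernel 4, stride 2 on an 8x8 image) rather than the dataset's real 256-stride configuration:

```python
import numpy as np

# Sliding-window grid as in sampling_major_vote_func: count how many
# overlapping patches cover each pixel, so averaging is a divide by idx_sum.
W = H = 8
kernel_size, stride = 4, 2
grid = (((W - kernel_size) // stride) + 1, ((H - kernel_size) // stride) + 1)

idx_sum = np.zeros((W, H))
for y, x in np.ndindex(grid):
    y, x = y * stride, x * stride
    idx_sum[y:y + kernel_size, x:x + kernel_size] += 1

# Corner pixels are covered by exactly 1 patch; interior pixels by up to 4.
```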
================================================
FILE: improved_diffusion/script_util.py
================================================
import argparse
import inspect
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
from .unet import SuperResModel, UNetModel
NUM_CLASSES = 1000
def model_and_diffusion_defaults():
"""
Defaults for image training.
"""
return dict(
image_size=64,
num_channels=128,
num_res_blocks=2,
num_heads=4,
num_heads_upsample=-1,
attention_resolutions="16,8",
dropout=0.0,
rrdb_blocks=10,
deeper_net=False,
learn_sigma=False,
sigma_small=False,
class_cond=False,
class_name="train",
expansion=False,
diffusion_steps=100,
noise_schedule="linear",
timestep_respacing="",
use_kl=False,
predict_xstart=False,
rescale_timesteps=True,
rescale_learned_sigmas=True,
use_checkpoint=False,
use_scale_shift_norm=True,
seed=None,
)
def create_model_and_diffusion(
image_size,
class_cond,
learn_sigma,
sigma_small,
num_channels,
num_res_blocks,
num_heads,
num_heads_upsample,
attention_resolutions,
dropout,
rrdb_blocks,
deeper_net,
class_name,
expansion,
diffusion_steps,
noise_schedule,
timestep_respacing,
use_kl,
predict_xstart,
rescale_timesteps,
rescale_learned_sigmas,
use_checkpoint,
use_scale_shift_norm,
seed,
):
_ = seed # intentionally unused; silences unused-variable warnings
_ = expansion
_ = class_name
model = create_model(
image_size,
num_channels,
num_res_blocks,
learn_sigma=learn_sigma,
class_cond=class_cond,
use_checkpoint=use_checkpoint,
attention_resolutions=attention_resolutions,
num_heads=num_heads,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
dropout=dropout,
rrdb_blocks=rrdb_blocks,
deeper_net=deeper_net
)
diffusion = create_gaussian_diffusion(
steps=diffusion_steps,
learn_sigma=learn_sigma,
sigma_small=sigma_small,
noise_schedule=noise_schedule,
use_kl=use_kl,
predict_xstart=predict_xstart,
rescale_timesteps=rescale_timesteps,
rescale_learned_sigmas=rescale_learned_sigmas,
timestep_respacing=timestep_respacing,
)
return model, diffusion
def create_model(
image_size,
num_channels,
num_res_blocks,
learn_sigma,
class_cond,
use_checkpoint,
attention_resolutions,
num_heads,
num_heads_upsample,
use_scale_shift_norm,
dropout,
rrdb_blocks,
deeper_net
):
if image_size == 256:
if deeper_net:
channel_mult = (1, 1, 1, 2, 2, 4, 4)
else:
channel_mult = (1, 1, 2, 2, 4, 4)
elif image_size == 128:
channel_mult = (1, 1, 2, 2, 4, 4)
elif image_size == 64:
channel_mult = (1, 2, 3, 4)
elif image_size == 32:
channel_mult = (1, 2, 2, 2)
else:
raise ValueError(f"unsupported image size: {image_size}")
attention_ds = []
for res in attention_resolutions.split(","):
attention_ds.append(image_size // int(res))
return UNetModel(
in_channels=1,
model_channels=num_channels,
out_channels=(1 if not learn_sigma else 2),
num_res_blocks=num_res_blocks,
attention_resolutions=tuple(attention_ds),
dropout=dropout,
channel_mult=channel_mult,
num_classes=(NUM_CLASSES if class_cond else None),
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
rrdb_blocks=rrdb_blocks
)
def sr_model_and_diffusion_defaults():
res = model_and_diffusion_defaults()
res["large_size"] = 256
res["small_size"] = 64
arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
for k in res.copy().keys():
if k not in arg_names:
del res[k]
return res
def sr_create_model_and_diffusion(
large_size,
small_size,
class_cond,
learn_sigma,
num_channels,
num_res_blocks,
num_heads,
num_heads_upsample,
attention_resolutions,
dropout,
rrdb_blocks,
deeper_net,
diffusion_steps,
noise_schedule,
timestep_respacing,
use_kl,
predict_xstart,
rescale_timesteps,
rescale_learned_sigmas,
use_checkpoint,
use_scale_shift_norm,
):
model = sr_create_model(
large_size,
small_size,
num_channels,
num_res_blocks,
learn_sigma=learn_sigma,
class_cond=class_cond,
use_checkpoint=use_checkpoint,
attention_resolutions=attention_resolutions,
num_heads=num_heads,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
dropout=dropout,
rrdb_blocks=rrdb_blocks,
deeper_net=deeper_net,
)
diffusion = create_gaussian_diffusion(
steps=diffusion_steps,
learn_sigma=learn_sigma,
noise_schedule=noise_schedule,
use_kl=use_kl,
predict_xstart=predict_xstart,
rescale_timesteps=rescale_timesteps,
rescale_learned_sigmas=rescale_learned_sigmas,
timestep_respacing=timestep_respacing,
)
return model, diffusion
def sr_create_model(
large_size,
small_size,
num_channels,
num_res_blocks,
learn_sigma,
class_cond,
use_checkpoint,
attention_resolutions,
num_heads,
num_heads_upsample,
use_scale_shift_norm,
dropout,
rrdb_blocks,
deeper_net,
):
_ = small_size # intentionally unused; silences unused-variable warnings
if large_size == 256:
if deeper_net:
channel_mult = (1, 1, 1, 2, 2, 4, 4)
else:
channel_mult = (1, 1, 2, 2, 4, 4)
elif large_size == 64:
channel_mult = (1, 2, 3, 4)
else:
raise ValueError(f"unsupported large size: {large_size}")
attention_ds = []
for res in attention_resolutions.split(","):
attention_ds.append(large_size // int(res))
return SuperResModel(
in_channels=1,
model_channels=num_channels,
out_channels=(1 if not learn_sigma else 2),
num_res_blocks=num_res_blocks,
attention_resolutions=tuple(attention_ds),
dropout=dropout,
channel_mult=channel_mult,
num_classes=(NUM_CLASSES if class_cond else None),
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
rrdb_blocks=rrdb_blocks,
)
def create_gaussian_diffusion(
*,
steps=1000,
learn_sigma=False,
sigma_small=False,
noise_schedule="linear",
use_kl=False,
predict_xstart=False,
rescale_timesteps=False,
rescale_learned_sigmas=False,
timestep_respacing="",
):
betas = gd.get_named_beta_schedule(noise_schedule, steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if not timestep_respacing:
timestep_respacing = [steps]
return SpacedDiffusion(
use_timesteps=space_timesteps(steps, timestep_respacing),
betas=betas,
model_mean_type=(
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
),
model_var_type=(
(
gd.ModelVarType.FIXED_LARGE
if not sigma_small
else gd.ModelVarType.FIXED_SMALL
)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
rescale_timesteps=rescale_timesteps,
)
def add_dict_to_argparser(parser, default_dict):
for k, v in default_dict.items():
v_type = type(v)
if v is None:
v_type = str
elif isinstance(v, bool):
v_type = str2bool
parser.add_argument(f"--{k}", default=v, type=v_type)
def args_to_dict(args, keys):
return {k: getattr(args, k) for k in keys}
def str2bool(v):
"""
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
"""
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("boolean value expected")
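`add_dict_to_argparser` maps every boolean default through `str2bool`, so flags like `--use_kl yes` or `--use_kl false` parse correctly instead of the usual argparse pitfall where `bool("false")` is `True`. A self-contained sketch of that pattern (with a local copy of the helper, and a hypothetical two-key defaults dict):

```python
import argparse

def str2bool(v):
    # Local copy of the helper above, for a self-contained example.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")

defaults = dict(diffusion_steps=100, use_kl=False)
parser = argparse.ArgumentParser()
for k, v in defaults.items():
    v_type = str2bool if isinstance(v, bool) else type(v)
    parser.add_argument(f"--{k}", default=v, type=v_type)

args = parser.parse_args(["--use_kl", "yes", "--diffusion_steps", "50"])
# args.use_kl -> True, args.diffusion_steps -> 50
```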
================================================
FILE: improved_diffusion/train_util.py
================================================
import copy
import functools
import os
from pathlib import Path
import blobfile as bf
import numpy as np
import torch as th
import torch.distributed as dist
from mpi4py import MPI
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from tqdm import tqdm
from . import dist_util, logger
from .fp16_util import (
make_master_params,
master_params_to_model_params,
model_grads_to_master_grads,
unflatten_master_params,
zero_grad,
)
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler
from .sampling_util import sampling_major_vote_func
from .utils import set_random_seed_for_iterations
# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0
class TrainLoop:
def __init__(
self,
*,
model,
diffusion,
data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
save_interval,
resume_checkpoint,
logger,
image_size,
val_dataset,
clip_denoised=True,
use_fp16=False,
fp16_scale_growth=1e-3,
schedule_sampler=None,
weight_decay=0.0,
lr_anneal_steps=0,
run_without_test=False,
args=None
):
self.model = model
self.diffusion = diffusion
self.data = data
self.batch_size = batch_size
self.microbatch = microbatch if microbatch > 0 else batch_size
self.lr = lr
self.args = args
self.ema_rate = (
[ema_rate]
if isinstance(ema_rate, float)
else [float(x) for x in ema_rate.split(",")]
)
self.log_interval = log_interval
self.save_interval = save_interval
self.resume_checkpoint = resume_checkpoint
self.use_fp16 = use_fp16
self.fp16_scale_growth = fp16_scale_growth
self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
self.weight_decay = weight_decay
self.lr_anneal_steps = lr_anneal_steps
self.step = 1
self.resume_step = 0
self.global_batch = self.batch_size * dist.get_world_size()
self.model_params = list(self.model.parameters())
self.master_params = self.model_params
self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
self.sync_cuda = th.cuda.is_available()
# if self.resume_checkpoint:
self._load_and_sync_parameters(self.resume_checkpoint)
if self.use_fp16:
self._setup_fp16()
self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
if self.resume_checkpoint:
self._load_optimizer_state(resume_checkpoint)
# Model was resumed, either due to a restart or a checkpoint
# being specified at the command line.
self.ema_params = [
self._load_ema_parameters(rate, resume_checkpoint) for rate in self.ema_rate
]
else:
self.ema_params = [
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]
if th.cuda.is_available():
self.use_ddp = True
self.ddp_model = DDP(
self.model,
device_ids=[dist_util.dev()],
output_device=dist_util.dev(),
broadcast_buffers=False,
bucket_cap_mb=128,
find_unused_parameters=False,
)
self.ema_model = copy.deepcopy(self.model).to(th.device("cpu"))
else:
if dist.get_world_size() > 1:
logger.warn(
"Distributed training requires CUDA. "
"Gradients will not be synchronized properly!"
)
self.use_ddp = False
self.ddp_model = self.model
self.val_dataset = val_dataset
self.logger = logger
self.ema_val_best_iou = 0
self.val_best_iou = 0
self.clip_denoised = clip_denoised
self.val_current_model_name = ""
self.val_current_model_ema_name = ""
self.current_model_checkpoint_name = ""
self.run_without_test = run_without_test
def _load_and_sync_parameters(self, logs_path):
# resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
# model_checkpoint = bf.join(
# bf.dirname(logs_path), f"model.pt"
# )
logger.log(f"model folder path: {logs_path}")
if logs_path:
if Path(logs_path).exists():
model_path = list(Path(logs_path).glob("model*.pt"))[0]
self.resume_step = parse_resume_step_from_filename(str(model_path))
self.step = self.resume_step
logger.log(f"loading model from checkpoint: {model_path} from step {self.step}...")
self.model.load_state_dict(
dist_util.load_state_dict(
str(model_path), map_location=dist_util.dev()
)
)
dist_util.sync_params(self.model.parameters())
def _load_ema_parameters(self, rate, logs_path):
ema_params = copy.deepcopy(self.master_params)
ema_checkpoint = Path(logs_path) / "ema.pt"
if ema_checkpoint.exists():
# if dist.get_rank() == 0:
logger.log(f"loading EMA from checkpoint: {str(ema_checkpoint)}...")
state_dict = dist_util.load_state_dict(
str(ema_checkpoint), map_location=dist_util.dev()
)
ema_params = self._state_dict_to_master_params(state_dict)
dist_util.sync_params(ema_params)
return ema_params
def _load_optimizer_state(self, logs_path):
opt_checkpoint = Path(logs_path) / "opt.pt"
if opt_checkpoint.exists():
logger.log(f"loading optimizer state from checkpoint: {str(opt_checkpoint)}")
state_dict = dist_util.load_state_dict(
str(opt_checkpoint), map_location=dist_util.dev()
)
self.opt.load_state_dict(state_dict)
def _setup_fp16(self):
self.master_params = make_master_params(self.model_params)
self.model.convert_to_fp16()
def run_loop(self, max_iter=250000, start_print_iter=100000, vis_batch_size=8, n_rounds=3):
if dist.get_rank() == 0:
pbar = tqdm()
while (
self.step < max_iter
):
self.ddp_model.train()
batch, cond, _ = next(self.data)
self.run_step(batch, cond)
if dist.get_rank() == 0:
pbar.update(1)
if self.step % self.log_interval == 0 and self.step != 0:
logger.log(f"interval")
logger.dumpkvs()
logger.log(f"class {self.args.class_name} lr {self.lr}, expansion {self.args.expansion}, "
f"rrdb blocks {self.args.rrdb_blocks} gpus {MPI.COMM_WORLD.Get_size()}")
if self.step % self.save_interval == 0:
logger.log(f"save model for checkpoint")
self.save_state_dict()
dist.barrier()
if self.step % self.save_interval == 0 and self.step >= start_print_iter or self.step == 60000:
if self.run_without_test:
if dist.get_rank() == 0:
self.save_checkpoint(self.ema_rate[0], self.ema_params[0], name=f"model")
else:
self.ddp_model.eval()
logger.log(f"ema sampling")
output_folder = os.path.join(os.environ["OPENAI_LOGDIR"], f"{self.step}_val_ema_major")
os.mkdir(output_folder)
self.ema_model = self.ema_model.to(dist_util.dev())
self.ema_model.load_state_dict(self._master_params_to_state_dict(self.ema_params[0]))
self.ema_model.eval()
ema_val_miou = sampling_major_vote_func(self.diffusion, self.ema_model, output_folder=output_folder,
dataset=self.val_dataset, logger=self.logger,
clip_denoised=self.clip_denoised, step=self.step, n_rounds=len(self.val_dataset))
self.ema_model = self.ema_model.to(th.device("cpu")) # release gpu memory
if dist.get_rank() == 0:
if self.ema_val_best_iou < ema_val_miou:
logger.log(f"best iou ema val: {ema_val_miou} step {self.step}")
self.ema_val_best_iou = ema_val_miou
ema_filename = self.save_checkpoint(self.ema_rate[0], self.ema_params[0], name=f"val_{ema_val_miou:.7f}")
if self.val_current_model_ema_name != "":
ckpt_path = bf.join(get_blob_logdir(), self.val_current_model_ema_name)
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
self.val_current_model_ema_name = ema_filename
set_random_seed_for_iterations(MPI.COMM_WORLD.Get_rank() + self.step)
dist.barrier()
self.step += 1
def run_step(self, batch, cond):
self.forward_backward(batch, cond)
if self.use_fp16:
self.optimize_fp16()
else:
self.optimize_normal()
self.log_step()
def forward_backward(self, batch, cond):
zero_grad(self.model_params)
for i in range(0, batch.shape[0], self.microbatch):
micro = batch[i : i + self.microbatch].to(dist_util.dev())
micro_cond = {
k: v[i : i + self.microbatch].to(dist_util.dev())
for k, v in cond.items()
}
last_batch = (i + self.microbatch) >= batch.shape[0]
t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
compute_losses = functools.partial(
self.diffusion.training_losses,
self.ddp_model,
micro,
t,
model_kwargs=micro_cond,
)
if last_batch or not self.use_ddp:
losses = compute_losses()
else:
with self.ddp_model.no_sync():
losses = compute_losses()
if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(
t, losses["loss"].detach()
)
loss = (losses["loss"] * weights).mean()
log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
)
if self.use_fp16:
loss_scale = 2 ** self.lg_loss_scale
(loss * loss_scale).backward()
else:
loss.backward()
def optimize_fp16(self):
if any(not th.isfinite(p.grad).all() for p in self.model_params):
self.lg_loss_scale -= 1
logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
return
model_grads_to_master_grads(self.model_params, self.master_params)
self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
self._log_grad_norm()
self._anneal_lr()
self.opt.step()
for rate, params in zip(self.ema_rate, self.ema_params):
update_ema(params, self.master_params, rate=rate)
master_params_to_model_params(self.model_params, self.master_params)
self.lg_loss_scale += self.fp16_scale_growth
def optimize_normal(self):
self._log_grad_norm()
self._anneal_lr()
self.opt.step()
for rate, params in zip(self.ema_rate, self.ema_params):
update_ema(params, self.master_params, rate=rate)
def _log_grad_norm(self):
sqsum = 0.0
for p in self.master_params:
sqsum += (p.grad ** 2).sum().item()
logger.logkv_mean("grad_norm", np.sqrt(sqsum))
def _anneal_lr(self):
if not self.lr_anneal_steps:
return
frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
lr = self.lr * (1 - frac_done)
for param_group in self.opt.param_groups:
param_group["lr"] = lr
def log_step(self):
logger.logkv("step", self.step + self.resume_step)
logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
if self.use_fp16:
logger.logkv("lg_loss_scale", self.lg_loss_scale)
def save_checkpoint(self, rate, params, name):
state_dict = self._master_params_to_state_dict(params)
if dist.get_rank() == 0:
logger.log(f"saving model {rate}...")
if not rate:
filename = f"model_{name}_{(self.step+self.resume_step):06d}.pt"
else:
filename = f"ema_{name}_{rate}_{(self.step+self.resume_step):06d}.pt"
with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
th.save(state_dict, f)
return filename
def save_state_dict(self):
if dist.get_rank() == 0:
with bf.BlobFile(bf.join(get_blob_logdir(), f"opt.pt"), "wb",) as f:
th.save(self.opt.state_dict(), f)
with bf.BlobFile(bf.join(get_blob_logdir(), f"model{self.step}.pt"), "wb") as f:
th.save(self._master_params_to_state_dict(self.master_params), f)
if self.current_model_checkpoint_name != "":
ckpt_path = bf.join(get_blob_logdir(), self.current_model_checkpoint_name)
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
self.current_model_checkpoint_name = bf.join(get_blob_logdir(), f"model{self.step}.pt")
with bf.BlobFile(bf.join(get_blob_logdir(), f"ema.pt"), "wb") as f:
th.save(self._master_params_to_state_dict(self.ema_params[0]), f)
#
# checkpoint = {
# 'step': self.step,
# 'state_dict': self._master_params_to_state_dict(self.master_params),
# 'ema_state_dict': self._master_params_to_state_dict(self.ema_params[0]),
# 'optimizer': self.opt.state_dict()
# }
#
# current_model_checkpoint_name = bf.join(get_blob_logdir(), file_name)
# th.save(checkpoint, current_model_checkpoint_name)
#
# if self.current_model_checkpoint_name != "":
# ckpt_path = bf.join(get_blob_logdir(), self.current_model_checkpoint_name)
# if os.path.exists(ckpt_path):
# os.remove(ckpt_path)
#
# self.current_model_checkpoint_name = current_model_checkpoint_name
def save(self, name):
filename = self.save_checkpoint(0, self.master_params, name)
for rate, params in zip(self.ema_rate, self.ema_params):
filename_ema = self.save_checkpoint(rate, params, name)
# if dist.get_rank() == 0:
# with bf.BlobFile(
# bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
# "wb",
# ) as f:
# th.save(self.opt.state_dict(), f)
# dist.barrier()
return filename, filename_ema
def _master_params_to_state_dict(self, master_params):
if self.use_fp16:
master_params = unflatten_master_params(
list(self.model.parameters()), master_params
)
state_dict = self.model.state_dict()
for i, (name, _value) in enumerate(self.model.named_parameters()):
assert name in state_dict
state_dict[name] = master_params[i]
return state_dict
def _state_dict_to_master_params(self, state_dict):
params = [state_dict[name] for name, _ in self.model.named_parameters()]
if self.use_fp16:
return make_master_params(params)
else:
return params
def parse_resume_step_from_filename(filename):
"""
Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
checkpoint's number of steps.
"""
split = filename.split("model")
if len(split) < 2:
return 0
split1 = split[-1].split(".")[0]
try:
return int(split1)
except ValueError:
return 0
def get_blob_logdir():
return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())
def find_resume_checkpoint():
# On your infrastructure, you may want to override this to automatically
# discover the latest checkpoint on your blob storage, etc.
return None
def find_ema_checkpoint(main_checkpoint, step, rate):
if main_checkpoint is None:
return None
filename = f"ema_{rate}_{(step):06d}.pt"
path = bf.join(bf.dirname(main_checkpoint), filename)
if bf.exists(path):
return path
return None
def log_loss_dict(diffusion, ts, losses):
for key, values in losses.items():
logger.logkv_mean(key, values.mean().item())
# Log the quantiles (four quartiles, in particular).
for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
quartile = int(4 * sub_t / diffusion.num_timesteps)
logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
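Checkpoint filenames like `model120000.pt` encode the training step, and `parse_resume_step_from_filename` recovers it by splitting on the literal `"model"`. A self-contained copy to illustrate its behavior, including the fall-back to 0 for names that don't match the pattern:

```python
def parse_resume_step_from_filename(filename):
    # Extract NNNNNN from path/to/modelNNNNNN.pt; return 0 on any mismatch.
    split = filename.split("model")
    if len(split) < 2:
        return 0
    split1 = split[-1].split(".")[0]
    try:
        return int(split1)
    except ValueError:
        return 0

# "path/to/model120000.pt" -> 120000; "opt.pt" or EMA names -> 0
```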
================================================
FILE: improved_diffusion/unet.py
================================================
from abc import abstractmethod
import math
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .RRDB import RRDBNet
from .fp16_util import convert_module_to_f16, convert_module_to_f32
from .nn import (
SiLU,
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
checkpoint,
)
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2):
super().__init__()
self.channels = channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, channels, channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2):
super().__init__()
self.channels = channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1)
else:
self.op = avg_pool_nd(stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
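The `use_scale_shift_norm` branch in `_forward` is FiLM-style conditioning: the timestep embedding is projected to `2 * out_channels`, split into a per-channel scale and shift, and applied after normalization. A minimal sketch with plain `torch` and hypothetical sizes (`GroupNorm` stands in for the project's `normalization`):

```python
import torch
import torch.nn as nn

N, C = 2, 8
h = torch.randn(N, C, 4, 4)             # feature map after in_layers
emb_out = torch.randn(N, 2 * C)         # emb_layers output in scale-shift mode

# Broadcast the embedding over the spatial dims, as _forward does.
while len(emb_out.shape) < len(h.shape):
    emb_out = emb_out[..., None]        # -> [N, 2C, 1, 1]

scale, shift = torch.chunk(emb_out, 2, dim=1)   # each [N, C, 1, 1]
norm = nn.GroupNorm(4, C)               # stand-in for normalization(C)
out = norm(h) * (1 + scale) + shift
assert out.shape == h.shape
```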
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(self, channels, num_heads=1, use_checkpoint=False):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
self.attention = QKVAttention()
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
h = self.attention(qkv)
h = h.reshape(b, -1, h.shape[-1])
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
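`_forward` folds the heads into the batch dimension before calling `QKVAttention`, so each head attends independently over its own channel slice. A shape-only sketch (plain `torch`, hypothetical sizes) of that reshape:

```python
import torch

b, c, heads = 2, 12, 3
spatial = 16                                  # flattened H * W
qkv = torch.randn(b, 3 * c, spatial)          # output of the 1x1 qkv conv

# Fold heads into the batch dim: each head gets 3 * (c // heads) channels.
qkv_heads = qkv.reshape(b * heads, -1, qkv.shape[2])
assert qkv_heads.shape == (b * heads, 3 * (c // heads), spatial)

# After attention each head yields [b*heads, c // heads, spatial];
# reshaping back over b recovers the full channel dimension.
h = torch.randn(b * heads, c // heads, spatial)
h = h.reshape(b, -1, h.shape[-1])
assert h.shape == (b, c, spatial)
```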
class QKVAttention(nn.Module):
"""
A module which performs QKV attention.
"""
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
:return: an [N x C x T] tensor after attention.
"""
ch = qkv.shape[1] // 3
q, k, v = th.split(qkv, ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
return th.einsum("bts,bcs->bct", weight, v)
@staticmethod
def count_flops(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
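Both the attention math and the FLOP estimate above can be checked on a tiny tensor. A self-contained sketch (plain `torch`, hypothetical sizes) confirming that scaling `q` and `k` each by `C**-0.25` reproduces the usual `softmax(q^T k / sqrt(C)) v` attention, and evaluating the `count_flops` formula for the two `T x T` matmuls:

```python
import math
import torch

B, C, T = 2, 8, 5                       # batch, channels per head, tokens
qkv = torch.randn(B, 3 * C, T)

q, k, v = torch.split(qkv, C, dim=1)
scale = 1 / math.sqrt(math.sqrt(C))     # C**-0.25 applied to each operand
weight = torch.softmax(torch.einsum("bct,bcs->bts", q * scale, k * scale), dim=-1)
out = torch.einsum("bts,bcs->bct", weight, v)

# Equivalent to standard scaled dot-product attention with one 1/sqrt(C):
ref = torch.softmax((q.transpose(1, 2) @ k) / math.sqrt(C), dim=-1) @ v.transpose(1, 2)
assert torch.allclose(out, ref.transpose(1, 2), atol=1e-5)

# count_flops estimate: two matmuls, each B * T^2 * C multiply-adds.
matmul_ops = 2 * B * (T ** 2) * C
assert matmul_ops == 800
```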
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_c
SYMBOL INDEX (345 symbols across 27 files)
FILE: datasets/city.py
function create_dataset (line 21) | def create_dataset(mode="train", class_name="train", expansion=False):
function load_data (line 42) | def load_data(
class CityscapesInstances (line 76) | class CityscapesInstances(Dataset):
method __init__ (line 80) | def __init__(self,
method _poly2mask (line 168) | def _poly2mask(mask_ann, img_h, img_w):
method __len__ (line 183) | def __len__(self):
method __getitem__ (line 186) | def __getitem__(self, item):
function main (line 196) | def main():
FILE: datasets/monu.py
function cv2_loader (line 18) | def cv2_loader(path, is_mask):
function get_monu_transform (line 30) | def get_monu_transform(image_size):
function create_dataset (line 55) | def create_dataset(mode="train", image_size=256):
function load_data (line 65) | def load_data(
class MonuDataset (line 99) | class MonuDataset(torch.utils.data.Dataset):
method __init__ (line 100) | def __init__(self, root, transform=None, target_transform=None, train=...
method __getitem__ (line 134) | def __getitem__(self, index):
method __len__ (line 143) | def __len__(self):
FILE: datasets/preprocess_vaihingen.py
function get_img (line 10) | def get_img(cfile):
function get_mask (line 16) | def get_mask(cfile):
function main (line 24) | def main(args, out_path):
FILE: datasets/transforms.py
class Compose (line 43) | class Compose(object):
method __init__ (line 44) | def __init__(self, transforms):
method __call__ (line 47) | def __call__(self, img, mask):
class ToTensor (line 53) | class ToTensor(object):
method __call__ (line 54) | def __call__(self, img, mask):
class ToPILImage (line 61) | class ToPILImage(object):
method __init__ (line 62) | def __init__(self, mode=None):
method __call__ (line 65) | def __call__(self, img, mask):
class Normalize (line 69) | class Normalize(object):
method __init__ (line 70) | def __init__(self, mean, std, inplace=False):
method __call__ (line 75) | def __call__(self, img, mask):
class Resize (line 79) | class Resize(object):
method __init__ (line 80) | def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True):
method __call__ (line 86) | def __call__(self, img, mask):
class CenterCrop (line 93) | class CenterCrop(object):
method __init__ (line 94) | def __init__(self, size):
method __call__ (line 100) | def __call__(self, img, mask):
class Pad (line 104) | class Pad(object):
method __init__ (line 105) | def __init__(self, padding, fill=0, padding_mode='constant'):
method __call__ (line 117) | def __call__(self, img, mask):
class Lambda (line 122) | class Lambda(object):
method __init__ (line 123) | def __init__(self, lambd):
method __call__ (line 127) | def __call__(self, img, mask):
class Lambda_image (line 131) | class Lambda_image(object):
method __init__ (line 132) | def __init__(self, lambd):
method __call__ (line 136) | def __call__(self, img, mask):
class RandomTransforms (line 140) | class RandomTransforms(object):
method __init__ (line 141) | def __init__(self, transforms):
method __call__ (line 145) | def __call__(self, *args, **kwargs):
class RandomApply (line 149) | class RandomApply(RandomTransforms):
method __init__ (line 150) | def __init__(self, transforms, p=0.5):
method __call__ (line 154) | def __call__(self, img, mask):
class RandomOrder (line 162) | class RandomOrder(RandomTransforms):
method __call__ (line 163) | def __call__(self, img, mask):
class RandomChoice (line 171) | class RandomChoice(RandomTransforms):
method __call__ (line 172) | def __call__(self, img, mask):
class RandomCrop (line 177) | class RandomCrop(object):
method __init__ (line 178) | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, pa...
method get_params (line 189) | def get_params(img, output_size):
method __call__ (line 199) | def __call__(self, img, mask):
class RandomHorizontalFlip (line 215) | class RandomHorizontalFlip(object):
method __init__ (line 216) | def __init__(self, p=0.5):
method __call__ (line 219) | def __call__(self, img, mask):
class RandomVerticalFlip (line 225) | class RandomVerticalFlip(object):
method __init__ (line 226) | def __init__(self, p=0.5):
method __call__ (line 229) | def __call__(self, img, mask):
class RandomPerspective (line 235) | class RandomPerspective(object):
method __init__ (line 236) | def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BI...
method __call__ (line 241) | def __call__(self, img, mask):
method get_params (line 253) | def get_params(width, height, distortion_scale):
class RandomResizedCrop (line 269) | class RandomResizedCrop(object):
method __init__ (line 270) | def __init__(self, size, mask_size, scale=(0.08, 1.0), ratio=(3. / 4.,...
method get_params (line 285) | def get_params(img, scale, ratio):
method __call__ (line 316) | def __call__(self, img, mask):
class FiveCrop (line 322) | class FiveCrop(object):
method __init__ (line 323) | def __init__(self, size):
method __call__ (line 331) | def __call__(self, img, mask):
class TenCrop (line 335) | class TenCrop(object):
method __init__ (line 336) | def __init__(self, size, vertical_flip=False):
method __call__ (line 345) | def __call__(self, img, mask):
class ColorJitter (line 349) | class ColorJitter(object):
method __init__ (line 350) | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
method _check_input (line 357) | def _check_input(self, value, name, center=1, bound=(0, float('inf')),...
method get_params (line 377) | def get_params(brightness, contrast, saturation, hue):
method __call__ (line 401) | def __call__(self, img, mask):
class RandomRotation (line 407) | class RandomRotation(object):
method __init__ (line 408) | def __init__(self, degrees, resample=False, expand=False, center=None):
method get_params (line 423) | def get_params(degrees):
method __call__ (line 428) | def __call__(self, img, mask):
class RandomAffine (line 435) | class RandomAffine(object):
method __init__ (line 436) | def __init__(self, degrees, translate=None, scale=None, shear=None, re...
method get_params (line 478) | def get_params(degrees, translate, scale_ranges, shears, img_size):
method __call__ (line 500) | def __call__(self, img, mask):
class RandomAffineFromSet (line 505) | class RandomAffineFromSet(object):
method __init__ (line 506) | def __init__(self, degrees, translate=None, scale=None, shear=None, re...
method get_params (line 543) | def get_params(degrees, translate, scale_ranges, shears, img_size):
method __call__ (line 565) | def __call__(self, img, mask):
FILE: datasets/vaih.py
function load_data (line 16) | def load_data(
class VaihDataset (line 54) | class VaihDataset(Dataset):
method __init__ (line 60) | def __init__(self, mode, std=np.array([0.22645572 * 255, 0.15276193 * ...
method __len__ (line 107) | def __len__(self):
method __getitem__ (line 110) | def __getitem__(self, item):
FILE: image_sample_diff_city.py
function main (line 27) | def main():
function create_argparser (line 70) | def create_argparser():
FILE: image_sample_diff_medical.py
function main (line 27) | def main():
function create_argparser (line 68) | def create_argparser():
FILE: image_sample_diff_vaih.py
function main (line 28) | def main():
function create_argparser (line 72) | def create_argparser():
FILE: image_train_diff_city.py
function main (line 29) | def main():
function create_argparser (line 121) | def create_argparser():
FILE: image_train_diff_medical.py
function main (line 29) | def main():
function create_argparser (line 119) | def create_argparser():
FILE: image_train_diff_vaih.py
function main (line 30) | def main():
function create_argparser (line 120) | def create_argparser():
FILE: improved_diffusion/RRDB.py
function make_layer (line 7) | def make_layer(block, n_layers):
class ResidualDenseBlock_5C (line 14) | class ResidualDenseBlock_5C(nn.Module):
method __init__ (line 15) | def __init__(self, nf=64, gc=32, bias=True):
method forward (line 28) | def forward(self, x):
class RRDB (line 37) | class RRDB(nn.Module):
method __init__ (line 40) | def __init__(self, nf=1, gc=32):
method forward (line 46) | def forward(self, x):
class RRDBNet (line 52) | class RRDBNet(nn.Module):
method __init__ (line 53) | def __init__(self, in_nc=3, out_nc=128, nf=64, nb=3, gc=32):
method forward (line 65) | def forward(self, x):
FILE: improved_diffusion/dist_util.py
function setup_dist (line 21) | def setup_dist():
function dev (line 44) | def dev():
function load_state_dict (line 53) | def load_state_dict(path, **kwargs):
function sync_params (line 66) | def sync_params(params):
function _find_free_port (line 75) | def _find_free_port():
FILE: improved_diffusion/fp16_util.py
function convert_module_to_f16 (line 9) | def convert_module_to_f16(l):
function convert_module_to_f32 (line 18) | def convert_module_to_f32(l):
function make_master_params (line 27) | def make_master_params(model_params):
function model_grads_to_master_grads (line 40) | def model_grads_to_master_grads(model_params, master_params):
function master_params_to_model_params (line 50) | def master_params_to_model_params(model_params, master_params):
function unflatten_master_params (line 64) | def unflatten_master_params(model_params, master_params):
function zero_grad (line 71) | def zero_grad(model_params):
FILE: improved_diffusion/gaussian_diffusion.py
function get_named_beta_schedule (line 18) | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
function betas_for_alpha_bar (line 45) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
class ModelMeanType (line 65) | class ModelMeanType(enum.Enum):
class ModelVarType (line 75) | class ModelVarType(enum.Enum):
class LossType (line 89) | class LossType(enum.Enum):
method is_vb (line 97) | def is_vb(self):
class GaussianDiffusion (line 101) | class GaussianDiffusion:
method __init__ (line 118) | def __init__(
method q_mean_variance (line 171) | def q_mean_variance(self, x_start, t):
method q_sample (line 188) | def q_sample(self, x_start, t, noise=None):
method q_posterior_mean_variance (line 208) | def q_posterior_mean_variance(self, x_start, x_t, t):
method p_mean_variance (line 232) | def p_mean_variance(
method _predict_xstart_from_eps (line 328) | def _predict_xstart_from_eps(self, x_t, t, eps):
method _predict_xstart_from_xprev (line 335) | def _predict_xstart_from_xprev(self, x_t, t, xprev):
method _predict_eps_from_xstart (line 345) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
method _scale_timesteps (line 351) | def _scale_timesteps(self, t):
method p_sample (line 356) | def p_sample(
method p_sample_loop (line 389) | def p_sample_loop(
method p_sample_loop_progressive (line 431) | def p_sample_loop_progressive(
method ddim_sample (line 479) | def ddim_sample(
method ddim_reverse_sample (line 524) | def ddim_reverse_sample(
method ddim_sample_loop (line 562) | def ddim_sample_loop(
method ddim_sample_loop_progressive (line 594) | def ddim_sample_loop_progressive(
method _vb_terms_bpd (line 642) | def _vb_terms_bpd(
method training_losses (line 677) | def training_losses(self, model, x_start, t, model_kwargs=None, noise=...
method _prior_bpd (line 753) | def _prior_bpd(self, x_start):
method calc_bpd_loop (line 771) | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwar...
function _extract_into_tensor (line 829) | def _extract_into_tensor(arr, timesteps, broadcast_shape):
FILE: improved_diffusion/image_datasets.py
function load_data (line 8) | def load_data(
function _list_image_files_recursively (line 56) | def _list_image_files_recursively(data_dir):
class ImageDataset (line 68) | class ImageDataset(Dataset):
method __init__ (line 69) | def __init__(self, resolution, image_paths, classes=None, shard=0, num...
method __len__ (line 75) | def __len__(self):
method __getitem__ (line 78) | def __getitem__(self, idx):
FILE: improved_diffusion/logger.py
class KVWriter (line 26) | class KVWriter(object):
method writekvs (line 27) | def writekvs(self, kvs):
class SeqWriter (line 31) | class SeqWriter(object):
method writeseq (line 32) | def writeseq(self, seq):
class HumanOutputFormat (line 36) | class HumanOutputFormat(KVWriter, SeqWriter):
method __init__ (line 37) | def __init__(self, filename_or_file):
method writekvs (line 48) | def writekvs(self, kvs):
method _truncate (line 80) | def _truncate(self, s):
method writeseq (line 84) | def writeseq(self, seq):
method close (line 93) | def close(self):
class JSONOutputFormat (line 98) | class JSONOutputFormat(KVWriter):
method __init__ (line 99) | def __init__(self, filename):
method writekvs (line 102) | def writekvs(self, kvs):
method close (line 109) | def close(self):
class CSVOutputFormat (line 113) | class CSVOutputFormat(KVWriter):
method __init__ (line 114) | def __init__(self, filename):
method writekvs (line 119) | def writekvs(self, kvs):
method close (line 146) | def close(self):
class TensorBoardOutputFormat (line 150) | class TensorBoardOutputFormat(KVWriter):
method __init__ (line 155) | def __init__(self, dir):
method writekvs (line 171) | def writekvs(self, kvs):
method close (line 185) | def close(self):
function make_output_format (line 191) | def make_output_format(format, ev_dir, log_suffix=""):
function logkv (line 212) | def logkv(key, val):
function logkv_mean (line 221) | def logkv_mean(key, val):
function logkvs (line 228) | def logkvs(d):
function dumpkvs (line 236) | def dumpkvs():
function getkvs (line 243) | def getkvs():
function log (line 247) | def log(*args, level=INFO):
function debug (line 254) | def debug(*args):
function info (line 258) | def info(*args):
function warn (line 262) | def warn(*args):
function error (line 266) | def error(*args):
function set_level (line 270) | def set_level(level):
function set_comm (line 277) | def set_comm(comm):
function get_dir (line 281) | def get_dir():
function profile_kv (line 294) | def profile_kv(scopename):
function profile (line 303) | def profile(n):
function get_current (line 325) | def get_current():
class Logger (line 332) | class Logger(object):
method __init__ (line 337) | def __init__(self, dir, output_formats, comm=None):
method logkv (line 347) | def logkv(self, key, val):
method logkv_mean (line 350) | def logkv_mean(self, key, val):
method dumpkvs (line 355) | def dumpkvs(self):
method log (line 376) | def log(self, *args, level=INFO):
method set_level (line 382) | def set_level(self, level):
method set_comm (line 385) | def set_comm(self, comm):
method get_dir (line 388) | def get_dir(self):
method close (line 391) | def close(self):
method _do_log (line 397) | def _do_log(self, args):
function get_rank_without_mpi_import (line 403) | def get_rank_without_mpi_import():
function mpi_weighted_mean (line 412) | def mpi_weighted_mean(comm, local_name2valcount):
function configure (line 442) | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
function _configure_default_logger (line 474) | def _configure_default_logger():
function reset (line 479) | def reset():
function scoped_configure (line 487) | def scoped_configure(dir=None, format_strs=None, comm=None):
FILE: improved_diffusion/losses.py
function normal_kl (line 12) | def normal_kl(mean1, logvar1, mean2, logvar2):
function approx_standard_normal_cdf (line 42) | def approx_standard_normal_cdf(x):
function discretized_gaussian_log_likelihood (line 50) | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
FILE: improved_diffusion/metrics.py
function WCov_metric (line 5) | def WCov_metric(pred, gt_mask):
function FBound_metric (line 12) | def FBound_metric(pred, gt_mask):
function db_eval_boundary (line 21) | def db_eval_boundary(foreground_mask, gt_mask, bound_th):
function seg2bmap (line 77) | def seg2bmap(seg, width=None, height=None):
FILE: improved_diffusion/nn.py
class SiLU (line 12) | class SiLU(nn.Module):
method forward (line 13) | def forward(self, x):
class GroupNorm32 (line 17) | class GroupNorm32(nn.GroupNorm):
method forward (line 18) | def forward(self, x):
function conv_nd (line 22) | def conv_nd(dims, *args, **kwargs):
function linear (line 35) | def linear(*args, **kwargs):
function avg_pool_nd (line 42) | def avg_pool_nd(dims, *args, **kwargs):
function update_ema (line 55) | def update_ema(target_params, source_params, rate=0.99):
function swap_ema (line 68) | def swap_ema(target_params, source_params):
function zero_module (line 82) | def zero_module(module):
function scale_module (line 91) | def scale_module(module, scale):
function mean_flat (line 100) | def mean_flat(tensor):
function normalization (line 107) | def normalization(channels):
function timestep_embedding (line 117) | def timestep_embedding(timesteps, dim, max_period=10000):
function checkpoint (line 138) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 156) | class CheckpointFunction(th.autograd.Function):
method forward (line 158) | def forward(ctx, run_function, length, *args):
method backward (line 167) | def backward(ctx, *output_grads):
FILE: improved_diffusion/resample.py
function create_named_schedule_sampler (line 8) | def create_named_schedule_sampler(name, diffusion):
class ScheduleSampler (line 23) | class ScheduleSampler(ABC):
method weights (line 35) | def weights(self):
method sample (line 42) | def sample(self, batch_size, device):
class UniformSampler (line 61) | class UniformSampler(ScheduleSampler):
method __init__ (line 62) | def __init__(self, diffusion):
method weights (line 66) | def weights(self):
class LossAwareSampler (line 70) | class LossAwareSampler(ScheduleSampler):
method update_with_local_losses (line 71) | def update_with_local_losses(self, local_ts, local_losses):
method update_with_all_losses (line 107) | def update_with_all_losses(self, ts, losses):
class LossSecondMomentResampler (line 124) | class LossSecondMomentResampler(LossAwareSampler):
method __init__ (line 125) | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
method weights (line 134) | def weights(self):
method update_with_all_losses (line 143) | def update_with_all_losses(self, ts, losses):
method _warmed_up (line 153) | def _warmed_up(self):
FILE: improved_diffusion/respace.py
function space_timesteps (line 7) | def space_timesteps(num_timesteps, section_counts):
class SpacedDiffusion (line 63) | class SpacedDiffusion(GaussianDiffusion):
method __init__ (line 72) | def __init__(self, use_timesteps, **kwargs):
method p_mean_variance (line 88) | def p_mean_variance(
method training_losses (line 93) | def training_losses(
method _wrap_model (line 98) | def _wrap_model(self, model):
method _scale_timesteps (line 105) | def _scale_timesteps(self, t):
class _WrappedModel (line 110) | class _WrappedModel:
method __init__ (line 111) | def __init__(self, model, timestep_map, rescale_timesteps, original_nu...
method __call__ (line 117) | def __call__(self, x, ts, **kwargs):
FILE: improved_diffusion/sampling_util.py
function calculate_metrics (line 44) | def calculate_metrics(x, gt):
function sampling_major_vote_func (line 51) | def sampling_major_vote_func(diffusion_model, ddp_model, output_folder, ...
FILE: improved_diffusion/script_util.py
function model_and_diffusion_defaults (line 11) | def model_and_diffusion_defaults():
function create_model_and_diffusion (line 43) | def create_model_and_diffusion(
function create_model (line 101) | def create_model(
function sr_model_and_diffusion_defaults (line 151) | def sr_model_and_diffusion_defaults():
function sr_create_model_and_diffusion (line 162) | def sr_create_model_and_diffusion(
function sr_create_model (line 214) | def sr_create_model(
function create_gaussian_diffusion (line 263) | def create_gaussian_diffusion(
function add_dict_to_argparser (line 304) | def add_dict_to_argparser(parser, default_dict):
function args_to_dict (line 314) | def args_to_dict(args, keys):
function str2bool (line 318) | def str2bool(v):
FILE: improved_diffusion/train_util.py
class TrainLoop (line 34) | class TrainLoop:
method __init__ (line 35) | def __init__(
method _load_and_sync_parameters (line 140) | def _load_and_sync_parameters(self, logs_path):
method _load_ema_parameters (line 163) | def _load_ema_parameters(self, rate, logs_path):
method _load_optimizer_state (line 179) | def _load_optimizer_state(self, logs_path):
method _setup_fp16 (line 190) | def _setup_fp16(self):
method run_loop (line 194) | def run_loop(self, max_iter=250000, start_print_iter=100000, vis_batch...
method run_step (line 252) | def run_step(self, batch, cond):
method forward_backward (line 260) | def forward_backward(self, batch, cond):
method optimize_fp16 (line 300) | def optimize_fp16(self):
method optimize_normal (line 316) | def optimize_normal(self):
method _log_grad_norm (line 323) | def _log_grad_norm(self):
method _anneal_lr (line 329) | def _anneal_lr(self):
method log_step (line 337) | def log_step(self):
method save_checkpoint (line 343) | def save_checkpoint(self, rate, params, name):
method save_state_dict (line 355) | def save_state_dict(self):
method save (line 391) | def save(self, name):
method _master_params_to_state_dict (line 408) | def _master_params_to_state_dict(self, master_params):
method _state_dict_to_master_params (line 419) | def _state_dict_to_master_params(self, state_dict):
function parse_resume_step_from_filename (line 427) | def parse_resume_step_from_filename(filename):
function get_blob_logdir (line 442) | def get_blob_logdir():
function find_resume_checkpoint (line 446) | def find_resume_checkpoint():
function find_ema_checkpoint (line 452) | def find_ema_checkpoint(main_checkpoint, step, rate):
function log_loss_dict (line 462) | def log_loss_dict(diffusion, ts, losses):
FILE: improved_diffusion/unet.py
class TimestepBlock (line 24) | class TimestepBlock(nn.Module):
method forward (line 30) | def forward(self, x, emb):
class TimestepEmbedSequential (line 36) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 42) | def forward(self, x, emb):
class Upsample (line 51) | class Upsample(nn.Module):
method __init__ (line 61) | def __init__(self, channels, use_conv, dims=2):
method forward (line 69) | def forward(self, x):
class Downsample (line 82) | class Downsample(nn.Module):
method __init__ (line 92) | def __init__(self, channels, use_conv, dims=2):
method forward (line 103) | def forward(self, x):
class ResBlock (line 108) | class ResBlock(TimestepBlock):
method __init__ (line 123) | def __init__(
method forward (line 173) | def forward(self, x, emb):
method _forward (line 185) | def _forward(self, x, emb):
class AttentionBlock (line 201) | class AttentionBlock(nn.Module):
method __init__ (line 209) | def __init__(self, channels, num_heads=1, use_checkpoint=False):
method forward (line 220) | def forward(self, x):
method _forward (line 223) | def _forward(self, x):
class QKVAttention (line 234) | class QKVAttention(nn.Module):
method forward (line 239) | def forward(self, qkv):
method count_flops (line 256) | def count_flops(model, _x, y):
class UNetModel (line 279) | class UNetModel(nn.Module):
method __init__ (line 302) | def __init__(
method convert_to_fp16 (line 441) | def convert_to_fp16(self):
method convert_to_fp32 (line 450) | def convert_to_fp32(self):
method inner_dtype (line 460) | def inner_dtype(self):
method forward (line 466) | def forward(self, x, timesteps, y=None, conditioned_image=None):
method get_feature_vectors (line 499) | def get_feature_vectors(self, x, timesteps, y=None):
class SuperResModel (line 532) | class SuperResModel(UNetModel):
method __init__ (line 539) | def __init__(self, in_channels, *args, **kwargs):
method forward (line 542) | def forward(self, x, timesteps, low_res=None, **kwargs):
method get_feature_vectors (line 548) | def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs):
FILE: improved_diffusion/utils.py
function set_random_seed (line 7) | def set_random_seed(seed, deterministic=False):
function set_random_seed_for_iterations (line 25) | def set_random_seed_for_iterations(seed):
},
{
"path": "improved_diffusion/train_util.py",
"chars": 17669,
"preview": "import copy\nimport functools\nimport os\nfrom pathlib import Path\n\nimport blobfile as bf\nimport numpy as np\nimport torch a"
},
{
"path": "improved_diffusion/unet.py",
"chars": 19220,
"preview": "from abc import abstractmethod\n\nimport math\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn"
},
{
"path": "improved_diffusion/utils.py",
"chars": 1124,
"preview": "import random\n\nimport numpy as np\nimport torch\n\n\ndef set_random_seed(seed, deterministic=False):\n \"\"\"Set random seed."
}
]
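The per-file records above share a simple `{"path", "chars", "preview"}` shape. A minimal sketch of how one might load such a manifest and total the reported sizes — using a hypothetical two-entry toy manifest rather than the full listing:

```python
import json

# Hypothetical two-entry manifest in the same shape as the records above.
toy_manifest = """
[
  {"path": "improved_diffusion/unet.py", "chars": 19220,
   "preview": "from abc import abstractmethod"},
  {"path": "improved_diffusion/utils.py", "chars": 1124,
   "preview": "import random"}
]
"""

entries = json.loads(toy_manifest)

# Sum the per-file character counts and find the largest file.
total_chars = sum(entry["chars"] for entry in entries)
largest = max(entries, key=lambda e: e["chars"])

print(total_chars)      # 20344
print(largest["path"])  # improved_diffusion/unet.py
```

Summing `chars` over all 31 records should approximate the 202.6 KB repository size reported in the header.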
About this extraction
This page contains the full source code of the tomeramit/SegDiff GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 31 files (202.6 KB), approximately 50.8k tokens, and a symbol index with 345 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract. Built by Nikandr Surkov.