Repository: stevliu/self-conditioned-gan
Branch: master
Commit: a12fc3a99876
Files: 119
Total size: 375.7 KB
Directory structure:
gitextract_vn0la05j/
├── .gitignore
├── 2d_mix/
│ ├── .gitignore
│ ├── __init__.py
│ ├── config.py
│ ├── evaluation.py
│ ├── inputs.py
│ ├── models/
│ │ ├── __init__.py
│ │ └── cluster.py
│ ├── train.py
│ └── visualizations.py
├── LICENSE
├── README.md
├── cluster_metrics.py
├── clusterers/
│ ├── __init__.py
│ ├── base_clusterer.py
│ ├── kmeans.py
│ ├── online.py
│ ├── random_labels.py
│ └── selfcondgan.py
├── configs/
│ ├── cifar/
│ │ ├── conditional.yaml
│ │ ├── default.yaml
│ │ ├── selfcondgan.yaml
│ │ └── unconditional.yaml
│ ├── default.yaml
│ ├── imagenet/
│ │ ├── conditional.yaml
│ │ ├── default.yaml
│ │ ├── selfcondgan.yaml
│ │ └── unconditional.yaml
│ ├── places/
│ │ ├── conditional.yaml
│ │ ├── default.yaml
│ │ ├── selfcondgan.yaml
│ │ └── unconditional.yaml
│ ├── pretrained/
│ │ ├── imagenet/
│ │ │ ├── conditional.yaml
│ │ │ ├── selfcondgan.yaml
│ │ │ └── unconditional.yaml
│ │ └── places/
│ │ ├── conditional.yaml
│ │ ├── selfcondgan.yaml
│ │ └── unconditional.yaml
│ └── stacked_mnist/
│ ├── conditional.yaml
│ ├── default.yaml
│ ├── selfcondgan.yaml
│ └── unconditional.yaml
├── gan_training/
│ ├── __init__.py
│ ├── checkpoints.py
│ ├── config.py
│ ├── distributions.py
│ ├── eval.py
│ ├── inputs.py
│ ├── logger.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ ├── clustering_metrics.py
│ │ ├── fid.py
│ │ ├── inception_score.py
│ │ └── tf_is/
│ │ ├── LICENSE
│ │ ├── README.md
│ │ └── inception_score.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── blocks.py
│ │ ├── dcgan_deep.py
│ │ ├── dcgan_shallow.py
│ │ ├── resnet2.py
│ │ ├── resnet2s.py
│ │ └── resnet3.py
│ ├── train.py
│ └── utils.py
├── metrics.py
├── requirements.txt
├── seeded_sampler.py
├── seeing/
│ ├── frechet_distance.py
│ ├── fsd.py
│ ├── lightbox.html
│ ├── parallelfolder.py
│ ├── pbar.py
│ ├── pidfile.py
│ ├── sampler.py
│ ├── segmenter.py
│ ├── upsegmodel/
│ │ ├── __init__.py
│ │ ├── models.py
│ │ ├── prroi_pool/
│ │ │ ├── .gitignore
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── build.py
│ │ │ ├── functional.py
│ │ │ ├── prroi_pool.py
│ │ │ ├── src/
│ │ │ │ ├── prroi_pooling_gpu.c
│ │ │ │ ├── prroi_pooling_gpu.h
│ │ │ │ ├── prroi_pooling_gpu_impl.cu
│ │ │ │ └── prroi_pooling_gpu_impl.cuh
│ │ │ └── test_prroi_pooling2d.py
│ │ ├── resnet.py
│ │ └── resnext.py
│ ├── yz_dataset.py
│ └── zdataset.py
├── train.py
├── utils/
│ ├── classifiers/
│ │ ├── __init__.py
│ │ ├── cifar.py
│ │ ├── imagenet.py
│ │ ├── imagenet_class_index.json
│ │ ├── places.py
│ │ ├── pytorch_playground/
│ │ │ ├── .gitignore
│ │ │ ├── LICENSE
│ │ │ ├── README.md
│ │ │ ├── cifar/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dataset.py
│ │ │ │ ├── model.py
│ │ │ │ └── train.py
│ │ │ ├── quantize.py
│ │ │ ├── requirements.txt
│ │ │ ├── roadmap_zh.md
│ │ │ ├── setup.py
│ │ │ └── utee/
│ │ │ ├── __init__.py
│ │ │ ├── misc.py
│ │ │ ├── quant.py
│ │ │ └── selector.py
│ │ └── stacked_mnist.py
│ ├── get_empirical_distribution.py
│ ├── get_gt_imgs.py
│ └── np_to_pt_img.py
└── visualize_clusters.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*/**/*pyc*
*/**/.DS_Store
.vscode
================================================
FILE: 2d_mix/.gitignore
================================================
**.png
**.pyc
**.pt
output/
================================================
FILE: 2d_mix/__init__.py
================================================
================================================
FILE: 2d_mix/config.py
================================================
import torch
from models import generator_dict, discriminator_dict
from torch import optim
import torch.utils.data as utils
def get_models(model_type, conditioning, k_value, d_act_dim, device):
    """Instantiate the generator/discriminator pair registered under ``model_type``.

    Both networks receive ``conditioning`` and ``k_value``; the discriminator
    additionally gets ``act_dim=d_act_dim``. Both are moved to ``device``.
    Returns ``(generator, discriminator)``.
    """
    generator_cls = generator_dict[model_type]
    discriminator_cls = discriminator_dict[model_type]
    generator = generator_cls(conditioning, k_value=k_value)
    discriminator = discriminator_cls(conditioning,
                                      k_value=k_value,
                                      act_dim=d_act_dim)
    for network in (generator, discriminator):
        network.to(device)
    return generator, discriminator
def get_optimizers(generator, discriminator, lr=1e-4, beta1=0.8, beta2=0.999):
    """Create one Adam optimizer per network with a shared learning rate and betas.

    Returns ``(g_optimizer, d_optimizer)``.
    """
    betas = (beta1, beta2)
    g_optimizer, d_optimizer = (
        optim.Adam(network.parameters(), lr=lr, betas=betas)
        for network in (generator, discriminator))
    return g_optimizer, d_optimizer
def get_test(get_data, batch_size, variance, k_value, device):
    """Draw ``batch_size`` evaluation points with ``get_data`` and place them on ``device``.

    ``get_data(n, var=...)`` must return numpy ``(samples, labels)``.
    ``k_value`` is accepted but unused. Returns ``(float32 samples, int64 labels)``.
    """
    samples, labels = get_data(batch_size, var=variance)
    sample_tensor = torch.from_numpy(samples).float().to(device)
    label_tensor = torch.from_numpy(labels).long().to(device)
    return sample_tensor, label_tensor
def get_dataset(get_data, batch_size, npts, variance, k_value):
    """Sample ``npts`` training points and wrap them in a shuffling DataLoader.

    Args:
        get_data: callable ``get_data(n, var=...)`` returning ``(samples, labels)``.
        batch_size: loader batch size; the last incomplete batch is dropped.
        npts: number of points to sample.
        variance: forwarded to ``get_data`` as ``var``.
        k_value: unused; kept for call-site compatibility.

    Returns:
        A ``DataLoader`` over ``(float32 samples, labels)`` pairs. The labels
        keep the integer dtype produced by ``get_data``.
    """
    samples, labels = get_data(npts, var=variance)
    # Convert the whole arrays in one call instead of building and stacking
    # one small tensor per point (npts defaults to 100k).
    tensor_samples = torch.tensor(samples, dtype=torch.float32)
    tensor_labels = torch.tensor(labels)
    dataset = utils.TensorDataset(tensor_samples, tensor_labels)
    train_loader = utils.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0,
                                    pin_memory=True,
                                    sampler=None,
                                    drop_last=True)
    return train_loader
================================================
FILE: 2d_mix/evaluation.py
================================================
def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``; ignores all arguments and returns None.

    NOTE(review): it is installed below as ``warnings.warn``, which silences
    every warning in the process once this module is imported, presumably to
    keep evaluation output quiet.
    """
    return None
import warnings
warnings.warn = warn
import numpy as np
def percent_good_grid(x_fake, var=0.0025, nrows=5, ncols=5):
    """Score samples against the 2-D grid mixture.

    Mode centres are ``(2*r - 4, 2*c - 4)`` for r in range(nrows) and
    c in range(ncols). With the default 5x5 grid they span [-4, 4]^2.
    A sample is "good" if it lies within 3 standard deviations of a centre
    on every coordinate. Returns whatever ``percent_good_pts`` returns.
    """
    threshold = 3 * np.sqrt(var)
    means = [
        np.array([row * 2 - 4, col * 2 - 4])
        for row in range(nrows)
        for col in range(ncols)
    ]
    return percent_good_pts(x_fake, means, threshold)
def percent_good_ring(x_fake, var=0.0001, n_clusters=8, radius=2.0):
    """Score samples against the 2-D ring mixture.

    ``n_clusters`` centres sit evenly spaced on a circle of ``radius``
    (x = r*sin(theta), y = r*cos(theta)). A sample is "good" if it lies within
    3 standard deviations of a centre on each coordinate.
    Returns whatever ``percent_good_pts`` returns.
    """
    std = np.sqrt(var)
    angles = np.linspace(0, 2 * np.pi, n_clusters + 1)[:n_clusters]
    xs = radius * np.sin(angles)
    ys = radius * np.cos(angles)
    means = [np.array([mx, my]) for mx, my in zip(xs, ys)]
    threshold = np.array([std * 3, std * 3])
    return percent_good_pts(x_fake, means, threshold)
def percent_good_pts(x_fake, means, threshold):
    """Calculate %good, #modes, kl

    Keyword arguments:
    x_fake -- detached generated samples, array-like of shape (N, D)
    means -- true means, a sequence of K length-D arrays
    threshold -- a point is good if every coordinate of |point - mean| is
                 strictly below threshold (scalar or per-coordinate array)
                 for some mean, i.e. it falls in an axis-aligned box around it

    Returns (fraction_good, n_modes, kl):
    fraction_good -- fraction of points that are good
    n_modes -- number of distinct means credited with at least one good point
               (a good point is credited to the first qualifying mean)
    kl -- KL(p || uniform), where p[k] is the fraction of points whose
          nearest mean (Euclidean, ties to the earlier mean) is means[k]
    An empty x_fake yields (0.0, 0, 0.0).
    """
    points = np.asarray(x_fake, dtype=np.float64)
    n_points = len(points)
    if n_points == 0:
        return 0.0, 0, 0.0
    centers = np.asarray(means, dtype=np.float64)
    n_means = len(centers)
    # diffs[n, k, :] = |points[n] - centers[k]|, shape (N, K, D)
    diffs = np.abs(points[:, None, :] - centers[None, :, :])
    within = np.all(diffs < threshold, axis=2)  # (N, K) box membership
    good = within.any(axis=1)
    # argmax on a boolean row gives the first qualifying mean.
    first_hit = within.argmax(axis=1)[good]
    visited = {tuple(centers[k]) for k in np.unique(first_hit)}
    # argmin returns the first minimum, so ties go to the earlier mean.
    nearest = np.linalg.norm(diffs, axis=2).argmin(axis=1)
    proportions = np.bincount(nearest, minlength=n_means) / n_points
    nonzero = proportions[proportions > 0]
    kl = float(np.sum(nonzero * np.log(n_means * nonzero)))
    return int(good.sum()) / n_points, len(visited), kl
================================================
FILE: 2d_mix/inputs.py
================================================
import numpy as np
import random
# Label permutation used by map_labels; identity over the 25 possible classes.
mapping = list(range(25))
def map_labels(labels):
    """Return a numpy array with each label replaced by ``mapping[label]``."""
    return np.array(list(map(mapping.__getitem__, labels)))
def get_data_ring(batch_size, radius=2.0, var=0.0001, n_clusters=8):
    """Sample ``batch_size`` points from an isotropic Gaussian mixture on a ring.

    The ``n_clusters`` centres are evenly spaced on a circle of ``radius``.
    Per-class counts come from a uniform multinomial, and the resulting label
    list is shuffled with ``random.shuffle``. Each point is drawn with
    covariance ``var * I``. Returns ``(samples, labels)``: samples is an (N, 2)
    array and labels an (N,) array.
    """
    angles = np.linspace(0, 2 * np.pi, n_clusters + 1)[:n_clusters]
    centre_x = radius * np.sin(angles)
    centre_y = radius * np.cos(angles)
    per_class = np.random.multinomial(batch_size,
                                      [1.0 / n_clusters] * n_clusters)
    label_list = np.repeat(np.arange(n_clusters), per_class).tolist()
    random.shuffle(label_list)
    labels = np.array(label_list)
    covariance = [[var, 0], [0, var]]
    drawn = []
    for cls in labels:
        drawn.append(
            np.random.multivariate_normal([centre_x[cls], centre_y[cls]],
                                          covariance))
    samples = np.array(drawn)
    return samples, labels
def get_data_grid(batch_size, radius=2.0, var=0.0025, nrows=5, ncols=5):
    """Sample ``batch_size`` points from an isotropic Gaussian mixture on a grid.

    Each point picks a column index i in [0, ncols) and a row index j in
    [0, nrows) uniformly via ``random.randint``. It is then drawn around
    ``(2*i - 4, 2*j - 4)`` with covariance ``var * I``. The -4 offset centres
    only the default 5x5 grid on the origin.

    Its label is ``nrows * i + j``, unique for every (i, j) cell. The former
    hard-coded ``5 * i + j`` collided whenever nrows != 5. Labels are then
    passed through ``map_labels``, whose ``mapping`` table only covers
    labels < 25. ``radius`` is unused.

    Returns ``(samples, labels)`` as numpy arrays.
    """
    covariance = [[var, 0], [0, var]]
    samples = []
    labels = []
    for _ in range(batch_size):
        i, j = random.randint(0, ncols - 1), random.randint(0, nrows - 1)
        samples.append(
            np.random.multivariate_normal([i * 2 - 4, j * 2 - 4], covariance))
        labels.append(nrows * i + j)
    return np.array(samples), map_labels(labels)
================================================
FILE: 2d_mix/models/__init__.py
================================================
from models import (cluster)
# Registries mapping the --model_type name to the generator / discriminator class.
generator_dict = dict(standard=cluster.G)
discriminator_dict = dict(standard=cluster.D)
================================================
FILE: 2d_mix/models/cluster.py
================================================
import sys
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
sys.path.append('../gan_training/models')
from blocks import LatentEmbeddingConcat, Identity, LinearUnconditionalLogits, LinearConditionalMaskLogits
class G(nn.Module):
    """MLP generator for the 2-D mixture experiments.

    The latent ``z`` (optionally combined with a label embedding) passes through
    four Linear-BatchNorm-ReLU layers of width ``act_dim`` and a final Linear
    projection to ``x_dim`` outputs.

    conditioning -- 'conditional' uses ``LatentEmbeddingConcat(k_value, embed_size)``,
                    so the first layer takes ``z_dim + embed_size`` inputs;
                    'unconditional' uses ``Identity()`` and ``z_dim`` inputs.
                    Anything else raises NotImplementedError.
    """

    def __init__(self,
                 conditioning,
                 k_value,
                 z_dim=2,
                 embed_size=32,
                 act_dim=400,
                 x_dim=2):
        super().__init__()
        if conditioning == 'conditional':
            self.embedding = LatentEmbeddingConcat(k_value, embed_size)
        elif conditioning == 'unconditional':
            embed_size = 0
            self.embedding = Identity()
        else:
            raise NotImplementedError()

        def hidden_layer(width_in):
            # Linear -> BatchNorm -> in-place ReLU, all producing act_dim features.
            return nn.Sequential(nn.Linear(width_in, act_dim),
                                 nn.BatchNorm1d(act_dim), nn.ReLU(True))

        self.fc1 = hidden_layer(z_dim + embed_size)
        self.fc2 = hidden_layer(act_dim)
        self.fc3 = hidden_layer(act_dim)
        self.fc4 = hidden_layer(act_dim)
        self.fc_out = nn.Linear(act_dim, x_dim)

    def forward(self, z, y=None):
        hidden = self.embedding(z, y)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc_out):
            hidden = layer(hidden)
        return hidden
class D(nn.Module):
    """Maxout MLP discriminator for the 2-D mixture experiments.

    Three Maxout layers of width ``act_dim`` produce features. The head is
    ``LinearUnconditionalLogits(act_dim)`` for 'unconditional' or
    ``LinearConditionalMaskLogits(act_dim, k_value)`` for 'conditional';
    any other ``conditioning`` raises NotImplementedError.
    """

    class Maxout(nn.Module):
        # Taken from https://github.com/pytorch/pytorch/issues/805
        """Linear layer producing ``d_out * pool_size`` values, max-pooled over groups of ``pool_size``."""

        def __init__(self, d_in, d_out, pool_size=5):
            super().__init__()
            self.d_in, self.d_out, self.pool_size = d_in, d_out, pool_size
            self.lin = nn.Linear(d_in, d_out * pool_size)

        def forward(self, inputs):
            # Split the last dimension into (d_out, pool_size) and keep the max of each group.
            grouped = self.lin(inputs).view(*inputs.shape[:-1], self.d_out,
                                            self.pool_size)
            return grouped.max(dim=-1)[0]

        def max(self, out, dim=5):
            return out.view(out.size(0), -1, dim).max(2)[0]

    def __init__(self, conditioning, k_value, act_dim=200, x_dim=2):
        super().__init__()
        self.fc1 = self.Maxout(x_dim, act_dim)
        self.fc2 = self.Maxout(act_dim, act_dim)
        self.fc3 = self.Maxout(act_dim, act_dim)
        if conditioning == 'unconditional':
            self.fc4 = LinearUnconditionalLogits(act_dim)
        elif conditioning == 'conditional':
            self.fc4 = LinearConditionalMaskLogits(act_dim, k_value)
        else:
            raise NotImplementedError()

    def forward(self, x, y=None, get_features=False):
        """Return the act_dim features if ``get_features`` else the head's logits for labels ``y``."""
        features = self.fc3(self.fc2(self.fc1(x)))
        if get_features:
            return features
        return self.fc4(features, y, get_features=False)
================================================
FILE: 2d_mix/train.py
================================================
import argparse
import os
import sys
import torch
from torch import optim
from torch import distributions
from torch import nn
import torch.nn.functional as F
import numpy as np
import evaluation
import inputs
from config import get_models, get_optimizers, get_test, get_dataset
from visualizations import (visualize_generated, visualize_clusters)
sys.path.append('../')
from clusterers import clusterer_dict
from gan_training.train import Trainer
sys.path.append('../seeing/')
import pidfile
# (flag, add_argument keyword arguments) for every CLI option, in help order.
_CLI_OPTIONS = (
    ('--clusterer', dict(help='type of clusterer to use. cluster specifies selfcondgan')),
    ('--data_type', dict(help='either grid or ring')),
    ('--recluster_every', dict(type=int, default=5000, help='how frequently to recluster')),
    ('--nruns', dict(type=int, default=1, help='number of trials to do')),
    ('--burnin_time', dict(type=int, default=0, help='wait this amount of iterations before clustering')),
    ('--variance', dict(type=float, default=None, help='variance of the gaussians')),
    ('--model_type', dict(type=str, default='standard', help='model architecture')),
    ('--num_clusters', dict(type=int, default=50, help='number of clusters to use for selfcondgan')),
    ('--z_dim', dict(type=int, default=2, help='G latent dim')),
    ('--d_act_dim', dict(type=int, default=200, help='hidden layer width')),
    ('--npts', dict(type=int, default=100000, help='number of points to use in dataset')),
    ('--train_batch_size', dict(type=int, default=100, help='training time batch size')),
    ('--test_batch_size', dict(type=int, default=50000, help='number of examples to get metrics with')),
    ('--nepochs', dict(type=int, default=100, help='number of epochs to run')),
    ('--outdir', dict(default='output')),
)
parser = argparse.ArgumentParser(description='2d dataset experiments')
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

data_type = args.data_type
# Number of true modes: 8 on the ring, 25 on the grid.
k_value = 8 if data_type == 'ring' else 25
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_clusters = k_value if args.clusterer == 'supervised' else args.num_clusters
exp_name = f'{args.data_type}_{args.clusterer}_{args.recluster_every}_{num_clusters}/'
if args.model_type != 'standard':
    exp_name = f'{args.model_type}_{exp_name}'
if args.variance is None:
    # Dataset-specific default spread of each Gaussian.
    variance = 0.0025 if data_type == 'grid' else 0.0001
else:
    exp_name = f'{args.variance}_{exp_name}'
    variance = args.variance
nepochs, z_dim = args.nepochs, args.z_dim
test_batch_size, train_batch_size = args.test_batch_size, args.train_batch_size
npts = args.npts
def main(outdir):
    """Train one conditional G/D pair on the selected 2-D mixture, logging under ``outdir``.

    Creates ``outdir/{all,snapshots,clusters}``. Every 1000 iterations it:
    - saves sample and cluster plots,
    - writes a checkpoint ``snapshots/model_<it>.pt``,
    - appends %good / #modes / KL to ``outdir/log.txt`` (cleared at start).
    """
    for subdir in ['all', 'snapshots', 'clusters']:
        os.makedirs(os.path.join(outdir, subdir), exist_ok=True)

    if data_type == 'grid':
        get_data = inputs.get_data_grid
        percent_good = evaluation.percent_good_grid
    elif data_type == 'ring':
        get_data = inputs.get_data_ring
        percent_good = evaluation.percent_good_ring
    else:
        raise NotImplementedError()

    zdist = distributions.Normal(torch.zeros(z_dim, device=device),
                                 torch.ones(z_dim, device=device))
    # Fixed latents and data so metrics are comparable across iterations.
    z_test = zdist.sample((test_batch_size, ))
    x_test, y_test = get_test(get_data=get_data,
                              batch_size=test_batch_size,
                              variance=variance,
                              k_value=k_value,
                              device=device)
    # Separate sample handed to the clusterer.
    x_cluster, _ = get_test(get_data=get_data,
                            batch_size=10000,
                            variance=variance,
                            k_value=k_value,
                            device=device)
    train_loader = get_dataset(get_data=get_data,
                               batch_size=train_batch_size,
                               npts=npts,
                               variance=variance,
                               k_value=k_value)

    def train(trainer, g, d, clusterer, exp_dir):
        it = 0
        log_path = os.path.join(exp_dir, 'log.txt')
        if os.path.exists(log_path):
            os.remove(log_path)
        for epoch in range(nepochs):
            for x_real, y in train_loader:
                z = zdist.sample((train_batch_size, ))
                x_real, y = x_real.to(device), y.to(device)
                y = clusterer.get_labels(x_real, y)
                dloss, _ = trainer.discriminator_trainstep(x_real, y, z)
                gloss = trainer.generator_trainstep(y, z)

                if it % args.recluster_every == 0 and args.clusterer != 'supervised':
                    if args.clusterer != 'burnin' or it >= args.burnin_time:
                        clusterer.recluster(d, x_batch=x_real)

                if it % 1000 == 0:
                    # Evaluate once, in eval mode and without an autograd graph
                    # over the whole test set. Plots and metrics share the samples.
                    g.eval()
                    d.eval()
                    with torch.no_grad():
                        y_test_pred = clusterer.get_labels(x_test, y_test)
                        x_fake = g(z_test, y_test_pred).detach().cpu().numpy()
                    x_test_np = x_test.detach().cpu().numpy()
                    visualize_generated(x_fake, x_test_np, y, it, exp_dir)
                    visualize_clusters(x_test_np, y_test_pred, it, exp_dir)
                    torch.save(
                        {
                            'generator': g.state_dict(),
                            'discriminator': d.state_dict(),
                            'g_optimizer': g_optimizer.state_dict(),
                            'd_optimizer': d_optimizer.state_dict()
                        },
                        os.path.join(exp_dir, 'snapshots', 'model_%d.pt' % it))
                    percent, modes, kl = percent_good(x_fake, var=variance)
                    log_message = f'[epoch {epoch} it {it}] dloss = {dloss}, gloss = {gloss}, prop_real = {percent}, modes = {modes}, kl = {kl}'
                    with open(log_path, 'a+') as f:
                        f.write(log_message + '\n')
                    print(log_message)
                    # Return both networks to training mode for the next steps.
                    g.train()
                    d.train()
                it += 1

    # train a G/D from scratch
    generator, discriminator = get_models(args.model_type, 'conditional', num_clusters, args.d_act_dim, device)
    g_optimizer, d_optimizer = get_optimizers(generator, discriminator)
    trainer = Trainer(generator, discriminator, g_optimizer, d_optimizer, gan_type='standard', reg_type='none', reg_param=0)
    clusterer = clusterer_dict[args.clusterer](discriminator=discriminator,
                                               k_value=num_clusters,
                                               x_cluster=x_cluster)
    clusterer.recluster(discriminator=discriminator)
    train(trainer, generator, discriminator, clusterer, outdir)
if __name__ == '__main__':
    outdir = os.path.join(args.outdir, exp_name)
    # pidfile helpers (seeing/pidfile.py) presumably skip already-finished
    # experiments and record completion; see that module for details.
    pidfile.exit_if_job_done(outdir)
    for run_number in range(args.nruns):
        # Several trials go to numbered sibling dirs; a single trial uses outdir.
        if args.nruns > 1:
            run_dir = f'{outdir}_run_{run_number}'
        else:
            run_dir = outdir
        main(run_dir)
    pidfile.mark_job_done(outdir)
================================================
FILE: 2d_mix/visualizations.py
================================================
import matplotlib
from matplotlib import pyplot
import os
# Distinct matplotlib colour names, cycled over cluster indices.
COLORS = '''
    purple wheat maroon red powderblue dodgerblue magenta tan aqua yellow
    slategray blue rosybrown violet lightseagreen pink darkorange teal
    royalblue lawngreen gold navy darkgreen deeppink palegreen silver
    saddlebrown plum peru black
'''.split()
# Guard against duplicates, which would make two clusters indistinguishable.
assert (len(COLORS) == len(set(COLORS)))
def visualize_generated(fake, real, y, it, outdir):
    """Save two scatter plots of generated 2-D points for iteration ``it``.

    ``<outdir>/all/<it>.png`` overlays ``real`` (red) and ``fake`` (blue).
    ``<outdir>/all/<it>square.png`` shows ``fake`` alone on a fixed,
    equal-aspect [-6, 6] x [-6, 6] window. ``y`` is accepted but unused.
    """
    stem = os.path.join(outdir, 'all', str(it))
    # Overlay plot: real data, then generated samples.
    for points, fmt in ((real, 'r.'), (fake, 'b.')):
        pyplot.plot(points[:, 0], points[:, 1], fmt)
    pyplot.savefig(stem + '.png')
    pyplot.clf()
    # Square plot with fixed limits so iterations can be compared visually.
    half_width = 6
    ax = pyplot.gca()
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlim([-half_width, half_width])
    ax.set_ylim([-half_width, half_width])
    pyplot.locator_params(nbins=4)
    pyplot.tight_layout()
    pyplot.plot(fake[:, 0], fake[:, 1], 'b.', alpha=0.1)
    pyplot.savefig(stem + 'square.png', dpi=100, bbox_inches='tight')
    pyplot.clf()
def visualize_clusters(x, y, it, outdir):
    """Scatter-plot ``x`` coloured by cluster label; save ``<outdir>/clusters/<it>.png``.

    x -- numpy array of 2-D points (indexed as ``x[mask, 0]`` / ``x[mask, 1]``)
    y -- tensor of integer cluster labels, one per row of ``x``
    Colours cycle through ``COLORS`` when there are more clusters than colours.
    """
    labels = y.detach().cpu().numpy()
    # Labels run 0..labels.max() inclusive; range(labels.max()) would
    # silently skip the highest cluster. Empty input plots nothing.
    n_labels = int(labels.max()) + 1 if labels.size > 0 else 0
    for i in range(n_labels):
        mask = labels == i
        pyplot.plot(x[mask, 0],
                    x[mask, 1],
                    '.',
                    color=COLORS[i % len(COLORS)])
    pyplot.savefig(os.path.join(outdir, 'clusters', str(it) + '.png'))
    pyplot.clf()
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2020 Steven Liu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Diverse Image Generation via Self-Conditioned GANs
#### [Project](http://selfcondgan.csail.mit.edu/) | [Paper](http://selfcondgan.csail.mit.edu/preprint.pdf)
**Diverse Image Generation via Self-Conditioned GANs**
[Steven Liu](http://people.csail.mit.edu/stevenliu/),
[Tongzhou Wang](https://ssnl.github.io/),
[David Bau](http://people.csail.mit.edu/davidbau/home/),
[Jun-Yan Zhu](http://people.csail.mit.edu/junyanz/),
[Antonio Torralba](http://web.mit.edu/torralba/www/)
MIT, Adobe Research
in CVPR 2020.

Our proposed self-conditioned GAN model learns to perform clustering and image synthesis simultaneously. The model training
requires no manual annotation of object classes. Here, we visualize several discovered clusters for both Places365 (top) and ImageNet
(bottom). For each cluster, we show both real images and the generated samples conditioned on the cluster index.
## Getting Started
### Installation
- Clone this repo:
```bash
git clone https://github.com/stevliu/self-conditioned-gan.git
cd self-conditioned-gan
```
- Install the dependencies
```bash
conda create --name selfcondgan python=3.6
conda activate selfcondgan
conda install --file requirements.txt
conda install -c conda-forge tensorboardx
```
### Training and Evaluation
- Train a model on CIFAR:
```bash
python train.py configs/cifar/selfcondgan.yaml
```
- Visualize samples and inferred clusters:
```bash
python visualize_clusters.py configs/cifar/selfcondgan.yaml --show_clusters
```
The samples and clusters will be saved to `output/cifar/selfcondgan/clusters`. If this directory lies on an Apache server, you can open the URL to `output/cifar/selfcondgan/clusters/+lightbox.html` in the browser and visualize all samples and clusters in one webpage.
- Evaluate the model's FID:
You will need to first gather a set of ground truth train set images to compute metrics against.
```bash
python utils/get_gt_imgs.py --cifar
python metrics.py configs/cifar/selfcondgan.yaml --fid --every -1
```
You can also evaluate with other metrics by appending additional flags, such as Inception Score (`--inception`), the number of covered modes + reverse-KL divergence (`--modes`), and cluster metrics (`--cluster_metrics`).
## Pretrained Models
You can load and evaluate pretrained models on ImageNet and Places. If you have access to ImageNet or Places directories, first fill in paths to your ImageNet and/or Places dataset directories in `configs/imagenet/default.yaml` and `configs/places/default.yaml` respectively. You can use the following config files with the evaluation scripts, and the code will automatically download the appropriate models.
```bash
configs/pretrained/imagenet/selfcondgan.yaml
configs/pretrained/places/selfcondgan.yaml
configs/pretrained/imagenet/conditional.yaml
configs/pretrained/places/conditional.yaml
configs/pretrained/imagenet/unconditional.yaml
configs/pretrained/places/unconditional.yaml
```
## Evaluation
### Visualizations
To visualize generated samples and inferred clusters, run
```bash
python visualize_clusters.py config-file
```
You can set the flag `--show_clusters` to also visualize the real inferred clusters, but this requires that you have a path to training set images.
### Metrics
To obtain generation metrics, fill in paths to your ImageNet or Places dataset directories in `utils/get_gt_imgs.py` and then run
```bash
python utils/get_gt_imgs.py --imagenet --places
```
to precompute batches of GT images for FID/FSD evaluation.
Then, you can use
```bash
python metrics.py config-file
```
with the appropriate flags to compute the FID (`--fid`), FSD (`--fsd`), IS (`--inception`), number of modes covered / reverse-KL divergence (`--modes`), and clustering metrics (`--cluster_metrics`) for each of the checkpoints.
## Training models
To train a model, set up a configuration file (examples in `/configs`), and run
```bash
python train.py config-file
```
An example config for a self-conditioned GAN on ImageNet is `configs/imagenet/selfcondgan.yaml`, and on Places it is `configs/places/selfcondgan.yaml`.
Some models may be too large to fit on one GPU, so you may want to add `--devices DEVICE_NUMBERS` as an additional flag to do multi GPU training.
## 2D-experiments
For synthetic dataset experiments, first go into the `2d_mix` directory.
To train a self-conditioned GAN on the 2D-ring and 2D-grid dataset, run
```bash
python train.py --clusterer selfcondgan --data_type ring
python train.py --clusterer selfcondgan --data_type grid
```
You can test several other configurations via the command line arguments.
## Acknowledgments
This code is heavily based on the [GAN-stability](https://github.com/LMescheder/GAN_stability) code base.
Our FSD code is taken from the [GANseeing](https://github.com/davidbau/ganseeing) work.
To compute inception score, we use the code provided from [Shichang Tang](https://github.com/tsc2017/Inception-Score.git).
To compute FID, we use the code provided from [TTUR](https://github.com/bioinf-jku/TTUR).
We also use pretrained classifiers given by the [pytorch-playground](https://github.com/aaron-xichen/pytorch-playground).
We thank all the authors for their useful code.
## Citation
If you use this code for your research, please cite the following work.
```
@inproceedings{liu2020selfconditioned,
title={Diverse Image Generation via Self-Conditioned GANs},
author={Liu, Steven and Wang, Tongzhou and Bau, David and Zhu, Jun-Yan and Torralba, Antonio},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2020}
}
```
================================================
FILE: cluster_metrics.py
================================================
import argparse
import os
from tqdm import tqdm
import torch
import numpy as np
from torch import nn
from gan_training import utils
from gan_training.inputs import get_dataset
from gan_training.checkpoints import CheckpointIO
from gan_training.config import load_config
from gan_training.metrics.clustering_metrics import (nmi, purity_score)
torch.backends.cudnn.benchmark = True
# Command-line interface: the experiment config plus optional checkpoint / baseline flags.
parser = argparse.ArgumentParser(
    description='Evaluate the clustering inferred by our method')
parser.add_argument('config',
                    type=str,
                    help='Path to config file.')
parser.add_argument('--model_it',
                    type=str)
parser.add_argument('--random',
                    action='store_true',
                    help='Figure out if the clusters were randomly assigned')
args = parser.parse_args()
# Experiment config layered over the repository defaults.
config = load_config(args.config, 'configs/default.yaml')
out_dir = config['training']['out_dir']
def _predict_clusters(checkpoint_dir, batch_size, name):
    """Run the checkpointed clusterer over the training set; return numpy (y_reals, y_preds)."""
    train_dataset, _ = get_dataset(
        name=name,
        data_dir=config['data']['train_dir'],
        size=config['data']['img_size'])
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=config['training']['nworkers'],
        shuffle=True,
        pin_memory=True,
        sampler=None,
        drop_last=True)
    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
    print('Loading clusterer:')
    if args.model_it is None:
        checkpoint_name = utils.get_most_recent(checkpoint_dir, 'model')
    else:
        checkpoint_name = args.model_it
    clusterer = checkpoint_io.load_clusterer(checkpoint_name, load_samples=False, pretrained=config['pretrained'])
    # Unwrap DataParallel so the clusterer calls the bare discriminator.
    if isinstance(clusterer.discriminator, nn.DataParallel):
        clusterer.discriminator = clusterer.discriminator.module
    predicted, actual = [], []
    for x_real, y_real in tqdm(train_loader, total=len(train_loader)):
        predicted.append(clusterer.get_labels(x_real.cuda(), None).detach().cpu())
        actual.append(y_real)
    return torch.cat(actual).numpy(), torch.cat(predicted).numpy()


def main():
    """Print NMI and purity of the inferred clusters against ground-truth labels.

    Predictions for the whole training set are cached in
    ``<out_dir>/cluster_preds.npz`` and reused when present. With ``--random``
    the predictions are replaced by uniform labels in [0, 100) as a baseline.
    """
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    batch_size = config['training']['batch_size']
    data_config = config['data']
    if 'cifar' in data_config['train_dir'].lower():
        name = 'cifar10'
    elif 'stacked_mnist' == data_config['type']:
        name = 'stacked_mnist'
    else:
        name = 'image'
    cache_path = os.path.join(out_dir, 'cluster_preds.npz')
    if os.path.exists(cache_path):
        # if we've already computed assignments, load them and move on
        with np.load(cache_path) as cached:
            y_reals = cached['y_reals']
            y_preds = cached['y_preds']
    else:
        y_reals, y_preds = _predict_clusters(checkpoint_dir, batch_size, name)
        np.savez(cache_path, y_reals=y_reals, y_preds=y_preds)
    if args.random:
        y_preds = np.random.randint(0, 100, size=y_reals.shape)
    nmi_score = nmi(y_preds, y_reals)
    purity = purity_score(y_preds, y_reals)
    print('nmi', nmi_score, 'purity', purity)


if __name__ == '__main__':
    main()
================================================
FILE: clusterers/__init__.py
================================================
from clusterers import (base_clusterer, selfcondgan, random_labels, online)
# Registry mapping the configured clusterer name to its class.
clusterer_dict = dict(
    supervised=base_clusterer.BaseClusterer,
    selfcondgan=selfcondgan.Clusterer,
    online=online.Clusterer,
    random_labels=random_labels.Clusterer,
)
================================================
FILE: clusterers/base_clusterer.py
================================================
import copy
import torch
import numpy as np
class BaseClusterer():
    """Pass-through clusterer and shared base for the feature-based clusterers.

    ``get_labels`` returns the labels it is given and ``recluster`` does
    nothing. The helpers compute discriminator features and label histograms
    for subclasses.
    """

    def __init__(self,
                 discriminator,
                 k_value=-1,
                 x_cluster=None,
                 batch_size=100,
                 **kwargs):
        ''' requires that self.x is not on the gpu, or else it hogs too much gpu memory

        Keeps a private deep copy of ``discriminator`` in eval mode. Batches of
        ``x_cluster`` are moved to the GPU one at a time in
        ``get_cluster_batch_features``. Extra keyword arguments are ignored.
        '''
        self.k = k_value
        self.cluster_counts = [0] * k_value
        self.discriminator = copy.deepcopy(discriminator)
        self.discriminator.eval()
        self.kmeans = None
        self.x = x_cluster
        self.x_labels = None
        self.batch_size = batch_size

    def get_labels(self, x, y):
        '''Identity labelling: return ``y`` unchanged.'''
        return y

    def recluster(self, discriminator, **kwargs):
        '''No-op in the base class.'''
        return

    def get_features(self, x):
        ''' by default gets discriminator, but you can use other things '''
        return self.get_discriminator_output(x)

    def get_cluster_batch_features(self):
        '''Features of every row of ``self.x`` as one numpy array.

        Computed in ``self.batch_size`` chunks (the final chunk may be shorter),
        each moved to CUDA and the result brought back to the CPU.
        '''
        with torch.no_grad():
            total = self.x.size(0)
            step = self.batch_size
            chunks = []
            for start in range(0, total, step):
                x_chunk = self.x[start:start + step].cuda()
                chunks.append(self.get_features(x_chunk).detach().cpu())
            return torch.cat(chunks, dim=0).numpy()

    def get_discriminator_output(self, x):
        '''``discriminator(x, get_features=True)`` evaluated in eval mode without gradients.'''
        self.discriminator.eval()
        with torch.no_grad():
            return self.discriminator(x, get_features=True)

    def get_label_distribution(self, x=None):
        '''Histogram (length ``self.k``) of ``self.x_labels``, or of the labels predicted for ``x``.'''
        labels = self.x_labels if x is None else self.get_labels(x, None)
        histogram = [0] * self.k
        for label in labels:
            histogram[label] += 1
        return histogram

    def sample_y(self, batch_size):
        '''Draw per-cluster counts for ``batch_size`` samples from the empirical label distribution (not sure if used anymore).'''
        counts = self.get_label_distribution()
        total = sum(counts)
        probs = torch.tensor([c / total for c in counts])
        return torch.distributions.Multinomial(batch_size, probs).sample()

    def print_label_distribution(self, x=None):
        print(self.get_label_distribution(x))
================================================
FILE: clusterers/kmeans.py
================================================
import torch
import numpy as np
from sklearn.cluster import KMeans
from clusterers import base_clusterer
class Clusterer(base_clusterer.BaseClusterer):
    """Clusters discriminator features with scikit-learn k-means.

    Raw k-means ids are relabelled through ``self.mapping``, which is
    initialised to the identity permutation here.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mapping = list(range(self.k))

    def kmeans_fit_predict(self, features, init='k-means++', n_init=10):
        '''fits kmeans, and returns the predictions of the kmeans'''
        print('Fitting k-means w data shape', features.shape)
        model = KMeans(init=init, n_clusters=self.k, n_init=n_init)
        self.kmeans = model.fit(features)
        return self.kmeans.predict(features)

    def get_labels(self, x, y):
        '''Nearest-centroid labels for ``x`` (``y`` ignored), remapped, as a CUDA LongTensor.'''
        feats = self.get_features(x).detach().cpu().numpy()
        cluster_ids = self.kmeans.predict(feats)
        remapped = np.array([self.mapping[cid] for cid in cluster_ids])
        return torch.from_numpy(remapped).long().cuda()
================================================
FILE: clusterers/online.py
================================================
import copy, random
import torch
import numpy as np
from clusterers import kmeans
class Clusterer(kmeans.Clusterer):
    """k-means clusterer whose centroids are updated online after burn-in.

    ``recluster`` behaves differently on successive calls:
      1. first call: fit k-means (k-means++ init) on features of ``self.x``;
      2. second call: refit on fresh features, initialised at the per-cluster
         means induced by the previous assignments;
      3. later calls: nudge each centroid toward the features of ``x_batch``
         with step size 1 / (points assigned to that cluster so far).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.burned_in = False

    def get_initialization(self, features, labels):
        '''Return a (k, d) numpy array of centroids induced by old assignments.

        features -- np array (n, d) from the new discriminator
        labels -- np array (n,) of old cluster assignments
        Cluster i starts at the mean of the features labelled i. A cluster with
        no members is seeded with a uniformly random feature row.
        '''
        means = []
        for i in range(self.k):
            members = features[labels == i]
            if len(members) > 0:
                # Accumulate in float64 so float32 features don't lose precision.
                means.append(members.mean(axis=0, dtype=np.float64))
            else:
                # Starved cluster: seed with a random data point. ``features``
                # is a numpy array, so use len() (ndarray.size is an int).
                rand_point = random.randint(0, len(features) - 1)
                means.append(features[rand_point])
        return np.array(means)

    def recluster(self, discriminator, x_batch=None, **kwargs):
        '''Advance the clustering by one stage; see the class docstring.'''
        if self.kmeans is None:
            print('kmeans clustering as initialization')
            self.discriminator = copy.deepcopy(discriminator)
            features = self.get_cluster_batch_features()
            self.x_labels = self.kmeans_fit_predict(features)
            return
        # NOTE(review): from here on the live discriminator is shared by
        # reference (no deepcopy), and BaseClusterer.get_discriminator_output
        # calls .eval() on it. Confirm the trainer switches it back to train().
        self.discriminator = discriminator
        if not self.burned_in:
            print('Burned in: computing initialization for kmeans')
            features = self.get_cluster_batch_features()
            initialization = self.get_initialization(features, self.x_labels)
            self.kmeans_fit_predict(features, init=initialization)
            self.burned_in = True
            return
        assert x_batch is not None
        features = self.get_features(x_batch).detach().cpu().numpy()
        y_pred = self.kmeans.predict(features)
        centers = self.kmeans.cluster_centers_
        for xi, yi in zip(features, y_pred):
            # Sequential k-means: running mean over all points seen by cluster yi.
            self.cluster_counts[yi] += 1
            step_size = 1.0 / self.cluster_counts[yi]
            centers[yi] = centers[yi] + step_size * (xi - centers[yi])
================================================
FILE: clusterers/random_labels.py
================================================
import torch
from clusterers import base_clusterer
class Clusterer(base_clusterer.BaseClusterer):
    '''Baseline clusterer that ignores the data and draws labels uniformly from [0, k).'''

    def get_labels(self, x, y):
        '''Returns a CUDA long tensor of random labels with the same shape as ``y``.'''
        random_labels = torch.randint(low=0, high=self.k, size=y.shape)
        return random_labels.long().cuda()
================================================
FILE: clusterers/selfcondgan.py
================================================
import copy, random
import torch
import numpy as np
from sklearn.utils.linear_assignment_ import linear_assignment
from clusterers import kmeans
class Clusterer(kmeans.Clusterer):
    '''Self-conditioned GAN clusterer: periodically re-runs k-means.

    Args:
        initialization (bool): warm-start k-means from the means induced by
            the previous labels
        matching (bool): permute new cluster ids to agree with the previous
            labels (Hungarian matching)
    '''

    def __init__(self, initialization=True, matching=True, **kwargs):
        self.initialization = initialization
        self.matching = matching
        super().__init__(**kwargs)

    def get_initialization(self, features, labels):
        '''given points (from new discriminator) and their old assignments as np arrays, compute the induced means as a np array

        Row i of the result is the mean of the features whose old label is i;
        a cluster with no points is seeded with a uniformly random feature row.
        '''
        # ndarray.size is an int attribute (not a method), so use len() for
        # the row count; np.asarray also accepts CPU tensors.
        features = np.asarray(features)
        labels = np.asarray(labels)
        means = []
        for i in range(self.k):
            mask = (labels == i)
            numels = int(mask.sum())
            if numels > 0:
                # float64 accumulation, vectorized over the member rows
                means.append(features[mask].sum(axis=0, dtype=np.float64) / numels)
            else:
                # cluster is starved: seed it with a random data point
                rand_point = random.randint(0, len(features) - 1)
                means.append(features[rand_point])
        return np.array(means)

    def fit_means(self):
        '''Re-runs k-means on the cluster batch and refreshes ``self.x_labels``.

        With previous labels available, ``initialization`` warm-starts
        k-means from their induced means, and ``matching`` chooses the
        permutation of new cluster ids that maximizes agreement with them.
        '''
        features = self.get_cluster_batch_features()
        # if clustered already, use old assignments for the cluster mean
        if self.x_labels is not None and self.initialization:
            print('Initializing k-means with previous cluster assignments')
            initialization = self.get_initialization(features, self.x_labels)
        else:
            initialization = 'k-means++'
        new_classes = self.kmeans_fit_predict(features, init=initialization)
        # we've clustered already, so compute the permutation
        if self.x_labels is not None and self.matching:
            print('Doing cluster matching')
            matching = self.hungarian_match(new_classes, self.x_labels, self.k,
                                            self.k)
            self.mapping = [int(j) for i, j in sorted(matching)]
        # recompute the fixed labels (vectorized lookup through the mapping)
        self.x_labels = np.asarray(self.mapping, dtype=np.int64)[new_classes]

    def recluster(self, discriminator, **kwargs):
        '''Snapshot ``discriminator`` and recompute the clustering.'''
        self.discriminator = copy.deepcopy(discriminator)
        self.fit_means()

    def hungarian_match(self, flat_preds, flat_targets, preds_k, targets_k):
        '''takes in np arrays flat_preds, flat_targets of integers

        Returns a list of ``(out_c, gt_c)`` pairs: the one-to-one assignment
        of predicted ids to target ids maximizing the number of points whose
        ids agree.
        '''
        flat_preds = np.asarray(flat_preds)
        flat_targets = np.asarray(flat_targets)
        num_samples = flat_targets.shape[0]
        assert (preds_k == targets_k)  # one to one
        num_k = preds_k
        # co-occurrence counts in a single pass instead of k*k passes;
        # ids outside [0, num_k) are ignored, as the per-pair loop did
        num_correct = np.zeros((num_k, num_k))
        valid = ((flat_preds >= 0) & (flat_preds < num_k) &
                 (flat_targets >= 0) & (flat_targets < num_k))
        np.add.at(num_correct,
                  (flat_preds[valid].astype(np.int64),
                   flat_targets[valid].astype(np.int64)),
                  1)
        # minimizing (num_samples - votes) maximizes total agreement
        match = linear_assignment(num_samples - num_correct)
        # return as list of tuples, out_c to gt_c
        return [(out_c, gt_c) for out_c, gt_c in match]
================================================
FILE: configs/cifar/conditional.yaml
================================================
generator:
nlabels: 10
conditioning: embedding
discriminator:
nlabels: 10
conditioning: mask
inherit_from: configs/cifar/default.yaml
training:
out_dir: output/cifar/conditional
================================================
FILE: configs/cifar/default.yaml
================================================
data:
type: cifar10
train_dir: data/CIFAR
img_size: 32
nlabels: 10
generator:
name: dcgan_deep
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
discriminator:
name: dcgan_deep
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
z_dist:
type: gauss
dim: 128
clusterer:
name: supervised
nimgs: 25000
kwargs:
placeholder: None
training:
gan_type: standard
reg_type: none
reg_param: 0.
take_model_average: false
sample_nlabels: 20
log_every: 1000
inception_every: 10000
batch_size: 64
================================================
FILE: configs/cifar/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
kwargs:
k_value: 100
inherit_from: configs/cifar/default.yaml
training:
out_dir: output/cifar/selfcondgan
recluster_every: 25000
================================================
FILE: configs/cifar/unconditional.yaml
================================================
inherit_from: configs/cifar/default.yaml
training:
out_dir: output/cifar/unconditional
================================================
FILE: configs/default.yaml
================================================
data:
type: lsun
train_dir: data/LSUN
deterministic: False
img_size: 128
nlabels: 1
generator:
name: resnet
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
discriminator:
name: resnet
nlabels: 1
conditioning: unconditional
kwargs:
pack_size: 1
placeholder: None
clusterer:
name: supervised
nimgs: 100
kwargs:
num_components: -1
z_dist:
type: gauss
dim: 256
training:
out_dir: output/default
gan_type: standard
reg_type: real
reg_param: 10.
log_every: 1
batch_size: 128
ntest: 128
nworkers: 72
burnin_time: 0
take_model_average: true
model_average_beta: 0.999
monitoring: tensorboard
sample_every: 5000
sample_nlabels: 20
inception_every: 10000
inception_nsamples: 50000
backup_every: 10000
recluster_every: 10000
optimizer: adam
lr_g: 0.0001
lr_d: 0.0001
beta1: 0.0
beta2: 0.99
pretrained: {}
================================================
FILE: configs/imagenet/conditional.yaml
================================================
generator:
nlabels: 1000
conditioning: embedding
discriminator:
nlabels: 1000
conditioning: mask
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/imagenet/conditional
================================================
FILE: configs/imagenet/default.yaml
================================================
data:
type: image
train_dir: data/ImageNet/train
test_dir: data/ImageNet/val
img_size: 128
nlabels: 1000
generator:
name: resnet2
nlabels: 1
conditioning: unconditional
discriminator:
name: resnet2
nlabels: 1
conditioning: unconditional
z_dist:
type: gauss
dim: 256
clusterer:
name: supervised
training:
gan_type: standard
reg_type: real
reg_param: 10.
take_model_average: true
model_average_beta: 0.999
sample_nlabels: 20
log_every: 10
inception_every: 10000
backup_every: 5000
batch_size: 128
================================================
FILE: configs/imagenet/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
nimgs: 50000
kwargs:
k_value: 100
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/imagenet/selfcondgan
recluster_every: 75000
reg_param: 0.1
================================================
FILE: configs/imagenet/unconditional.yaml
================================================
generator:
nlabels: 1
conditioning: unconditional
discriminator:
nlabels: 1
conditioning: unconditional
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/imagenet/unconditional
================================================
FILE: configs/places/conditional.yaml
================================================
generator:
nlabels: 365
conditioning: embedding
discriminator:
nlabels: 365
conditioning: mask
training:
out_dir: output/places/conditional
inherit_from: configs/places/default.yaml
================================================
FILE: configs/places/default.yaml
================================================
data:
type: image
train_dir: data/places365/train
test_dir: data/places365/val
img_size: 128
nlabels: 365
generator:
name: resnet2
nlabels: 1
conditioning: unconditional
discriminator:
name: resnet2
nlabels: 1
conditioning: unconditional
z_dist:
type: gauss
dim: 256
clusterer:
name: supervised
training:
gan_type: standard
reg_type: real
reg_param: 10.
take_model_average: true
model_average_beta: 0.999
sample_nlabels: 20
log_every: 10
inception_every: 10000
backup_every: 5000
batch_size: 128
================================================
FILE: configs/places/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
nimgs: 50000
kwargs:
k_value: 100
inherit_from: configs/places/default.yaml
training:
out_dir: output/places/selfcondgan
recluster_every: 75000
reg_param: 0.1
================================================
FILE: configs/places/unconditional.yaml
================================================
generator:
nlabels: 1
conditioning: embedding
discriminator:
nlabels: 1
conditioning: mask
inherit_from: configs/places/default.yaml
training:
out_dir: output/places/unconditional
================================================
FILE: configs/pretrained/imagenet/conditional.yaml
================================================
generator:
nlabels: 1000
conditioning: embedding
discriminator:
nlabels: 1000
conditioning: mask
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/pretrained/imagenet/class_conditional
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/classcondgan_i_model.pt
================================================
FILE: configs/pretrained/imagenet/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
nimgs: 50000
kwargs:
k_value: 100
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/pretrained/imagenet/selfcondgan
recluster_every: 75000
reg_param: 0.1
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/selfcondgan_i_model.pt
clusterer: http://selfcondgan.csail.mit.edu/weights/selfcondgan_i_clusterer.pkl
================================================
FILE: configs/pretrained/imagenet/unconditional.yaml
================================================
generator:
nlabels: 1
conditioning: unconditional
discriminator:
nlabels: 1
conditioning: unconditional
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/pretrained/imagenet/unconditional
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/uncondgan_i_model.pt
================================================
FILE: configs/pretrained/places/conditional.yaml
================================================
generator:
nlabels: 365
conditioning: embedding
discriminator:
nlabels: 365
conditioning: mask
training:
out_dir: output/pretrained/places/class_conditional
inherit_from: configs/places/default.yaml
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/classcondgan_p_model.pt
================================================
FILE: configs/pretrained/places/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
nimgs: 50000
kwargs:
k_value: 100
inherit_from: configs/places/default.yaml
training:
out_dir: output/pretrained/places/selfcondgan
reg_param: 0.1
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/selfcondgan_p_model.pt
clusterer: http://selfcondgan.csail.mit.edu/weights/selfcondgan_p_clusterer.pkl
================================================
FILE: configs/pretrained/places/unconditional.yaml
================================================
generator:
nlabels: 1
conditioning: embedding
discriminator:
nlabels: 1
conditioning: mask
inherit_from: configs/places/default.yaml
training:
out_dir: output/pretrained/places/unconditional
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/uncondgan_p_model.pt
================================================
FILE: configs/stacked_mnist/conditional.yaml
================================================
generator:
nlabels: 1000
conditioning: embedding
discriminator:
nlabels: 1000
conditioning: mask
inherit_from: configs/stacked_mnist/default.yaml
training:
out_dir: output/stacked_mnist/conditional
================================================
FILE: configs/stacked_mnist/default.yaml
================================================
data:
type: stacked_mnist
train_dir: data/MNIST
img_size: 32
nlabels: 1000
generator:
name: dcgan_shallow
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
discriminator:
name: dcgan_shallow
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
z_dist:
type: gauss
dim: 128
clusterer:
name: supervised
nimgs: 25000
kwargs:
placeholder: None
training:
gan_type: standard
reg_type: none
reg_param: 0.
take_model_average: false
sample_nlabels: 20
log_every: 1000
backup_every: 5000
inception_every: 10000
batch_size: 64
================================================
FILE: configs/stacked_mnist/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
kwargs:
k_value: 100
inherit_from: configs/stacked_mnist/default.yaml
training:
out_dir: output/stacked_mnist/selfcondgan
recluster_every: 25000
================================================
FILE: configs/stacked_mnist/unconditional.yaml
================================================
inherit_from: configs/stacked_mnist/default.yaml
training:
out_dir: output/stacked_mnist/unconditional
================================================
FILE: gan_training/__init__.py
================================================
================================================
FILE: gan_training/checkpoints.py
================================================
import os, pickle
import urllib
import torch
import numpy as np
from torch.utils import model_zoo
class CheckpointIO(object):
    ''' CheckpointIO class.

    It handles saving and loading checkpoints.

    Args:
        checkpoint_dir (str): path where checkpoints are saved (created if
            missing)
        **kwargs: named modules to checkpoint; each must provide
            ``state_dict()`` and ``load_state_dict()``
    '''

    def __init__(self, checkpoint_dir='./chkpts', **kwargs):
        self.module_dict = kwargs
        self.checkpoint_dir = checkpoint_dir
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

    def register_modules(self, **kwargs):
        ''' Registers modules in current module dictionary.
        '''
        self.module_dict.update(kwargs)

    def save(self, filename, **kwargs):
        ''' Saves the current module dictionary.

        Args:
            filename (str): name of output file; relative names are placed
                inside ``checkpoint_dir``
            **kwargs: extra values stored next to the module state dicts
        '''
        if not os.path.isabs(filename):
            filename = os.path.join(self.checkpoint_dir, filename)
        outdict = dict(kwargs)
        for k, v in self.module_dict.items():
            outdict[k] = v.state_dict()
        torch.save(outdict, filename)

    def load(self, filename, pretrained=None):
        '''Loads a module dictionary from local file or url.

        Args:
            filename (str): name of saved module dictionary
            pretrained (dict, optional): if it has a ``'model'`` entry, that
                path/url is loaded instead of ``filename``
        Returns:
            dict of the checkpoint entries that are not registered modules
        '''
        pretrained = pretrained or {}
        if 'model' in pretrained:
            filename = pretrained['model']
        if is_url(filename):
            return self.load_url(filename)
        return self.load_file(filename)

    def load_file(self, filename):
        '''Loads a module dictionary from file.

        Args:
            filename (str): name of saved module dictionary
        Raises:
            FileNotFoundError: if the file does not exist
        '''
        if not os.path.isabs(filename):
            filename = os.path.join(self.checkpoint_dir, filename)
        if not os.path.exists(filename):
            print('File not found', filename)
            raise FileNotFoundError(filename)
        print('=> Loading checkpoint from local file...', filename)
        state_dict = torch.load(filename)
        return self.parse_state_dict(state_dict)

    def load_url(self, url):
        '''Load a module dictionary from url.

        Args:
            url (str): url to saved model (cached in ``checkpoint_dir``)
        '''
        print('=> Loading checkpoint from url...', url)
        state_dict = model_zoo.load_url(url, model_dir=self.checkpoint_dir, progress=True)
        return self.parse_state_dict(state_dict)

    def parse_state_dict(self, state_dict):
        '''Parse state_dict of model and return scalars.

        Loads every registered module found in ``state_dict`` (warning about
        missing ones) and returns the remaining, non-module entries.

        Args:
            state_dict (dict): State dict of model
        '''
        for k, v in self.module_dict.items():
            if k in state_dict:
                v.load_state_dict(state_dict[k])
            else:
                print('Warning: Could not find %s in checkpoint!' % k)
        return {
            k: v
            for k, v in state_dict.items() if k not in self.module_dict
        }

    def load_clusterer(self, it, load_samples, pretrained=None):
        '''Loads the clusterer saved at iteration ``it`` (or a pretrained one).

        A pretrained clusterer (``pretrained['clusterer']`` url) is downloaded
        once and cached as ``pretrained_clusterer.pkl`` in ``checkpoint_dir``.
        If ``load_samples`` is set, ``clusterer.x`` is restored from
        ``cluster_samples.npz`` (not done for pretrained clusterers).

        Raises:
            FileNotFoundError: if the clusterer pickle does not exist
        '''
        pretrained = pretrained or {}
        if 'clusterer' in pretrained:
            pretrained_file = os.path.join(self.checkpoint_dir, 'pretrained_clusterer.pkl')
            if not os.path.exists(pretrained_file):
                import cloudpickle as cp
                from urllib.request import urlopen
                print('Loading pretrained clusterer from', pretrained['clusterer'])
                # NOTE(security): unpickling runs arbitrary code; only use
                # trusted URLs. The response is closed after reading.
                with urlopen(pretrained['clusterer']) as response:
                    clusterer = cp.load(response)
                print('Saving pretrained clusterer to', pretrained_file)
                with open(pretrained_file, 'wb') as f:
                    pickle.dump(clusterer, f)
            else:
                with open(pretrained_file, 'rb') as f:
                    clusterer = pickle.load(f)
            return clusterer

        print('Loading clusterer:')
        with open(os.path.join(self.checkpoint_dir, f'clusterer{it}.pkl'), 'rb') as f:
            clusterer = pickle.load(f)
        if load_samples:
            print('Loading cluster samples:')
            with np.load(os.path.join(self.checkpoint_dir, 'cluster_samples.npz')) as f:
                x = f['x']
            clusterer.x = torch.from_numpy(x)
        return clusterer

    def load_models(self, it, pretrained=None, load_samples=False):
        '''Loads models and clusterer for iteration ``it``.

        If the first load fails for any reason, DataParallel wrappers in the
        module dictionary are replaced by their inner modules and the load is
        retried.

        Returns:
            (it, epoch_idx, clusterer); ``it`` and ``epoch_idx`` are -1 when
            no model file is found, ``clusterer`` is None when no clusterer
            file is found
        '''
        pretrained = pretrained or {}
        try:
            load_dict = self.load('model_%08d.pt' % it, pretrained)
            epoch_idx = load_dict.get('epoch_idx', -1)
        except Exception as e:  # models are not dataparallel modules
            print('Trying again to load w/o data parallel modules')
            try:
                for name, module in self.module_dict.items():
                    if isinstance(module, torch.nn.DataParallel):
                        self.module_dict[name] = module.module
                load_dict = self.load('model_%08d.pt' % it, pretrained)
                epoch_idx = load_dict.get('epoch_idx', -1)
            except FileNotFoundError as e:
                print(e)
                print("Models not found")
                it = epoch_idx = -1
        try:
            clusterer = self.load_clusterer(it, load_samples, pretrained)
        except FileNotFoundError:
            clusterer = None
        return it, epoch_idx, clusterer

    def save_clusterer(self, clusterer, it):
        '''Pickles ``clusterer`` to ``clusterer{it}.pkl`` without its samples ``x``.'''
        with open(os.path.join(self.checkpoint_dir, f'clusterer{it}.pkl'), 'wb') as f:
            # hack: only save changing data; x is restored even if pickling fails
            x = clusterer.x
            clusterer.x = None
            try:
                pickle.dump(clusterer, f)
            finally:
                clusterer.x = x
def is_url(url):
    '''Returns True if ``url`` has an http or https scheme.'''
    # Import the submodule explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.parse`` has been loaded.
    from urllib.parse import urlparse
    return urlparse(url).scheme in ('http', 'https')
================================================
FILE: gan_training/config.py
================================================
import yaml
from torch import optim
from os import path
from gan_training.models import generator_dict, discriminator_dict
from gan_training.train import toggle_grad
from clusterers import clusterer_dict
# General config
def load_config(path, default_path):
    ''' Loads config file.

    Args:
        path (str): path to config file
        default_path (str or None): config loaded as the base when ``path``
            has no ``inherit_from`` entry (inherited files are resolved
            recursively with the same ``default_path``)
    Returns:
        dict: merged configuration; entries of ``path`` take precedence
    '''
    # Load configuration from file itself. Configs are plain data, so use
    # safe_load: yaml.load without an explicit Loader is deprecated since
    # PyYAML 5.1 and rejected by PyYAML 6. An empty file yields {}.
    with open(path, 'r') as f:
        cfg_special = yaml.safe_load(f) or {}

    # Check if we should inherit from a config
    inherit_from = cfg_special.get('inherit_from')

    # If yes, load this config first as default
    # If no, use the default_path
    if inherit_from is not None:
        cfg = load_config(inherit_from, default_path)
    elif default_path is not None:
        with open(default_path, 'r') as f:
            cfg = yaml.safe_load(f) or {}
    else:
        cfg = dict()

    # Include main configuration
    update_recursive(cfg, cfg_special)
    return cfg
def update_recursive(dict1, dict2):
    ''' Update two config dictionaries recursively.

    Nested dicts are merged key by key. Any other value in ``dict2``
    (including None) replaces the entry in ``dict1``.

    Args:
        dict1 (dict): first dictionary to be updated (in place)
        dict2 (dict): second dictionary which entries should be used
    '''
    for k, v in dict2.items():
        # Recurse only when both sides are dicts; otherwise calling
        # v.items() on a non-dict (e.g. None from an empty YAML key) fails.
        if isinstance(dict1.get(k), dict) and isinstance(v, dict):
            update_recursive(dict1[k], v)
        else:
            dict1[k] = v
def get_clusterer(config):
    '''Returns the clusterer class registered under config['clusterer']['name'].'''
    clusterer_name = config['clusterer']['name']
    return clusterer_dict[clusterer_name]
def build_models(config):
    '''Instantiates the generator and discriminator described by ``config``.

    Architectures are looked up by name in ``generator_dict`` /
    ``discriminator_dict``; each section's ``kwargs`` are forwarded as
    extra keyword arguments.

    Returns:
        (generator, discriminator)
    '''
    gen_cfg = config['generator']
    generator_cls = generator_dict[gen_cfg['name']]
    dis_cfg = config['discriminator']
    discriminator_cls = discriminator_dict[dis_cfg['name']]

    generator = generator_cls(
        z_dim=config['z_dist']['dim'],
        nlabels=gen_cfg['nlabels'],
        size=config['data']['img_size'],
        conditioning=gen_cfg['conditioning'],
        **gen_cfg['kwargs'],
    )
    discriminator = discriminator_cls(
        nlabels=dis_cfg['nlabels'],
        conditioning=dis_cfg['conditioning'],
        size=config['data']['img_size'],
        **dis_cfg['kwargs'],
    )
    return generator, discriminator
def build_optimizers(generator, discriminator, config):
    '''Creates the generator and discriminator optimizers from ``config['training']``.

    Supported ``optimizer`` values are 'rmsprop', 'adam' (uses ``beta1``,
    ``beta2``) and 'sgd'; learning rates come from ``lr_g`` / ``lr_d``.

    Returns:
        (g_optimizer, d_optimizer)
    Raises:
        NotImplementedError: for any other optimizer name
    '''
    optimizer = config['training']['optimizer']
    lr_g = config['training']['lr_g']
    lr_d = config['training']['lr_d']

    toggle_grad(generator, True)
    toggle_grad(discriminator, True)

    g_params = generator.parameters()
    d_params = discriminator.parameters()

    if optimizer == 'rmsprop':
        g_optimizer = optim.RMSprop(g_params, lr=lr_g, alpha=0.99, eps=1e-8)
        d_optimizer = optim.RMSprop(d_params, lr=lr_d, alpha=0.99, eps=1e-8)
    elif optimizer == 'adam':
        beta1 = config['training']['beta1']
        beta2 = config['training']['beta2']
        g_optimizer = optim.Adam(g_params, lr=lr_g, betas=(beta1, beta2), eps=1e-8)
        d_optimizer = optim.Adam(d_params, lr=lr_d, betas=(beta1, beta2), eps=1e-8)
    elif optimizer == 'sgd':
        g_optimizer = optim.SGD(g_params, lr=lr_g, momentum=0.)
        d_optimizer = optim.SGD(d_params, lr=lr_d, momentum=0.)
    else:
        # without this, the return below fails with an UnboundLocalError
        raise NotImplementedError('Optimizer "%s" not supported!' % optimizer)

    return g_optimizer, d_optimizer
# Some utility functions
def get_parameter_groups(parameters, gradient_scales, base_lr):
    '''Builds one optimizer parameter group per parameter.

    Each group's learning rate is ``base_lr`` scaled by
    ``gradient_scales[p]`` (default 1.).
    '''
    return [
        {'params': [param], 'lr': gradient_scales.get(param, 1.) * base_lr}
        for param in parameters
    ]
================================================
FILE: gan_training/distributions.py
================================================
import torch
from torch import distributions
def get_zdist(dist_name, dim, device=None):
    '''Returns the latent prior with independent coordinates.

    Args:
        dist_name (str): 'uniform' (U[-1, 1) per coordinate) or 'gauss'
            (standard normal per coordinate)
        dim (int): latent dimensionality
        device: device of the distribution parameters
    Returns:
        a torch distribution with an extra ``dim`` attribute
    Raises:
        NotImplementedError: for any other ``dist_name``
    '''
    # Get distribution
    if dist_name == 'uniform':
        low = -torch.ones(dim, device=device)
        high = torch.ones(dim, device=device)
        zdist = distributions.Uniform(low, high)
    elif dist_name == 'gauss':
        mu = torch.zeros(dim, device=device)
        scale = torch.ones(dim, device=device)
        zdist = distributions.Normal(mu, scale)
    else:
        raise NotImplementedError('Unknown latent distribution: %s' % dist_name)

    # Add dim attribute
    zdist.dim = dim

    return zdist
def get_ydist(nlabels, device=None):
    '''Uniform categorical distribution over ``nlabels`` classes.

    The returned distribution also carries an ``nlabels`` attribute.
    '''
    uniform_logits = torch.zeros(nlabels, device=device)
    label_dist = distributions.categorical.Categorical(logits=uniform_logits)
    label_dist.nlabels = nlabels
    return label_dist
def interpolate_sphere(z1, z2, t, eps=1e-6):
    '''Spherical linear interpolation (slerp) between z1 and z2 along the last dim.

    t=0 returns z1 and t=1 returns z2. The cosine of the angle is clamped to
    [-1, 1], because rounding can push it just outside and make acos return
    NaN. Where sin(omega) < eps (nearly parallel or anti-parallel inputs),
    plain linear interpolation is used instead of dividing by ~0.

    Args:
        z1, z2: tensors of matching (broadcastable) shape
        t: interpolation weight (float or broadcastable tensor)
        eps: threshold on sin(omega) for the linear fallback
    '''
    p = (z1 * z2).sum(dim=-1, keepdim=True)
    p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt()
    p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt()
    p = p.clamp(-1., 1.)
    omega = torch.acos(p)
    sin_omega = torch.sin(omega)
    degenerate = sin_omega.abs() < eps
    safe_sin = torch.where(degenerate, torch.ones_like(sin_omega), sin_omega)
    s1 = torch.sin((1 - t) * omega) / safe_sin
    s2 = torch.sin(t * omega) / safe_sin
    s1 = torch.where(degenerate, (1 - t) * torch.ones_like(s1), s1)
    s2 = torch.where(degenerate, t * torch.ones_like(s2), s2)
    z = s1 * z1 + s2 * z2
    return z
================================================
FILE: gan_training/eval.py
================================================
import numpy as np
import torch
from torch.nn import functional as F
from gan_training.metrics import inception_score
class Evaluator(object):
    '''Draws generator samples for evaluation (inception score, sample grids).

    Args:
        generator: generator module, called as ``generator(z, y)``
        zdist: latent distribution, sampled with ``sample((n,))``
        ydist: label distribution (stored; not used by this class's methods)
        train_loader: iterable of ``(x_real, y_gt)`` batches
        clusterer: object whose ``get_labels(x, y)`` supplies the
            conditioning labels
        batch_size: stored; not used by this class's methods
        inception_nsamples: number of samples used for the inception score
        device: device for latents, labels and real images
    '''

    def __init__(self,
                 generator,
                 zdist,
                 ydist,
                 train_loader,
                 clusterer,
                 batch_size=64,
                 inception_nsamples=10000,
                 device=None):
        self.generator = generator
        self.clusterer = clusterer
        self.train_loader = train_loader
        self.zdist = zdist
        self.ydist = ydist
        self.inception_nsamples = inception_nsamples
        self.batch_size = batch_size
        self.device = device

    def sample_z(self, batch_size):
        return self.zdist.sample((batch_size, )).to(self.device)

    def get_y(self, x, y):
        return self.clusterer.get_labels(x, y).to(self.device)

    def get_fake_real_samples(self, N):
        ''' returns N fake images and N real images in pytorch form

        Walks ``train_loader`` (restarting it as often as needed). For each
        real batch it generates fakes conditioned on the clusterer's labels,
        and stops as soon as N pairs exist. Images are returned as CPU tensors.

        Raises:
            ValueError: if ``train_loader`` yields no samples (this would
                otherwise loop forever)
        '''
        with torch.no_grad():
            self.generator.eval()
            fake_imgs = []
            real_imgs = []
            while len(fake_imgs) < N:
                n_before = len(fake_imgs)
                for x_real, y_gt in self.train_loader:
                    # follow self.device like sample_z/get_y; keep the CUDA
                    # default when no device was given
                    if self.device is not None:
                        x_real = x_real.to(self.device)
                    else:
                        x_real = x_real.cuda()
                    z = self.sample_z(x_real.size(0))
                    y = self.get_y(x_real, y_gt)
                    samples = self.generator(z, y)
                    fake_imgs.extend(s.data.cpu() for s in samples)
                    real_imgs.extend(img.data.cpu() for img in x_real)
                    assert (len(real_imgs) == len(fake_imgs))
                    # stop early instead of generating for the rest of the epoch
                    if len(fake_imgs) >= N:
                        break
                if len(fake_imgs) == n_before:
                    raise ValueError('train_loader yielded no samples')
        return fake_imgs[:N], real_imgs[:N]

    def compute_inception_score(self):
        '''Inception score (mean, std) of ``inception_nsamples`` generated images.'''
        imgs, _ = self.get_fake_real_samples(self.inception_nsamples)
        imgs = [img.numpy() for img in imgs]
        score, score_std = inception_score(imgs,
                                           device=self.device,
                                           resize=True,
                                           splits=1)

        return score, score_std

    def create_samples(self, z, y=None):
        '''Generates images for latents ``z``.

        ``y`` is either a label tensor or an int used as the label of every
        sample; ``y=None`` raises NotImplementedError.
        '''
        self.generator.eval()
        batch_size = z.size(0)
        # Parse y
        if y is None:
            raise NotImplementedError()
        elif isinstance(y, int):
            y = torch.full((batch_size, ),
                           y,
                           device=self.device,
                           dtype=torch.int64)
        # Sample x
        with torch.no_grad():
            x = self.generator(z, y)
        return x
================================================
FILE: gan_training/inputs.py
================================================
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import os
import torch.utils.data as data
from torchvision.datasets.folder import default_loader
from PIL import Image
import random
from PIL import ImageFile
# Let PIL decode truncated image files instead of raising an error, so a few
# damaged files in a dataset do not abort data loading.
ImageFile.LOAD_TRUNCATED_IMAGES = True
def get_dataset(name,
                data_dir,
                size=64,
                lsun_categories=None,
                deterministic=False,
                transform=None):
    '''Builds the dataset ``name`` rooted at ``data_dir``.

    Args:
        name (str): one of 'image', 'webp', 'npy', 'cifar10',
            'stacked_mnist', 'lsun', 'lsun_class'
        data_dir (str): dataset root directory
        size (int): images are resized and center-cropped to this size
        lsun_categories: LSUN classes argument (defaults to 'train')
        deterministic (bool): if False, the default pipeline adds random
            horizontal flips and uniform noise in [0, 1/128) after
            normalization
        transform: replaces the default pipeline; the 'npy' and
            'stacked_mnist' branches do not use it
    Returns:
        (dataset, nlabels)
    Raises:
        NotImplementedError: for an unknown ``name``
    '''
    if transform is None:
        transform = transforms.Compose([
            t for t in [
                transforms.Resize(size),
                transforms.CenterCrop(size),
                (not deterministic) and transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                (not deterministic) and
                transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
            ] if t is not False
        ])

    if name == 'image':
        print('Using image labels')
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'webp':
        # every item is labelled 0, but nlabels counts the class folders
        print('Using no labels from webp')
        dataset = CachedImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=True,
                                   download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'stacked_mnist':
        dataset = StackedMNIST(data_dir,
                               transform=transforms.Compose([
                                   transforms.Resize(size),
                                   transforms.CenterCrop(size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, ), (0.5, ))
                               ]))
        nlabels = 1000
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        dataset = datasets.LSUNClass(data_dir,
                                     transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    else:
        # NotImplemented is a constant, not an exception class; raising it
        # would fail with a TypeError.
        raise NotImplementedError('Unknown dataset type: %s' % name)
    return dataset, nlabels
class CachedImageFolder(data.Dataset):
    """
    ImageFolder-like dataset over the files returned by walk_image_files.

    When a '<root>.txt' listing exists, it is used instead of scanning the
    directory. The listing has one path per line, relative to the parent of
    root, for example:
        photo/park/004234.jpg
        photo/park/004236.jpg
    Items are (image, class index) pairs; make_class_dataset assigns class
    index 0 to every file.
    """

    def __init__(self, root, transform=None, loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_class_dataset(root, class_to_idx)
        self.imgs = imgs
        if not imgs:
            raise RuntimeError("Found 0 images within: %s" % root)
        self.root = root
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        image_path, class_index = self.imgs[index]
        sample = self.loader(image_path)
        if self.transform is None:
            return sample, class_index
        return self.transform(sample), class_index

    def __len__(self):
        return len(self.imgs)
class StackedMNIST(data.Dataset):
    '''Stacked MNIST: each item stacks three random MNIST digits as channels.

    The label is y1 * 100 + y2 * 10 + y3.

    Args:
        data_dir (str): MNIST root (downloaded if missing)
        transform: transform applied to each single-channel digit
        batch_size (int): number of stacked items in the dataset (its
            length), not a loader batch size
        seed (int, optional): if given, index triples are drawn from a
            private ``random.Random(seed)`` for reproducibility; by default
            the global ``random`` module is used
    '''

    def __init__(self, data_dir, transform, batch_size=100000, seed=None):
        super().__init__()
        self.channel1 = datasets.MNIST(data_dir,
                                       transform=transform,
                                       train=True,
                                       download=True)
        # The three channels are the same MNIST train set with the same
        # transform; share one instance instead of loading it three times.
        self.channel2 = self.channel1
        self.channel3 = self.channel1
        rng = random if seed is None else random.Random(seed)
        last = len(self.channel1) - 1
        self.indices = {
            k: (rng.randint(0, last), rng.randint(0, last), rng.randint(0, last))
            for k in range(batch_size)
        }

    def __getitem__(self, index):
        index1, index2, index3 = self.indices[index]
        x1, y1 = self.channel1[index1]
        x2, y2 = self.channel2[index2]
        x3, y3 = self.channel3[index3]
        return torch.cat([x1, x2, x3], dim=0), y1 * 100 + y2 * 10 + y3

    def __len__(self):
        return len(self.indices)
def is_npy_file(path):
    '''True when ``path`` ends in ".npy" or ".NPY" (exact case only).'''
    return path.endswith(('.npy', '.NPY'))
def walk_image_files(rootdir):
    '''Lists the image and .npy files under ``rootdir``.

    If a file '<rootdir>.txt' exists, its lines (paths relative to the
    parent directory of ``rootdir``) are used instead. That list is sorted
    and then shuffled with a fixed seed (1). Otherwise the tree is walked in
    sorted directory/file order.
    '''
    print(rootdir)
    listing = '%s.txt' % rootdir
    if os.path.isfile(listing):
        print('Loading file list from %s.txt instead of scanning dir' %
              rootdir)
        basedir = os.path.dirname(rootdir)
        with open(listing) as handle:
            paths = [os.path.join(basedir, line.strip()) for line in handle]
        paths.sort()
        random.Random(1).shuffle(paths)
        return paths

    image_suffixes = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG',
                      '.ppm', '.PPM', '.bmp', '.BMP')
    found = []
    for dirname, _, fnames in sorted(os.walk(rootdir)):
        found.extend(
            os.path.join(dirname, fname) for fname in sorted(fnames)
            if fname.endswith(image_suffixes) or is_npy_file(fname))
    return found
def find_classes(dir):
    '''Returns the sorted subdirectory names of ``dir`` and a name -> index map.'''
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry)))
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
def make_class_dataset(source_root, class_to_idx):
    """
    Returns a list of (image_path, 0) pairs, one for every file that
    walk_image_files finds under source_root (after ~ expansion).

    Every item gets class index 0. class_to_idx is accepted for interface
    compatibility but not used, so the resulting dataset is unlabeled.
    """
    source_root = os.path.expanduser(source_root)
    return [(path, 0) for path in walk_image_files(source_root)]
def npy_loader(path):
    '''Loads a .npy image as a float tensor scaled to [-1, 1].

    uint8 data is mapped with x / 127.5 - 1 and float32 data with x * 2 - 1.
    Any other dtype raises NotImplementedError. A 4-D array has its leading
    dimension squeezed (if it has size 1).
    '''
    array = np.load(path)
    if array.dtype == np.float32:
        array = array * 2 - 1.
    elif array.dtype == np.uint8:
        array = array.astype(np.float32) / 127.5 - 1.
    else:
        raise NotImplementedError
    tensor = torch.Tensor(array)
    if tensor.dim() == 4:
        tensor.squeeze_(0)
    return tensor
================================================
FILE: gan_training/logger.py
================================================
import pickle
import os
import torchvision
import copy
class Logger(object):
    '''Collects scalar training statistics and image grids.

    Scalars are stored in ``self.stats[category][k]`` as lists of
    ``(it, v)`` tuples. They are optionally mirrored to a monitoring backend
    ('telemetry' or 'tensorboard').
    '''

    def __init__(self,
                 log_dir='./logs',
                 img_dir='./imgs',
                 monitoring=None,
                 monitoring_dir=None):
        self.stats = dict()
        self.log_dir = log_dir
        self.img_dir = img_dir

        # exist_ok avoids the check-then-create race between processes
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(img_dir, exist_ok=True)

        if not (monitoring is None or monitoring == 'none'):
            self.setup_monitoring(monitoring, monitoring_dir)
        else:
            self.monitoring = None
            self.monitoring_dir = None

    def setup_monitoring(self, monitoring, monitoring_dir=None):
        '''Connects the 'telemetry' or 'tensorboard' backend.

        Raises:
            NotImplementedError: for any other backend name
        '''
        self.monitoring = monitoring
        self.monitoring_dir = monitoring_dir

        if monitoring == 'telemetry':
            import telemetry
            self.tm = telemetry.ApplicationTelemetry()
            if self.tm.get_status() == 0:
                print('Telemetry successfully connected.')
        elif monitoring == 'tensorboard':
            import tensorboardX
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)
        else:
            raise NotImplementedError('Monitoring tool "%s" not supported!' %
                                      monitoring)

    def add(self, category, k, v, it):
        '''Records value ``v`` of statistic ``category/k`` at iteration ``it``.'''
        if category not in self.stats:
            self.stats[category] = {}

        if k not in self.stats[category]:
            self.stats[category][k] = []

        self.stats[category][k].append((it, v))

        k_name = '%s/%s' % (category, k)
        if self.monitoring == 'telemetry':
            self.tm.metric_push_async({'metric': k_name, 'value': v, 'it': it})
        elif self.monitoring == 'tensorboard':
            self.tb.add_scalar(k_name, v, it)

    def add_imgs(self, imgs, class_name, it):
        '''Saves ``imgs`` (values in [-1, 1]) as a grid to ``img_dir/class_name/<it>.png``.'''
        outdir = os.path.join(self.img_dir, class_name)
        os.makedirs(outdir, exist_ok=True)
        outfile = os.path.join(outdir, '%08d.png' % it)

        # rescale from [-1, 1] to [0, 1]
        imgs = imgs / 2 + 0.5
        imgs = torchvision.utils.make_grid(imgs)
        torchvision.utils.save_image(copy.deepcopy(imgs), outfile, nrow=8)

        if self.monitoring == 'tensorboard':
            self.tb.add_image(class_name, copy.deepcopy(imgs), it)

    def get_last(self, category, k, default=0.):
        '''Returns the most recent value of ``category/k``, or ``default``.'''
        if category not in self.stats:
            return default
        elif k not in self.stats[category]:
            return default
        else:
            return self.stats[category][k][-1][1]

    def save_stats(self, filename):
        '''Pickles all statistics to ``log_dir/filename``.'''
        filename = os.path.join(self.log_dir, filename)
        with open(filename, 'wb') as f:
            pickle.dump(self.stats, f)

    def load_stats(self, filename):
        '''Loads statistics from ``log_dir/filename``.

        A missing or corrupted file only prints a warning and leaves the
        current stats unchanged.
        '''
        filename = os.path.join(self.log_dir, filename)
        if not os.path.exists(filename):
            print('Warning: file "%s" does not exist!' % filename)
            return

        try:
            with open(filename, 'rb') as f:
                self.stats = pickle.load(f)
        except (EOFError, pickle.UnpicklingError):
            print('Warning: log file corrupted!')
================================================
FILE: gan_training/metrics/__init__.py
================================================
from gan_training.metrics.inception_score import inception_score
# Public API of gan_training.metrics. __all__ must contain names (strings);
# listing the function object itself makes `from ... import *` raise TypeError.
__all__ = [
    'inception_score',
]
================================================
FILE: gan_training/metrics/clustering_metrics.py
================================================
def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``: accepts any arguments, does nothing."""
    pass
import warnings
# NOTE(review): this replaces warnings.warn for the whole process, not just
# this module. After this module is imported, every call that looks up
# `warnings.warn` at call time (in any library, and in other modules of this
# repo) is silently dropped. It runs before the sklearn imports below,
# presumably to hide sklearn deprecation warnings -- confirm that a global
# mute is intended.
warnings.warn = warn
from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score, homogeneity_score
from sklearn import metrics
import numpy as np
def nmi(inferred, gt):
    """Normalized mutual information between the inferred and ground-truth
    labelings, via sklearn's ``normalized_mutual_info_score`` (``inferred``
    passed first)."""
    score = normalized_mutual_info_score(inferred, gt)
    return score
def acc(inferred, gt):
    """Unsupervised clustering accuracy.

    Finds the one-to-one mapping between inferred cluster ids and
    ground-truth labels that maximizes the number of matched samples
    (Hungarian matching), and returns the fraction of samples matched.

    Args:
        inferred: array-like of non-negative integer cluster ids.
        gt: array-like of non-negative ground-truth labels (cast to int).

    Returns:
        Accuracy in [0, 1].

    Raises:
        AssertionError: if the two labelings have different sizes.
    """
    # sklearn.utils.linear_assignment_ was removed from scikit-learn; scipy's
    # solver is the maintained equivalent (returns row and column indices).
    from scipy.optimize import linear_sum_assignment
    inferred = np.asarray(inferred).astype(np.int64).ravel()
    gt = np.asarray(gt).astype(np.int64).ravel()
    assert inferred.size == gt.size
    D = max(inferred.max(), gt.max()) + 1
    # w[i, j] = number of samples in inferred cluster i with ground truth j.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (inferred, gt), 1)
    # The solver minimizes cost, so turn counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / inferred.size
def purity_score(y_true, y_pred):
    """Cluster purity: for each predicted cluster, count its most frequent
    true class, sum these counts, and divide by the number of samples."""
    table = metrics.cluster.contingency_matrix(y_true, y_pred)
    best_per_cluster = np.amax(table, axis=0)
    return np.sum(best_per_cluster) / np.sum(table)
def ari(inferred, gt):
    """Adjusted Rand index of ``inferred`` against the ground truth ``gt``."""
    score = adjusted_rand_score(gt, inferred)
    return score
def homogeneity(inferred, gt):
    """Homogeneity score of ``inferred`` clusters with respect to ``gt``."""
    score = homogeneity_score(gt, inferred)
    return score
================================================
FILE: gan_training/metrics/fid.py
================================================
from __future__ import absolute_import, division, print_function
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from scipy import linalg
import pathlib
import urllib
from tqdm import tqdm
import warnings
def check_or_download_inception(inception_path):
    ''' Checks if the path to the inception file is valid, or downloads
    the file if it is not present.

    Args:
        inception_path: directory expected to contain
            ``classify_image_graph_def.pb``; ``None`` means ``/tmp``.

    Returns:
        The path of the graph file, as a string.
    '''
    INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    if inception_path is None:
        inception_path = '/tmp'
    inception_path = pathlib.Path(inception_path)
    model_file = inception_path / 'classify_image_graph_def.pb'
    if not model_file.exists():
        print("Downloading Inception model")
        from urllib import request
        import tarfile
        fn, _ = request.urlretrieve(INCEPTION_URL)
        try:
            with tarfile.open(fn, mode='r') as f:
                f.extract('classify_image_graph_def.pb', str(model_file.parent))
        finally:
            # urlretrieve without a target filename stores the archive in a
            # temporary file; only the extracted .pb is needed afterwards.
            os.remove(fn)
    return str(model_file)
def create_inception_graph(pth):
    """Import the frozen Inception GraphDef stored at ``pth`` into the
    default graph, prefixing its nodes with 'FID_Inception_Net'."""
    with tf.io.gfile.GFile(pth, 'rb') as graph_file:
        serialized_graph = graph_file.read()
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(graph_def, name='FID_Inception_Net')
def calculate_activation_statistics(images,
                                    sess,
                                    batch_size=200,
                                    verbose=False):
    """Compute the statistics used by the FID.

    Params:
    -- images     : Numpy array of dimension (n_images, hi, wi, 3). The values
                    must lie between 0 and 255.
    -- sess       : current session
    -- batch_size : the images are fed to the network in batches of this size;
                    a reasonable value depends on the available hardware.
    -- verbose    : if True, report the batch currently being propagated.
    Returns:
    -- mu    : mean over samples of the pool_3 activations of the inception model.
    -- sigma : covariance matrix of the pool_3 activations of the inception model.
    """
    activations = get_activations(images, sess, batch_size, verbose)
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)
# code for handling inception net derived from
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
def _get_inception_layer(sess):
    """Prepares inception net for batched usage and returns pool_3 layer.

    Looks up 'FID_Inception_Net/pool_3:0' in ``sess.graph`` and, for every
    output tensor of every op in that graph whose leading dimension is 1,
    rewrites that dimension to ``None`` so the graph accepts batches of any
    size.
    """
    layername = 'FID_Inception_Net/pool_3:0'
    pool3 = sess.graph.get_tensor_by_name(layername)
    ops = pool3.graph.get_operations()
    for op_idx, op in enumerate(ops):
        for o in op.outputs:
            shape = o.get_shape()
            # NOTE(review): relies on private TensorFlow attributes
            # (TensorShape._dims, Tensor._shape_val) and on Dimension.value,
            # which is TF 1.x API -- may break on other TF versions.
            if shape._dims != []:
                shape = [s.value for s in shape]
                new_shape = []
                for j, s in enumerate(shape):
                    # Only a leading (batch) dimension of size 1 is relaxed.
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                # Overwrite the cached static shape in place.
                o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
    return pool3
#-------------------------------------------------------------------------------
def get_activations(images, sess, batch_size=200, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- images     : Numpy array of dimension (n_images, hi, wi, 3). The values
                    must lie between 0 and 255.
    -- sess       : current session
    -- batch_size : the images are fed to the network in batches of this size;
                    the last batch may be smaller. A reasonable value depends
                    on the available hardware.
    -- verbose    : if True, report the batch currently being propagated.
    Returns:
    -- A numpy array of dimension (num images, 2048) that contains the
       activations of the given tensor when feeding inception with the query tensor.
    """
    inception_layer = _get_inception_layer(sess)
    n_images = images.shape[0]
    if n_images == 0:
        return np.empty((0, 2048))
    if batch_size > n_images:
        print(
            "warning: batch size is bigger than the data size. setting batch size to data size"
        )
        batch_size = n_images
    # Ceil division: the final partial batch must be propagated too, otherwise
    # the trailing rows of pred_arr stay uninitialized and corrupt mu/sigma.
    n_batches = (n_images + batch_size - 1) // batch_size
    pred_arr = np.empty((n_images, 2048))
    for i in tqdm(range(n_batches)):
        if verbose:
            print("\rPropagating batch %d/%d" % (i + 1, n_batches),
                  end="",
                  flush=True)
        start = i * batch_size
        end = min(start + batch_size, n_images)
        batch = images[start:end]
        pred = sess.run(inception_layer,
                        {'FID_Inception_Net/ExpandDims:0': batch})
        # Reshape by the actual batch length (the last one can be shorter).
        pred_arr[start:end] = pred.reshape(end - start, -1)
    if verbose:
        print(" done")
    return pred_arr
#-------------------------------------------------------------------------------
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : sample mean of the pool_3 activations for generated samples.
    -- mu2   : sample mean of the pool_3 activations, precalculated on a
               representative data set.
    -- sigma1: covariance matrix of the pool_3 activations for generated samples.
    -- sigma2: covariance matrix of the pool_3 activations, precalculated on a
               representative data set.
    -- eps   : diagonal offset added when the covariance product is singular.
    Returns:
    -- : The Frechet Distance.
    Raises:
    -- ValueError if sqrtm leaves a significant imaginary component.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # The covariance product may be close to singular.
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        warnings.warn(
            "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error may leave a tiny imaginary part; only large ones are fatal.
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            raise ValueError("Imaginary component {}".format(np.max(np.abs(covmean.imag))))
        covmean = covmean.real

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
def compute_fid_from_npz(path):
    """Compute the FID of the generated images stored in an ``.npz`` file.

    The file must contain the generated images under key ``'fake'``. If
    'imagenet', 'cifar' or 'places' occurs in ``path``, the cached real-image
    statistics from ``precomputed_stats`` are used. Otherwise the real images
    are read from key ``'real'`` of the same file.

    Returns:
        The FID, or 0 when fewer than 1000 fake images are available.
    """
    print(path)
    with np.load(path) as data:
        fake_imgs = data['fake']
        # Use a separate loop variable: after a loop without `break` the loop
        # variable keeps its last value, so `name is None` could never hold.
        name = None
        for candidate in ['imagenet', 'cifar', 'places']:
            if candidate in path:
                name = candidate
                break
        print('Inferred name', name)
        # Presumably guards against unreliable estimates on small sample sets.
        # Checked before touching 'real' so such files need not contain it.
        if fake_imgs.shape[0] < 1000:
            return 0
        real_imgs = name if name is not None else data['real']

    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        m1, s1 = calculate_activation_statistics(fake_imgs, sess)
        if isinstance(real_imgs, str):
            print(f'using cached image stats for {real_imgs}')
            with np.load(precomputed_stats[real_imgs]) as data:
                m2, s2 = data['m'], data['s']
        else:
            print('computing real images stats from scratch')
            m2, s2 = calculate_activation_statistics(real_imgs, sess)
    return calculate_frechet_distance(m1, s1, m2, s2)
# Cache files (keys 'm' and 's') holding the Inception activation mean and
# covariance of each dataset's real images; written by compute_stats.
precomputed_stats = {
    dataset: f'output/{dataset}_gt_stats.npz'
    for dataset in ('places', 'imagenet', 'cifar')
}
def compute_fid_from_imgs(fake_imgs, real_imgs):
    """FID between ``fake_imgs`` and either an image array ``real_imgs`` or,
    when ``real_imgs`` is a string, the cached statistics
    ``precomputed_stats[real_imgs]``."""
    create_inception_graph(check_or_download_inception(None))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu_fake, sigma_fake = calculate_activation_statistics(fake_imgs, sess)
        if not isinstance(real_imgs, str):
            mu_real, sigma_real = calculate_activation_statistics(real_imgs, sess)
        else:
            with np.load(precomputed_stats[real_imgs]) as cached:
                mu_real, sigma_real = cached['m'], cached['s']
    return calculate_frechet_distance(mu_fake, sigma_fake, mu_real, sigma_real)
def compute_stats(exp_path):
    """Precompute and cache Inception statistics of real images.

    For each of 'places', 'imagenet' and 'cifar' (in that order) that occurs
    as a substring of ``exp_path`` and whose cache file
    ``precomputed_stats[name]`` does not exist yet, this loads the real images
    from ``output/<name>_gt_imgs.npz`` (key ``'real'``), computes their
    activation mean/covariance, and saves them under keys ``m`` / ``s``.
    """
    #TODO: a bit hacky -- the dataset is inferred from substrings of exp_path
    for dataset in ('places', 'imagenet', 'cifar'):
        if dataset not in exp_path or os.path.exists(precomputed_stats[dataset]):
            continue
        with np.load(f'output/{dataset}_gt_imgs.npz') as data_real:
            real_imgs = data_real['real']
        print(f'loaded real {dataset} images', real_imgs.shape)
        inception_path = check_or_download_inception(None)
        create_inception_graph(inception_path)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            m, s = calculate_activation_statistics(real_imgs, sess)
        np.savez(precomputed_stats[dataset], m=m, s=s)
if __name__ == '__main__':
    import argparse
    import json
    parser = argparse.ArgumentParser('compute TF FID')
    parser.add_argument('--samples', help='path to samples')
    parser.add_argument('--it', type=str,
                        help='training iteration of the samples (key in fid_results.json)')
    parser.add_argument('--results_dir', help='path to results_dir')
    args = parser.parse_args()

    it = args.it
    results_dir = args.results_dir

    compute_stats(args.samples)
    mean = compute_fid_from_npz(args.samples)
    print(f'FID: {mean}')

    if results_dir is not None:
        results_path = os.path.join(results_dir, 'fid_results.json')
        # Start a fresh results file on the first run instead of crashing.
        fid_results = {}
        if os.path.exists(results_path):
            with open(results_path) as f:
                fid_results = json.load(f)
        # float() keeps the JSON plain regardless of the numpy scalar type.
        fid_results[it] = float(mean)
        print(f'{results_dir} iteration {it} FID: {mean}')
        with open(results_path, 'w') as f:
            f.write(json.dumps(fid_results))
================================================
FILE: gan_training/metrics/inception_score.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3
import numpy as np
from scipy.stats import entropy
def inception_score(imgs, device=None, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs

    Args:
        imgs: Torch dataset (anything ``torch.utils.data.DataLoader`` accepts)
            whose items are 3xHxW image tensors normalized to the range [-1, 1].
        device: device passed to ``.to()`` for the Inception v3 model and the
            input batches (e.g. ``'cuda'``).
        batch_size: batch size for feeding into Inception v3; ``len(imgs)``
            must be larger than it.
        resize: if True, bilinearly upsample each batch to 299x299 first.
        splits: number of equal splits; the score is computed per split.

    Returns:
        (mean, std) of the per-split inception scores.
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    # Load inception model
    inception_model = inception_v3(pretrained=True, transform_input=False)
    inception_model = inception_model.to(device)
    inception_model.eval()
    # align_corners=False is the default bilinear behaviour; passing it
    # explicitly keeps the numerics identical and silences PyTorch's warning.
    up = nn.Upsample(size=(299, 299), mode='bilinear',
                     align_corners=False).to(device)

    def get_pred(x):
        """Class probabilities (softmax over 1000 classes) for one batch."""
        with torch.no_grad():
            if resize:
                x = up(x)
            x = inception_model(x)
            out = F.softmax(x, dim=-1)
        out = out.cpu().numpy()
        return out

    # Get predictions
    preds = np.zeros((N, 1000))

    for i, batch in enumerate(dataloader, 0):
        batchv = batch.to(device)
        batch_size_i = batch.size()[0]
        preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)

    # Now compute the mean kl-div; samples beyond splits * (N // splits) are
    # not used.
    split_scores = []

    for k in range(splits):
        part = preds[k * (N // splits):(k + 1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)
================================================
FILE: gan_training/metrics/tf_is/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: gan_training/metrics/tf_is/README.md
================================================
Inception Score
=====================================
A new Tensorflow implementation of the "Inception Score" (IS) for the evaluation of generative models, with a bug raised in [https://github.com/openai/improved-gan/issues/29](https://github.com/openai/improved-gan/issues/29) fixed.
## Major Dependency
- `tensorflow >= 1.14`
## Features
- Fast, easy-to-use and memory-efficient, written in a way that is similar to the original implementation
- No prior knowledge about Tensorflow is necessary if you are using CPU or GPU
- Makes use of [TF-GAN](https://github.com/tensorflow/gan)
- Downloads InceptionV1 automatically
- Compatible with both Python 2 and Python 3
## Usage
- If you are working with GPU, use `inception_score.py`; if you are working with TPU, use `inception_score_tpu.py` and pass a Tensorflow Session and a [TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy) as additional arguments.
- Call `get_inception_score(images, splits=10)`, where `images` is a numpy array with values ranging from 0 to 255 and shape in the form `[N, 3, HEIGHT, WIDTH]` where `N`, `HEIGHT` and `WIDTH` can be arbitrary. `dtype` of the images is recommended to be `np.uint8` to save CPU memory.
- A smaller `BATCH_SIZE` reduces GPU/TPU memory usage, but at the cost of a slight slowdown.
- If you want to compute a general "Classifier Score" with probabilities `preds` from another classifier, call `preds2score(preds, splits=10)`. `preds` can be a numpy array of arbitrary shape `[N, num_classes]`.
## Links
- The Inception Score was proposed in the paper [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498)
- Code for the [Fréchet Inception Distance](https://github.com/tsc2017/Frechet-Inception-Distance)
================================================
FILE: gan_training/metrics/tf_is/inception_score.py
================================================
'''
From https://github.com/tsc2017/Inception-Score
Code derived from https://github.com/openai/improved-gan/blob/master/inception_score/model.py and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
Usage:
Call get_inception_score(images, splits=10)
Args:
images: A numpy array with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. A dtype of np.uint8 is recommended to save CPU memory.
splits: The number of splits of the images, default is 10.
Returns:
Mean and standard deviation of the Inception Score across the splits.
'''
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import functools
import numpy as np
import time
from tqdm import tqdm
from tensorflow.python.ops import array_ops
# TF-GAN helpers (tfgan.eval.run_inception, get_graph_def_from_url_tarball)
# used by inception_logits.
# NOTE(review): tf.contrib exists only in TensorFlow 1.x -- confirm the
# pinned TensorFlow version.
tfgan = tf.contrib.gan
# Session created at import time; get_inception_probs runs the Inception
# graph through it.
session=tf.compat.v1.InteractiveSession()
# A smaller BATCH_SIZE reduces GPU memory usage, but at the cost of a slight slowdown
BATCH_SIZE = 64
# Tarball URL and the frozen .pb name inside it; both are handed to
# tfgan.eval.get_graph_def_from_url_tarball in inception_logits.
INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'
INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb'
# Run images through Inception.
# Input placeholder: float32 batch with 3 channels on axis 1 and arbitrary
# batch size, height and width.
inception_images = tf.compat.v1.placeholder(tf.float32, [None, 3, None, None])
def inception_logits(images = inception_images, num_splits = 1):
    """Build the graph that maps a batch of images to Inception logits.

    ``images`` has channels on axis 1 (see the placeholder); they are
    transposed to channels-last, resized to 299x299, split into
    ``num_splits`` chunks along the batch axis, run through the frozen
    Inception graph chunk by chunk with ``tf.map_fn``, and concatenated back
    into one logits tensor (output tensor 'logits:0').
    """
    images = tf.transpose(images, [0, 2, 3, 1])
    size = 299
    images = tf.compat.v1.image.resize_bilinear(images, [size, size])
    # NOTE(review): array_ops.split with an int requires the batch size to be
    # divisible by num_splits -- with the default of 1 this always holds.
    generated_images_list = array_ops.split(images, num_or_size_splits = num_splits)
    logits = tf.map_fn(
        fn = functools.partial(
            tfgan.eval.run_inception,
            default_graph_def_fn = functools.partial(
                tfgan.eval.get_graph_def_from_url_tarball,
                INCEPTION_URL,
                INCEPTION_FROZEN_GRAPH,
                os.path.basename(INCEPTION_URL)),
            output_tensor = 'logits:0'),
        elems = array_ops.stack(generated_images_list),
        parallel_iterations = 8,
        back_prop = False,
        swap_memory = True,
        name = 'RunClassifier')
    # map_fn returns one logits tensor per chunk; merge them along the batch axis.
    logits = array_ops.concat(array_ops.unstack(logits), 0)
    return logits
# Build the logits tensor once at import time; get_inception_probs evaluates it.
logits=inception_logits()
def get_inception_probs(inps):
    """Class probabilities of Inception for every image in ``inps``.

    Args:
        inps: numpy array with values in [0, 255] and channels on axis 1; it
            is rescaled to [-1, 1] batch by batch (BATCH_SIZE images each).

    Returns:
        float32 array of shape (len(inps), 1000) whose rows sum to 1.
    """
    n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE))
    preds = np.zeros([inps.shape[0], 1000], dtype = np.float32)
    for i in tqdm(range(n_batches)):
        inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] / 255. * 2 - 1
        # Keep only the first 1000 logits of each row.
        preds[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(logits,{inception_images: inp})[:, :1000]
    # Numerically stable softmax: subtracting the per-row max leaves the
    # result unchanged mathematically but prevents exp() overflow to inf/nan.
    shifted = np.exp(preds - np.max(preds, 1, keepdims=True))
    preds = shifted / np.sum(shifted, 1, keepdims=True)
    return preds
def preds2score(preds, splits=10):
    """Inception (classifier) score from class probabilities.

    Args:
        preds: array of shape [N, num_classes] whose rows are probability
            distributions.
        splits: number of contiguous splits; the score is computed per split.

    Returns:
        (mean, std) of the per-split scores exp(E_x[KL(p(y|x) || p(y))]).
    """
    scores = []
    for i in range(splits):
        part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
        py = np.mean(part, 0, keepdims=True)
        # By the KL convention 0 * log(0) = 0. Zero probabilities (e.g. float32
        # softmax underflow) would otherwise produce nan and poison the score.
        with np.errstate(divide='ignore', invalid='ignore'):
            kl = part * (np.log(part) - np.log(py))
        kl = np.where(part > 0, kl, 0.0)
        kl = np.mean(np.sum(kl, 1))
        scores.append(np.exp(kl))
    return np.mean(scores), np.std(scores)
def get_inception_score(images, splits=10):
    """Inception Score of ``images``.

    Args:
        images: numpy.ndarray of shape [N, 3, H, W] with values in [0, 255];
            only the first image is range-checked.
        splits: number of splits for the score statistics.

    Returns:
        (mean, std) of the score across splits.
    """
    assert type(images) == np.ndarray
    assert images.ndim == 4
    assert images.shape[1] == 3
    first_image = images[0]
    assert np.min(first_image) >= 0 and np.max(first_image) > 10, 'Image values should be in the range [0, 255]'
    print('Calculating Inception Score with %i images in %i splits' % (images.shape[0], splits))
    t_start = time.time()
    mean, std = preds2score(get_inception_probs(images), splits)
    print('Inception Score calculation time: %f s' % (time.time() - t_start))
    # Reference values: 11.38 for 50000 CIFAR-10 training set images, or mean=11.31, std=0.10 if in 10 splits.
    return mean, std
def compute_is_from_npz(path):
    """Inception Score of the images stored under key ``'fake'`` of an
    ``.npz`` file; the stored last axis is moved to position 1 before
    scoring."""
    with np.load(path) as archive:
        samples = archive['fake']
    samples = np.transpose(samples, (0, 3, 1, 2))
    print(samples.shape)
    return get_inception_score(samples)
if __name__ == '__main__':
    import argparse
    import json
    parser = argparse.ArgumentParser('compute TF IS')
    parser.add_argument('--samples', help='path to samples')
    parser.add_argument('--it', type=str,
                        help='training iteration of the samples (key in is_results.json)')
    parser.add_argument('--results_dir', help='path to results_dir')
    args = parser.parse_args()

    it = args.it
    results_dir = args.results_dir

    mean, std = compute_is_from_npz(args.samples)
    print(f'{results_dir} iteration {it} IS: {mean}')

    if results_dir is not None:
        results_path = os.path.join(results_dir, 'is_results.json')
        # Start a fresh results file on the first run instead of crashing.
        is_results = {}
        if os.path.exists(results_path):
            with open(results_path) as f:
                is_results = json.load(f)
        is_results[it] = float(mean)
        with open(results_path, 'w') as f:
            f.write(json.dumps(is_results))
================================================
FILE: gan_training/models/__init__.py
================================================
from gan_training.models import (dcgan_deep, dcgan_shallow, resnet2)
# (config name, module) pairs; each module provides Generator and Discriminator.
_ARCHITECTURES = (
    ('resnet2', resnet2),
    ('dcgan_deep', dcgan_deep),
    ('dcgan_shallow', dcgan_shallow),
)

# Lookup tables from architecture name to the model class.
generator_dict = {name: module.Generator for name, module in _ARCHITECTURES}
discriminator_dict = {name: module.Discriminator for name, module in _ARCHITECTURES}
================================================
FILE: gan_training/models/blocks.py
================================================
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
class ResnetBlock(nn.Module):
    """Residual block: ``shortcut(x) + 0.1 * conv_1(actvn(bn1(conv_0(actvn(bn0(x, y))), y)))``.

    Args:
        fin: number of input channels.
        fout: number of output channels; when it differs from ``fin`` the
            shortcut is a learned bias-free 1x1 convolution.
        bn: factory called as ``bn(channels, nclasses)``; the returned module
            is applied as ``module(x, y)``.
        nclasses: forwarded to ``bn``.
        fhidden: hidden channel count (defaults to ``min(fin, fout)``).
        is_bias: whether ``conv_1`` has a bias (``conv_0`` always has one).
    """
    def __init__(self,
                 fin,
                 fout,
                 bn,
                 nclasses,
                 fhidden=None,
                 is_bias=True):
        super().__init__()
        # Attributes
        self.is_bias = is_bias
        self.learned_shortcut = fin != fout
        self.fin, self.fout = fin, fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden

        # Submodules (registration order kept: it fixes init RNG consumption)
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3,
                                stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout, 1,
                                    stride=1, padding=0, bias=False)
        self.bn0 = bn(self.fin, nclasses)
        self.bn1 = bn(self.fhidden, nclasses)

    def forward(self, x, y):
        skip = self._shortcut(x)
        hidden = self.conv_0(actvn(self.bn0(x, y)))
        hidden = self.conv_1(actvn(self.bn1(hidden, y)))
        return skip + 0.1 * hidden

    def _shortcut(self, x):
        return self.conv_s(x) if self.learned_shortcut else x
def actvn(x):
    """Leaky ReLU with negative slope 0.2 (the activation used by ResnetBlock)."""
    return F.leaky_relu(x, 0.2)
class LatentEmbeddingConcat(nn.Module):
    ''' projects class embedding onto hypersphere and returns the concat of the latent and the class embedding '''
    def __init__(self, nlabels, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(nlabels, embed_dim)

    def forward(self, z, y):
        """Concatenate ``z`` with the unit-L2-norm embedding of labels ``y``.

        Args:
            z: latent batch, concatenated along dim 1.
            y: integer class index per row of ``z``.
        Returns:
            ``cat([z, normalize(embedding(y))], dim=1)``.
        """
        assert (y.size(0) == z.size(0))
        # F.normalize clamps the norm at a tiny epsilon, so an all-zero
        # embedding row yields zeros instead of NaN (plain division is 0/0).
        yembed = F.normalize(self.embedding(y), p=2, dim=1)
        return torch.cat([z, yembed], dim=1)
class NormalizeLinear(nn.Module):
    """Linear layer whose weight rows are rescaled to unit L2 norm before every forward."""

    def __init__(self, act_dim, k_value):
        super().__init__()
        self.lin = nn.Linear(act_dim, k_value)

    def normalize(self):
        # Overwrite the raw weight data (outside autograd) with row-normalized values.
        weight = self.lin.weight
        weight.data = F.normalize(weight.data, p=2, dim=1)

    def forward(self, x):
        self.normalize()
        return self.lin(x)
class Identity(nn.Module):
    """No-op module: ignores constructor arguments and extra forward arguments, returns its first input."""

    def __init__(self, *_args, **_kwargs):
        super().__init__()

    def forward(self, inp, *_args, **_kwargs):
        return inp
class LinearConditionalMaskLogits(nn.Module):
    ''' runs activated logits through fc and masks out the appropriate discriminator score according to class number'''
    def __init__(self, nc, nlabels):
        super().__init__()
        self.fc = nn.Linear(nc, nlabels)

    def forward(self, inp, y=None, take_best=False, get_features=False):
        """Score ``inp`` with one logit per class and select per sample.

        Args:
            inp: flattened features, shape (batch, nc).
            y: class index per sample; required unless ``take_best`` or
                ``get_features`` is set. Any shape that flattens to (batch,).
            take_best: return the maximum logit over classes instead.
            get_features: return the full (batch, nlabels) logit matrix.
        Returns:
            (batch,) logits, or (batch, nlabels) when ``get_features``.
        """
        out = self.fc(inp)
        if get_features:
            return out
        if take_best:
            # high activation means real, so take the highest activations
            best_logits, _ = out.max(dim=1)
            return best_logits
        # Select column y[i] of row i; indices live on out's device, which
        # replaces the deprecated Variable/LongTensor + conditional .cuda().
        y = y.view(-1).to(out.device)
        rows = torch.arange(out.size(0), device=out.device)
        return out[rows, y]
class ProjectionDiscriminatorLogits(nn.Module):
    ''' takes in activated flattened logits before last linear layer and implements https://arxiv.org/pdf/1802.05637.pdf '''
    def __init__(self, nc, nlabels):
        super().__init__()
        self.fc = nn.Linear(nc, 1)
        self.embedding = nn.Embedding(nlabels, nc)
        self.nlabels = nlabels

    def forward(self, x, y, take_best=False):
        """Projection-discriminator logit: fc(x) + <embedding(y), x>.

        Args:
            x: flattened features, shape (batch, nc).
            y: class index per sample (ignored when ``take_best``).
            take_best: use the label whose projection score is largest.
        Returns:
            (batch,) logits.
        """
        output = self.fc(x)
        if not take_best:
            label_info = torch.sum(self.embedding(y) * x, dim=1, keepdim=True)
            return (output + label_info).view(x.size(0))
        # Scores for every label at once: entry [i, j] = <embedding_j, x_i>.
        # One matmul on x's device replaces the repeat/repeat_interleave
        # expansion and the hard-coded .cuda() that failed on CPU.
        label_info = x @ self.embedding.weight.t()
        # high activation means real, so take the highest activations
        best_logits, _ = label_info.max(dim=1)
        return output.view(output.size(0)) + best_logits
class LinearUnconditionalLogits(nn.Module):
    ''' standard discriminator logit layer '''
    def __init__(self, nc):
        super().__init__()
        self.fc = nn.Linear(nc, 1)

    def forward(self, inp, y, take_best=False):
        """Return one logit per sample; ``y`` exists only for interface parity and is ignored."""
        assert take_best == False  # noqa: E712 -- equality (not identity) kept on purpose
        logits = self.fc(inp)
        return logits.view(logits.size(0))
class Reshape(nn.Module):
    """View each sample as ``shape`` while keeping the leading batch dimension."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.shape[0], *self.shape)
class ConditionalBatchNorm2d(nn.Module):
    ''' from https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775 '''
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # One row per class: first half is the scale (gamma), second half the shift (beta).
        self.embed = nn.Embedding(num_classes, num_features * 2)
        table = self.embed.weight.data
        table[:, :num_features].normal_(1, 0.02)  # scale ~ N(1, 0.02)
        table[:, num_features:].zero_()  # shift starts at 0

    def forward(self, x, y):
        normed = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, 1)
        per_channel = (-1, self.num_features, 1, 1)
        return gamma.view(*per_channel) * normed + beta.view(*per_channel)
class BatchNorm2d(nn.Module):
    """Plain ``nn.BatchNorm2d`` that accepts (and ignores) a label argument.

    ``nchannels`` and extra keyword arguments are accepted only so this class
    can be constructed with the same ``(features, nlabels)`` call as the
    conditional norm layers.
    """

    def __init__(self, nc, nchannels, **kwargs):
        super().__init__()
        self.bn = nn.BatchNorm2d(nc)

    def forward(self, x, y):
        del y  # unused; kept for interface parity
        return self.bn(x)
================================================
FILE: gan_training/models/dcgan_deep.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data
import torch.utils.data.distributed
from gan_training.models import blocks
class Generator(nn.Module):
    """DCGAN-style generator.

    The (optionally label-conditioned) latent is mapped by ``fc`` to an
    (ngf*8, 4, 4) map, doubled in size by three stride-2 transposed convs
    (each followed by batch norm and ReLU), then projected to ``nc`` channels
    by a 3x3 conv and tanh.

    Args:
        nlabels: number of classes; must be 1 when unconditional.
        conditioning: 'embedding' (concat a unit-norm class embedding to the
            latent) or 'unconditional'.
        z_dim: latent size.
        nc: output channels.
        ngf: base channel width.
        embed_dim: class-embedding size (only used with 'embedding').
    """

    def __init__(self, nlabels, conditioning, z_dim=128, nc=3, ngf=64,
                 embed_dim=256, **kwargs):
        super(Generator, self).__init__()
        assert conditioning != 'unconditional' or nlabels == 1
        if conditioning == 'embedding':
            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim)
            latent_dim = z_dim + embed_dim
        elif conditioning == 'unconditional':
            self.get_latent = blocks.Identity()
            latent_dim = z_dim
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for generator")
        self.fc = nn.Linear(latent_dim, 4 * 4 * ngf * 8)
        norm = blocks.BatchNorm2d
        self.nlabels = nlabels
        self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1)
        self.bn1 = norm(ngf * 4, nlabels)
        self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1)
        self.bn2 = norm(ngf * 2, nlabels)
        self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1)
        self.bn3 = norm(ngf, nlabels)
        self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh())

    def forward(self, input, y):
        # Labels above the last class index are clamped into range.
        y = y.clamp(None, self.nlabels - 1)
        h = self.fc(self.get_latent(input, y))
        h = h.view(h.size(0), -1, 4, 4)
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2),
                           (self.conv3, self.bn3)):
            h = F.relu(norm(conv(h), y))
        return self.conv_out(h)
class Discriminator(nn.Module):
    """DCGAN-style discriminator: seven convs (three of them stride-2) and a linear logit head.

    Args:
        nlabels: number of classes; must be 1 for 'unconditional'.
        conditioning: 'mask' (one logit per class, selected by ``y``) or
            'unconditional'.
        nc: channels per input image.
        ndf: base channel width.
        pack_size: PacGAN packing; this many consecutive images are
            concatenated along channels before the first conv.
        features: what ``get_features=True`` returns: 'penultimate'
            (flattened conv7 output) or 'summed' (conv7 output summed over
            spatial positions).

    The head expects a 4x4 map after conv7, i.e. 32x32 inputs.
    """

    def __init__(self, nlabels, conditioning, nc=3, ndf=64, pack_size=1,
                 features='penultimate', **kwargs):
        super(Discriminator, self).__init__()
        assert conditioning != 'unconditional' or nlabels == 1
        self.nlabels = nlabels
        self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 3, 1, 1), nn.LeakyReLU(0.1))
        self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf, 4, 2, 1), nn.LeakyReLU(0.1))
        self.conv3 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 3, 1, 1), nn.LeakyReLU(0.1))
        self.conv4 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 2, 4, 2, 1), nn.LeakyReLU(0.1))
        self.conv5 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 3, 1, 1), nn.LeakyReLU(0.1))
        self.conv6 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1), nn.LeakyReLU(0.1))
        self.conv7 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 3, 1, 1), nn.LeakyReLU(0.1))
        if conditioning == 'mask':
            self.fc_out = blocks.LinearConditionalMaskLogits(
                ndf * 8 * 4 * 4, nlabels)
        elif conditioning == 'unconditional':
            self.fc_out = blocks.LinearUnconditionalLogits(
                ndf * 8 * 4 * 4)
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for discriminator")
        self.features = features
        self.pack_size = pack_size
        print(f'Getting features from {self.features}')

    def stack(self, x):
        """PacGAN: merge each group of ``pack_size`` consecutive images along channels.

        (B, C, H, W) -> (B // pack_size, pack_size * C, H, W).
        """
        nc = self.pack_size
        assert (x.size(0) % nc == 0)
        if nc == 1:
            return x
        # A single reshape gives the same layout as concatenating each group's
        # images along their channel dim, without a Python loop.
        return x.reshape(x.size(0) // nc, nc * x.size(1), *x.shape[2:])

    def forward(self, input, y=None, get_features=False):
        out = self.stack(input)
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5, self.conv6, self.conv7):
            out = layer(out)
        if get_features and self.features == "penultimate":
            return out.view(out.size(0), -1)
        if get_features and self.features == "summed":
            return out.view(out.size(0), out.size(1), -1).sum(dim=2)
        out = out.view(out.size(0), -1)
        # y defaults to None (the unconditional head ignores it); only clamp
        # real labels, matching resnet2.Discriminator.
        if y is not None:
            y = y.clamp(None, self.nlabels - 1)
        result = self.fc_out(out, y)
        assert (len(result.shape) == 1)
        return result
if __name__ == '__main__':
    # Smoke test: one forward pass through unconditional models on 32x32
    # inputs. Both constructors require nlabels/conditioning, and both
    # forwards take a label tensor.
    z = torch.zeros((1, 128))
    y = torch.zeros((1, ), dtype=torch.long)
    g = Generator(nlabels=1, conditioning='unconditional')
    x = torch.zeros((1, 3, 32, 32))
    d = Discriminator(nlabels=1, conditioning='unconditional')
    with torch.no_grad():
        fake = g(z, y)
        assert fake.shape == x.shape
        d(fake, y)
        d(x, y)
================================================
FILE: gan_training/models/dcgan_shallow.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data
import torch.utils.data.distributed
from gan_training.models import blocks
class Generator(nn.Module):
    """DCGAN-style generator (same architecture as in dcgan_deep).

    ``fc`` maps the (optionally label-conditioned) latent to an (ngf*8, 4, 4)
    map; three stride-2 transposed convs with batch norm + ReLU double its
    size each time; a 3x3 conv + tanh produces ``nc`` output channels.

    ``conditioning`` is 'embedding' or 'unconditional' (requires nlabels == 1).
    """

    def __init__(self, nlabels, conditioning, z_dim=128, nc=3, ngf=64,
                 embed_dim=256, **kwargs):
        super(Generator, self).__init__()
        assert conditioning != 'unconditional' or nlabels == 1
        if conditioning == 'embedding':
            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim)
            in_features = z_dim + embed_dim
        elif conditioning == 'unconditional':
            self.get_latent = blocks.Identity()
            in_features = z_dim
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for generator")
        self.fc = nn.Linear(in_features, 4 * 4 * ngf * 8)
        norm_layer = blocks.BatchNorm2d
        self.nlabels = nlabels
        self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1)
        self.bn1 = norm_layer(ngf * 4, nlabels)
        self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1)
        self.bn2 = norm_layer(ngf * 2, nlabels)
        self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1)
        self.bn3 = norm_layer(ngf, nlabels)
        self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh())

    def forward(self, input, y):
        labels = y.clamp(None, self.nlabels - 1)
        feat = self.fc(self.get_latent(input, labels))
        feat = feat.view(feat.size(0), -1, 4, 4)
        stages = [(self.conv1, self.bn1), (self.conv2, self.bn2),
                  (self.conv3, self.bn3)]
        for upconv, bn in stages:
            feat = F.relu(bn(upconv(feat), labels))
        return self.conv_out(feat)
class Discriminator(nn.Module):
    """Shallow DCGAN discriminator: four stride-2 conv + BN + LeakyReLU layers and a linear head.

    Args:
        nlabels: number of classes; must be 1 for 'unconditional'.
        conditioning: 'mask' or 'unconditional'.
        features: stored (and printed) but not consulted; ``get_features=True``
            always returns the flattened conv4 output.
        pack_size: PacGAN packing factor (consecutive images concatenated
            along channels).
        nc: channels per input image.
        ndf: base channel width.

    The head size ``ndf * 8 * 4`` corresponds to a 2x2 map after conv4,
    i.e. 32x32 inputs.
    """

    def __init__(self, nlabels, conditioning, features='penultimate',
                 pack_size=1, nc=3, ndf=64, **kwargs):
        super(Discriminator, self).__init__()
        assert conditioning != 'unconditional' or nlabels == 1
        self.nlabels = nlabels
        self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 4, 2, 1),
                                   nn.BatchNorm2d(ndf),
                                   nn.LeakyReLU(0.2, inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
                                   nn.BatchNorm2d(ndf * 2),
                                   nn.LeakyReLU(0.2, inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
                                   nn.BatchNorm2d(ndf * 4),
                                   nn.LeakyReLU(0.2, inplace=True))
        self.conv4 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
                                   nn.BatchNorm2d(ndf * 8),
                                   nn.LeakyReLU(0.2, inplace=True))
        if conditioning == 'mask':
            self.fc_out = blocks.LinearConditionalMaskLogits(ndf * 8 * 4 , nlabels)
        elif conditioning == 'unconditional':
            self.fc_out = blocks.LinearUnconditionalLogits(ndf * 8 * 4)
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for discriminator")
        self.pack_size = pack_size
        self.features = features
        print(f'Getting features from {self.features}')

    def stack(self, x):
        """PacGAN: (B, C, H, W) -> (B // pack_size, pack_size * C, H, W)."""
        nc = self.pack_size
        # Same check as dcgan_deep: a ragged final group would otherwise be
        # silently dropped.
        assert (x.size(0) % nc == 0)
        if nc == 1:
            return x
        return x.reshape(x.size(0) // nc, nc * x.size(1), *x.shape[2:])

    def forward(self, input, y=None, get_features=False):
        out = self.stack(input)
        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = out.view(out.size(0), -1)
        if get_features:
            return out
        # y defaults to None (the unconditional head ignores it).
        if y is not None:
            y = y.clamp(None, self.nlabels - 1)
        result = self.fc_out(out, y)
        assert (len(result.shape) == 1)
        return result
if __name__ == '__main__':
    # Smoke test: one forward pass through unconditional models on 32x32
    # inputs. Both constructors require nlabels/conditioning, and both
    # forwards take a label tensor.
    z = torch.zeros((1, 128))
    y = torch.zeros((1, ), dtype=torch.long)
    g = Generator(nlabels=1, conditioning='unconditional')
    x = torch.zeros((1, 3, 32, 32))
    d = Discriminator(nlabels=1, conditioning='unconditional')
    with torch.no_grad():
        fake = g(z, y)
        assert fake.shape == x.shape
        d(fake, y)
        d(x, y)
================================================
FILE: gan_training/models/resnet2.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data
import torch.utils.data.distributed
from gan_training.models import blocks
from gan_training.models.blocks import ResnetBlock
from torch.nn.utils.spectral_norm import spectral_norm
class Generator(nn.Module):
    """ResNet generator.

    ``fc`` maps the (optionally label-conditioned) latent to a
    (16*nfilter, size//32, size//32) map. Six stages of two ResnetBlocks
    follow, with 2x nearest upsampling before stages 1-5. A leaky ReLU,
    3x3 conv to RGB and tanh produce the image.

    ``conditioning`` is 'embedding' or 'unconditional' (requires nlabels == 1).
    """

    def __init__(self, z_dim, nlabels, size, conditioning, embed_size=256,
                 nfilter=64, **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.nlabels = nlabels
        self.z_dim = z_dim
        assert conditioning != 'unconditional' or nlabels == 1
        if conditioning == 'embedding':
            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_size)
            latent_dim = z_dim + embed_size
        elif conditioning == 'unconditional':
            self.get_latent = blocks.Identity()
            latent_dim = z_dim
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for generator")
        self.fc = nn.Linear(latent_dim, 16 * nf * s0 * s0)
        # No normalization inside the blocks (Identity ignores the label).
        norm = blocks.Identity
        # Channel multiplier (of nf) at each stage boundary:
        # stage i maps mults[i] * nf -> mults[i + 1] * nf.
        mults = (16, 16, 16, 8, 4, 2, 1)
        for stage in range(6):
            fin, fout = mults[stage] * nf, mults[stage + 1] * nf
            setattr(self, f'resnet_{stage}_0', ResnetBlock(fin, fout, norm, nlabels))
            setattr(self, f'resnet_{stage}_1', ResnetBlock(fout, fout, norm, nlabels))
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)

    def forward(self, z, y):
        y = y.clamp(None, self.nlabels - 1)
        out = self.fc(self.get_latent(z, y))
        out = out.view(z.size(0), 16 * self.nf, self.s0, self.s0)
        for stage in range(6):
            if stage:
                out = F.interpolate(out, scale_factor=2)
            out = getattr(self, f'resnet_{stage}_0')(out, y)
            out = getattr(self, f'resnet_{stage}_1')(out, y)
        return torch.tanh(self.conv_img(actvn(out)))
class Discriminator(nn.Module):
    """ResNet discriminator.

    A 3x3 conv lifts RGB to ``nfilter`` channels. Six stages of two
    ResnetBlocks follow, with 3x3/stride-2 average pooling before stages
    1-5, then a leaky ReLU and a linear logit head over the flattened
    (16*nfilter, size//32, size//32) map.

    ``features`` selects what ``get_features=True`` returns: 'summed'
    (spatially summed final map) or, otherwise, the flattened final map.
    """

    def __init__(self, nlabels, size, conditioning, nfilter=64,
                 features='penultimate', **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.nlabels = nlabels
        assert conditioning != 'unconditional' or nlabels == 1
        norm = blocks.Identity
        self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1)
        # Stage i keeps mults[i] * nf channels in its first block and widens
        # to mults[i + 1] * nf in its second.
        mults = (1, 2, 4, 8, 16, 16, 16)
        for stage in range(6):
            fin, fout = mults[stage] * nf, mults[stage + 1] * nf
            setattr(self, f'resnet_{stage}_0', ResnetBlock(fin, fin, norm, nlabels))
            setattr(self, f'resnet_{stage}_1', ResnetBlock(fin, fout, norm, nlabels))
        flat_dim = 16 * nf * s0 * s0
        if conditioning == 'mask':
            self.fc_out = blocks.LinearConditionalMaskLogits(flat_dim, nlabels)
        elif conditioning == 'unconditional':
            self.fc_out = blocks.LinearUnconditionalLogits(flat_dim)
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for discriminator")
        self.features = features

    def forward(self, x, y=None, get_features=False):
        batch_size = x.size(0)
        if y is not None:
            y = y.clamp(None, self.nlabels - 1)
        out = self.conv_img(x)
        for stage in range(6):
            if stage:
                out = F.avg_pool2d(out, 3, stride=2, padding=1)
            out = getattr(self, f'resnet_{stage}_0')(out, y)
            out = getattr(self, f'resnet_{stage}_1')(out, y)
        out = actvn(out)
        if get_features and self.features == 'summed':
            return out.view(out.size(0), out.size(1), -1).sum(dim=2)
        out = out.view(batch_size, 16 * self.nf * self.s0 * self.s0)
        if get_features:
            return out.view(batch_size, -1)
        result = self.fc_out(out, y)
        assert (len(result.shape) == 1)
        return result
def actvn(x):
    """Leaky ReLU (negative slope 0.2) applied by the ResNet generator/discriminator."""
    negative_slope = 0.2
    return F.leaky_relu(x, negative_slope)
================================================
FILE: gan_training/models/resnet2s.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data
import torch.utils.data.distributed
from collections import OrderedDict
class Reshape(nn.Module):
    """Reinterpret each sample as a tensor of shape ``shape`` (batch dim preserved)."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.size(0), *self.shape)
class Generator(nn.Module):
    '''
    Perfectly equivalent to resnet2.Generator (can load state dicts
    from that class), but organizes layers as a sequence for more
    automatic inversion.
    '''
    def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64,
                 use_class_labels=False, **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.z_dim = z_dim
        self.use_class_labels = use_class_labels
        if use_class_labels:
            self.condition = ConditionGen(z_dim, nlabels, embed_size)
            latent_dim = self.condition.latent_dim
        else:
            latent_dim = z_dim
        # Channel multiplier (of nf) at each stage boundary: stage i maps
        # mults[i] * nf -> mults[i + 1] * nf, with 2x upsampling before stages 1-5.
        mults = (16, 16, 16, 8, 4, 2, 1)
        layers = [('fc', nn.Linear(latent_dim, 16 * nf * s0 * s0)),
                  ('reshape', Reshape(16 * nf, s0, s0))]
        for stage in range(6):
            if stage:
                layers.append((f'upsample_{stage}', nn.Upsample(scale_factor=2)))
            fin, fout = mults[stage] * nf, mults[stage + 1] * nf
            layers.append((f'resnet_{stage}_0', ResnetBlock(fin, fout)))
            layers.append((f'resnet_{stage}_1', ResnetBlock(fout, fout)))
        layers.append(('img_relu', nn.LeakyReLU(2e-1)))
        layers.append(('conv_img', nn.Conv2d(nf, 3, 3, padding=1)))
        layers.append(('tanh', nn.Tanh()))
        self.layers = nn.Sequential(OrderedDict(layers))

    def forward(self, z, y=None):
        assert (y is None or z.size(0) == y.size(0))
        assert (not self.use_class_labels or y is not None)
        latent = self.condition(z, y) if self.use_class_labels else z
        return self.layers(latent)

    def load_v2_state_dict(self, state_dict):
        """Load a resnet2.Generator state dict by remapping its keys onto this layout."""
        def rename(key):
            if 'module.' in key:
                # presumably a DataParallel-style wrapper prefix; keep what follows it
                key = key.split('module.')[1]
            if key.startswith('embedding'):
                return 'condition.' + key
            if key == 'get_latent.embedding.weight':
                return 'condition.embedding.weight'
            return 'layers.' + key

        self.load_state_dict({rename(k): v for k, v in state_dict.items()})
class ConditionGen(nn.Module):
    """Concatenate a latent with an L2-normalized class code.

    ``y`` is either int64 labels (looked up in the embedding table) or an
    already-embedded tensor, which is used as given before normalization.
    The output has ``latent_dim = z_dim + embed_size`` columns.
    """

    def __init__(self, z_dim, nlabels, embed_size=256):
        super().__init__()
        self.embedding = nn.Embedding(nlabels, embed_size)
        self.latent_dim = z_dim + embed_size
        self.z_dim = z_dim
        self.nlabels = nlabels
        self.embed_size = embed_size

    def forward(self, z, y):
        assert (z.size(0) == y.size(0))
        code = self.embedding(y) if y.dtype is torch.int64 else y
        code = code / torch.norm(code, p=2, dim=1, keepdim=True)
        return torch.cat([z, code], dim=1)
def convert_from_resnet2_generator(gen):
    """Build a sequential Generator equivalent to resnet2-style ``gen`` and copy its weights.

    The class embedding may live at ``gen.embedding`` or
    ``gen.get_latent.embedding``; the latter wins if both exist. A
    ``get_latent`` without an embedding (the Identity used by unconditional
    models) now yields an unconditional Generator instead of raising
    AttributeError.
    """
    nlabels, embed_size = 0, 0
    use_class_labels = False
    embedding = getattr(gen, 'embedding', None)
    get_latent = getattr(gen, 'get_latent', None)
    if hasattr(get_latent, 'embedding'):
        embedding = get_latent.embedding
    if embedding is not None:
        nlabels = embedding.num_embeddings
        embed_size = embedding.embedding_dim
        use_class_labels = True
    size = gen.s0 * 32
    newgen = Generator(gen.z_dim, nlabels, size, embed_size, gen.nf,
                       use_class_labels)
    newgen.load_v2_state_dict(gen.state_dict())
    return newgen
class ResnetBlock(nn.Module):
    """Residual block without normalization.

    The residual branch is leaky-ReLU -> 3x3 conv -> leaky-ReLU -> 3x3 conv,
    scaled by 0.1 and added to the shortcut. The shortcut is a bias-free 1x1
    conv when ``fin != fout`` and the identity otherwise. ``fhidden`` defaults
    to ``min(fin, fout)``; ``is_bias`` controls the second conv's bias.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        self.is_bias = is_bias
        self.learned_shortcut = fin != fout
        self.fin = fin
        self.fout = fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden
        conv3x3 = dict(kernel_size=3, stride=1, padding=1)
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, **conv3x3)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout, bias=is_bias, **conv3x3)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout, kernel_size=1,
                                    stride=1, padding=0, bias=False)

    def forward(self, x):
        residual = self.conv_1(actvn(self.conv_0(actvn(x))))
        return self._shortcut(x) + 0.1 * residual

    def _shortcut(self, x):
        return self.conv_s(x) if self.learned_shortcut else x
def actvn(x):
    """Leaky ReLU with slope 0.2 for negative inputs."""
    return F.leaky_relu(x, negative_slope=0.2)
================================================
FILE: gan_training/models/resnet3.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data
import torch.utils.data.distributed
from collections import OrderedDict
class Generator(nn.Module):
    '''
    Perfectly equivalent to resnet2.Generator (can load state dicts
    from that class), but organizes layers as a sequence for more
    automatic inversion.
    '''
    def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64,
                 use_class_labels=False, **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.z_dim = z_dim
        self.use_class_labels = use_class_labels
        # Submodules
        if use_class_labels:
            self.condition = ConditionGen(z_dim, nlabels, embed_size)
            latent_dim = self.condition.latent_dim
        else:
            latent_dim = z_dim
        self.layers = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(latent_dim, 16*nf*s0*s0)),
            ('reshape', Reshape(16*self.nf, self.s0, self.s0)),
            ('resnet_0_0', ResnetBlock(16*nf, 16*nf)),
            ('resnet_0_1', ResnetBlock(16*nf, 16*nf)),
            ('upsample_1', nn.Upsample(scale_factor=2)),
            ('resnet_1_0', ResnetBlock(16*nf, 16*nf)),
            ('resnet_1_1', ResnetBlock(16*nf, 16*nf)),
            ('upsample_2', nn.Upsample(scale_factor=2)),
            ('resnet_2_0', ResnetBlock(16*nf, 8*nf)),
            ('resnet_2_1', ResnetBlock(8*nf, 8*nf)),
            ('upsample_3', nn.Upsample(scale_factor=2)),
            ('resnet_3_0', ResnetBlock(8*nf, 4*nf)),
            ('resnet_3_1', ResnetBlock(4*nf, 4*nf)),
            ('upsample_4', nn.Upsample(scale_factor=2)),
            ('resnet_4_0', ResnetBlock(4*nf, 2*nf)),
            ('resnet_4_1', ResnetBlock(2*nf, 2*nf)),
            ('upsample_5', nn.Upsample(scale_factor=2)),
            ('resnet_5_0', ResnetBlock(2*nf, 1*nf)),
            ('resnet_5_1', ResnetBlock(1*nf, 1*nf)),
            ('img_relu', nn.LeakyReLU(2e-1)),
            ('conv_img', nn.Conv2d(nf, 3, 3, padding=1)),
            ('tanh', nn.Tanh())
        ]))

    def forward(self, z, y=None):
        assert(y is None or z.size(0) == y.size(0))
        assert(not self.use_class_labels or y is not None)
        if self.use_class_labels:
            z = self.condition(z, y)
        return self.layers(z)

    def load_v2_state_dict(self, state_dict):
        """Load a resnet2.Generator state dict by remapping its keys onto this layout."""
        converted = {}
        for k, v in state_dict.items():
            # Strip a 'module.' wrapper prefix (presumably from DataParallel),
            # as resnet2s.Generator.load_v2_state_dict already does; without
            # it such checkpoints map to 'layers.module.*' and fail to load.
            if 'module.' in k:
                k = k.split('module.')[1]
            if k.startswith('embedding'):
                k = 'condition.' + k
            elif k == 'get_latent.embedding.weight':
                k = 'condition.embedding.weight'
            else:
                k = 'layers.' + k
            converted[k] = v
        self.load_state_dict(converted)
class Reshape(nn.Module):
    """Keep the batch dimension and view the rest of each sample as ``shape``."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target = (x.shape[0],) + tuple(self.shape)
        return x.view(*target)
class ConditionGen(nn.Module):
    """Append an L2-normalized class code to the latent.

    int64 ``y`` is looked up in the embedding table; any other dtype is
    treated as a precomputed code. Output width is ``z_dim + embed_size``.
    """

    def __init__(self, z_dim, nlabels, embed_size=256):
        super().__init__()
        self.embedding = nn.Embedding(nlabels, embed_size)
        self.latent_dim = z_dim + embed_size
        self.z_dim = z_dim
        self.nlabels = nlabels
        self.embed_size = embed_size

    def forward(self, z, y):
        assert (z.size(0) == y.size(0))
        if y.dtype is not torch.int64:
            code = y
        else:
            code = self.embedding(y)
        norms = torch.norm(code, p=2, dim=1, keepdim=True)
        return torch.cat([z, code / norms], dim=1)
def convert_from_resnet2_generator(gen):
    """Build a sequential Generator equivalent to resnet2-style ``gen`` and copy its weights.

    The class embedding is taken from ``gen.get_latent.embedding`` if present,
    else from ``gen.embedding`` when ``gen.use_class_labels`` is true.
    Otherwise the result is unconditional. Previously ``use_class_labels``
    was left unbound in that case (UnboundLocalError), and an Identity
    ``get_latent`` raised AttributeError.
    """
    nlabels, embed_size = 0, 0
    use_class_labels = False
    get_latent = getattr(gen, 'get_latent', None)
    if hasattr(get_latent, 'embedding'):
        embedding = get_latent.embedding
    elif getattr(gen, 'use_class_labels', False):
        embedding = gen.embedding
    else:
        embedding = None
    if embedding is not None:
        nlabels = embedding.num_embeddings
        embed_size = embedding.embedding_dim
        use_class_labels = True
    size = gen.s0 * 32
    newgen = Generator(gen.z_dim, nlabels, size, embed_size, gen.nf, use_class_labels)
    newgen.load_v2_state_dict(gen.state_dict())
    return newgen
class ResnetBlock(nn.Module):
    """Un-normalized residual block: shortcut(x) + 0.1 * conv(act(conv(act(x)))).

    The shortcut is a bias-free 1x1 conv when ``fin != fout``, else identity.
    ``fhidden`` (default ``min(fin, fout)``) is the width between the 3x3
    convs; ``is_bias`` toggles the bias of the second conv.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        self.is_bias = is_bias
        self.learned_shortcut = fin != fout
        self.fin = fin
        self.fout = fout
        if fhidden is not None:
            self.fhidden = fhidden
        else:
            self.fhidden = min(fin, fout)
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden,
                                kernel_size=3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout,
                                kernel_size=3, stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout,
                                    kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        hidden = self.conv_0(actvn(x))
        hidden = self.conv_1(actvn(hidden))
        return self._shortcut(x) + 0.1 * hidden

    def _shortcut(self, x):
        if not self.learned_shortcut:
            return x
        return self.conv_s(x)
def actvn(x):
    """Activation shared by the blocks here: leaky ReLU, negative slope 0.2."""
    slope = 2e-1
    return F.leaky_relu(x, slope)
================================================
FILE: gan_training/train.py
================================================
# coding: utf-8
import torch
from torch.nn import functional as F
import torch.utils.data
import torch.utils.data.distributed
from torch import autograd
import numpy as np
class Trainer(object):
    """Single-step GAN optimization for a generator/discriminator pair.

    Args:
        generator: module called as ``generator(z, y)``.
        discriminator: module called as ``discriminator(x, y)`` returning logits.
        g_optimizer, d_optimizer: optimizers stepped by the respective steps.
        gan_type: 'standard' (binary cross-entropy on logits) or 'wgan'.
        reg_type: discriminator regularizer, one of ``REG_TYPES``.
        reg_param: weight applied to the regularizer.
    """

    # 'real' / 'fake' / 'real_fake': mean squared input-gradient norm of D on
    # real and/or fake samples; 'wgangp' / 'wgangp0': penalty on interpolated
    # samples pulling the gradient norm towards 1 / 0; 'none': no penalty.
    REG_TYPES = ('real', 'fake', 'real_fake', 'wgangp', 'wgangp0', 'none')

    def __init__(self,
                 generator,
                 discriminator,
                 g_optimizer,
                 d_optimizer,
                 gan_type,
                 reg_type,
                 reg_param):
        self.generator = generator
        self.discriminator = discriminator
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.gan_type = gan_type
        self.reg_type = reg_type
        self.reg_param = reg_param
        print('D reg gamma', self.reg_param)

    def generator_trainstep(self, y, z):
        """One generator update towards D labelling G(z, y) as real; returns the loss as a float."""
        assert (y.size(0) == z.size(0))
        toggle_grad(self.generator, True)
        toggle_grad(self.discriminator, False)
        self.generator.train()
        self.discriminator.train()
        self.g_optimizer.zero_grad()
        x_fake = self.generator(z, y)
        d_fake = self.discriminator(x_fake, y)
        gloss = self.compute_loss(d_fake, 1)
        gloss.backward()
        self.g_optimizer.step()
        return gloss.item()

    def discriminator_trainstep(self, x_real, y, z):
        """One discriminator update.

        Returns:
            ``(dloss, reg)`` as floats; ``reg`` is the sum of every penalty
            applied this step (both terms for 'real_fake'; 0.0 for 'none').
        Raises:
            ValueError: if ``reg_type`` is not in ``REG_TYPES`` (previously
            this surfaced as an UnboundLocalError on ``reg``).
        Note: sets ``requires_grad`` on the caller's ``x_real`` in place.
        """
        if self.reg_type not in self.REG_TYPES:
            raise ValueError(
                f'unknown reg_type {self.reg_type!r}; expected one of {self.REG_TYPES}')
        toggle_grad(self.generator, False)
        toggle_grad(self.discriminator, True)
        self.generator.train()
        self.discriminator.train()
        self.d_optimizer.zero_grad()
        reg_total = 0.

        # On real data
        x_real.requires_grad_()
        d_real = self.discriminator(x_real, y)
        dloss_real = self.compute_loss(d_real, 1)
        if self.reg_type in ('real', 'real_fake'):
            dloss_real.backward(retain_graph=True)
            reg = self.reg_param * compute_grad2(d_real, x_real).mean()
            reg.backward()
            reg_total += reg.item()
        else:
            dloss_real.backward()

        # On fake data
        with torch.no_grad():
            x_fake = self.generator(z, y)
        x_fake.requires_grad_()
        d_fake = self.discriminator(x_fake, y)
        dloss_fake = self.compute_loss(d_fake, 0)
        if self.reg_type in ('fake', 'real_fake'):
            dloss_fake.backward(retain_graph=True)
            reg = self.reg_param * compute_grad2(d_fake, x_fake).mean()
            reg.backward()
            reg_total += reg.item()
        else:
            dloss_fake.backward()

        if self.reg_type == 'wgangp':
            reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y)
            reg.backward()
            reg_total += reg.item()
        elif self.reg_type == 'wgangp0':
            reg = self.reg_param * self.wgan_gp_reg(
                x_real, x_fake, y, center=0.)
            reg.backward()
            reg_total += reg.item()
        self.d_optimizer.step()
        dloss = (dloss_real + dloss_fake)
        return dloss.item(), reg_total

    def compute_loss(self, d_out, target):
        """Loss of logits ``d_out`` against a constant ``target`` (1 = real, 0 = fake)."""
        targets = d_out.new_full(size=d_out.size(), fill_value=target)
        if self.gan_type == 'standard':
            loss = F.binary_cross_entropy_with_logits(d_out, targets)
        elif self.gan_type == 'wgan':
            loss = (2 * target - 1) * d_out.mean()
        else:
            raise NotImplementedError(f'gan_type {self.gan_type!r} not implemented')
        return loss

    def wgan_gp_reg(self, x_real, x_fake, y, center=1.):
        """Mean of (||grad_x D(x_interp)|| - center)^2 over random real/fake interpolates."""
        batch_size = y.size(0)
        eps = torch.rand(batch_size, device=y.device).view(batch_size, 1, 1, 1)
        x_interp = (1 - eps) * x_real + eps * x_fake
        x_interp = x_interp.detach()
        x_interp.requires_grad_()
        d_out = self.discriminator(x_interp, y)
        reg = (compute_grad2(d_out, x_interp).sqrt() - center).pow(2).mean()
        return reg
# Utility functions
def toggle_grad(model, requires_grad):
    """Enable or disable gradient tracking on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(requires_grad)
def compute_grad2(d_out, x_in):
    """Per-sample squared L2 norm of the gradient of ``d_out.sum()`` w.r.t. ``x_in``.

    The graph is kept (``create_graph=True``) so the result can itself be
    back-propagated, as the gradient penalties require. Returns a tensor of
    shape ``(x_in.size(0),)``.
    """
    grad, = autograd.grad(outputs=d_out.sum(),
                          inputs=x_in,
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)
    squared = grad.pow(2)
    assert (squared.size() == x_in.size())
    return squared.view(x_in.size(0), -1).sum(1)
def update_average(model_tgt, model_src, beta):
    """In-place moving average of parameters: tgt <- beta * tgt + (1 - beta) * src.

    Gradient tracking is switched off on both models first (source, then
    target), which lets the in-place ``copy_`` run outside ``torch.no_grad``.
    Parameters are matched by name.
    """
    for model in (model_src, model_tgt):
        for param in model.parameters():
            param.requires_grad_(False)
    src_by_name = dict(model_src.named_parameters())
    for name, tgt in model_tgt.named_parameters():
        src = src_by_name[name]
        assert (src is not tgt)
        tgt.copy_(beta * tgt + (1. - beta) * src)
================================================
FILE: gan_training/utils.py
================================================
import torch
import torch.utils.data
import torch.utils.data.distributed
import torchvision
import os
def save_images(imgs, outfile, nrow=8):
    """Save ``imgs`` as a grid with ``nrow`` images per row, after mapping values from [-1, 1] to [0, 1]."""
    torchvision.utils.save_image(imgs / 2 + 0.5, outfile, nrow=nrow)
def get_nsamples(data_loader, N):
    """Collect the first ``N`` (x, y) samples from an iterable of batches.

    Iteration stops as soon as at least ``N`` samples are gathered (the old
    ``n > N`` test read one extra batch whenever a batch boundary landed
    exactly on ``N``). Fewer than ``N`` samples are returned if the loader is
    exhausted first.

    Returns:
        ``(x, y)`` tensors concatenated along dim 0 and truncated to ``N``.
    """
    xs = []
    ys = []
    n = 0
    for x_next, y_next in data_loader:
        xs.append(x_next)
        ys.append(y_next)
        n += x_next.size(0)
        if n >= N:
            break
    x = torch.cat(xs, dim=0)[:N]
    y = torch.cat(ys, dim=0)[:N]
    return x, y
def update_average(model_tgt, model_src, beta):
    """In-place moving average of parameters: tgt <- beta * tgt + (1 - beta) * src.

    Parameters are matched by name. ``requires_grad`` flags are left as they
    are; the update runs under ``torch.no_grad()``.
    """
    param_dict_src = dict(model_src.named_parameters())
    # Without no_grad, copy_ on a leaf parameter that requires grad raises
    # "a leaf Variable that requires grad is being used in an in-place operation".
    with torch.no_grad():
        for p_name, p_tgt in model_tgt.named_parameters():
            p_src = param_dict_src[p_name]
            assert (p_src is not p_tgt)
            p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)
def get_most_recent(d, ext):
    """Return the largest iteration number among files in ``d`` tagged with ``ext``.

    A file counts if the text after its first ``"{ext}_"`` and before the
    first ``".pt"`` parses as an int (e.g. ``model_1000.pt`` for
    ``ext='model'``). Returns -1, after printing a message, when ``d`` does
    not exist or contains no such file.
    """
    if not os.path.exists(d):
        print('Directory', d, 'does not exist')
        return -1
    its = []
    for fname in os.listdir(d):
        try:
            its.append(int(fname.split(ext + "_")[1].split('.pt')[0]))
        except (IndexError, ValueError):
            # No "<ext>_" in the name, or a non-integer suffix: not a match.
            continue
    if not its:
        print(f'Found no files with extension "{ext}" under {d}')
        return -1
    return max(its)
================================================
FILE: metrics.py
================================================
import argparse
import os
import json
from tqdm import tqdm
import numpy as np
import torch
from gan_training.config import load_config
from seeded_sampler import SeededSampler
# The first positional argument of ArgumentParser is `prog`, not the
# description; pass the text as `description` so `--help` shows a proper
# usage line followed by this explanation.
parser = argparse.ArgumentParser(description='Computes numbers used in paper and caches them to a result files. Examples include FID, IS, reverse-KL, # modes, FSD, cluster NMI, Purity.')
parser.add_argument('paths', nargs='+', type=str, help='list of configs for each experiment')
parser.add_argument('--it', type=int, default=-1, help='If set, computes numbers only for that iteration')
# With the default of -1, `it % -1 == 0` for every integer, so nothing is skipped.
parser.add_argument('--every', type=int, default=-1, help='skips some checkpoints and only computes those whose iteration number are divisible by every')
parser.add_argument('--fid', action='store_true', help='compute FID metric')
parser.add_argument('--inception', action='store_true', help='compute IS metric')
parser.add_argument('--modes', action='store_true', help='compute # modes and reverse-KL metric')
parser.add_argument('--fsd', action='store_true', help='compute FSD metric')
parser.add_argument('--cluster_metrics', action='store_true', help='compute clustering metrics (NMI, purity)')
parser.add_argument('--device', type=int, default=1, help='device to run the metrics on (can run into OOM issues if same as main device)')
args = parser.parse_args()
device = args.device
# Work list of paths still to visit (directories are expanded by the main loop).
dirs = list(args.paths)
N = 50000  # number of generated samples per checkpoint
BS = 100  # generation batch size
datasets = ['imagenet', 'cifar', 'stacked_mnist', 'places']
# Ground-truth image archives used as the reference for FSD.
dataset_to_img = {
    'places': 'output/places_gt_imgs.npz',
    'imagenet': 'output/imagenet_gt_imgs.npz'}
def load_results(results_dir):
    '''Load the six per-metric JSON result dicts stored in `results_dir`.

    Any result file that does not exist yet is first created containing an
    empty JSON object. Returns a list in this order: FID, IS, KL, #modes, FSD,
    cluster metrics.
    '''
    def read_or_create(filename):
        full_path = os.path.join(results_dir, filename)
        if not os.path.exists(full_path):
            with open(full_path, 'w') as handle:
                handle.write(json.dumps({}))
        with open(full_path) as handle:
            return json.load(handle)

    filenames = ('fid_results.json', 'is_results.json', 'kl_results.json',
                 'nmodes_results.json', 'fsd_results.json', 'cluster_metrics.json')
    return [read_or_create(filename) for filename in filenames]
def get_dataset_from_path(path):
    '''Return the first name in `datasets` that occurs as a substring of `path`.

    Prints the inferred name; returns None when no dataset name matches.
    '''
    match = next((name for name in datasets if name in path), None)
    if match is not None:
        print('Inferred dataset:', match)
    return match
def pt_to_np(imgs):
    '''normalizes pytorch image in [-1, 1] to [0, 255]

    `imgs` is permuted with (0, 2, 3, 1), i.e. a channels-first 4-D batch
    becomes channels-last, and values are mapped x -> (x * 0.5 + 0.5) * 255,
    clamped to [0, 255]. Returns a float numpy array.

    Out-of-place ops are used: `permute` returns a view, so the in-place
    variants would silently overwrite the caller's tensor. The input is also
    detached and moved to CPU so grad-tracking or CUDA tensors work.
    '''
    out = imgs.detach().cpu().permute(0, 2, 3, 1) * 0.5 + 0.5
    return (out * 255).clamp(0, 255).numpy()
def sample(sampler):
    '''Generate N images from `sampler` in batches of BS and return them via pt_to_np.

    `sampler.sample(BS)[0]` is taken as the image batch. Exactly
    ceil(N / BS) batches are drawn (no surplus batch when BS divides N), and
    the concatenated result is truncated to N images.
    '''
    num_batches = -(-N // BS)  # ceil(N / BS)
    with torch.no_grad():
        batches = [sampler.sample(BS)[0].detach().cpu()
                   for _ in tqdm(range(num_batches))]
        samples = torch.cat(batches, dim=0)[:N]
        return pt_to_np(samples)
root = './'
# Walk the given paths depth-first: directories are expanded, and every .yaml
# config found is evaluated for each checkpoint under its training out_dir.
while len(dirs) > 0:
    path = dirs.pop()
    if os.path.isdir(path):  # search down tree for config files
        for d1 in os.listdir(path):
            dirs.append(os.path.join(path, d1))
        continue
    if not path.endswith('.yaml'):
        continue
    config = load_config(path, default_path='configs/default.yaml')
    outdir = config['training']['out_dir']
    if not os.path.exists(outdir) and config['pretrained'] == {}:
        print('Skipping', path, 'outdir', outdir)
        continue
    results_dir = os.path.join(outdir, 'results')
    checkpoint_dir = os.path.join(outdir, 'chkpts')
    os.makedirs(results_dir, exist_ok=True)
    fid_results, is_results, kl_results, nmodes_results, fsd_results, cluster_results = load_results(results_dir)
    checkpoint_files = os.listdir(checkpoint_dir) if os.path.exists(checkpoint_dir) else []
    if config['pretrained'] != {}:
        checkpoint_files = checkpoint_files + ['pretrained']
    for checkpoint in checkpoint_files:
        if not ((checkpoint.endswith('.pt') and checkpoint != 'model.pt') or checkpoint == 'pretrained'):
            continue
        print('Computing for', checkpoint)
        if 'model' in checkpoint:
            # infer iteration number from checkpoint file w/o loading it
            if 'model_' not in checkpoint:
                continue
            it = int(checkpoint.split('model_')[1].split('.pt')[0])
            if args.every != 0 and it % args.every != 0:
                continue
            # iteration 0 is often useless, skip it
            if it == 0 or (args.it != -1 and it != args.it):
                continue
        elif checkpoint == 'pretrained':
            it = 'pretrained'
        else:
            # Some other .pt file that is not a generator checkpoint; without
            # this branch `it` would be undefined or stale from a previous file.
            continue
        it = str(it)
        clusterer_path = os.path.join(root, checkpoint_dir, f'clusterer{it}.pkl')
        # don't save samples for each iteration for disk space
        samples_path = os.path.join(outdir, 'results', 'samples.npz')

        # Result dicts of the requested metrics that are computed from samples.
        sample_targets = []
        if args.inception:
            sample_targets.append(is_results)
        if args.fid:
            sample_targets.append(fid_results)
        if args.modes:
            sample_targets.extend([kl_results, nmodes_results])
        if args.fsd:
            sample_targets.append(fsd_results)
        need_samples = not all(it in result for result in sample_targets)
        # Cluster metrics don't need samples, but must be part of the
        # "already done" check; otherwise --cluster_metrics alone never runs.
        need_cluster = args.cluster_metrics and it not in cluster_results
        if not need_samples and not need_cluster:
            print('Already generated', it, path)
            continue

        dataset_name = get_dataset_from_path(path)
        if need_samples:
            sampler = SeededSampler(path,
                                    model_path=os.path.join(root, checkpoint_dir, checkpoint),
                                    clusterer_path=clusterer_path,
                                    pretrained=config['pretrained'])
            samples = sample(sampler)
            np.savez(samples_path, fake=samples, real=dataset_name)
        arguments = f'--samples {samples_path} --it {it} --results_dir {results_dir}'
        if args.fid and it not in fid_results:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python gan_training/metrics/fid.py {arguments}')
        if args.inception and it not in is_results:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python gan_training/metrics/tf_is/inception_score.py {arguments}')
        if args.modes and (it not in kl_results or it not in nmodes_results):
            os.system(f'CUDA_VISIBLE_DEVICES={device} python utils/get_empirical_distribution.py {arguments} --dataset {dataset_name}')
        if args.cluster_metrics and it not in cluster_results:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python cluster_metrics.py {path} --model_it {it}')
        if args.fsd and it not in fsd_results:
            gt_path = dataset_to_img.get(dataset_name)
            if gt_path is None:
                # only datasets listed in dataset_to_img have FSD ground truth
                print('No FSD ground-truth images for dataset', dataset_name, '- skipping FSD')
            else:
                os.system(f'CUDA_VISIBLE_DEVICES={device} python -m seeing.fsd {gt_path} {samples_path} --it {it} --results_dir {results_dir}')
================================================
FILE: requirements.txt
================================================
pytorch-gpu==1.3.1
tensorflow-gpu==1.14.0
scikit-learn
scikit-image
torchvision
tqdm
pyyaml
cloudpickle
ipython
opencv
================================================
FILE: seeded_sampler.py
================================================
''' Samples from a (class-conditional) GAN, so that the samples can be reproduced '''
import os
import pickle
import random
import copy
import torch
from torch import nn
from gan_training.checkpoints import CheckpointIO
from gan_training.config import (load_config, build_models)
from seeing.yz_dataset import YZDataset
def get_most_recent(models):
    '''Return the file name of the most recent generator checkpoint in `models`.

    `models` is a list of file names (e.g. from os.listdir). Numbered
    checkpoints named either "<it>model.pt" or "model_<it>.pt" are ranked by
    the integer <it>. The unnumbered "model.pt" is chosen only when no numbered
    checkpoint is present. Any other file name is ignored.

    The actual matching file name is returned, so zero-padded names such as
    "model_00001000.pt" are preserved and a lone "model.pt" is returned as is.

    Raises:
        ValueError if `models` contains no checkpoint.
    '''
    import re
    best_name, best_it = None, None
    for name in models:
        if name == 'model.pt':
            it = -1  # rank below any numbered checkpoint
        else:
            match = re.fullmatch(r'(\d+)model\.pt|model_(\d+)\.pt', name)
            if match is None:
                continue
            it = int(match.group(1) or match.group(2))
        if best_it is None or it > best_it:
            best_name, best_it = name, it
    if best_name is None:
        raise ValueError('No generator checkpoint found among: %s' % (models,))
    return best_name
class SeededSampler():
    '''Samples from a (class-conditional) generator so that samples can be reproduced from integer seeds.'''

    # upper bound (inclusive) of the random integer seeds drawn per image
    _MAX_SEED = 10 ** 8

    def __init__(
            self,
            config_name,  # name of experiment's config file
            model_path="",  # path to the model. empty string infers the most recent checkpoint
            clusterer_path="",  # path to the clusterer, ignored if gan type doesn't require a clusterer
            pretrained=None,  # urls to the pretrained models; None means no pretrained models
            rootdir='./',
            device='cuda:0'):
        self.config = load_config(os.path.join(rootdir, config_name), 'configs/default.yaml')
        self.model_path = model_path
        self.clusterer_path = clusterer_path
        self.rootdir = rootdir
        self.nlabels = self.config['generator']['nlabels']
        self.device = device
        # None instead of a shared mutable `{}` default; behaves the same as {}
        self.pretrained = {} if pretrained is None else pretrained
        self.generator = self.get_generator()
        self.generator.eval()
        self.yz_dist = self.get_yz_dist()

    def sample(self, nimgs):
        '''
        samples an image using the generator, with z drawn from isotropic gaussian, and y drawn from self.yz_dist.
        For baseline methods, y doesn't matter because y is ignored in the input
        yz_dist is the empirical label distribution for the clustered gans.
        returns the image, and the integer seed used to generate it. generated sample is in [-1, 1]
        '''
        self.generator.eval()
        with torch.no_grad():
            # integer bound: random.randint rejects float bounds on Python >= 3.12
            seeds = [random.randint(0, self._MAX_SEED) for _ in range(nimgs)]
            z, y = self.yz_dist(seeds)
            return self.generator(z, y), seeds

    def conditional_sample(self, yi, seed=None):
        ''' returns a generated sample, which is in [-1, 1], seed is an int'''
        self.generator.eval()
        with torch.no_grad():
            if seed is None:
                seed = [random.randint(0, self._MAX_SEED)]
            else:
                seed = [seed]
            z, _ = self.yz_dist(seed)
            y = torch.LongTensor([yi]).to(self.device)
            return self.generator(z, y)

    def sample_with_seed(self, seeds):
        ''' returns a generated sample, which is in [-1, 1] '''
        self.generator.eval()
        z, y = self.yz_dist(seeds)
        return self.generator(z, y)

    def get_zy(self, seeds):
        '''returns the batch of z, y corresponding to the seeds'''
        return self.yz_dist(seeds)

    def sample_with_zy(self, z, y):
        ''' returns a generated sample given z and y, which is in [-1, 1].'''
        self.generator.eval()
        return self.generator(z, y)

    def get_generator(self):
        ''' loads a generator according to self.model_path '''
        exp_out_dir = os.path.join(self.rootdir, self.config['training']['out_dir'])
        # infer checkpoint if needed
        checkpoint_dir = os.path.join(exp_out_dir, 'chkpts') if self.model_path == "" or 'model' in self.pretrained else "./"
        model_name = get_most_recent(os.listdir(checkpoint_dir)) if self.model_path == "" else self.model_path

        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        self.checkpoint_io = checkpoint_io

        generator, _ = build_models(self.config)
        generator = generator.to(self.device)
        generator = nn.DataParallel(generator)

        if self.config['training']['take_model_average']:
            generator_test = copy.deepcopy(generator)
            checkpoint_io.register_modules(generator_test=generator_test)
        else:
            generator_test = generator

        checkpoint_io.register_modules(generator=generator)

        try:
            it = checkpoint_io.load(model_name, pretrained=self.pretrained)
            assert (it != -1)
        except Exception as e:
            # try again without data parallel
            print(e)
            checkpoint_io.register_modules(generator=generator.module)
            checkpoint_io.register_modules(generator_test=generator_test.module)
            it = checkpoint_io.load(model_name, pretrained=self.pretrained)
            assert (it != -1)

        print('Loaded iteration:', it['it'])
        return generator_test

    def get_yz_dist(self):
        '''loads the z and y dists used to sample from the generator.

        Raises FileNotFoundError if the config needs a clusterer but neither a
        pretrained one nor a file at self.clusterer_path is available.
        '''
        if self.config['clusterer']['name'] != 'supervised':
            if 'clusterer' in self.pretrained:
                clusterer = self.checkpoint_io.load_clusterer('pretrained', load_samples=False, pretrained=self.pretrained)
            elif os.path.exists(self.clusterer_path):
                # NOTE: pickle.load can execute arbitrary code; only load trusted clusterer files.
                with open(self.clusterer_path, 'rb') as f:
                    clusterer = pickle.load(f)
            else:
                raise FileNotFoundError(
                    'No clusterer found at %r and no pretrained clusterer was given' % (self.clusterer_path,))

            if isinstance(clusterer.discriminator, nn.DataParallel):
                clusterer.discriminator = clusterer.discriminator.module

            if clusterer.kmeans is not None:
                # use clusterer empirical distribution as sampling
                print('Using k-means empirical distribution')
                distribution = clusterer.get_label_distribution()
                probs = [f / sum(distribution) for f in distribution]
            else:
                # otherwise, use a uniform distribution. this is not desired, unless it's a random label or unconditional GAN
                print("Sampling with uniform distribution over", clusterer.k, "labels")
                probs = [1. / clusterer.k for _ in range(clusterer.k)]
        else:
            # if it's supervised, then sample uniformly over all classes.
            # this might not be the right thing to do, since datasets are usually imbalanced.
            print("Sampling with uniform distribution over", self.nlabels,
                  "labels")
            probs = [1. / self.nlabels for _ in range(self.nlabels)]

        return YZDataset(zdim=self.config['z_dist']['dim'],
                         nlabels=len(probs),
                         distribution=probs,
                         device=self.device)
================================================
FILE: seeing/frechet_distance.py
================================================
#!/usr/bin/env python3
"""Calculates the Frechet Distance (FD) between two samples.
Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of Tensorflow
Copyright 2018 Institute of Bioinformatics, JKU Linz
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import torch
from scipy import linalg
def sample_frechet_distance(sample1, sample2, eps=1e-6,
                            return_components=False):
    '''
    Fit a Gaussian (mean, covariance) to each sample and return the Frechet
    distance between the two Gaussians.

    Both samples should be numpy arrays with one observation per row.
    `eps` and `return_components` are forwarded to calculate_frechet_distance.
    '''
    stats1 = calculate_activation_statistics(sample1)
    stats2 = calculate_activation_statistics(sample2)
    return calculate_frechet_distance(*stats1, *stats2, eps=eps,
                                      return_components=return_components)
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6,
                               return_components=False):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : mean vector of the first distribution (scalars are promoted to 1-D).
    -- sigma1: covariance matrix of the first distribution (promoted to 2-D).
    -- mu2   : mean vector of the second distribution.
    -- sigma2: covariance matrix of the second distribution.
    -- eps   : added to both covariance diagonals only if the matrix square
               root of their product is not finite.
    -- return_components: if True, also return the mean and covariance terms.

    Returns:
    --   : the Frechet distance, or (distance, mean term, covariance term).

    Raises ValueError if the matrix square root has a significant imaginary
    component on its diagonal.
    """
    m1, m2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    s1, s2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert m1.shape == m2.shape, \
        'Training and test mean vectors have different lengths'
    assert s1.shape == s2.shape, \
        'Training and test covariances have different dimensions'

    # The product may be nearly singular; retry with a jittered diagonal.
    sqrt_prod, _ = linalg.sqrtm(s1.dot(s2), disp=False)
    if not np.isfinite(sqrt_prod).all():
        print(('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps)
        jitter = np.eye(s1.shape[0]) * eps
        sqrt_prod = linalg.sqrtm((s1 + jitter).dot(s2 + jitter))

    # Numerical error can leave a tiny imaginary part; reject large ones.
    if np.iscomplexobj(sqrt_prod):
        if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
            largest = np.max(np.abs(sqrt_prod.imag))
            raise ValueError('Imaginary component {}'.format(largest))
        sqrt_prod = sqrt_prod.real

    delta = m1 - m2
    mean_term = delta.dot(delta)
    cov_term = np.trace(s1) + np.trace(s2) - 2 * np.trace(sqrt_prod)
    total = mean_term + cov_term
    return (total, mean_term, cov_term) if return_components else total
def calculate_activation_statistics(act):
    """Compute the mean and covariance used by the Frechet distance.

    Params:
    -- act : numpy array of activations, one observation per row and one
             feature per column.

    Returns:
    -- mu    : per-column mean of `act`.
    -- sigma : sample covariance matrix of the columns of `act`
               (np.cov with rowvar=False).
    """
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma
================================================
FILE: seeing/fsd.py
================================================
import torch, argparse, sys, os, numpy
from .sampler import FixedRandomSubsetSampler, FixedSubsetSampler
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from torchvision import transforms, utils
from . import pbar, zdataset, segmenter, frechet_distance, parallelfolder
NUM_OBJECTS = 336
def main():
    '''Command-line entry point: Frechet Segmentation Distance between two image sets.

    `true_dir` and `gen_dir` may each be an image folder or a .npz file (see
    tally_directory). Prints the FSD and its mean/covariance components,
    optionally saves a comparison figure to --histout, and, when
    --results_dir is given, stores the numbers under key --it in
    <results_dir>/fsd_results.json (created if missing) and saves
    fsd_<it>.png there.
    '''
    parser = argparse.ArgumentParser(description='Net dissect utility')
    parser.add_argument('true_dir')
    parser.add_argument('gen_dir')
    parser.add_argument('--size', type=int, default=10000)
    parser.add_argument('--cachedir', default='results/fsd/cache')
    parser.add_argument('--histout', default=None)
    parser.add_argument('--maxscale', type=float, default=50)
    parser.add_argument('--labelcount', type=int, default=30)
    parser.add_argument('--dpi', type=float, default=100)
    parser.add_argument('--it', type=str, default="-1")
    parser.add_argument('--results_dir', default=None, help='path to results_dir')
    # Check before parsing: parse_args() itself exits on the missing
    # positional arguments, which made this usage message unreachable.
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    print(args.true_dir, args.gen_dir)
    true_dir, gen_dir = args.true_dir, args.gen_dir
    # use a different sampling seed when comparing a directory against itself
    seed1, seed2 = [1, 1 if true_dir != gen_dir else 2]
    true_tally, gen_tally = [
        cached_tally_directory(d,
                               size=args.size,
                               cachedir=args.cachedir,
                               seed=seed)
        for d, seed in [(true_dir, seed1), (gen_dir, seed2)]
    ]
    fsd, meandiff, covdiff = frechet_distance.sample_frechet_distance(
        true_tally * 100, gen_tally * 100, return_components=True)
    print('fsd: %f; meandiff: %f; covdiff: %f' % (fsd, meandiff, covdiff))
    if args.histout is not None:
        diff_figure(true_tally * 100,
                    gen_tally * 100,
                    labelcount=args.labelcount,
                    maxscale=args.maxscale,
                    dpi=args.dpi).savefig(args.histout)
    if args.results_dir is not None:
        import json
        it = args.it
        results_path = os.path.join(args.results_dir, 'fsd_results.json')
        fsd_results = {}
        if os.path.exists(results_path):
            with open(results_path) as f:
                fsd_results = json.load(f)
        fsd_results[it] = (fsd, meandiff, covdiff)
        with open(results_path, 'w') as f:
            f.write(json.dumps(fsd_results))
        diff_figure(true_tally * 100,
                    gen_tally * 100,
                    labelcount=args.labelcount,
                    maxscale=args.maxscale,
                    dpi=args.dpi).savefig(os.path.join(args.results_dir, f'fsd_{it}.png'))
def cached_tally_directory(directory, size=10000, cachedir=None, seed=1):
    '''Return tally_directory(directory, size, seed), cached as a .npy file.

    With `cachedir`, the cache file is
    "<cachedir>/[<seed>_]<directory with '/' -> '_'>_segtally_<size>.npy".
    Without it, the cache file sits next to `directory`
    ("<directory>_segtally_<size>.npy", with the "<seed>_" prefix on the
    base name). The seed prefix is only used when seed != 1.

    A cached file is reused only for image directories or for .npz paths
    containing 'gt'. Other .npz inputs (generated samples, overwritten between
    runs) are always recomputed.
    '''
    base = '%s_segtally_%d.npy' % (directory, size)
    if cachedir is not None:
        name = base.replace('/', '_')
        if seed != 1:
            name = '%d_%s' % (seed, name)
        filename = os.path.join(cachedir, name)
    elif seed != 1:
        # prefix the base name, not the whole path
        head, tail = os.path.split(base)
        filename = os.path.join(head, '%d_%s' % (seed, tail))
    else:
        filename = base
    #load only if gt stats, or image directory
    if os.path.isfile(filename) and (not directory.endswith('.npz') or 'gt' in directory):
        return numpy.load(filename)
    result = tally_directory(directory, size, seed=seed)
    if cachedir is not None:
        # os.makedirs(None) would raise TypeError
        os.makedirs(cachedir, exist_ok=True)
    numpy.save(filename, result)
    return result
def tally_directory(directory, size=10000, seed=1):
    '''Per-image fraction of pixels assigned to each of NUM_OBJECTS segmentation labels.

    `directory` is either an image folder or a .npz file whose 'fake' array
    holds channels-last images with values in [0, 255]. These are converted
    to channels-first, mapped to [-1, 1] and resized to 256x256. A fixed
    random subset of `size` images (chosen by `seed`) is segmented with
    UnifiedParsingSegmenter. Row i of the returned (size, NUM_OBJECTS) float64
    array is the normalized bincount of labels in channel 0 of image i's
    segmentation.
    '''
    if directory.endswith('.npz'):
        with np.load(directory) as f:
            images = torch.from_numpy(f['fake'])
        images = images.permute(0, 3, 1, 2)  #BHWC -> BCHW
        images = (images / 127.5) - 1  #normalize in [-1, 1]
        images = torch.nn.functional.interpolate(images, size=(256, 256))
        print(images.shape, images.max(), images.min())
        dataset = TensorDataset(images)
    else:
        dataset = parallelfolder.ParallelImageFolders(
            [directory],
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(256),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
    loader = DataLoader(dataset,
                        sampler=FixedRandomSubsetSampler(dataset,
                                                         end=size,
                                                         seed=seed),
                        batch_size=10,
                        pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # numpy.float was removed in NumPy 1.24; it was an alias of float64.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            n = len(batch)
            seg_result = upp.segment_batch(batch.cuda())
            for i in range(n):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            # only the first n rows are valid when the last batch is short
            result[batch_index:batch_index + n] = batch_result[:n].cpu().numpy()
            batch_index += n
    return result
def tally_dataset_objects(dataset, size=10000):
    '''Per-image segmentation-label pixel fractions for a fixed random subset of `dataset`.

    Each item of `dataset` is unpacked as a single-element tuple (`[batch]`).
    Returns a (size, NUM_OBJECTS) float64 array. Row i is the normalized
    bincount of labels in channel 0 of the segmentation of image i.
    '''
    loader = DataLoader(dataset,
                        sampler=FixedRandomSubsetSampler(dataset, end=size),
                        batch_size=10,
                        pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # numpy.float was removed in NumPy 1.24; it was an alias of float64.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            n = len(batch)
            seg_result = upp.segment_batch(batch.cuda())
            for i in range(n):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            # only the first n rows are valid when the last batch is short
            result[batch_index:batch_index + n] = batch_result[:n].cpu().numpy()
            batch_index += n
    return result
def tally_generated_objects(model, size=10000):
    '''Per-image segmentation-label pixel fractions for `size` images generated by `model`.

    The latents come from zdataset.z_dataset_for_model(model, size).
    Returns a (size, NUM_OBJECTS) float64 array. Row i is the normalized
    bincount of labels in channel 0 of the segmentation of generated image i.
    '''
    zds = zdataset.z_dataset_for_model(model, size)
    loader = DataLoader(zds, batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # numpy.float was removed in NumPy 1.24; it was an alias of float64.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [zbatch] in pbar(loader):
            n = len(zbatch)
            img = model(zbatch.cuda())
            seg_result = upp.segment_batch(img)
            for i in range(n):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            # only the first n rows are valid when the last batch is short
            result[batch_index:batch_index + n] = batch_result[:n].cpu().numpy()
            batch_index += n
    return result
def diff_figure(ttally,
                gtally,
                labelcount=30,
                labelleft=True,
                dpi=100,
                maxscale=50.0,
                legend=False):
    '''Build a figure comparing mean per-label areas of training vs generated tallies.

    Labels are ordered by decreasing mean area in `ttally`. Label 0,
    'material' labels and labels with zero training area are left out, and at
    most `labelcount` labels are shown.
    - Bottom panel: relative delta (gen - train) / train per label.
    - Top panel: training mean area as bars and generated mean area as a red
      line, on a log scale from maxscale / 5000 to maxscale.
    Returns the matplotlib Figure.
    '''
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    from matplotlib.figure import Figure
    tresult, gresult = [t.mean(0) for t in [ttally, gtally]]
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, _ = upp.get_label_and_category_names()
    x = []
    labels = []
    gen_amount = []
    change_frac = []
    true_amount = []
    for label in numpy.argsort(-tresult):
        if label == 0 or labelnames[label][1] == 'material':
            continue
        if tresult[label] == 0:
            break
        x.append(len(x))
        labels.append(labelnames[label][0].split()[0])
        true_amount.append(tresult[label].item())
        gen_amount.append(gresult[label].item())
        change_frac.append(
            (float(gresult[label] - tresult[label]) / tresult[label]))
        if len(x) >= labelcount:
            break
    fig = Figure(dpi=dpi, figsize=(1.4 + 5.0 * labelcount / 30, 4.0))
    FigureCanvas(fig)
    a1, a0 = fig.subplots(2, 1, gridspec_kw={'height_ratios': [1, 2]})
    a0.bar(x, change_frac, label='relative delta')
    a0.set_xticks(x)
    a0.set_xticklabels(labels, rotation='vertical')
    if labelleft:
        a0.set_ylabel('relative delta\n(gen - train) / train')
    a0.set_xlim(-1.0, len(x))
    a0.set_ylim([-1, 1.1])
    a0.grid(axis='y', antialiased=False, alpha=0.25)
    if legend:
        a0.legend(loc=2)
    # Bars above the y-limit get a text label with their value; presumably the
    # offset staggers labels of adjacent tall bars so they don't overlap.
    prev_high = None
    for ix, cf in enumerate(change_frac):
        if cf > 1.15:
            if prev_high == (ix - 1):
                offset = 0.1
            else:
                offset = 0.0
                prev_high = ix
            a0.text(ix,
                    1.15 + offset,
                    '%.1f' % cf,
                    horizontalalignment='center',
                    size=6)
    a1.bar(x, true_amount, label='training')
    a1.plot(x, gen_amount, linewidth=3, color='red', label='generated')
    a1.set_yscale('log')
    a1.set_xlim(-1.0, len(x))
    a1.set_ylim(maxscale / 5000, maxscale)
    if labelleft:
        a1.set_ylabel('mean area\nlog scale')
    if legend:
        a1.legend()
    a1.set_yticks([1e-2, 1e-1, 1.0, 1e+1])
    # `minor` must be passed by keyword: since matplotlib 3.5 the second
    # positional parameter of set_yticks is `labels`.
    a1.set_yticks([
        a * b for a in [1e-2, 1e-1, 1.0, 1e+1]
        for b in range(1, 10) if maxscale / 5000 <= a * b <= maxscale
    ], minor=True)
    a1.set_xticks([])
    fig.tight_layout()
    return fig
if __name__ == "__main__":  # script entry point, e.g. `python -m seeing.fsd <true> <gen>`
    main()
================================================
FILE: seeing/lightbox.html
================================================
