Showing preview only (407K chars total). Download the full file or copy to clipboard to get everything.
Repository: stevliu/self-conditioned-gan
Branch: master
Commit: a12fc3a99876
Files: 119
Total size: 375.7 KB
Directory structure:
gitextract_vn0la05j/
├── .gitignore
├── 2d_mix/
│ ├── .gitignore
│ ├── __init__.py
│ ├── config.py
│ ├── evaluation.py
│ ├── inputs.py
│ ├── models/
│ │ ├── __init__.py
│ │ └── cluster.py
│ ├── train.py
│ └── visualizations.py
├── LICENSE
├── README.md
├── cluster_metrics.py
├── clusterers/
│ ├── __init__.py
│ ├── base_clusterer.py
│ ├── kmeans.py
│ ├── online.py
│ ├── random_labels.py
│ └── selfcondgan.py
├── configs/
│ ├── cifar/
│ │ ├── conditional.yaml
│ │ ├── default.yaml
│ │ ├── selfcondgan.yaml
│ │ └── unconditional.yaml
│ ├── default.yaml
│ ├── imagenet/
│ │ ├── conditional.yaml
│ │ ├── default.yaml
│ │ ├── selfcondgan.yaml
│ │ └── unconditional.yaml
│ ├── places/
│ │ ├── conditional.yaml
│ │ ├── default.yaml
│ │ ├── selfcondgan.yaml
│ │ └── unconditional.yaml
│ ├── pretrained/
│ │ ├── imagenet/
│ │ │ ├── conditional.yaml
│ │ │ ├── selfcondgan.yaml
│ │ │ └── unconditional.yaml
│ │ └── places/
│ │ ├── conditional.yaml
│ │ ├── selfcondgan.yaml
│ │ └── unconditional.yaml
│ └── stacked_mnist/
│ ├── conditional.yaml
│ ├── default.yaml
│ ├── selfcondgan.yaml
│ └── unconditional.yaml
├── gan_training/
│ ├── __init__.py
│ ├── checkpoints.py
│ ├── config.py
│ ├── distributions.py
│ ├── eval.py
│ ├── inputs.py
│ ├── logger.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ ├── clustering_metrics.py
│ │ ├── fid.py
│ │ ├── inception_score.py
│ │ └── tf_is/
│ │ ├── LICENSE
│ │ ├── README.md
│ │ └── inception_score.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── blocks.py
│ │ ├── dcgan_deep.py
│ │ ├── dcgan_shallow.py
│ │ ├── resnet2.py
│ │ ├── resnet2s.py
│ │ └── resnet3.py
│ ├── train.py
│ └── utils.py
├── metrics.py
├── requirements.txt
├── seeded_sampler.py
├── seeing/
│ ├── frechet_distance.py
│ ├── fsd.py
│ ├── lightbox.html
│ ├── parallelfolder.py
│ ├── pbar.py
│ ├── pidfile.py
│ ├── sampler.py
│ ├── segmenter.py
│ ├── upsegmodel/
│ │ ├── __init__.py
│ │ ├── models.py
│ │ ├── prroi_pool/
│ │ │ ├── .gitignore
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── build.py
│ │ │ ├── functional.py
│ │ │ ├── prroi_pool.py
│ │ │ ├── src/
│ │ │ │ ├── prroi_pooling_gpu.c
│ │ │ │ ├── prroi_pooling_gpu.h
│ │ │ │ ├── prroi_pooling_gpu_impl.cu
│ │ │ │ └── prroi_pooling_gpu_impl.cuh
│ │ │ └── test_prroi_pooling2d.py
│ │ ├── resnet.py
│ │ └── resnext.py
│ ├── yz_dataset.py
│ └── zdataset.py
├── train.py
├── utils/
│ ├── classifiers/
│ │ ├── __init__.py
│ │ ├── cifar.py
│ │ ├── imagenet.py
│ │ ├── imagenet_class_index.json
│ │ ├── places.py
│ │ ├── pytorch_playground/
│ │ │ ├── .gitignore
│ │ │ ├── LICENSE
│ │ │ ├── README.md
│ │ │ ├── cifar/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dataset.py
│ │ │ │ ├── model.py
│ │ │ │ └── train.py
│ │ │ ├── quantize.py
│ │ │ ├── requirements.txt
│ │ │ ├── roadmap_zh.md
│ │ │ ├── setup.py
│ │ │ └── utee/
│ │ │ ├── __init__.py
│ │ │ ├── misc.py
│ │ │ ├── quant.py
│ │ │ └── selector.py
│ │ └── stacked_mnist.py
│ ├── get_empirical_distribution.py
│ ├── get_gt_imgs.py
│ └── np_to_pt_img.py
└── visualize_clusters.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*/**/*pyc*
*/**/.DS_Store
.vscode
================================================
FILE: 2d_mix/.gitignore
================================================
**.png
**.pyc
**.pt
output/
================================================
FILE: 2d_mix/__init__.py
================================================
================================================
FILE: 2d_mix/config.py
================================================
import torch
from models import generator_dict, discriminator_dict
from torch import optim
import torch.utils.data as utils
def get_models(model_type, conditioning, k_value, d_act_dim, device):
    """Instantiate the generator/discriminator pair registered under ``model_type``.

    Args:
        model_type: key into ``generator_dict`` / ``discriminator_dict``.
        conditioning: conditioning mode forwarded to both constructors.
        k_value: number of classes forwarded to both constructors.
        d_act_dim: hidden width forwarded to the discriminator as ``act_dim``.
        device: device both modules are moved to.

    Returns:
        (generator, discriminator)
    """
    # Resolve both classes before constructing anything.
    gen_cls, disc_cls = generator_dict[model_type], discriminator_dict[model_type]
    generator = gen_cls(conditioning, k_value=k_value)
    discriminator = disc_cls(conditioning, k_value=k_value, act_dim=d_act_dim)
    for net in (generator, discriminator):
        net.to(device)
    return generator, discriminator
def get_optimizers(generator, discriminator, lr=1e-4, beta1=0.8, beta2=0.999):
    """Create one Adam optimizer per network with shared hyper-parameters.

    Returns:
        (g_optimizer, d_optimizer), both ``optim.Adam`` with ``lr`` and
        ``betas=(beta1, beta2)``.
    """
    def make_adam(module):
        return optim.Adam(module.parameters(), lr=lr, betas=(beta1, beta2))

    return make_adam(generator), make_adam(discriminator)
def get_test(get_data, batch_size, variance, k_value, device):
    """Draw a fixed evaluation batch and move it to ``device``.

    ``get_data`` is called as ``get_data(batch_size, var=variance)`` and must return
    numpy (samples, labels). Samples become float32 and labels int64 tensors.
    ``k_value`` is accepted for call-site compatibility but unused.
    """
    samples, labels = get_data(batch_size, var=variance)
    x_test = torch.from_numpy(samples).float().to(device)
    y_test = torch.from_numpy(labels).long().to(device)
    return x_test, y_test
def get_dataset(get_data, batch_size, npts, variance, k_value):
    """Build a shuffled training DataLoader over ``npts`` freshly sampled points.

    Args:
        get_data: sampler called as ``get_data(npts, var=variance)`` that returns
            array-like (samples, labels).
        batch_size: loader batch size; the trailing incomplete batch is dropped.
        npts: number of points to sample.
        variance: forwarded to ``get_data`` as ``var``.
        k_value: unused; kept for call-site compatibility.

    Returns:
        A ``DataLoader`` yielding (samples, labels) batches. Samples use the default
        float dtype; labels keep the integer dtype of ``labels``.
    """
    samples, labels = get_data(npts, var=variance)
    # Convert in one shot instead of creating and stacking one tiny tensor per point.
    tensor_samples = torch.as_tensor(samples, dtype=torch.get_default_dtype())
    tensor_labels = torch.as_tensor(labels)
    dataset = utils.TensorDataset(tensor_samples, tensor_labels)
    train_loader = utils.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0,
                                    pin_memory=True,
                                    sampler=None,
                                    drop_last=True)
    return train_loader
================================================
FILE: 2d_mix/evaluation.py
================================================
def warn(*args, **kwargs):
    """Accept any arguments and discard the warning."""
    return None
import warnings
# Replace warnings.warn process-wide so every call routed through it becomes a no-op.
warnings.warn = warn
import numpy as np
def percent_good_grid(x_fake, var=0.0025, nrows=5, ncols=5):
    """Score generated points against the grid mixture's component means.

    Means are (2*r - 4, 2*c - 4) for r in range(nrows), c in range(ncols), listed
    row-major. A point counts as good when it lies within 3 standard deviations
    (sqrt(var)) of a mean on every coordinate. Returns the (%good, #modes, kl) tuple
    from ``percent_good_pts``.
    """
    threshold = 3 * np.sqrt(var)
    means = [
        np.array([row * 2 - 4, col * 2 - 4])
        for row in range(nrows)
        for col in range(ncols)
    ]
    return percent_good_pts(x_fake, means, threshold)
def percent_good_ring(x_fake, var=0.0001, n_clusters=8, radius=2.0):
    """Score generated points against the ring mixture's component means.

    Means are ``n_clusters`` points evenly spaced on a circle of ``radius`` at angles
    2*pi*k/n_clusters, given as (radius*sin, radius*cos). The per-coordinate threshold
    is 3 standard deviations (sqrt(var)). Returns the (%good, #modes, kl) tuple from
    ``percent_good_pts``.
    """
    std = np.sqrt(var)
    angles = np.linspace(0, 2 * np.pi, n_clusters + 1)[:n_clusters]
    xs, ys = radius * np.sin(angles), radius * np.cos(angles)
    means = [np.array([px, py]) for px, py in zip(xs, ys)]
    threshold = np.full(2, 3 * std)
    return percent_good_pts(x_fake, means, threshold)
def percent_good_pts(x_fake, means, threshold):
    """Calculate %good, #modes, kl

    Keyword arguments:
    x_fake -- detached generated samples, array-like of shape (n, d)
    means -- true means, a sequence of length-d arrays
    threshold -- a point is good if, for some mean, every coordinate's absolute
                 difference to that mean is below threshold (scalar or length-d array)

    Returns:
        (fraction of good points,
         number of distinct means credited with a good point -- each good point
         credits only the first mean, in ``means`` order, that it is within threshold of,
         sum_k p_k * log(K * p_k), i.e. KL(p || uniform), where p_k is the fraction of
         points whose Euclidean-nearest mean is k (ties go to the lowest index))

    Raises:
        ValueError: if ``x_fake`` is empty.

    Memory is O(n * len(means) * d) because all point/mean differences are formed at once.
    """
    points = np.asarray(x_fake, dtype=np.float64)
    if len(points) == 0:
        raise ValueError('percent_good_pts needs at least one generated point')
    centers = np.asarray(means, dtype=np.float64)
    n_points, n_means = len(points), len(centers)

    # |point - mean| for every (point, mean) pair: shape (n_points, n_means, d).
    abs_diff = np.abs(points[:, None, :] - centers[None, :, :])

    # Good points and the first mean (in list order) each one falls within.
    within = np.all(abs_diff < threshold, axis=-1)
    is_good = within.any(axis=1)
    first_hit = within.argmax(axis=1)
    visited = {tuple(center) for center in centers[first_hit[is_good]]}

    # Empirical distribution over nearest means; argmin keeps the first minimum on ties.
    nearest = np.linalg.norm(abs_diff, axis=-1).argmin(axis=1)
    probs = np.bincount(nearest, minlength=n_means) / n_points
    nonzero = probs[probs > 0]
    kl = float(np.sum(nonzero * np.log(n_means * nonzero)))

    return int(is_good.sum()) / n_points, len(visited), kl
================================================
FILE: 2d_mix/inputs.py
================================================
import numpy as np
import random
# Label translation table used by map_labels; identity over 25 ids.
mapping = [*range(25)]
def map_labels(labels):
    """Translate each label through ``mapping`` and return a numpy array."""
    return np.array(list(map(mapping.__getitem__, labels)))
def get_data_ring(batch_size, radius=2.0, var=0.0001, n_clusters=8):
    """Sample points from ``n_clusters`` isotropic Gaussians spaced evenly on a circle.

    Component k is centred at (radius*sin(t_k), radius*cos(t_k)) with t_k = 2*pi*k/n_clusters
    and has covariance var*I. Per-component counts come from one multinomial draw and
    the resulting label list is shuffled.

    Returns:
        (samples, labels): for batch_size > 0 samples has shape (batch_size, 2);
        labels holds the integer component id of each sample.
    """
    angles = np.linspace(0, 2 * np.pi, n_clusters + 1)[:n_clusters]
    centers_x = radius * np.sin(angles)
    centers_y = radius * np.cos(angles)
    per_cluster = np.random.multinomial(batch_size, [1.0 / n_clusters] * n_clusters)

    labels = []
    for cluster_id, count in enumerate(per_cluster):
        labels.extend([cluster_id] * count)
    random.shuffle(labels)
    labels = np.array(labels)

    covariance = [[var, 0], [0, var]]
    samples = np.array([
        np.random.multivariate_normal([centers_x[k], centers_y[k]], covariance)
        for k in labels
    ])
    return samples, labels
def get_data_grid(batch_size, radius=2.0, var=0.0025, nrows=5, ncols=5):
    """Sample ``batch_size`` points from a grid of isotropic Gaussians.

    For each sample a column i in [0, ncols) and a row j in [0, nrows) are drawn
    uniformly. The point is drawn from N((2*i - 4, 2*j - 4), var*I) and labelled
    ``nrows * i + j`` before being translated by ``map_labels``. ``radius`` is unused.

    NOTE(review): map_labels indexes a module-level table; with the 25-entry table in
    this file, nrows * ncols must not exceed 25.

    Returns:
        (samples, labels) numpy arrays.
    """
    covariance = [[var, 0], [0, var]]
    samples = []
    labels = []
    for _ in range(batch_size):
        i, j = random.randint(0, ncols - 1), random.randint(0, nrows - 1)
        samples.append(np.random.multivariate_normal([i * 2 - 4, j * 2 - 4], covariance))
        # Unique id per (i, j): stride by nrows (was a hard-coded 5, which collides for nrows != 5).
        labels.append(nrows * i + j)
    return np.array(samples), map_labels(labels)
================================================
FILE: 2d_mix/models/__init__.py
================================================
from models import (cluster)
# Architecture registries keyed by model type name (only 'standard' is provided).
generator_dict = dict(standard=cluster.G)
discriminator_dict = dict(standard=cluster.D)
================================================
FILE: 2d_mix/models/cluster.py
================================================
import sys
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
sys.path.append('../gan_training/models')
from blocks import LatentEmbeddingConcat, Identity, LinearUnconditionalLogits, LinearConditionalMaskLogits
class G(nn.Module):
    """MLP generator that maps a latent code (plus optional class label) to a point.

    Layout: embedding -> 4 x (Linear, BatchNorm1d, ReLU) of width ``act_dim`` ->
    Linear to ``x_dim``.

    Args:
        conditioning: 'conditional' uses ``LatentEmbeddingConcat(k_value, embed_size)``;
            'unconditional' uses ``Identity()`` and no embedding width.
        k_value: number of classes for the conditional embedding.
        z_dim: latent dimensionality.
        embed_size: embedding width added to the first layer's input when conditional.
        act_dim: hidden width.
        x_dim: output dimensionality.

    Raises:
        NotImplementedError: for any other ``conditioning`` value.
    """

    def __init__(self,
                 conditioning,
                 k_value,
                 z_dim=2,
                 embed_size=32,
                 act_dim=400,
                 x_dim=2):
        super().__init__()
        if conditioning == 'conditional':
            # presumably concatenates an embedding of y onto z (first layer takes
            # z_dim + embed_size inputs) -- confirm against blocks.LatentEmbeddingConcat
            self.embedding = LatentEmbeddingConcat(k_value, embed_size)
        elif conditioning == 'unconditional':
            embed_size = 0
            self.embedding = Identity()
        else:
            raise NotImplementedError()

        def hidden(n_in, n_out):
            return nn.Sequential(nn.Linear(n_in, n_out),
                                 nn.BatchNorm1d(n_out), nn.ReLU(True))

        self.fc1 = hidden(z_dim + embed_size, act_dim)
        self.fc2 = hidden(act_dim, act_dim)
        self.fc3 = hidden(act_dim, act_dim)
        self.fc4 = hidden(act_dim, act_dim)
        self.fc_out = nn.Linear(act_dim, x_dim)

    def forward(self, z, y=None):
        h = self.fc1(self.embedding(z, y))
        for layer in (self.fc2, self.fc3, self.fc4, self.fc_out):
            h = layer(h)
        return h
class D(nn.Module):
    """Maxout-MLP discriminator for low-dimensional points.

    Three Maxout layers of width ``act_dim`` followed by a logit head:
    ``LinearUnconditionalLogits`` ('unconditional') or
    ``LinearConditionalMaskLogits`` over ``k_value`` classes ('conditional').

    Raises:
        NotImplementedError: for any other ``conditioning`` value.
    """

    class Maxout(nn.Module):
        # Taken from https://github.com/pytorch/pytorch/issues/805
        def __init__(self, d_in, d_out, pool_size=5):
            super().__init__()
            self.d_in, self.d_out, self.pool_size = d_in, d_out, pool_size
            self.lin = nn.Linear(d_in, d_out * pool_size)

        def forward(self, inputs):
            """Project the last dim to d_out * pool_size units, then max over each pool."""
            grouped_shape = list(inputs.size()[:-1]) + [self.d_out, self.pool_size]
            projected = self.lin(inputs)
            return projected.view(*grouped_shape).max(len(grouped_shape) - 1)[0]

    def max(self, out, dim=5):
        """Max over consecutive chunks of ``dim`` elements within each leading-dim row."""
        chunks = out.view(out.size(0), -1, dim)
        return chunks.max(2)[0]

    def __init__(self, conditioning, k_value, act_dim=200, x_dim=2):
        super().__init__()
        self.fc1 = self.Maxout(x_dim, act_dim)
        self.fc2 = self.Maxout(act_dim, act_dim)
        self.fc3 = self.Maxout(act_dim, act_dim)
        if conditioning == 'conditional':
            self.fc4 = LinearConditionalMaskLogits(act_dim, k_value)
        elif conditioning == 'unconditional':
            self.fc4 = LinearUnconditionalLogits(act_dim)
        else:
            raise NotImplementedError()

    def forward(self, x, y=None, get_features=False):
        h = x
        for layer in (self.fc1, self.fc2, self.fc3):
            h = layer(h)
        # Feature mode returns the last Maxout activations and skips the logit head.
        if get_features:
            return h
        return self.fc4(h, y, get_features=get_features)
================================================
FILE: 2d_mix/train.py
================================================
import argparse
import os
import sys
import torch
from torch import optim
from torch import distributions
from torch import nn
import torch.nn.functional as F
import numpy as np
import evaluation
import inputs
from config import get_models, get_optimizers, get_test, get_dataset
from visualizations import (visualize_generated, visualize_clusters)
sys.path.append('../')
from clusterers import clusterer_dict
from gan_training.train import Trainer
sys.path.append('../seeing/')
import pidfile
parser = argparse.ArgumentParser(description='2d dataset experiments')
# Both are required: without them the run cannot pick a clusterer or a dataset.
parser.add_argument('--clusterer', required=True, help='type of clusterer to use: a key of clusterers.clusterer_dict (e.g. selfcondgan, online, random_labels, supervised)')
parser.add_argument('--data_type', required=True, choices=['grid', 'ring'], help='either grid or ring')
parser.add_argument('--recluster_every', type=int, default=5000, help='how frequently to recluster')
parser.add_argument('--nruns', type=int, default=1, help='number of trials to do')
parser.add_argument('--burnin_time', type=int, default=0, help='wait this amount of iterations before clustering')
parser.add_argument('--variance', type=float, default=None, help='variance of the gaussians')
parser.add_argument('--model_type', type=str, default='standard', help='model architecture')
parser.add_argument('--num_clusters', type=int, default=50, help='number of clusters to use for selfcondgan')
parser.add_argument('--z_dim', type=int, default=2, help='G latent dim')
parser.add_argument('--d_act_dim', type=int, default=200, help='hidden layer width')
parser.add_argument('--npts', type=int, default=100000, help='number of points to use in dataset')
parser.add_argument('--train_batch_size', type=int, default=100, help='training time batch size')
parser.add_argument('--test_batch_size', type=int, default=50000, help='number of examples to get metrics with')
parser.add_argument('--nepochs', type=int, default=100, help='number of epochs to run')
parser.add_argument('--outdir', default='output')
args = parser.parse_args()
data_type = args.data_type
# Number of mixture components: 8 on the ring, 25 on the grid.
k_value = 8 if data_type == 'ring' else 25
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# The supervised baseline conditions on the true components.
num_clusters = k_value if args.clusterer == 'supervised' else args.num_clusters
exp_name = f'{args.data_type}_{args.clusterer}_{args.recluster_every}_{num_clusters}/'
if args.model_type != 'standard':
    exp_name = f'{args.model_type}_{exp_name}'
if args.variance is not None:
    exp_name = f'{args.variance}_{exp_name}'
if args.variance is None:
    # Same defaults as inputs.get_data_grid / inputs.get_data_ring.
    variance = 0.0025 if data_type == 'grid' else 0.0001
else:
    variance = args.variance
nepochs = args.nepochs
z_dim = args.z_dim
test_batch_size = args.test_batch_size
train_batch_size = args.train_batch_size
npts = args.npts
def main(outdir):
    """Run one 2D-mixture GAN experiment and write all outputs under ``outdir``.

    Creates ``all/``, ``snapshots/`` and ``clusters/`` subdirectories, samples the grid
    or ring dataset selected by the module-level ``data_type``, and builds a
    conditional G/D pair with ``num_clusters`` classes. It then trains for ``nepochs``
    epochs. Every 1000 iterations it saves scatter plots, a checkpoint
    (``snapshots/model_<it>.pt``) and a metrics line appended to ``log.txt``.

    Raises:
        NotImplementedError: if ``data_type`` is neither 'grid' nor 'ring'.
    """
    for subdir in ['all', 'snapshots', 'clusters']:
        os.makedirs(os.path.join(outdir, subdir), exist_ok=True)

    if data_type == 'grid':
        get_data = inputs.get_data_grid
        percent_good = evaluation.percent_good_grid
    elif data_type == 'ring':
        get_data = inputs.get_data_ring
        percent_good = evaluation.percent_good_ring
    else:
        raise NotImplementedError()

    zdist = distributions.Normal(torch.zeros(z_dim, device=device),
                                 torch.ones(z_dim, device=device))
    # Fixed latent codes and real points reused at every evaluation.
    z_test = zdist.sample((test_batch_size, ))
    x_test, y_test = get_test(get_data=get_data,
                              batch_size=test_batch_size,
                              variance=variance,
                              k_value=k_value,
                              device=device)
    # Separate sample handed to the clusterer as x_cluster.
    x_cluster, _ = get_test(get_data=get_data,
                            batch_size=10000,
                            variance=variance,
                            k_value=k_value,
                            device=device)
    train_loader = get_dataset(get_data=get_data,
                               batch_size=train_batch_size,
                               npts=npts,
                               variance=variance,
                               k_value=k_value)

    def train(trainer, g, d, clusterer, exp_dir):
        """Training loop; ``it`` counts D/G update pairs across all epochs."""
        it = 0
        log_path = os.path.join(exp_dir, 'log.txt')
        if os.path.exists(log_path):
            os.remove(log_path)
        for epoch in range(nepochs):
            for x_real, y in train_loader:
                z = zdist.sample((train_batch_size, ))
                x_real, y = x_real.to(device), y.to(device)
                y = clusterer.get_labels(x_real, y)
                dloss, _ = trainer.discriminator_trainstep(x_real, y, z)
                gloss = trainer.generator_trainstep(y, z)

                if it % args.recluster_every == 0 and args.clusterer != 'supervised':
                    if args.clusterer != 'burnin' or it >= args.burnin_time:
                        clusterer.recluster(d, x_batch=x_real)

                if it % 1000 == 0:
                    # Evaluate in eval mode and without an autograd graph over the whole
                    # test batch; one generated sample feeds both plots and metrics.
                    g.eval()
                    d.eval()
                    with torch.no_grad():
                        y_test_pred = clusterer.get_labels(x_test, y_test)
                        x_fake = g(z_test, y_test_pred).cpu().numpy()
                    x_test_np = x_test.cpu().numpy()
                    visualize_generated(x_fake, x_test_np, y, it, exp_dir)
                    visualize_clusters(x_test_np, y_test_pred, it, exp_dir)
                    torch.save(
                        {
                            'generator': g.state_dict(),
                            'discriminator': d.state_dict(),
                            'g_optimizer': g_optimizer.state_dict(),
                            'd_optimizer': d_optimizer.state_dict()
                        },
                        os.path.join(exp_dir, 'snapshots', 'model_%d.pt' % it))
                    percent, modes, kl = percent_good(x_fake, var=variance)
                    log_message = f'[epoch {epoch} it {it}] dloss = {dloss}, gloss = {gloss}, prop_real = {percent}, modes = {modes}, kl = {kl}'
                    with open(log_path, 'a+') as f:
                        f.write(log_message + '\n')
                    print(log_message)
                    # Back to training mode for the following update steps.
                    g.train()
                    d.train()
                it += 1

    # train a G/D from scratch
    generator, discriminator = get_models(args.model_type, 'conditional', num_clusters, args.d_act_dim, device)
    g_optimizer, d_optimizer = get_optimizers(generator, discriminator)
    trainer = Trainer(generator, discriminator, g_optimizer, d_optimizer, gan_type='standard', reg_type='none', reg_param=0)
    clusterer = clusterer_dict[args.clusterer](discriminator=discriminator,
                                               k_value=num_clusters,
                                               x_cluster=x_cluster)
    clusterer.recluster(discriminator=discriminator)
    train(trainer, generator, discriminator, clusterer, outdir)
if __name__ == '__main__':
    # One output directory per experiment configuration.
    outdir = os.path.join(args.outdir, exp_name)
    # presumably exits early when this experiment was already marked done (see seeing/pidfile)
    pidfile.exit_if_job_done(outdir)
    several_runs = args.nruns > 1
    for run_idx in range(args.nruns):
        # Several runs each get their own suffixed directory.
        main(f'{outdir}_run_{run_idx}' if several_runs else outdir)
    pidfile.mark_job_done(outdir)
================================================
FILE: 2d_mix/visualizations.py
================================================
import matplotlib
from matplotlib import pyplot
import os
# Named matplotlib colors, picked by cluster id (modulo the list length) when plotting.
COLORS = '''
purple wheat maroon red powderblue dodgerblue magenta tan aqua yellow
slategray blue rosybrown violet lightseagreen pink darkorange teal royalblue lawngreen
gold navy darkgreen deeppink palegreen silver saddlebrown plum peru black
'''.split()
# Entries must be unique so distinct clusters get distinct colors (within one cycle).
assert len(set(COLORS)) == len(COLORS)
def visualize_generated(fake, real, y, it, outdir):
    """Save two scatter plots of generated points for iteration ``it``.

    ``<outdir>/all/<it>.png`` overlays real (red) and generated (blue) points.
    ``<outdir>/all/<it>square.png`` shows only the generated points, on a fixed,
    equal-aspect [-6, 6] x [-6, 6] window. ``y`` is accepted but not used.
    """
    prefix = os.path.join(outdir, 'all', str(it))

    # Real vs generated overlay.
    pyplot.plot(real[:, 0], real[:, 1], 'r.')
    pyplot.plot(fake[:, 0], fake[:, 1], 'b.')
    pyplot.savefig(prefix + '.png')
    pyplot.clf()

    # Generated samples alone, on a fixed square window.
    half_extent = 6
    ax = pyplot.gca()
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlim([-half_extent, half_extent])
    ax.set_ylim([-half_extent, half_extent])
    pyplot.locator_params(nbins=4)
    pyplot.tight_layout()
    pyplot.plot(fake[:, 0], fake[:, 1], 'b.', alpha=0.1)
    pyplot.savefig(prefix + 'square.png', dpi=100, bbox_inches='tight')
    pyplot.clf()
def visualize_clusters(x, y, it, outdir):
    """Scatter-plot points colored by cluster id; save to ``<outdir>/clusters/<it>.png``.

    Args:
        x: numpy array of points; columns 0 and 1 are plotted.
        y: tensor of integer cluster ids, one per row of ``x``.
        it: iteration number, used as the file name.
        outdir: experiment directory containing a ``clusters/`` subdirectory.
    """
    y = y.detach().cpu().numpy()
    # Ids run from 0 to y.max() inclusive, so "+ 1" keeps the highest cluster in the plot.
    n_labels = int(y.max()) + 1 if y.size else 0
    for i in range(n_labels):
        members = y == i
        pyplot.plot(x[members, 0],
                    x[members, 1],
                    '.',
                    color=COLORS[i % len(COLORS)])
    pyplot.savefig(os.path.join(outdir, 'clusters', str(it) + '.png'))
    pyplot.clf()
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2020 Steven Liu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Diverse Image Generation via Self-Conditioned GANs
#### [Project](http://selfcondgan.csail.mit.edu/) | [Paper](http://selfcondgan.csail.mit.edu/preprint.pdf)
**Diverse Image Generation via Self-Conditioned GANs** <br>
[Steven Liu](http://people.csail.mit.edu/stevenliu/),
[Tongzhou Wang](https://ssnl.github.io/),
[David Bau](http://people.csail.mit.edu/davidbau/home/),
[Jun-Yan Zhu](http://people.csail.mit.edu/junyanz/),
[Antonio Torralba](http://web.mit.edu/torralba/www/) <br>
MIT, Adobe Research<br>
in CVPR 2020.

Our proposed self-conditioned GAN model learns to perform clustering and image synthesis simultaneously. The model training
requires no manual annotation of object classes. Here, we visualize several discovered clusters for both Places365 (top) and ImageNet
(bottom). For each cluster, we show both real images and the generated samples conditioned on the cluster index.
## Getting Started
### Installation
- Clone this repo:
```bash
git clone https://github.com/stevliu/self-conditioned-gan.git
cd self-conditioned-gan
```
- Install the dependencies
```bash
conda create --name selfcondgan python=3.6
conda activate selfcondgan
conda install --file requirements.txt
conda install -c conda-forge tensorboardx
```
### Training and Evaluation
- Train a model on CIFAR:
```bash
python train.py configs/cifar/selfcondgan.yaml
```
- Visualize samples and inferred clusters:
```bash
python visualize_clusters.py configs/cifar/selfcondgan.yaml --show_clusters
```
The samples and clusters will be saved to `output/cifar/selfcondgan/clusters`. If this directory lies on an Apache server, you can open the URL to `output/cifar/selfcondgan/clusters/+lightbox.html` in the browser and visualize all samples and clusters in one webpage.
- Evaluate the model's FID:
You will need to first gather a set of ground truth train set images to compute metrics against.
```bash
python utils/get_gt_imgs.py --cifar
python metrics.py configs/cifar/selfcondgan.yaml --fid --every -1
```
You can also evaluate with other metrics by appending additional flags, such as Inception Score (`--inception`), the number of covered modes + reverse-KL divergence (`--modes`), and cluster metrics (`--cluster_metrics`).
## Pretrained Models
You can load and evaluate pretrained models on ImageNet and Places. If you have access to ImageNet or Places directories, first fill in paths to your ImageNet and/or Places dataset directories in `configs/imagenet/default.yaml` and `configs/places/default.yaml` respectively. You can use the following config files with the evaluation scripts, and the code will automatically download the appropriate models.
```bash
configs/pretrained/imagenet/selfcondgan.yaml
configs/pretrained/places/selfcondgan.yaml
configs/pretrained/imagenet/conditional.yaml
configs/pretrained/places/conditional.yaml
configs/pretrained/imagenet/unconditional.yaml
configs/pretrained/places/unconditional.yaml
```
## Evaluation
### Visualizations
To visualize generated samples and inferred clusters, run
```bash
python visualize_clusters.py config-file
```
You can set the flag `--show_clusters` to also visualize the real inferred clusters, but this requires that you have a path to training set images.
### Metrics
To obtain generation metrics, fill in paths to your ImageNet or Places dataset directories in `utils/get_gt_imgs.py` and then run
```bash
python utils/get_gt_imgs.py --imagenet --places
```
to precompute batches of GT images for FID/FSD evaluation.
Then, you can use
```bash
python metrics.py config-file
```
with the appropriate flags to compute the FID (`--fid`), FSD (`--fsd`), IS (`--inception`), number of modes covered / reverse-KL divergence (`--modes`), and clustering metrics (`--cluster_metrics`) for each of the checkpoints.
## Training models
To train a model, set up a configuration file (examples in `configs/`), and run
```bash
python train.py config-file
```
Example configs for the self-conditioned GAN are `configs/imagenet/selfcondgan.yaml` (ImageNet) and `configs/places/selfcondgan.yaml` (Places).
Some models may be too large to fit on one GPU, so you may want to add `--devices DEVICE_NUMBERS` as an additional flag to do multi-GPU training.
## 2D-experiments
For synthetic dataset experiments, first go into the `2d_mix` directory.
To train a self-conditioned GAN on the 2D-ring and 2D-grid dataset, run
```bash
python train.py --clusterer selfcondgan --data_type ring
python train.py --clusterer selfcondgan --data_type grid
```
You can test several other configurations via the command line arguments.
## Acknowledgments
This code is heavily based on the [GAN-stability](https://github.com/LMescheder/GAN_stability) code base.
Our FSD code is taken from the [GANseeing](https://github.com/davidbau/ganseeing) work.
To compute inception score, we use the code provided from [Shichang Tang](https://github.com/tsc2017/Inception-Score.git).
To compute FID, we use the code provided from [TTUR](https://github.com/bioinf-jku/TTUR).
We also use pretrained classifiers given by the [pytorch-playground](https://github.com/aaron-xichen/pytorch-playground).
We thank all the authors for their useful code.
## Citation
If you use this code for your research, please cite the following work.
```
@inproceedings{liu2020selfconditioned,
title={Diverse Image Generation via Self-Conditioned GANs},
author={Liu, Steven and Wang, Tongzhou and Bau, David and Zhu, Jun-Yan and Torralba, Antonio},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2020}
}
```
================================================
FILE: cluster_metrics.py
================================================
import argparse
import os
from tqdm import tqdm
import torch
import numpy as np
from torch import nn
from gan_training import utils
from gan_training.inputs import get_dataset
from gan_training.checkpoints import CheckpointIO
from gan_training.config import load_config
from gan_training.metrics.clustering_metrics import (nmi, purity_score)
# Let cuDNN benchmark and cache the fastest convolution algorithms (pays off when input shapes stay fixed across batches).
torch.backends.cudnn.benchmark = True
# Arguments
def _build_parser():
    """Command-line interface: a config path plus optional checkpoint and baseline flags."""
    p = argparse.ArgumentParser(description='Evaluate the clustering inferred by our method')
    p.add_argument('config', type=str, help='Path to config file.')
    p.add_argument('--model_it', type=str)
    p.add_argument('--random', action='store_true', help='Figure out if the clusters were randomly assigned')
    return p


parser = _build_parser()
args = parser.parse_args()
# Experiment config layered over the repository defaults.
config = load_config(args.config, 'configs/default.yaml')
out_dir = config['training']['out_dir']
def main():
    """Print NMI and purity of the learned clustering against ground-truth labels.

    Training-set cluster predictions are cached in ``<out_dir>/cluster_preds.npz`` and
    reused on later runs. With ``--random`` the predictions are replaced by uniformly
    random ids, giving a chance-level baseline.
    """
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    batch_size = config['training']['batch_size']
    preds_path = os.path.join(out_dir, 'cluster_preds.npz')

    if 'cifar' in config['data']['train_dir'].lower():
        name = 'cifar10'
    elif 'stacked_mnist' == config['data']['type']:
        name = 'stacked_mnist'
    else:
        name = 'image'

    if os.path.exists(preds_path):
        # if we've already computed assignments, load them and move on
        # NOTE(review): the cache is not keyed on --model_it; delete it to score another checkpoint.
        with np.load(preds_path) as f:
            y_reals = f['y_reals']
            y_preds = f['y_preds']
    else:
        train_dataset, _ = get_dataset(
            name=name,
            data_dir=config['data']['train_dir'],
            size=config['data']['img_size'])
        # drop_last=True: the final partial batch is not scored.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            num_workers=config['training']['nworkers'],
            shuffle=True,
            pin_memory=True,
            sampler=None,
            drop_last=True)
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        print('Loading clusterer:')
        most_recent = utils.get_most_recent(checkpoint_dir, 'model') if args.model_it is None else args.model_it
        clusterer = checkpoint_io.load_clusterer(most_recent, load_samples=False, pretrained=config['pretrained'])
        if isinstance(clusterer.discriminator, nn.DataParallel):
            clusterer.discriminator = clusterer.discriminator.module
        y_preds = []
        y_reals = []
        for x_real, y_real in tqdm(train_loader, total=len(train_loader)):
            y_pred = clusterer.get_labels(x_real.cuda(), None)
            y_preds.append(y_pred.detach().cpu())
            y_reals.append(y_real)
        y_reals = torch.cat(y_reals).numpy()
        y_preds = torch.cat(y_preds).numpy()
        np.savez(preds_path, y_reals=y_reals, y_preds=y_preds)

    if args.random:
        # Chance baseline over the same id range as the real predictions, so the
        # comparison is fair for any number of clusters (not a fixed 100).
        n_clusters = int(y_preds.max()) + 1
        y_preds = np.random.randint(0, n_clusters, size=y_reals.shape)
    nmi_score = nmi(y_preds, y_reals)
    purity = purity_score(y_preds, y_reals)
    print('nmi', nmi_score, 'purity', purity)
if __name__ == '__main__':
    # Script entry point: compute and print the clustering metrics.
    main()
================================================
FILE: clusterers/__init__.py
================================================
from clusterers import (base_clusterer, selfcondgan, random_labels, online)
# Maps a clusterer name (e.g. the 2d_mix --clusterer flag) to its class.
clusterer_dict = dict(
    supervised=base_clusterer.BaseClusterer,
    selfcondgan=selfcondgan.Clusterer,
    online=online.Clusterer,
    random_labels=random_labels.Clusterer,
)
================================================
FILE: clusterers/base_clusterer.py
================================================
import copy
import torch
import numpy as np
class BaseClusterer():
    """Base clusterer that passes the given labels through unchanged.

    Subclasses override ``get_labels`` / ``recluster``. The helpers here embed points
    with a deep-copied, eval-mode copy of the discriminator.

    Args:
        discriminator: module called as ``discriminator(x, get_features=True)``.
        k_value: number of clusters (sizes ``cluster_counts`` and label histograms).
        x_cluster: tensor of points embedded by ``get_cluster_batch_features``.
        batch_size: batch size used when embedding ``x_cluster``.
    """

    def __init__(self,
                 discriminator,
                 k_value=-1,
                 x_cluster=None,
                 batch_size=100,
                 **kwargs):
        ''' requires that self.x is not on the gpu, or else it hogs too much gpu memory '''
        self.cluster_counts = [0] * k_value
        self.discriminator = copy.deepcopy(discriminator)
        self.discriminator.eval()
        self.k = k_value
        self.kmeans = None
        self.x = x_cluster
        self.x_labels = None
        self.batch_size = batch_size

    def get_labels(self, x, y):
        '''returns the given labels unchanged'''
        return y

    def recluster(self, discriminator, **kwargs):
        '''no-op for fixed labels'''
        return

    def get_features(self, x):
        ''' by default gets discriminator, but you can use other things '''
        return self.get_discriminator_output(x)

    def _device(self):
        '''device holding the discriminator's parameters (CUDA if it exposes none)'''
        try:
            return next(self.discriminator.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device('cuda')

    def get_cluster_batch_features(self):
        ''' returns the discriminator features for the batch self.x as a numpy array '''
        with torch.no_grad():
            # Move batches to the discriminator's device rather than forcing CUDA.
            device = self._device()
            x = self.x
            # range(0, n, batch_size) also yields the trailing partial batch.
            outputs = [
                self.get_features(x[start:start + self.batch_size].to(device)).detach().cpu()
                for start in range(0, x.size(0), self.batch_size)
            ]
            return torch.cat(outputs, dim=0).numpy()

    def get_discriminator_output(self, x):
        '''returns discriminator features'''
        self.discriminator.eval()
        with torch.no_grad():
            return self.discriminator(x, get_features=True)

    def get_label_distribution(self, x=None):
        '''returns per-cluster counts (a list of length k) over self.x_labels,
        or over get_labels(x, None) when x is given'''
        y = self.x_labels if x is None else self.get_labels(x, None)
        counts = [0] * self.k
        for yi in y:
            counts[yi] += 1
        return counts

    def sample_y(self, batch_size):
        '''samples per-cluster counts for batch_size draws from the empirical label
        distribution (not sure if used anymore)'''
        counts = self.get_label_distribution()
        total = sum(counts)
        probs = torch.tensor([c / total for c in counts])
        return torch.distributions.Multinomial(batch_size, probs).sample()

    def print_label_distribution(self, x=None):
        print(self.get_label_distribution(x))
================================================
FILE: clusterers/kmeans.py
================================================
import torch
import numpy as np
from sklearn.cluster import KMeans
from clusterers import base_clusterer
class Clusterer(base_clusterer.BaseClusterer):
    """k-means over discriminator features; predicted ids are remapped via ``self.mapping``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Identity relabelling by default.
        self.mapping = list(range(self.k))

    def kmeans_fit_predict(self, features, init='k-means++', n_init=10):
        '''fits kmeans, and returns the predictions of the kmeans'''
        print('Fitting k-means w data shape', features.shape)
        self.kmeans = KMeans(init=init, n_clusters=self.k,
                             n_init=n_init).fit(features)
        return self.kmeans.predict(features)

    def get_labels(self, x, y):
        '''assigns each row of x to a k-means cluster (via self.mapping); y is ignored'''
        d_features = self.get_features(x).detach().cpu().numpy()
        np_prediction = self.kmeans.predict(d_features)
        # `label`, not `x`: avoid shadowing the input tensor.
        permuted_prediction = np.array([self.mapping[label] for label in np_prediction])
        # Return labels on x's device rather than forcing CUDA.
        return torch.from_numpy(permuted_prediction).long().to(x.device)
================================================
FILE: clusterers/online.py
================================================
import copy, random
import torch
import numpy as np
from clusterers import kmeans
class Clusterer(kmeans.Clusterer):
    """Online k-means clusterer.

    The first ``recluster`` call fits k-means on features of ``self.x`` from a copy of
    the discriminator. The second refits on features from the live discriminator,
    initialised at the means induced by the first assignment. Every later call moves
    the centres towards the features of ``x_batch`` with per-cluster 1/count steps.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.burned_in = False

    def get_initialization(self, features, labels):
        '''given points (from new discriminator) and their old assignments as np arrays, compute the induced means as a np array'''
        means = []
        for i in range(self.k):
            mask = (labels == i)
            numels = int(mask.sum())
            if numels > 0:
                # accumulate in float64 for a precise mean
                means.append(features[mask].sum(axis=0, dtype=np.float64) / numels)
            else:
                # Starved cluster: seed its centre with a uniformly chosen feature vector.
                # features is a numpy array, so use len() (ndarray.size is an int, not a method).
                rand_point = random.randint(0, len(features) - 1)
                means.append(features[rand_point])
        return np.array(means)

    def recluster(self, discriminator, x_batch=None, **kwargs):
        if self.kmeans is None:
            print('kmeans clustering as initialization')
            self.discriminator = copy.deepcopy(discriminator)
            features = self.get_cluster_batch_features()
            self.x_labels = self.kmeans_fit_predict(features)
            return

        # After initialisation, track the live discriminator (no copy).
        self.discriminator = discriminator
        if not self.burned_in:
            print('Burned in: computing initialization for kmeans')
            features = self.get_cluster_batch_features()
            initialization = self.get_initialization(
                features, self.x_labels)
            self.kmeans_fit_predict(features, init=initialization)
            self.burned_in = True
            return

        if x_batch is None:
            raise ValueError('online reclustering requires x_batch')
        features = self.get_features(x_batch).detach().cpu().numpy()
        y_pred = self.kmeans.predict(features)
        centers = self.kmeans.cluster_centers_  # updated in place
        # Sequential running-mean updates: later samples see earlier moves.
        for xi, yi in zip(features, y_pred):
            self.cluster_counts[yi] += 1
            step_size = 1.0 / self.cluster_counts[yi]
            centers[yi] = centers[yi] + step_size * (xi - centers[yi])
================================================
FILE: clusterers/random_labels.py
================================================
import torch
from clusterers import base_clusterer
class Clusterer(base_clusterer.BaseClusterer):
    '''Baseline clusterer that ignores the data and draws labels at random.'''

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_labels(self, x, y):
        '''Return one label per entry of ``y``, uniform over [0, k), on the GPU.'''
        labels = torch.randint(0, self.k, y.shape)
        return labels.long().cuda()
================================================
FILE: clusterers/selfcondgan.py
================================================
import copy, random
import torch
import numpy as np
from sklearn.utils.linear_assignment_ import linear_assignment
from clusterers import kmeans
class Clusterer(kmeans.Clusterer):
    '''Self-conditioned GAN clusterer.

    On every ``recluster`` it refits k-means on features of a deep copy of the
    discriminator and keeps cluster ids stable across refits.

    Args:
        initialization (bool): seed each refit with the means induced by the
            previous assignments instead of k-means++
        matching (bool): permute the new cluster ids so they best agree with
            the previous assignments (Hungarian matching)
    '''

    def __init__(self, initialization=True, matching=True, **kwargs):
        self.initialization = initialization
        self.matching = matching
        super().__init__(**kwargs)

    def get_initialization(self, features, labels):
        '''Compute the cluster means induced by ``labels`` on ``features``.

        Args:
            features: one feature vector per row (NumPy array or CPU tensor)
            labels (np.ndarray): cluster id of each row
        Returns:
            np.ndarray with one row per cluster (k rows). A cluster with no
            members gets a randomly chosen row of ``features`` instead.
        '''
        features = np.asarray(features, dtype=np.float64)
        labels = np.asarray(labels)
        means = []
        for i in range(self.k):
            members = features[labels == i]
            if len(members) > 0:
                means.append(members.mean(axis=0))
            else:
                # Starved cluster: seed it with a random data point. len() is
                # used because ndarray.size is an int attribute, not a method.
                rand_point = random.randint(0, len(features) - 1)
                means.append(features[rand_point])
        return np.array(means)

    def fit_means(self):
        '''Refit k-means on the cluster batch and update ``mapping``/``x_labels``.'''
        features = self.get_cluster_batch_features()
        # if clustered already, use old assignments for the cluster mean
        if self.x_labels is not None and self.initialization:
            print('Initializing k-means with previous cluster assignments')
            initialization = self.get_initialization(features, self.x_labels)
        else:
            initialization = 'k-means++'
        new_classes = self.kmeans_fit_predict(features, init=initialization)
        # we've clustered already, so compute the permutation
        if self.x_labels is not None and self.matching:
            print('Doing cluster matching')
            matching = self.hungarian_match(new_classes, self.x_labels, self.k,
                                            self.k)
            self.mapping = [int(j) for i, j in sorted(matching)]
        # recompute the fixed labels
        self.x_labels = np.asarray(self.mapping)[np.asarray(new_classes)]

    def recluster(self, discriminator, **kwargs):
        '''Snapshot ``discriminator`` (deep copy) and refit the cluster means.'''
        self.discriminator = copy.deepcopy(discriminator)
        self.fit_means()

    def hungarian_match(self, flat_preds, flat_targets, preds_k, targets_k):
        '''Match predicted cluster ids to target ids maximizing agreement.

        Args:
            flat_preds, flat_targets (np.ndarray): integer labels of equal length
            preds_k, targets_k (int): number of ids on each side (must be equal)
        Returns:
            list of (pred_id, target_id) tuples from ``linear_assignment``
        '''
        num_samples = flat_targets.shape[0]
        assert (preds_k == targets_k)  # one to one
        num_k = preds_k
        preds = np.asarray(flat_preds).astype(np.int64).ravel()
        targets = np.asarray(flat_targets).astype(np.int64).ravel()
        # num_correct[c1, c2] = #samples predicted c1 whose target is c2,
        # accumulated in one pass instead of k*k full-array comparisons.
        # Ids outside [0, num_k) are ignored.
        in_range = ((preds >= 0) & (preds < num_k) & (targets >= 0) &
                    (targets < num_k))
        num_correct = np.zeros((num_k, num_k))
        np.add.at(num_correct, (preds[in_range], targets[in_range]), 1)
        # minimizing (num_samples - agreement) maximizes agreement
        match = linear_assignment(num_samples - num_correct)
        # return as list of tuples, out_c to gt_c
        return [(out_c, gt_c) for out_c, gt_c in match]
================================================
FILE: configs/cifar/conditional.yaml
================================================
generator:
nlabels: 10
conditioning: embedding
discriminator:
nlabels: 10
conditioning: mask
inherit_from: configs/cifar/default.yaml
training:
out_dir: output/cifar/conditional
================================================
FILE: configs/cifar/default.yaml
================================================
data:
type: cifar10
train_dir: data/CIFAR
img_size: 32
nlabels: 10
generator:
name: dcgan_deep
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
discriminator:
name: dcgan_deep
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
z_dist:
type: gauss
dim: 128
clusterer:
name: supervised
nimgs: 25000
kwargs:
placeholder: None
training:
gan_type: standard
reg_type: none
reg_param: 0.
take_model_average: false
sample_nlabels: 20
log_every: 1000
inception_every: 10000
batch_size: 64
================================================
FILE: configs/cifar/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
kwargs:
k_value: 100
inherit_from: configs/cifar/default.yaml
training:
out_dir: output/cifar/selfcondgan
recluster_every: 25000
================================================
FILE: configs/cifar/unconditional.yaml
================================================
inherit_from: configs/cifar/default.yaml
training:
out_dir: output/cifar/unconditional
================================================
FILE: configs/default.yaml
================================================
data:
type: lsun
train_dir: data/LSUN
deterministic: False
img_size: 128
nlabels: 1
generator:
name: resnet
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
discriminator:
name: resnet
nlabels: 1
conditioning: unconditional
kwargs:
pack_size: 1
placeholder: None
clusterer:
name: supervised
nimgs: 100
kwargs:
num_components: -1
z_dist:
type: gauss
dim: 256
training:
out_dir: output/default
gan_type: standard
reg_type: real
reg_param: 10.
log_every: 1
batch_size: 128
ntest: 128
nworkers: 72
burnin_time: 0
take_model_average: true
model_average_beta: 0.999
monitoring: tensorboard
sample_every: 5000
sample_nlabels: 20
inception_every: 10000
inception_nsamples: 50000
backup_every: 10000
recluster_every: 10000
optimizer: adam
lr_g: 0.0001
lr_d: 0.0001
beta1: 0.0
beta2: 0.99
pretrained: {}
================================================
FILE: configs/imagenet/conditional.yaml
================================================
generator:
nlabels: 1000
conditioning: embedding
discriminator:
nlabels: 1000
conditioning: mask
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/imagenet/conditional
================================================
FILE: configs/imagenet/default.yaml
================================================
data:
type: image
train_dir: data/ImageNet/train
test_dir: data/ImageNet/val
img_size: 128
nlabels: 1000
generator:
name: resnet2
nlabels: 1
conditioning: unconditional
discriminator:
name: resnet2
nlabels: 1
conditioning: unconditional
z_dist:
type: gauss
dim: 256
clusterer:
name: supervised
training:
gan_type: standard
reg_type: real
reg_param: 10.
take_model_average: true
model_average_beta: 0.999
sample_nlabels: 20
log_every: 10
inception_every: 10000
backup_every: 5000
batch_size: 128
================================================
FILE: configs/imagenet/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
nimgs: 50000
kwargs:
k_value: 100
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/imagenet/selfcondgan
recluster_every: 75000
reg_param: 0.1
================================================
FILE: configs/imagenet/unconditional.yaml
================================================
generator:
nlabels: 1
conditioning: unconditional
discriminator:
nlabels: 1
conditioning: unconditional
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/imagenet/unconditional
================================================
FILE: configs/places/conditional.yaml
================================================
generator:
nlabels: 365
conditioning: embedding
discriminator:
nlabels: 365
conditioning: mask
training:
out_dir: output/places/conditional
inherit_from: configs/places/default.yaml
================================================
FILE: configs/places/default.yaml
================================================
data:
type: image
train_dir: data/places365/train
test_dir: data/places365/val
img_size: 128
nlabels: 365
generator:
name: resnet2
nlabels: 1
conditioning: unconditional
discriminator:
name: resnet2
nlabels: 1
conditioning: unconditional
z_dist:
type: gauss
dim: 256
clusterer:
name: supervised
training:
gan_type: standard
reg_type: real
reg_param: 10.
take_model_average: true
model_average_beta: 0.999
sample_nlabels: 20
log_every: 10
inception_every: 10000
backup_every: 5000
batch_size: 128
================================================
FILE: configs/places/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
nimgs: 50000
kwargs:
k_value: 100
inherit_from: configs/places/default.yaml
training:
out_dir: output/places/selfcondgan
recluster_every: 75000
reg_param: 0.1
================================================
FILE: configs/places/unconditional.yaml
================================================
generator:
nlabels: 1
conditioning: embedding
discriminator:
nlabels: 1
conditioning: mask
inherit_from: configs/places/default.yaml
training:
out_dir: output/places/unconditional
================================================
FILE: configs/pretrained/imagenet/conditional.yaml
================================================
generator:
nlabels: 1000
conditioning: embedding
discriminator:
nlabels: 1000
conditioning: mask
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/pretrained/imagenet/class_conditional
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/classcondgan_i_model.pt
================================================
FILE: configs/pretrained/imagenet/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
nimgs: 50000
kwargs:
k_value: 100
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/pretrained/imagenet/selfcondgan
recluster_every: 75000
reg_param: 0.1
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/selfcondgan_i_model.pt
clusterer: http://selfcondgan.csail.mit.edu/weights/selfcondgan_i_clusterer.pkl
================================================
FILE: configs/pretrained/imagenet/unconditional.yaml
================================================
generator:
nlabels: 1
conditioning: unconditional
discriminator:
nlabels: 1
conditioning: unconditional
inherit_from: configs/imagenet/default.yaml
training:
out_dir: output/pretrained/imagenet/unconditional
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/uncondgan_i_model.pt
================================================
FILE: configs/pretrained/places/conditional.yaml
================================================
generator:
nlabels: 365
conditioning: embedding
discriminator:
nlabels: 365
conditioning: mask
training:
out_dir: output/pretrained/places/class_conditional
inherit_from: configs/places/default.yaml
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/classcondgan_p_model.pt
================================================
FILE: configs/pretrained/places/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
nimgs: 50000
kwargs:
k_value: 100
inherit_from: configs/places/default.yaml
training:
out_dir: output/pretrained/places/selfcondgan
reg_param: 0.1
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/selfcondgan_p_model.pt
clusterer: http://selfcondgan.csail.mit.edu/weights/selfcondgan_p_clusterer.pkl
================================================
FILE: configs/pretrained/places/unconditional.yaml
================================================
generator:
nlabels: 1
conditioning: embedding
discriminator:
nlabels: 1
conditioning: mask
inherit_from: configs/places/default.yaml
training:
out_dir: output/pretrained/places/unconditional
pretrained:
model: http://selfcondgan.csail.mit.edu/weights/uncondgan_p_model.pt
================================================
FILE: configs/stacked_mnist/conditional.yaml
================================================
generator:
nlabels: 1000
conditioning: embedding
discriminator:
nlabels: 1000
conditioning: mask
inherit_from: configs/stacked_mnist/default.yaml
training:
out_dir: output/stacked_mnist/conditional
================================================
FILE: configs/stacked_mnist/default.yaml
================================================
data:
type: stacked_mnist
train_dir: data/MNIST
img_size: 32
nlabels: 1000
generator:
name: dcgan_shallow
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
discriminator:
name: dcgan_shallow
nlabels: 1
conditioning: unconditional
kwargs:
placeholder: None
z_dist:
type: gauss
dim: 128
clusterer:
name: supervised
nimgs: 25000
kwargs:
placeholder: None
training:
gan_type: standard
reg_type: none
reg_param: 0.
take_model_average: false
sample_nlabels: 20
log_every: 1000
backup_every: 5000
inception_every: 10000
batch_size: 64
================================================
FILE: configs/stacked_mnist/selfcondgan.yaml
================================================
generator:
nlabels: 100
conditioning: embedding
discriminator:
nlabels: 100
conditioning: mask
clusterer:
name: selfcondgan
kwargs:
k_value: 100
inherit_from: configs/stacked_mnist/default.yaml
training:
out_dir: output/stacked_mnist/selfcondgan
recluster_every: 25000
================================================
FILE: configs/stacked_mnist/unconditional.yaml
================================================
inherit_from: configs/stacked_mnist/default.yaml
training:
out_dir: output/stacked_mnist/unconditional
================================================
FILE: gan_training/__init__.py
================================================
================================================
FILE: gan_training/checkpoints.py
================================================
import os, pickle
import urllib
import torch
import numpy as np
from torch.utils import model_zoo
class CheckpointIO(object):
    ''' CheckpointIO class.

    It handles saving and loading checkpoints of the registered modules
    (objects with ``state_dict``/``load_state_dict``) and of the clusterer.

    Args:
        checkpoint_dir (str): path where checkpoints are saved (created if
            missing)
        **kwargs: modules to register, keyed by their name in the checkpoint
    '''

    def __init__(self, checkpoint_dir='./chkpts', **kwargs):
        self.module_dict = kwargs
        self.checkpoint_dir = checkpoint_dir
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

    def register_modules(self, **kwargs):
        ''' Registers modules in current module dictionary.
        '''
        self.module_dict.update(kwargs)

    def save(self, filename, **kwargs):
        ''' Saves the state dicts of all registered modules plus extra values.

        Args:
            filename (str): name of output file; relative names are resolved
                against ``checkpoint_dir``
            **kwargs: extra (scalar) values stored alongside the state dicts
        '''
        if not os.path.isabs(filename):
            filename = os.path.join(self.checkpoint_dir, filename)
        outdict = kwargs
        for k, v in self.module_dict.items():
            outdict[k] = v.state_dict()
        torch.save(outdict, filename)

    def load(self, filename, pretrained=None):
        '''Loads a module dictionary from local file or url.

        Args:
            filename (str): name of saved module dictionary
            pretrained (dict or None): if it has a ``'model'`` entry, that
                path/url is loaded instead of ``filename``
        Returns:
            dict of the non-module entries stored in the checkpoint
        '''
        if pretrained is None:
            pretrained = {}
        if 'model' in pretrained:
            filename = pretrained['model']
        if is_url(filename):
            return self.load_url(filename)
        return self.load_file(filename)

    def load_file(self, filename):
        '''Loads a module dictionary from file.

        Args:
            filename (str): name of saved module dictionary; relative names are
                resolved against ``checkpoint_dir``
        Raises:
            FileNotFoundError: if the file does not exist (message has the path)
        '''
        if not os.path.isabs(filename):
            filename = os.path.join(self.checkpoint_dir, filename)
        if not os.path.exists(filename):
            print('File not found', filename)
            raise FileNotFoundError('Checkpoint not found: %s' % filename)
        print('=> Loading checkpoint from local file...', filename)
        state_dict = torch.load(filename)
        return self.parse_state_dict(state_dict)

    def load_url(self, url):
        '''Load a module dictionary from url.

        Args:
            url (str): url to saved model (cached in ``checkpoint_dir``)
        '''
        print('=> Loading checkpoint from url...', url)
        state_dict = model_zoo.load_url(url, model_dir=self.checkpoint_dir, progress=True)
        return self.parse_state_dict(state_dict)

    def parse_state_dict(self, state_dict):
        '''Load registered modules from ``state_dict`` and return the rest.

        Args:
            state_dict (dict): State dict of model
        Returns:
            dict of entries whose keys are not registered module names
        '''
        for k, v in self.module_dict.items():
            if k in state_dict:
                v.load_state_dict(state_dict[k])
            else:
                print('Warning: Could not find %s in checkpoint!' % k)
        return {
            k: v
            for k, v in state_dict.items() if k not in self.module_dict
        }

    def load_clusterer(self, it, load_samples, pretrained=None):
        '''Load the clusterer saved at iteration ``it`` (or a pretrained one).

        Args:
            it (int): iteration number in ``clusterer{it}.pkl``
            load_samples (bool): also restore ``clusterer.x`` from
                ``cluster_samples.npz`` (local clusterers only)
            pretrained (dict or None): if it has a ``'clusterer'`` url, that
                clusterer is downloaded once and cached as
                ``pretrained_clusterer.pkl``
        Raises:
            FileNotFoundError: if the local clusterer file is missing
        '''
        if pretrained is None:
            pretrained = {}
        if 'clusterer' in pretrained:
            pretrained_file = os.path.join(self.checkpoint_dir, 'pretrained_clusterer.pkl')
            if not os.path.exists(pretrained_file):
                import cloudpickle as cp
                from urllib.request import urlopen
                print('Loading pretrained clusterer from', pretrained['clusterer'])
                # SECURITY: unpickling downloaded data can execute arbitrary
                # code; only point pretrained['clusterer'] at trusted URLs.
                with urlopen(pretrained['clusterer']) as response:
                    clusterer = cp.load(response)
                print('Saving pretrained clusterer to', pretrained_file)
                with open(pretrained_file, 'wb') as f:
                    f.write(pickle.dumps(clusterer))
            else:
                with open(pretrained_file, 'rb') as f:
                    clusterer = pickle.load(f)
            return clusterer

        print('Loading clusterer:')
        with open(os.path.join(self.checkpoint_dir, f'clusterer{it}.pkl'), 'rb') as f:
            clusterer = pickle.load(f)
        if load_samples:
            print('Loading cluster samples:')
            with np.load(os.path.join(self.checkpoint_dir, 'cluster_samples.npz')) as f:
                x = f['x']
            clusterer.x = torch.from_numpy(x)
        return clusterer

    def load_models(self, it, pretrained=None, load_samples=False):
        '''Load models and clusterer saved at iteration ``it``.

        If the first load attempt fails for any reason (including a missing
        file), DataParallel wrappers in ``module_dict`` are replaced by their
        inner modules (which also affects later saves) and loading is retried.
        A missing file on the retry yields ``it = epoch_idx = -1``.

        Returns:
            (it, epoch_idx, clusterer), where clusterer is None if not found
        '''
        if pretrained is None:
            pretrained = {}
        try:
            load_dict = self.load('model_%08d.pt' % it, pretrained)
            epoch_idx = load_dict.get('epoch_idx', -1)
        except Exception:  # models are not dataparallel modules
            print('Trying again to load w/o data parallel modules')
            try:
                for name, module in self.module_dict.items():
                    if isinstance(module, torch.nn.DataParallel):
                        self.module_dict[name] = module.module
                load_dict = self.load('model_%08d.pt' % it, pretrained)
                epoch_idx = load_dict.get('epoch_idx', -1)
            except FileNotFoundError as e:
                print(e)
                print("Models not found")
                it = epoch_idx = -1
        try:
            clusterer = self.load_clusterer(it, load_samples, pretrained)
        except FileNotFoundError:
            clusterer = None
        return it, epoch_idx, clusterer

    def save_clusterer(self, clusterer, it):
        '''Pickle ``clusterer`` to ``clusterer{it}.pkl`` without its samples ``x``.'''
        path = os.path.join(self.checkpoint_dir, f'clusterer{it}.pkl')
        # hack: only save changing data -- ``x`` is stripped for the dump and
        # always restored afterwards, even if pickling fails.
        x = clusterer.x
        clusterer.x = None
        try:
            with open(path, 'wb') as f:
                pickle.dump(clusterer, f)
        finally:
            clusterer.x = x
def is_url(url):
    '''Return True if ``url`` has an ``http`` or ``https`` scheme.'''
    # ``import urllib`` alone does not guarantee that the ``urllib.parse``
    # submodule is loaded, so import it explicitly.
    from urllib.parse import urlparse
    return urlparse(url).scheme in ('http', 'https')
================================================
FILE: gan_training/config.py
================================================
import yaml
from torch import optim
from os import path
from gan_training.models import generator_dict, discriminator_dict
from gan_training.train import toggle_grad
from clusterers import clusterer_dict
# General config
def load_config(path, default_path):
    ''' Loads config file.

    If the file has an ``inherit_from`` entry, that config is loaded first
    (recursively) and this file's entries are merged over it. Otherwise
    ``default_path`` (if not None) is loaded as the base.

    Args:
        path (str): path to config file
        default_path (str or None): path of the default config used as base
            when the file does not inherit from another config
    Returns:
        dict: the merged configuration
    '''
    # Load configuration from file itself. safe_load: config files hold only
    # plain YAML data, and yaml.load without an explicit Loader is deprecated
    # (PyYAML 5.1) and a TypeError in PyYAML >= 6. An empty file yields {}.
    with open(path, 'r') as f:
        cfg_special = yaml.safe_load(f) or {}
    # Check if we should inherit from a config
    inherit_from = cfg_special.get('inherit_from')
    # If yes, load this config first as default
    # If no, use the default_path
    if inherit_from is not None:
        cfg = load_config(inherit_from, default_path)
    elif default_path is not None:
        with open(default_path, 'r') as f:
            cfg = yaml.safe_load(f) or {}
    else:
        cfg = dict()
    # Include main configuration
    update_recursive(cfg, cfg_special)
    return cfg
def update_recursive(dict1, dict2):
    ''' Merge ``dict2`` into ``dict1`` in place.

    Where ``dict1`` already holds a dict, the merge descends into it;
    everything else in ``dict2`` overwrites (or adds) the entry in ``dict1``.

    Args:
        dict1 (dict): dictionary to be updated
        dict2 (dict): dictionary whose entries take precedence
    '''
    for key, new_value in dict2.items():
        current = dict1.setdefault(key, None)
        if not isinstance(current, dict):
            dict1[key] = new_value
            continue
        update_recursive(current, new_value)
def get_clusterer(config):
    '''Return the clusterer class registered under ``config['clusterer']['name']``.'''
    name = config['clusterer']['name']
    return clusterer_dict[name]
def build_models(config):
    '''Instantiate the generator and discriminator described by ``config``.

    Model classes are looked up by name in ``generator_dict`` and
    ``discriminator_dict``; ``kwargs`` entries are forwarded to the
    constructors.

    Returns:
        (generator, discriminator)
    '''
    gen_cfg, dis_cfg = config['generator'], config['discriminator']
    Generator = generator_dict[gen_cfg['name']]
    Discriminator = discriminator_dict[dis_cfg['name']]
    generator = Generator(
        z_dim=config['z_dist']['dim'],
        nlabels=gen_cfg['nlabels'],
        size=config['data']['img_size'],
        conditioning=gen_cfg['conditioning'],
        **gen_cfg['kwargs'])
    discriminator = Discriminator(
        nlabels=dis_cfg['nlabels'],
        conditioning=dis_cfg['conditioning'],
        size=config['data']['img_size'],
        **dis_cfg['kwargs'])
    return generator, discriminator
def build_optimizers(generator, discriminator, config):
    '''Create the generator and discriminator optimizers named in the config.

    Reads ``optimizer`` ('rmsprop', 'adam' or 'sgd'), ``lr_g``, ``lr_d`` and,
    for adam, ``beta1``/``beta2`` from ``config['training']``. Gradients are
    enabled on both models first.

    Returns:
        (g_optimizer, d_optimizer)
    Raises:
        NotImplementedError: if the optimizer name is not supported
    '''
    optimizer = config['training']['optimizer']
    lr_g = config['training']['lr_g']
    lr_d = config['training']['lr_d']
    toggle_grad(generator, True)
    toggle_grad(discriminator, True)
    g_params = generator.parameters()
    d_params = discriminator.parameters()
    if optimizer == 'rmsprop':
        g_optimizer = optim.RMSprop(g_params, lr=lr_g, alpha=0.99, eps=1e-8)
        d_optimizer = optim.RMSprop(d_params, lr=lr_d, alpha=0.99, eps=1e-8)
    elif optimizer == 'adam':
        beta1 = config['training']['beta1']
        beta2 = config['training']['beta2']
        g_optimizer = optim.Adam(g_params, lr=lr_g, betas=(beta1, beta2), eps=1e-8)
        d_optimizer = optim.Adam(d_params, lr=lr_d, betas=(beta1, beta2), eps=1e-8)
    elif optimizer == 'sgd':
        g_optimizer = optim.SGD(g_params, lr=lr_g, momentum=0.)
        d_optimizer = optim.SGD(d_params, lr=lr_d, momentum=0.)
    else:
        # previously this fell through to an UnboundLocalError at the return
        raise NotImplementedError('Optimizer "%s" not supported!' % optimizer)
    return g_optimizer, d_optimizer
# Some utility functions
def get_parameter_groups(parameters, gradient_scales, base_lr):
    '''Build one optimizer param group per parameter.

    Each group's learning rate is ``base_lr`` scaled by
    ``gradient_scales[p]`` (default 1.).
    '''
    return [
        {'params': [param], 'lr': gradient_scales.get(param, 1.) * base_lr}
        for param in parameters
    ]
================================================
FILE: gan_training/distributions.py
================================================
import torch
from torch import distributions
def get_zdist(dist_name, dim, device=None):
    '''Build the latent prior.

    Args:
        dist_name (str): 'uniform' for U(-1, 1) or 'gauss' for N(0, 1),
            independently per coordinate
        dim (int): latent dimensionality
        device: torch device for the distribution parameters
    Returns:
        torch distribution with an extra ``dim`` attribute
    Raises:
        NotImplementedError: for any other ``dist_name``
    '''
    if dist_name == 'gauss':
        loc = torch.zeros(dim, device=device)
        spread = torch.ones(dim, device=device)
        zdist = distributions.Normal(loc, spread)
    elif dist_name == 'uniform':
        bound = torch.ones(dim, device=device)
        zdist = distributions.Uniform(-bound, bound)
    else:
        raise NotImplementedError
    zdist.dim = dim
    return zdist
def get_ydist(nlabels, device=None):
    '''Uniform categorical distribution over ``nlabels`` labels.

    The returned distribution carries an extra ``nlabels`` attribute.
    '''
    uniform_logits = torch.zeros(nlabels, device=device)
    ydist = distributions.Categorical(logits=uniform_logits)
    ydist.nlabels = nlabels
    return ydist
def interpolate_sphere(z1, z2, t, eps=1e-6):
    '''Spherical linear interpolation (slerp) from z1 (t=0) to z2 (t=1).

    The angle is measured along the last dimension. When the angle's sine is
    below ``eps`` (parallel or anti-parallel inputs), slerp is undefined (0/0)
    and plain linear interpolation is used instead of returning NaN.

    Args:
        z1, z2 (Tensor): endpoints, broadcastable against each other
        t (float or Tensor): interpolation weight(s)
        eps (float): threshold on sin(angle) for the linear fallback
    '''
    p = (z1 * z2).sum(dim=-1, keepdim=True)
    p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt()
    p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt()
    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    p = p.clamp(-1.0, 1.0)
    omega = torch.acos(p)
    sin_omega = torch.sin(omega)
    degenerate = sin_omega.abs() < eps
    safe_sin = torch.where(degenerate, torch.ones_like(sin_omega), sin_omega)
    s1 = torch.sin((1 - t) * omega) / safe_sin
    s2 = torch.sin(t * omega) / safe_sin
    s1 = torch.where(degenerate, (1 - t) * torch.ones_like(s1), s1)
    s2 = torch.where(degenerate, t * torch.ones_like(s2), s2)
    return s1 * z1 + s2 * z2
================================================
FILE: gan_training/eval.py
================================================
import numpy as np
import torch
from torch.nn import functional as F
from gan_training.metrics import inception_score
class Evaluator(object):
    '''Draws samples from a conditional generator and computes metrics.

    Labels for fake samples come from ``clusterer.get_labels`` applied to
    real batches drawn from ``train_loader``.
    '''

    def __init__(self,
                 generator,
                 zdist,
                 ydist,
                 train_loader,
                 clusterer,
                 batch_size=64,
                 inception_nsamples=10000,
                 device=None):
        self.generator = generator
        self.clusterer = clusterer
        self.train_loader = train_loader
        self.zdist = zdist
        self.ydist = ydist
        self.inception_nsamples = inception_nsamples
        self.batch_size = batch_size
        self.device = device

    def sample_z(self, batch_size):
        '''Draw ``batch_size`` latents from ``zdist`` on ``device``.'''
        return self.zdist.sample((batch_size, )).to(self.device)

    def get_y(self, x, y):
        '''Labels the clusterer assigns to real images ``x`` (``y`` = ground truth).'''
        return self.clusterer.get_labels(x, y).to(self.device)

    def get_fake_real_samples(self, N):
        ''' returns N fake images and N real images in pytorch form

        Real images are taken from ``train_loader`` (cycled as often as
        needed). Each fake image is generated from a fresh latent and the
        label the clusterer assigns to the matching real batch. Both results
        are lists of N CPU tensors. The generator is left in eval mode.

        Raises:
            ValueError: if a full pass over ``train_loader`` yields no images
                (the loop would otherwise never terminate)
        '''
        fake_imgs = []
        real_imgs = []
        with torch.no_grad():
            self.generator.eval()
            while len(fake_imgs) < N:
                n_before = len(fake_imgs)
                for x_real, y_gt in self.train_loader:
                    x_real = x_real.cuda()
                    z = self.sample_z(x_real.size(0))
                    y = self.get_y(x_real, y_gt)
                    samples = self.generator(z, y)
                    fake_imgs.extend(s.data.cpu() for s in samples)
                    real_imgs.extend(img.data.cpu() for img in x_real)
                    assert (len(real_imgs) == len(fake_imgs))
                    if len(fake_imgs) >= N:
                        return fake_imgs[:N], real_imgs[:N]
                if len(fake_imgs) == n_before:
                    raise ValueError(
                        'train_loader yielded no images; cannot collect %d samples' % N)
        return fake_imgs, real_imgs

    def compute_inception_score(self):
        '''Inception score (mean, std) of ``inception_nsamples`` fake images.'''
        imgs, _ = self.get_fake_real_samples(self.inception_nsamples)
        imgs = [img.numpy() for img in imgs]
        score, score_std = inception_score(imgs,
                                           device=self.device,
                                           resize=True,
                                           splits=1)
        return score, score_std

    def create_samples(self, z, y=None):
        '''Generate images from latents ``z`` and labels ``y``.

        Args:
            z (Tensor): latents, one per row
            y (int or Tensor): a single label for the whole batch or one per row
        Raises:
            NotImplementedError: if ``y`` is None
        '''
        self.generator.eval()
        batch_size = z.size(0)
        # Parse y
        if y is None:
            raise NotImplementedError()
        elif isinstance(y, int):
            y = torch.full((batch_size, ),
                           y,
                           device=self.device,
                           dtype=torch.int64)
        # Sample x
        with torch.no_grad():
            x = self.generator(z, y)
        return x
================================================
FILE: gan_training/inputs.py
================================================
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import os
import torch.utils.data as data
from torchvision.datasets.folder import default_loader
from PIL import Image
import random
from PIL import ImageFile
# Process-wide PIL setting: decode truncated/partially written image files
# instead of raising an error while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
def get_dataset(name,
                data_dir,
                size=64,
                lsun_categories=None,
                deterministic=False,
                transform=None):
    '''Build the training dataset ``name`` rooted at ``data_dir``.

    Args:
        name (str): one of 'image', 'webp', 'npy', 'cifar10',
            'stacked_mnist', 'lsun', 'lsun_class'
        data_dir (str): dataset root
        size (int): images are resized and center-cropped to ``size``
        lsun_categories: LSUN ``classes`` argument (defaults to 'train')
        deterministic (bool): if False, the default transform also applies a
            random horizontal flip and adds uniform noise in [0, 1/128)
        transform: replaces the default transform (not used by 'npy' and
            'stacked_mnist', which build their own)
    Returns:
        (dataset, nlabels)
    Raises:
        NotImplementedError: for an unknown ``name``
    '''
    if transform is None:
        transform = transforms.Compose([
            t for t in [
                transforms.Resize(size),
                transforms.CenterCrop(size),
                (not deterministic) and transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                (not deterministic) and
                transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),
            ] if t is not False
        ])
    if name == 'image':
        print('Using image labels')
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'webp':
        print('Using no labels from webp')
        dataset = CachedImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=True,
                                   download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'stacked_mnist':
        dataset = StackedMNIST(data_dir,
                               transform=transforms.Compose([
                                   transforms.Resize(size),
                                   transforms.CenterCrop(size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, ), (0.5, ))
                               ]))
        nlabels = 1000
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        dataset = datasets.LSUNClass(data_dir,
                                     transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    else:
        # `raise NotImplemented` raised a TypeError: NotImplemented is a
        # constant, not an exception class.
        raise NotImplementedError('Dataset "%s" not supported!' % name)
    return dataset, nlabels
class CachedImageFolder(data.Dataset):
    """
    Image dataset over ``root`` that can reuse a cached file list.

    File paths come from ``make_class_dataset`` (which uses ``<root>.txt``
    via ``walk_image_files`` when that listing exists), e.g.

        photo/park/004234.jpg
        photo/park/004236.jpg
        photo/park/004237.jpg

    Items are ``(transformed image, label)`` pairs with the label produced by
    ``make_class_dataset``; ``classes``/``class_to_idx`` list the
    subdirectories of ``root``.
    """

    def __init__(self, root, transform=None, loader=default_loader):
        classes, class_to_idx = find_classes(root)
        self.imgs = make_class_dataset(root, class_to_idx)
        if not self.imgs:
            raise RuntimeError("Found 0 images within: %s" % root)
        self.root = root
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        img_path, label = self.imgs[index]
        image = self.loader(img_path)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.imgs)
class StackedMNIST(data.Dataset):
    '''Stacked MNIST.

    Each item concatenates three randomly paired MNIST training digits along
    dim 0 and is labelled ``y1 * 100 + y2 * 10 + y3``.

    Args:
        data_dir (str): MNIST root (downloaded if missing)
        transform: transform applied to each single-digit image
        batch_size (int): number of stacked items in the dataset (i.e. its
            length; the name is kept for backward compatibility)
    '''

    def __init__(self, data_dir, transform, batch_size=100000):
        super().__init__()
        # All three channels read the same MNIST training split with the same
        # transform, so load it once and share it instead of keeping three
        # identical copies in memory.
        mnist = datasets.MNIST(data_dir,
                               transform=transform,
                               train=True,
                               download=True)
        self.channel1 = mnist
        self.channel2 = mnist
        self.channel3 = mnist
        last = len(mnist) - 1
        self.indices = {
            k: (random.randint(0, last),
                random.randint(0, last),
                random.randint(0, last))
            for k in range(batch_size)
        }

    def __getitem__(self, index):
        try:
            index1, index2, index3 = self.indices[index]
        except KeyError:
            # the sequence protocol (e.g. plain iteration) expects IndexError
            raise IndexError(index) from None
        x1, y1 = self.channel1[index1]
        x2, y2 = self.channel2[index2]
        x3, y3 = self.channel3[index3]
        return torch.cat([x1, x2, x3], dim=0), y1 * 100 + y2 * 10 + y3

    def __len__(self):
        return len(self.indices)
def is_npy_file(path):
    '''True if ``path`` ends with ``.npy`` or ``.NPY``.'''
    return path.endswith(('.npy', '.NPY'))
def walk_image_files(rootdir):
    '''List the image (and ``.npy``) files belonging to ``rootdir``.

    If a listing file ``<rootdir>.txt`` exists, it is used instead of
    scanning the directory. Each line is a path relative to the parent of
    ``rootdir``; the joined paths are sorted and then shuffled with a fixed
    seed (``random.Random(1)``). Otherwise the tree is walked in sorted order
    and every file whose name ends with a known image extension or with
    ``.npy``/``.NPY`` is returned.
    '''
    print(rootdir)
    listing = '%s.txt' % rootdir
    if os.path.isfile(listing):
        print('Loading file list from %s.txt instead of scanning dir' %
              rootdir)
        parent = os.path.dirname(rootdir)
        with open(listing) as handle:
            entries = [os.path.join(parent, line.strip())
                       for line in handle.readlines()]
        entries.sort()
        import random
        random.Random(1).shuffle(entries)
        return entries
    accepted_suffixes = (
        '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG',
        '.ppm', '.PPM', '.bmp', '.BMP',
        # NumPy arrays are accepted as well
        '.npy', '.NPY',
    )
    found = []
    for dirname, _, fnames in sorted(os.walk(rootdir)):
        found.extend(os.path.join(dirname, fname)
                     for fname in sorted(fnames)
                     if fname.endswith(accepted_suffixes))
    return found
def find_classes(dir):
    '''Return (sorted subdirectory names of ``dir``, name -> index mapping).'''
    classes = sorted(entry for entry in os.listdir(dir)
                     if os.path.isdir(os.path.join(dir, entry)))
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
def make_class_dataset(source_root, class_to_idx):
    """
    Return ``(image_path, 0)`` pairs for every file that ``walk_image_files``
    finds under ``source_root`` (after ``~`` expansion).

    Every label is 0. ``class_to_idx`` is accepted for interface
    compatibility but is currently not used.
    """
    source_root = os.path.expanduser(source_root)
    return [(path, 0) for path in walk_image_files(source_root)]
def npy_loader(path):
    '''Load a ``.npy`` image as a float tensor rescaled to [-1, 1].

    uint8 values are mapped by v / 127.5 - 1 and float32 values by 2v - 1;
    any other dtype raises NotImplementedError. For a 4-D array, dim 0 is
    squeezed (removed only if it has size 1).
    '''
    array = np.load(path)
    if array.dtype == np.float32:
        array = array * 2 - 1.
    elif array.dtype == np.uint8:
        array = array.astype(np.float32) / 127.5 - 1.
    else:
        raise NotImplementedError
    tensor = torch.Tensor(array)
    if tensor.dim() == 4:
        tensor.squeeze_(0)
    return tensor
================================================
FILE: gan_training/logger.py
================================================
import pickle
import os
import torchvision
import copy
class Logger(object):
    '''Collects scalar statistics per (category, name) and saves image grids.

    Scalars can also be forwarded to a monitoring backend ('telemetry' or
    'tensorboard').

    Args:
        log_dir (str): directory for pickled statistics (created if missing)
        img_dir (str): directory for image grids (created if missing)
        monitoring (str or None): backend name; None or 'none' disables it
        monitoring_dir (str or None): log directory passed to the backend
    '''

    def __init__(self,
                 log_dir='./logs',
                 img_dir='./imgs',
                 monitoring=None,
                 monitoring_dir=None):
        self.stats = dict()
        self.log_dir = log_dir
        self.img_dir = img_dir
        for directory in (log_dir, img_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)
        if monitoring is None or monitoring == 'none':
            self.monitoring = None
            self.monitoring_dir = None
        else:
            self.setup_monitoring(monitoring, monitoring_dir)

    def setup_monitoring(self, monitoring, monitoring_dir=None):
        '''Connect the named backend; unknown names raise NotImplementedError.'''
        self.monitoring = monitoring
        self.monitoring_dir = monitoring_dir
        if monitoring == 'tensorboard':
            import tensorboardX
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)
        elif monitoring == 'telemetry':
            import telemetry
            self.tm = telemetry.ApplicationTelemetry()
            if self.tm.get_status() == 0:
                print('Telemetry successfully connected.')
        else:
            raise NotImplementedError('Monitoring tool "%s" not supported!' %
                                      monitoring)

    def add(self, category, k, v, it):
        '''Record ``(it, v)`` under ``category``/``k`` and forward it to the backend.'''
        series = self.stats.setdefault(category, {}).setdefault(k, [])
        series.append((it, v))
        k_name = '%s/%s' % (category, k)
        if self.monitoring == 'telemetry':
            self.tm.metric_push_async({'metric': k_name, 'value': v, 'it': it})
        elif self.monitoring == 'tensorboard':
            self.tb.add_scalar(k_name, v, it)

    def add_imgs(self, imgs, class_name, it):
        '''Save ``imgs`` as a grid to ``<img_dir>/<class_name>/<it:08d>.png``.

        Pixel values are rescaled by x / 2 + 0.5 first (presumably from
        [-1, 1] to [0, 1]). The grid is also sent to tensorboard if enabled.
        '''
        outdir = os.path.join(self.img_dir, class_name)
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        outfile = os.path.join(outdir, '%08d.png' % it)
        grid = torchvision.utils.make_grid(imgs / 2 + 0.5)
        torchvision.utils.save_image(copy.deepcopy(grid), outfile, nrow=8)
        if self.monitoring == 'tensorboard':
            self.tb.add_image(class_name, copy.deepcopy(grid), it)

    def get_last(self, category, k, default=0.):
        '''Most recent value recorded under ``category``/``k``, else ``default``.'''
        if category not in self.stats:
            return default
        per_category = self.stats[category]
        if k not in per_category:
            return default
        _, last_value = per_category[k][-1]
        return last_value

    def save_stats(self, filename):
        '''Pickle all recorded statistics to ``<log_dir>/<filename>``.'''
        with open(os.path.join(self.log_dir, filename), 'wb') as f:
            pickle.dump(self.stats, f)

    def load_stats(self, filename):
        '''Replace the statistics with those pickled in ``<log_dir>/<filename>``.

        A missing or truncated file only prints a warning.
        '''
        filename = os.path.join(self.log_dir, filename)
        if not os.path.exists(filename):
            print('Warning: file "%s" does not exist!' % filename)
            return
        try:
            with open(filename, 'rb') as f:
                self.stats = pickle.load(f)
        except EOFError:
            print('Warning: log file corrupted!')
================================================
FILE: gan_training/metrics/__init__.py
================================================
from gan_training.metrics.inception_score import inception_score
# Public API of gan_training.metrics. __all__ must list names as strings;
# listing the function object makes ``from gan_training.metrics import *``
# raise TypeError.
__all__ = [
    'inception_score',
]
================================================
FILE: gan_training/metrics/clustering_metrics.py
================================================
def warn(*_args, **_kwargs):
    """Drop a warning: a no-op stand-in for ``warnings.warn``."""
    return None


import warnings
# NOTE(review): rebinding warnings.warn is process-wide. Once this module is
# imported, every later ``warnings.warn(...)`` call in any module is silenced,
# not only the scikit-learn noise this is presumably meant to hide.
setattr(warnings, 'warn', warn)
from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score, homogeneity_score
from sklearn import metrics
import numpy as np
def nmi(inferred, gt):
    """Normalized mutual information between a clustering and ground truth.

    ``inferred`` is passed as sklearn's first argument and ``gt`` as the
    second; the score is symmetric in its arguments.
    """
    score = normalized_mutual_info_score(inferred, gt)
    return score
def acc(inferred, gt):
    """Unsupervised clustering accuracy.

    Finds the one-to-one mapping between inferred cluster ids and
    ground-truth labels that maximises the number of agreeing samples
    (Hungarian algorithm), and returns the fraction of samples that agree
    under that mapping.

    Args:
        inferred: array of non-negative integer cluster assignments.
        gt: array of non-negative integer ground-truth labels, same size.

    Returns:
        float in [0, 1].
    """
    # sklearn.utils.linear_assignment_ was removed in scikit-learn 0.23;
    # scipy solves the same assignment problem.
    from scipy.optimize import linear_sum_assignment

    inferred = np.asarray(inferred).astype(np.int64).ravel()
    gt = np.asarray(gt).astype(np.int64).ravel()
    assert inferred.size == gt.size
    D = max(inferred.max(), gt.max()) + 1
    # w[i, j] = number of samples with cluster i and label j.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (inferred, gt), 1)
    # Turn the maximisation into a minimisation by using w.max() - w as cost.
    rows, cols = linear_sum_assignment(w.max() - w)
    return float(w[rows, cols].sum()) / inferred.size
def purity_score(y_true, y_pred):
    """Cluster purity of ``y_pred`` with respect to ``y_true``.

    For every predicted cluster, count the samples of its most frequent true
    class, then divide the sum of those counts by the number of samples.
    Note the argument order (true labels first), unlike the other helpers
    here.
    """
    table = metrics.cluster.contingency_matrix(y_true, y_pred)
    # Rows are true classes, columns are predicted clusters.
    majority_counts = table.max(axis=0)
    return majority_counts.sum() / table.sum()
def ari(inferred, gt):
    """Adjusted Rand index of the inferred clustering against ground truth."""
    labels_true, labels_pred = gt, inferred
    return adjusted_rand_score(labels_true, labels_pred)
def homogeneity(inferred, gt):
    """Homogeneity of the inferred clustering with respect to ground truth."""
    labels_true, labels_pred = gt, inferred
    return homogeneity_score(labels_true, labels_pred)
================================================
FILE: gan_training/metrics/fid.py
================================================
from __future__ import absolute_import, division, print_function
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from scipy import linalg
import pathlib
import urllib
from tqdm import tqdm
import warnings
def check_or_download_inception(inception_path):
    ''' Return the path of the Inception GraphDef, downloading it if needed.

    Looks for ``classify_image_graph_def.pb`` in ``inception_path``
    (``/tmp`` when None). If it is missing, the Inception 2015-12-05 tarball
    is downloaded and that single file is extracted into the directory.

    Returns:
        The model file path as a string.
    '''
    INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    if inception_path is None:
        inception_path = '/tmp'
    inception_path = pathlib.Path(inception_path)
    model_file = inception_path / 'classify_image_graph_def.pb'
    if not model_file.exists():
        print("Downloading Inception model")
        from urllib import request
        import tarfile
        fn, _ = request.urlretrieve(INCEPTION_URL)
        try:
            with tarfile.open(fn, mode='r') as f:
                f.extract('classify_image_graph_def.pb', str(model_file.parent))
        finally:
            # urlretrieve without a target filename downloads into a temporary
            # file; remove it once the model file has been extracted.
            request.urlcleanup()
    return str(model_file)
def create_inception_graph(pth):
    """Import the serialized GraphDef stored at ``pth`` into the default graph.

    All imported nodes are placed under the ``FID_Inception_Net`` name scope.
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pth, 'rb') as fh:
        serialized = fh.read()
    graph_def.ParseFromString(serialized)
    tf.import_graph_def(graph_def, name='FID_Inception_Net')
def calculate_activation_statistics(images,
                                    sess,
                                    batch_size=200,
                                    verbose=False):
    """Calculation of the statistics used by the FID.

    Params:
    -- images     : Numpy array of dimension (n_images, hi, wi, 3). The values
                    must lie between 0 and 255.
    -- sess       : current session
    -- batch_size : the images numpy array is split into batches with batch size
                    batch_size. A reasonable batch size depends on the available hardware.
    -- verbose    : If set to True, the progress over batches is reported.
    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    activations = get_activations(images, sess, batch_size, verbose)
    # Rows are samples and columns are features, hence rowvar=False.
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)
# code for handling inception net derived from
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
def _get_inception_layer(sess):
    """Prepares inception net for batched usage and returns pool_3 layer.

    Looks up ``FID_Inception_Net/pool_3:0`` in ``sess.graph``. Then, for every
    non-scalar op output in that graph, it rewrites a leading dimension of
    size 1 to ``None``. The apparent intent is to let the frozen graph, which
    was presumably exported with batch size 1 (confirm against the GraphDef),
    accept batches of any size.

    Relies on TF 1.x internals (``TensorShape._dims``, ``Dimension.value``
    and the private ``_shape_val`` attribute).
    """
    layername = 'FID_Inception_Net/pool_3:0'
    pool3 = sess.graph.get_tensor_by_name(layername)
    ops = pool3.graph.get_operations()
    for op_idx, op in enumerate(ops):
        for o in op.outputs:
            shape = o.get_shape()
            # Rank-0 (scalar) outputs have an empty ``_dims`` list and are skipped.
            if shape._dims != []:
                shape = [s.value for s in shape]
                new_shape = []
                for j, s in enumerate(shape):
                    # Only the leading (batch) axis is relaxed, and only when it is 1.
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                # Overwrite the tensor's cached static shape in place.
                o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
    return pool3
#-------------------------------------------------------------------------------
def get_activations(images, sess, batch_size=200, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- images     : Numpy array of dimension (n_images, hi, wi, 3). The values
                    must lie between 0 and 255.
    -- sess       : current session
    -- batch_size : the images numpy array is split into batches with batch size
                    batch_size (the last batch may be smaller). A reasonable
                    batch size depends on the available hardware.
    -- verbose    : If set to True, the number of calculated batches is reported.
    Returns:
    -- A numpy array of dimension (num images, 2048) that contains the
       activations of the given tensor when feeding inception with the query tensor.
    Raises:
    -- ValueError if ``images`` contains no images.
    """
    inception_layer = _get_inception_layer(sess)
    n_images = images.shape[0]
    if n_images == 0:
        raise ValueError('get_activations: images is empty')
    if batch_size > n_images:
        print(
            "warning: batch size is bigger than the data size. setting batch size to data size"
        )
        batch_size = n_images
    # Round up so the final partial batch is propagated too; otherwise the
    # trailing n_images % batch_size rows of pred_arr (allocated with
    # np.empty) would never be filled.
    n_batches = (n_images + batch_size - 1) // batch_size
    pred_arr = np.empty((n_images, 2048))
    for i in tqdm(range(n_batches)):
        if verbose:
            print("\rPropagating batch %d/%d" % (i + 1, n_batches),
                  end="",
                  flush=True)
        start = i * batch_size
        end = min(start + batch_size, n_images)
        batch = images[start:end]
        pred = sess.run(inception_layer,
                        {'FID_Inception_Net/ExpandDims:0': batch})
        # Use the actual batch length: the last batch can be shorter.
        pred_arr[start:end] = pred.reshape(end - start, -1)
    if verbose:
        print(" done")
    return pred_arr
#-------------------------------------------------------------------------------
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    For two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2):
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : mean of the pool_3 activations of the generated samples.
    -- sigma1: covariance of the pool_3 activations of the generated samples.
    -- mu2   : mean of the pool_3 activations, precalculated on a
               representative data set.
    -- sigma2: covariance of the pool_3 activations, precalculated on a
               representative data set.
    -- eps   : value added to both covariance diagonals when the product of
               the covariances is numerically singular.
    Returns:
    -- : d^2 as defined above (the value reported as FID).
    Raises:
    -- ValueError if the matrix square root has a non-negligible imaginary part.
    """
    m1, m2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    s1, s2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert m1.shape == m2.shape, "Training and test mean vectors have different lengths"
    assert s1.shape == s2.shape, "Training and test covariances have different dimensions"

    delta = m1 - m2

    # The covariance product may be close to singular.
    sqrt_prod, _ = linalg.sqrtm(s1.dot(s2), disp=False)
    if not np.isfinite(sqrt_prod).all():
        warnings.warn(
            "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps)
        jitter = np.eye(s1.shape[0]) * eps
        sqrt_prod = linalg.sqrtm((s1 + jitter).dot(s2 + jitter))

    # Rounding can leave a tiny imaginary part; anything larger is an error.
    if np.iscomplexobj(sqrt_prod):
        if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
            largest = np.max(np.abs(sqrt_prod.imag))
            raise ValueError("Imaginary component {}".format(largest))
        sqrt_prod = sqrt_prod.real

    trace_sqrt = np.trace(sqrt_prod)
    return delta.dot(delta) + np.trace(s1) + np.trace(s2) - 2 * trace_sqrt
def compute_fid_from_npz(path):
    """Compute the FID of the generated images stored in an ``.npz`` archive.

    The archive must contain a ``'fake'`` array of generated images.

    If ``path`` contains one of the dataset names 'imagenet', 'cifar' or
    'places', cached real-image statistics from ``precomputed_stats`` are
    used. Otherwise the archive must also contain a ``'real'`` array, whose
    statistics are computed here.

    Returns 0 without computing anything when fewer than 1000 fake images
    are present.
    """
    print(path)
    # First dataset name found in the path, or None when there is none.
    name = next((n for n in ['imagenet', 'cifar', 'places'] if n in path),
                None)
    print('Inferred name', name)
    with np.load(path) as data:
        fake_imgs = data['fake']
        if fake_imgs.shape[0] < 1000: return 0
        # The real images must be read while the archive is still open.
        real_imgs = name if name is not None else data['real']

    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        m1, s1 = calculate_activation_statistics(fake_imgs, sess)
        if isinstance(real_imgs, str):
            print(f'using cached image stats for {real_imgs}')
            with np.load(precomputed_stats[real_imgs]) as data:
                m2, s2 = data['m'], data['s']
        else:
            print('computing real images stats from scratch')
            m2, s2 = calculate_activation_statistics(real_imgs, sess)
    return calculate_frechet_distance(m1, s1, m2, s2)
# Cache files holding real-image Inception statistics (arrays ``m`` and ``s``)
# per known dataset; written by compute_stats and read by the FID helpers.
precomputed_stats = {
    dataset: f'output/{dataset}_gt_stats.npz'
    for dataset in ('places', 'imagenet', 'cifar')
}
def compute_fid_from_imgs(fake_imgs, real_imgs):
    """FID between ``fake_imgs`` and a reference.

    ``real_imgs`` is either an image array, whose statistics are computed
    here, or a key of ``precomputed_stats`` naming an ``.npz`` file with the
    cached mean ``m`` and covariance ``s``.
    """
    create_inception_graph(check_or_download_inception(None))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        fake_mu, fake_sigma = calculate_activation_statistics(fake_imgs, sess)
        if not isinstance(real_imgs, str):
            real_mu, real_sigma = calculate_activation_statistics(real_imgs, sess)
        else:
            with np.load(precomputed_stats[real_imgs]) as cached:
                real_mu, real_sigma = cached['m'], cached['s']
    return calculate_frechet_distance(fake_mu, fake_sigma, real_mu, real_sigma)
def compute_stats(exp_path):
    """Cache real-image Inception statistics for datasets named in ``exp_path``.

    For each of 'places', 'imagenet' and 'cifar' that appears as a substring
    of ``exp_path`` and whose cache file in ``precomputed_stats`` does not
    exist yet, this loads ``output/<name>_gt_imgs.npz`` (key ``'real'``). It
    then computes the pool_3 mean and covariance and saves them to the cache
    file under keys ``m`` and ``s``.
    """
    # TODO: a bit hacky -- the dataset is inferred from substrings of the path.
    for name in ('places', 'imagenet', 'cifar'):
        if name not in exp_path or os.path.exists(precomputed_stats[name]):
            continue
        with np.load(f'output/{name}_gt_imgs.npz') as data_real:
            real_imgs = data_real['real']
        print(f'loaded real {name} images', real_imgs.shape)
        inception_path = check_or_download_inception(None)
        create_inception_graph(inception_path)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            m, s = calculate_activation_statistics(real_imgs, sess)
        np.savez(precomputed_stats[name], m=m, s=s)
if __name__ == '__main__':
    # CLI: compute the FID of an .npz of samples and optionally record it in
    # <results_dir>/fid_results.json under the key given by --it.
    import argparse
    import json
    parser = argparse.ArgumentParser('compute TF FID')
    parser.add_argument('--samples', help='path to samples')
    parser.add_argument('--it', type=str, help='iteration key under which to store the result')
    parser.add_argument('--results_dir', help='path to results_dir')
    args = parser.parse_args()

    it = args.it
    results_dir = args.results_dir

    compute_stats(args.samples)
    mean = compute_fid_from_npz(args.samples)
    print(f'FID: {mean}')
    if results_dir is not None:
        results_file = os.path.join(results_dir, 'fid_results.json')
        # Start a fresh results file on the first run instead of failing.
        fid_results = {}
        if os.path.exists(results_file):
            with open(results_file) as f:
                fid_results = json.load(f)
        # float() makes the value a plain JSON number whatever its numpy type.
        fid_results[it] = float(mean)
        print(f'{results_dir} iteration {it} FID: {mean}')
        with open(results_file, 'w') as f:
            f.write(json.dumps(fid_results))
================================================
FILE: gan_training/metrics/inception_score.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3
import numpy as np
from scipy.stats import entropy
def inception_score(imgs, device=None, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs.

    Args:
        imgs: Torch dataset (anything ``DataLoader`` accepts) yielding 3xHxW
            image tensors normalized to the range [-1, 1].
        device: torch device on which Inception v3 is run.
        batch_size: batch size for feeding into Inception v3.
        resize: if True, each batch is bilinearly resized to 299x299 first.
        splits: number of splits; a score is computed per split.

    Returns:
        (mean, std) of the per-split scores.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``imgs`` is empty.
    """
    N = len(imgs)
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    if N == 0:
        raise ValueError('imgs must not be empty')

    # Set up dataloader; a final partial batch is handled below.
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    # Load inception model.
    # NOTE(review): transform_input=False means inputs in [-1, 1] are fed
    # as-is rather than ImageNet-normalized; confirm this is intended.
    inception_model = inception_v3(pretrained=True, transform_input=False)
    inception_model = inception_model.to(device)
    inception_model.eval()

    def get_pred(x):
        with torch.no_grad():
            if resize:
                # Same output as nn.Upsample(size, mode='bilinear'), whose
                # implicit align_corners is False, without its warning.
                x = F.interpolate(x, size=(299, 299), mode='bilinear',
                                  align_corners=False)
            logits = inception_model(x)
            return F.softmax(logits, dim=-1).cpu().numpy()

    # Get predictions
    preds = np.zeros((N, 1000))
    start = 0
    for batch in dataloader:
        n = batch.size(0)
        preds[start:start + n] = get_pred(batch.to(device))
        start += n

    # Now compute the mean kl-div
    split_size = N // splits
    split_scores = []
    for k in range(splits):
        part = preds[k * split_size:(k + 1) * split_size, :]
        py = np.mean(part, axis=0)
        scores = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)
================================================
FILE: gan_training/metrics/tf_is/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: gan_training/metrics/tf_is/README.md
================================================
Inception Score
=====================================
A new Tensorflow implementation of the "Inception Score" (IS) for the evaluation of generative models, with a bug raised in [https://github.com/openai/improved-gan/issues/29](https://github.com/openai/improved-gan/issues/29) fixed.
## Major Dependency
- `tensorflow >= 1.14`
## Features
- Fast, easy-to-use and memory-efficient, written in a way that is similar to the original implementation
- No prior knowledge about Tensorflow is necessary if you are using CPU or GPU
- Makes use of [TF-GAN](https://github.com/tensorflow/gan)
- Downloads InceptionV1 automatically
- Compatible with both Python 2 and Python 3
## Usage
- If you are working with GPU, use `inception_score.py`; if you are working with TPU, use `inception_score_tpu.py` and pass a Tensorflow Session and a [TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy) as additional arguments.
- Call `get_inception_score(images, splits=10)`, where `images` is a numpy array with values ranging from 0 to 255 and shape in the form `[N, 3, HEIGHT, WIDTH]` where `N`, `HEIGHT` and `WIDTH` can be arbitrary. `dtype` of the images is recommended to be `np.uint8` to save CPU memory.
- A smaller `BATCH_SIZE` reduces GPU/TPU memory usage, but at the cost of a slight slowdown.
- If you want to compute a general "Classifier Score" with probabilities `preds` from another classifier, call `preds2score(preds, splits=10)`. `preds` can be a numpy array of arbitrary shape `[N, num_classes]`.
## Links
- The Inception Score was proposed in the paper [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498)
- Code for the [Fréchet Inception Distance](https://github.com/tsc2017/Frechet-Inception-Distance)
================================================
FILE: gan_training/metrics/tf_is/inception_score.py
================================================
'''
From https://github.com/tsc2017/Inception-Score
Code derived from https://github.com/openai/improved-gan/blob/master/inception_score/model.py and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
Usage:
Call get_inception_score(images, splits=10)
Args:
images: A numpy array with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. A dtype of np.uint8 is recommended to save CPU memory.
splits: The number of splits of the images, default is 10.
Returns:
Mean and standard deviation of the Inception Score across the splits.
'''
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import functools
import numpy as np
import time
from tqdm import tqdm
from tensorflow.python.ops import array_ops
# TF-GAN is taken from tf.contrib, which exists only in TensorFlow 1.x.
tfgan = tf.contrib.gan
# NOTE(review): importing this module creates an InteractiveSession as a side
# effect; get_inception_probs runs the Inception graph through it.
session=tf.compat.v1.InteractiveSession()
# A smaller BATCH_SIZE reduces GPU memory usage, but at the cost of a slight slowdown
BATCH_SIZE = 64
# Tarball and the frozen graph file inside it that inception_logits passes
# to tfgan.eval.get_graph_def_from_url_tarball.
INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'
INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb'
# Run images through Inception.
# Placeholder for a batch of channels-first (N, 3, H, W) float images; batch
# size and spatial size are left dynamic. get_inception_probs feeds values
# rescaled from [0, 255] to [-1, 1].
inception_images = tf.compat.v1.placeholder(tf.float32, [None, 3, None, None])
def inception_logits(images=inception_images, num_splits=1):
    """Build the graph computing Inception-v1 logits for NCHW ``images``.

    The images are transposed to NHWC and bilinearly resized to 299x299. They
    are then split into ``num_splits`` chunks along the batch axis and mapped
    through ``tfgan.eval.run_inception`` with ``tf.map_fn``. The per-chunk
    logits are concatenated back into one batch.
    """
    nhwc = tf.transpose(images, [0, 2, 3, 1])
    resized = tf.compat.v1.image.resize_bilinear(nhwc, [299, 299])
    chunks = array_ops.split(resized, num_or_size_splits=num_splits)

    load_graph_def = functools.partial(
        tfgan.eval.get_graph_def_from_url_tarball,
        INCEPTION_URL,
        INCEPTION_FROZEN_GRAPH,
        os.path.basename(INCEPTION_URL))
    classifier = functools.partial(
        tfgan.eval.run_inception,
        default_graph_def_fn=load_graph_def,
        output_tensor='logits:0')

    per_chunk = tf.map_fn(
        fn=classifier,
        elems=array_ops.stack(chunks),
        parallel_iterations=8,
        back_prop=False,
        swap_memory=True,
        name='RunClassifier')
    return array_ops.concat(array_ops.unstack(per_chunk), 0)
# Build the logits tensor once, at import time, on the module-level
# placeholder; get_inception_probs evaluates it with session.run.
logits=inception_logits()
def get_inception_probs(inps):
    """Class probabilities from Inception for every image in ``inps``.

    Args:
        inps: numpy array of shape (N, 3, H, W) with values in [0, 255].

    Returns:
        float32 array of shape (N, 1000) whose rows sum to 1.
    """
    n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE))
    preds = np.zeros([inps.shape[0], 1000], dtype=np.float32)
    for i in tqdm(range(n_batches)):
        start = i * BATCH_SIZE
        # Map [0, 255] pixel values to [-1, 1].
        inp = inps[start:start + BATCH_SIZE] / 255. * 2 - 1
        preds[start:start + inp.shape[0]] = session.run(
            logits, {inception_images: inp})[:, :1000]
    # Softmax over the 1000 class logits. Subtracting the row maximum first
    # leaves the result unchanged but keeps np.exp from overflowing float32
    # (which happens for logits above ~88).
    preds = preds - np.max(preds, axis=1, keepdims=True)
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds, 1, keepdims=True)
    return preds
def preds2score(preds, splits=10):
    """Inception/classifier score from class probabilities.

    Args:
        preds: array of shape (N, num_classes); each row is a probability
            distribution p(y|x).
        splits: number of contiguous splits of the N rows to score separately.

    Returns:
        (mean, std) over splits of exp(E_x[KL(p(y|x) || p(y))]).
    """
    scores = []
    for i in range(splits):
        part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
        log_py = np.log(np.expand_dims(np.mean(part, 0), 0))
        # Terms with p(y|x) == 0 contribute 0 (the limit of p*log p). Computing
        # them directly gives 0 * -inf = nan whenever a probability is exactly 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            kl = np.where(part > 0, part * (np.log(part) - log_py), 0.0)
        kl = np.mean(np.sum(kl, 1))
        scores.append(np.exp(kl))
    return np.mean(scores), np.std(scores)
def get_inception_score(images, splits=10):
    """Inception Score of ``images``.

    Args:
        images: numpy array of shape [N, 3, HEIGHT, WIDTH] with values in
            [0, 255]; uint8 is recommended to save CPU memory.
        splits: number of splits over which the score is computed.

    Returns:
        (mean, std) of the per-split scores.

    Raises:
        TypeError: if ``images`` is not a numpy array.
        ValueError: if ``images`` is not 4-D with 3 channels on axis 1, or if
            the first image does not look like [0, 255] data.
    """
    # isinstance also accepts ndarray subclasses such as np.memmap.
    if not isinstance(images, np.ndarray):
        raise TypeError('images must be a numpy array, got %s' % type(images).__name__)
    if images.ndim != 4 or images.shape[1] != 3:
        raise ValueError('images must have shape [N, 3, HEIGHT, WIDTH], got %s' % (images.shape,))
    # Heuristic range check on the first image only.
    if not (np.min(images[0]) >= 0 and np.max(images[0]) > 10):
        raise ValueError('Image values should be in the range [0, 255]')
    print('Calculating Inception Score with %i images in %i splits' % (images.shape[0], splits))
    start_time = time.time()
    preds = get_inception_probs(images)
    mean, std = preds2score(preds, splits)
    print('Inception Score calculation time: %f s' % (time.time() - start_time))
    return mean, std  # Reference values: 11.38 for 50000 CIFAR-10 training set images, or mean=11.31, std=0.10 if in 10 splits.
def compute_is_from_npz(path):
    """Inception Score of the ``'fake'`` images stored in the ``.npz`` at ``path``.

    Axis 3 of the stored array is moved to axis 1, giving the channels-first
    layout that get_inception_score expects.
    """
    with np.load(path) as archive:
        samples = archive['fake']
    samples = np.transpose(samples, (0, 3, 1, 2))
    print(samples.shape)
    return get_inception_score(samples)
if __name__ == '__main__':
    # CLI: compute the Inception Score of an .npz of samples and optionally
    # record it in <results_dir>/is_results.json under the key given by --it.
    import argparse
    import json
    parser = argparse.ArgumentParser('compute TF IS')
    parser.add_argument('--samples', help='path to samples')
    parser.add_argument('--it', type=str, help='iteration key under which to store the result')
    parser.add_argument('--results_dir', help='path to results_dir')
    args = parser.parse_args()

    it = args.it
    results_dir = args.results_dir
    mean, std = compute_is_from_npz(args.samples)
    print(f'{results_dir} iteration {it} IS: {mean}')
    if results_dir is not None:
        results_file = os.path.join(results_dir, 'is_results.json')
        # Start a fresh results file on the first run instead of failing.
        is_results = {}
        if os.path.exists(results_file):
            with open(results_file) as f:
                is_results = json.load(f)
        is_results[it] = float(mean)
        with open(results_file, 'w') as f:
            f.write(json.dumps(is_results))
================================================
FILE: gan_training/models/__init__.py
================================================
from gan_training.models import (dcgan_deep, dcgan_shallow, resnet2)
# Architecture name -> generator class.
generator_dict = dict(
    resnet2=resnet2.Generator,
    dcgan_deep=dcgan_deep.Generator,
    dcgan_shallow=dcgan_shallow.Generator,
)

# Architecture name -> discriminator class (same keys as generator_dict).
discriminator_dict = dict(
    resnet2=resnet2.Discriminator,
    dcgan_deep=dcgan_deep.Discriminator,
    dcgan_shallow=dcgan_shallow.Discriminator,
)
================================================
FILE: gan_training/models/blocks.py
================================================
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
class ResnetBlock(nn.Module):
    """Residual block computing ``shortcut(x) + 0.1 * residual(x, y)``.

    The residual branch is bn0 -> actvn -> conv_0 (3x3) -> bn1 -> actvn ->
    conv_1 (3x3). Both normalisation layers also receive the label ``y``.
    The shortcut is the identity when ``fin == fout`` and a bias-free 1x1
    convolution otherwise.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        bn: normalisation factory, called as ``bn(num_channels, nclasses)``;
            the returned module is applied as ``module(x, y)``.
        nclasses: number of classes, forwarded to ``bn``.
        fhidden: channels between conv_0 and conv_1; defaults to
            ``min(fin, fout)``.
        is_bias: whether conv_1 has a bias term (conv_0 always has one).
    """
    def __init__(self,
                 fin,
                 fout,
                 bn,
                 nclasses,
                 fhidden=None,
                 is_bias=True):
        super().__init__()
        # Attributes
        self.is_bias = is_bias
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        if fhidden is None:
            self.fhidden = min(fin, fout)
        else:
            self.fhidden = fhidden
        # Submodules
        # NOTE(review): the creation order below determines the order of
        # self.parameters(); reordering it could break loading state that is
        # stored by position (e.g. optimizer state).
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden,
                                self.fout,
                                3,
                                stride=1,
                                padding=1,
                                bias=is_bias)
        if self.learned_shortcut:
            # 1x1 projection so the shortcut matches fout channels.
            self.conv_s = nn.Conv2d(self.fin,
                                    self.fout,
                                    1,
                                    stride=1,
                                    padding=0,
                                    bias=False)
        self.bn0 = bn(self.fin, nclasses)
        self.bn1 = bn(self.fhidden, nclasses)

    def forward(self, x, y):
        x_s = self._shortcut(x)
        dx = self.conv_0(actvn(self.bn0(x, y)))
        dx = self.conv_1(actvn(self.bn1(dx, y)))
        # The residual branch is scaled by 0.1 before being added.
        out = x_s + 0.1 * dx
        return out

    def _shortcut(self, x):
        # Identity when fin == fout, otherwise the learned 1x1 projection.
        if self.learned_shortcut:
            x_s = self.conv_s(x)
        else:
            x_s = x
        return x_s
def actvn(x):
    """LeakyReLU with negative slope 0.2, shared by the residual blocks."""
    return F.leaky_relu(x, negative_slope=2e-1)
class LatentEmbeddingConcat(nn.Module):
    """Concatenate a latent code with an L2-normalised class embedding.

    The embedding looked up for label ``y`` is scaled to unit length
    (projected onto the unit hypersphere) and appended after ``z`` on dim 1.
    """

    def __init__(self, nlabels, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(nlabels, embed_dim)

    def forward(self, z, y):
        assert (y.size(0) == z.size(0))
        class_vec = self.embedding(y)
        class_vec = class_vec / class_vec.norm(p=2, dim=1, keepdim=True)
        return torch.cat((z, class_vec), dim=1)
class NormalizeLinear(nn.Module):
    """Linear layer whose weight rows are rescaled to unit L2 norm on every call."""

    def __init__(self, act_dim, k_value):
        super().__init__()
        self.lin = nn.Linear(act_dim, k_value)

    def normalize(self):
        # Rewrites .data directly, so the renormalisation is not tracked by autograd.
        weight = self.lin.weight
        weight.data = F.normalize(weight.data, p=2, dim=1)

    def forward(self, x):
        self.normalize()
        return self.lin(x)
class Identity(nn.Module):
    """No-op module; accepts and ignores any constructor or extra forward arguments."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, inp, *args, **kwargs):
        # Extra inputs (e.g. a label tensor) are discarded; inp is returned as is.
        result = inp
        return result
class LinearConditionalMaskLogits(nn.Module):
    """Per-class linear logits that keep only each sample's own class score.

    ``fc`` maps (already activated) features of size ``nc`` to one logit per
    class.
    """

    def __init__(self, nc, nlabels):
        super().__init__()
        self.fc = nn.Linear(nc, nlabels)

    def forward(self, inp, y=None, take_best=False, get_features=False):
        """
        Args:
            inp: features of shape (batch, nc).
            y: class indices, one per sample; required unless ``take_best``
                or ``get_features`` is set.
            take_best: return the maximum logit over classes instead of the
                logit of class ``y``.
            get_features: return the full (batch, nlabels) logit matrix.

        Returns:
            A (batch,) tensor of logits, or (batch, nlabels) when
            ``get_features`` is true.
        """
        out = self.fc(inp)
        if get_features:
            return out
        if take_best:
            # high activation means real, so take the highest activations
            best_logits, _ = out.max(dim=1)
            return best_logits
        # Pick out[i, y[i]] for every row. The row index is created directly on
        # the logits' device instead of via the deprecated Variable wrapper and
        # a host-side LongTensor(range(...)).
        y = y.view(-1).to(out.device)
        rows = torch.arange(out.size(0), device=out.device)
        return out[rows, y]
class ProjectionDiscriminatorLogits(nn.Module):
    """Projection-discriminator head (https://arxiv.org/pdf/1802.05637.pdf).

    Takes activated, flattened features right before the last linear layer
    and returns ``fc(x) + <embedding(y), x>``.
    """

    def __init__(self, nc, nlabels):
        super().__init__()
        self.fc = nn.Linear(nc, 1)
        self.embedding = nn.Embedding(nlabels, nc)
        self.nlabels = nlabels

    def forward(self, x, y, take_best=False):
        """
        Args:
            x: features of shape (batch, nc).
            y: class indices of shape (batch,); ignored when ``take_best``.
            take_best: use, per sample, the class with the largest projection
                term instead of ``y``.

        Returns:
            Logits of shape (batch,).
        """
        output = self.fc(x)
        if not take_best:
            label_info = torch.sum(self.embedding(y) * x, dim=1, keepdim=True)
            return (output + label_info).view(x.size(0))
        # Projection term for every (sample, class) pair at once:
        # label_info[i, j] == <embedding.weight[j], x[i]>. This replaces the
        # batch * nlabels expansion of x and runs on x's own device rather
        # than forcing the index tensor onto CUDA.
        label_info = x @ self.embedding.weight.t()
        # high activation means real, so take the highest activations
        best_logits, _ = label_info.max(dim=1)
        return output.view(output.size(0)) + best_logits
class LinearUnconditionalLogits(nn.Module):
    """Label-free discriminator logit layer: one linear score per sample."""

    def __init__(self, nc):
        super().__init__()
        self.fc = nn.Linear(nc, 1)

    def forward(self, inp, y, take_best=False):
        # y is accepted for interface parity with the conditional heads and ignored.
        assert take_best == False
        scores = self.fc(inp)
        return scores.view(scores.size(0))
class Reshape(nn.Module):
    """View each sample as ``shape``, keeping the leading batch dimension."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target = (x.shape[0], *self.shape)
        return x.view(*target)
class ConditionalBatchNorm2d(nn.Module):
    """Class-conditional BatchNorm2d.

    Adapted from https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775.
    Inputs are normalised without a learned affine transform. The per-class
    scale and shift come from one embedding row of size ``2 * num_features``:
    the first half is the scale, the second half the shift.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        table = self.embed.weight.data
        table[:, :num_features].normal_(1, 0.02)  # scale initialised ~ N(1, 0.02)
        table[:, num_features:].zero_()  # shift initialised at 0

    def forward(self, x, y):
        normed = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, 1)
        # Broadcast the per-sample channel parameters over the spatial dims.
        shape = (-1, self.num_features, 1, 1)
        return gamma.view(*shape) * normed + beta.view(*shape)
class BatchNorm2d(nn.Module):
    """Plain nn.BatchNorm2d that accepts, and ignores, a label argument.

    Only ``nc`` is used. ``nchannels`` and any keyword arguments are accepted
    so the constructor matches ConditionalBatchNorm2d's ``(features, classes)``.
    """

    def __init__(self, nc, nchannels, **kwargs):
        super().__init__()
        self.bn = nn.BatchNorm2d(nc)

    def forward(self, x, y):
        del y  # unused; present for drop-in compatibility with the conditional variant
        return self.bn(x)
================================================
FILE: gan_training/models/dcgan_deep.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data
import torch.utils.data.distributed
from gan_training.models import blocks
class Generator(nn.Module):
    """DCGAN generator: a linear layer to a 4x4 map, three stride-2 transposed
    convs (4 -> 8 -> 16 -> 32), then a 3x3 conv and tanh.

    Args:
        nlabels: number of classes; must be 1 for 'unconditional'.
        conditioning: 'embedding' (append a normalised class embedding to the
            latent) or 'unconditional'.
        z_dim: latent size.
        nc: output image channels.
        ngf: base channel width.
        embed_dim: class-embedding size used by 'embedding' conditioning.
    """

    def __init__(self,
                 nlabels,
                 conditioning,
                 z_dim=128,
                 nc=3,
                 ngf=64,
                 embed_dim=256,
                 **kwargs):
        super(Generator, self).__init__()
        assert conditioning != 'unconditional' or nlabels == 1
        if conditioning == 'embedding':
            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim)
            in_features = z_dim + embed_dim
        elif conditioning == 'unconditional':
            self.get_latent = blocks.Identity()
            in_features = z_dim
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for generator")
        self.fc = nn.Linear(in_features, 4 * 4 * ngf * 8)

        # Label-agnostic batch norm (blocks.BatchNorm2d ignores y).
        norm = blocks.BatchNorm2d
        self.nlabels = nlabels
        self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1)
        self.bn1 = norm(ngf * 4, nlabels)
        self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1)
        self.bn2 = norm(ngf * 2, nlabels)
        self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1)
        self.bn3 = norm(ngf, nlabels)
        self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh())

    def forward(self, input, y):
        # Labels past the last class are clamped onto it.
        y = y.clamp(None, self.nlabels - 1)
        h = self.fc(self.get_latent(input, y))
        h = h.view(h.size(0), -1, 4, 4)
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2),
                  (self.conv3, self.bn3))
        for upconv, norm in stages:
            h = F.relu(norm(upconv(h), y))
        return self.conv_out(h)
class Discriminator(nn.Module):
    """Deep DCGAN discriminator with optional PacGAN input packing.

    Sized for 32x32 inputs: three stride-2 convs reduce the map to 4x4, so the
    head sees ``ndf * 8 * 4 * 4`` features.

    Args:
        nlabels: number of classes; must be 1 for 'unconditional'.
        conditioning: 'mask' (one logit per class, keep the sample's class)
            or 'unconditional'.
        nc: channels per input image.
        ndf: base channel width.
        pack_size: number of consecutive images packed into one sample.
        features: 'penultimate' (flattened last conv output) or 'summed'
            (spatially summed activations), used when ``get_features`` is set.
    """

    def __init__(self,
                 nlabels,
                 conditioning,
                 nc=3,
                 ndf=64,
                 pack_size=1,
                 features='penultimate',
                 **kwargs):
        super(Discriminator, self).__init__()
        assert conditioning != 'unconditional' or nlabels == 1
        self.nlabels = nlabels
        self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 3, 1, 1), nn.LeakyReLU(0.1))
        self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf, 4, 2, 1), nn.LeakyReLU(0.1))
        self.conv3 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 3, 1, 1), nn.LeakyReLU(0.1))
        self.conv4 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 2, 4, 2, 1), nn.LeakyReLU(0.1))
        self.conv5 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 3, 1, 1), nn.LeakyReLU(0.1))
        self.conv6 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1), nn.LeakyReLU(0.1))
        self.conv7 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 3, 1, 1), nn.LeakyReLU(0.1))
        if conditioning == 'mask':
            self.fc_out = blocks.LinearConditionalMaskLogits(
                ndf * 8 * 4 * 4, nlabels)
        elif conditioning == 'unconditional':
            self.fc_out = blocks.LinearUnconditionalLogits(
                ndf * 8 * 4 * 4)
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for discriminator")
        self.features = features
        self.pack_size = pack_size
        print(f'Getting features from {self.features}')

    def stack(self, x):
        """PacGAN packing: merge each group of ``pack_size`` consecutive images
        along the channel axis (first image's channels first)."""
        nc = self.pack_size
        assert (x.size(0) % nc == 0)
        if nc == 1:
            return x
        # A single reshape yields the same layout as concatenating each
        # group's images along channels, without a Python loop.
        return x.reshape(x.size(0) // nc, nc * x.size(1), *x.shape[2:])

    def forward(self, input, y=None, get_features=False):
        out = self.stack(input)
        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        if get_features and self.features == "penultimate":
            return out.view(out.size(0), -1)
        if get_features and self.features == "summed":
            return out.view(out.size(0), out.size(1), -1).sum(dim=2)
        out = out.view(out.size(0), -1)
        # y may be omitted for the unconditional head, which ignores it;
        # previously y=None crashed here on y.clamp.
        if y is not None:
            y = y.clamp(None, self.nlabels - 1)
        result = self.fc_out(out, y)
        assert (len(result.shape) == 1)
        return result
if __name__ == '__main__':
    # Shape smoke test for an unconditional 32x32 model. Both constructors
    # need nlabels/conditioning, and both forwards take a label tensor.
    z = torch.zeros((1, 128))
    y = torch.zeros((1, ), dtype=torch.long)
    g = Generator(nlabels=1, conditioning='unconditional')
    x = torch.zeros((1, 3, 32, 32))
    d = Discriminator(nlabels=1, conditioning='unconditional')
    fake = g(z, y)
    assert fake.shape == x.shape
    d(fake, y)
    d(x, y)
================================================
FILE: gan_training/models/dcgan_shallow.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data
import torch.utils.data.distributed
from gan_training.models import blocks
class Generator(nn.Module):
    """DCGAN generator: a linear layer to a 4x4 map, three stride-2 transposed
    convs (4 -> 8 -> 16 -> 32), then a 3x3 conv and tanh.

    Args:
        nlabels: number of classes; must be 1 for 'unconditional'.
        conditioning: 'embedding' (append a normalised class embedding to the
            latent) or 'unconditional'.
        z_dim: latent size.
        nc: output image channels.
        ngf: base channel width.
        embed_dim: class-embedding size used by 'embedding' conditioning.
    """

    def __init__(self,
                 nlabels,
                 conditioning,
                 z_dim=128,
                 nc=3,
                 ngf=64,
                 embed_dim=256,
                 **kwargs):
        super(Generator, self).__init__()
        assert conditioning != 'unconditional' or nlabels == 1
        if conditioning == 'embedding':
            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim)
            in_features = z_dim + embed_dim
        elif conditioning == 'unconditional':
            self.get_latent = blocks.Identity()
            in_features = z_dim
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for generator")
        self.fc = nn.Linear(in_features, 4 * 4 * ngf * 8)

        # Label-agnostic batch norm (blocks.BatchNorm2d ignores y).
        norm = blocks.BatchNorm2d
        self.nlabels = nlabels
        self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1)
        self.bn1 = norm(ngf * 4, nlabels)
        self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1)
        self.bn2 = norm(ngf * 2, nlabels)
        self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1)
        self.bn3 = norm(ngf, nlabels)
        self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh())

    def forward(self, input, y):
        # Labels past the last class are clamped onto it.
        y = y.clamp(None, self.nlabels - 1)
        h = self.fc(self.get_latent(input, y))
        h = h.view(h.size(0), -1, 4, 4)
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2),
                  (self.conv3, self.bn3))
        for upconv, norm in stages:
            h = F.relu(norm(upconv(h), y))
        return self.conv_out(h)
class Discriminator(nn.Module):
    """Shallow DCGAN discriminator: four stride-2 conv/BN/LeakyReLU stages.

    Sized for 32x32 inputs (32 -> 16 -> 8 -> 4 -> 2), so the head sees
    ``ndf * 8 * 2 * 2`` features.

    Args:
        nlabels: number of classes; must be 1 for 'unconditional'.
        conditioning: 'mask' or 'unconditional' logit head.
        features: stored and printed only; ``get_features`` always returns
            the flattened last conv output.
        pack_size: number of consecutive images packed into one sample (PacGAN).
        nc: channels per input image.
        ndf: base channel width.
    """

    def __init__(self,
                 nlabels,
                 conditioning,
                 features='penultimate',
                 pack_size=1,
                 nc=3,
                 ndf=64,
                 **kwargs):
        super(Discriminator, self).__init__()
        assert conditioning != 'unconditional' or nlabels == 1
        self.nlabels = nlabels
        self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 4, 2, 1),
                                   nn.BatchNorm2d(ndf),
                                   nn.LeakyReLU(0.2, inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
                                   nn.BatchNorm2d(ndf * 2),
                                   nn.LeakyReLU(0.2, inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
                                   nn.BatchNorm2d(ndf * 4),
                                   nn.LeakyReLU(0.2, inplace=True))
        self.conv4 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
                                   nn.BatchNorm2d(ndf * 8),
                                   nn.LeakyReLU(0.2, inplace=True))
        if conditioning == 'mask':
            self.fc_out = blocks.LinearConditionalMaskLogits(ndf * 8 * 4 , nlabels)
        elif conditioning == 'unconditional':
            self.fc_out = blocks.LinearUnconditionalLogits(ndf * 8 * 4)
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for discriminator")
        self.pack_size = pack_size
        self.features = features
        print(f'Getting features from {self.features}')

    def stack(self, x):
        """PacGAN packing: merge each group of ``pack_size`` consecutive images
        along the channel axis. Trailing images that do not fill a whole
        group are dropped."""
        nc = self.pack_size
        if nc == 1:
            return x
        groups = x.size(0) // nc
        x = x[:groups * nc]
        # One reshape gives the same layout as concatenating each group's
        # images along channels, without a Python loop.
        return x.reshape(groups, nc * x.size(1), *x.shape[2:])

    def forward(self, input, y=None, get_features=False):
        out = self.stack(input)
        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = out.view(out.size(0), -1)
        if get_features:
            return out
        # y may be omitted for the unconditional head, which ignores it;
        # previously y=None crashed here on y.clamp.
        if y is not None:
            y = y.clamp(None, self.nlabels - 1)
        result = self.fc_out(out, y)
        assert (len(result.shape) == 1)
        return result
if __name__ == '__main__':
    # Shape smoke test for an unconditional 32x32 model. Both constructors
    # need nlabels/conditioning, and both forwards take a label tensor.
    z = torch.zeros((1, 128))
    y = torch.zeros((1, ), dtype=torch.long)
    g = Generator(nlabels=1, conditioning='unconditional')
    x = torch.zeros((1, 3, 32, 32))
    d = Discriminator(nlabels=1, conditioning='unconditional')
    fake = g(z, y)
    assert fake.shape == x.shape
    d(fake, y)
    d(x, y)
================================================
FILE: gan_training/models/resnet2.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data
import torch.utils.data.distributed
from gan_training.models import blocks
from gan_training.models.blocks import ResnetBlock
from torch.nn.utils.spectral_norm import spectral_norm
class Generator(nn.Module):
    """Residual generator: a linear layer to a (16*nf, s0, s0) map, then six
    stages of two ResnetBlocks with 2x nearest-neighbour upsampling before
    every stage except the first, and a final 3x3 conv + tanh to RGB.

    Args:
        z_dim: latent size.
        nlabels: number of classes; must be 1 for 'unconditional'.
        size: output resolution; the seed map is ``size // 32`` wide.
        conditioning: 'embedding' or 'unconditional'.
        embed_size: class-embedding size for 'embedding' conditioning.
        nfilter: base channel width ``nf``.
    """

    def __init__(self,
                 z_dim,
                 nlabels,
                 size,
                 conditioning,
                 embed_size=256,
                 nfilter=64,
                 **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.nlabels = nlabels
        self.z_dim = z_dim
        assert conditioning != 'unconditional' or nlabels == 1
        if conditioning == 'embedding':
            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_size)
            latent_dim = z_dim + embed_size
        elif conditioning == 'unconditional':
            self.get_latent = blocks.Identity()
            latent_dim = z_dim
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for generator")
        self.fc = nn.Linear(latent_dim, 16 * nf * s0 * s0)

        # No normalisation inside the residual blocks.
        bn = blocks.Identity
        # Stage k maps widths[k] -> widths[k + 1] channels (in units of nf)
        # in its first block; its second block keeps the width.
        widths = [16, 16, 16, 8, 4, 2, 1]
        for stage in range(6):
            fin, fout = widths[stage] * nf, widths[stage + 1] * nf
            setattr(self, f'resnet_{stage}_0', ResnetBlock(fin, fout, bn, nlabels))
            setattr(self, f'resnet_{stage}_1', ResnetBlock(fout, fout, bn, nlabels))
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)

    def forward(self, z, y):
        y = y.clamp(None, self.nlabels - 1)
        out = self.fc(self.get_latent(z, y))
        out = out.view(z.size(0), 16 * self.nf, self.s0, self.s0)
        for stage in range(6):
            if stage:
                out = F.interpolate(out, scale_factor=2)
            out = getattr(self, f'resnet_{stage}_0')(out, y)
            out = getattr(self, f'resnet_{stage}_1')(out, y)
        return torch.tanh(self.conv_img(actvn(out)))
class Discriminator(nn.Module):
    """Residual discriminator: a 3x3 stem conv, six stages of two ResnetBlocks
    with 3x3/stride-2 average pooling before every stage except the first,
    then a linear logit head on the flattened (16*nf, s0, s0) map.

    Args:
        nlabels: number of classes; must be 1 for 'unconditional'.
        size: input resolution; the final map is ``size // 32`` wide.
        conditioning: 'mask' or 'unconditional' logit head.
        nfilter: base channel width ``nf``.
        features: 'summed' returns spatially summed activations from
            ``get_features``; any other value returns the flattened map.
    """

    def __init__(self,
                 nlabels,
                 size,
                 conditioning,
                 nfilter=64,
                 features='penultimate',
                 **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.nlabels = nlabels
        assert conditioning != 'unconditional' or nlabels == 1
        bn = blocks.Identity
        self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1)
        # Stage k keeps widths[k] channels (in units of nf) in its first block
        # and widens to widths[k + 1] in its second block.
        widths = [1, 2, 4, 8, 16, 16, 16]
        for stage in range(6):
            fin, fout = widths[stage] * nf, widths[stage + 1] * nf
            setattr(self, f'resnet_{stage}_0', ResnetBlock(fin, fin, bn, nlabels))
            setattr(self, f'resnet_{stage}_1', ResnetBlock(fin, fout, bn, nlabels))
        head_dim = 16 * nf * s0 * s0
        if conditioning == 'mask':
            self.fc_out = blocks.LinearConditionalMaskLogits(head_dim, nlabels)
        elif conditioning == 'unconditional':
            self.fc_out = blocks.LinearUnconditionalLogits(head_dim)
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for discriminator")
        self.features = features

    def forward(self, x, y=None, get_features=False):
        batch_size = x.size(0)
        if y is not None:
            y = y.clamp(None, self.nlabels - 1)
        out = self.conv_img(x)
        for stage in range(6):
            if stage:
                out = F.avg_pool2d(out, 3, stride=2, padding=1)
            out = getattr(self, f'resnet_{stage}_0')(out, y)
            out = getattr(self, f'resnet_{stage}_1')(out, y)
        out = actvn(out)
        if get_features and self.features == 'summed':
            return out.view(out.size(0), out.size(1), -1).sum(dim=2)
        out = out.view(batch_size, 16 * self.nf * self.s0 * self.s0)
        if get_features:
            return out.view(batch_size, -1)
        result = self.fc_out(out, y)
        assert (len(result.shape) == 1)
        return result
def actvn(x):
    """LeakyReLU with negative slope 0.2, used between residual stages."""
    return F.leaky_relu(x, negative_slope=2e-1)
================================================
FILE: gan_training/models/resnet2s.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data
import torch.utils.data.distributed
from collections import OrderedDict
class Reshape(nn.Module):
    """View each sample as ``shape``, keeping the leading batch dimension."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target = (x.shape[0], *self.shape)
        return x.view(*target)
class Generator(nn.Module):
    '''
    Perfectly equivalent to resnet2.Generator (can load state dicts
    from that class), but organizes layers as a sequence for more
    automatic inversion.

    Layer names inside ``self.layers`` mirror the attribute names of
    resnet2.Generator so that load_v2_state_dict can map keys one-to-one.
    '''

    def __init__(self,
                 z_dim,
                 nlabels,
                 size,
                 embed_size=256,
                 nfilter=64,
                 use_class_labels=False,
                 **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.z_dim = z_dim
        self.use_class_labels = use_class_labels
        if use_class_labels:
            self.condition = ConditionGen(z_dim, nlabels, embed_size)
            latent_dim = self.condition.latent_dim
        else:
            latent_dim = z_dim

        layers = [('fc', nn.Linear(latent_dim, 16 * nf * s0 * s0)),
                  ('reshape', Reshape(16 * self.nf, self.s0, self.s0))]
        # Stage k maps widths[k] -> widths[k + 1] channels (in units of nf);
        # every stage after the first is preceded by a 2x upsample.
        widths = [16, 16, 16, 8, 4, 2, 1]
        for stage in range(6):
            if stage:
                layers.append((f'upsample_{stage}', nn.Upsample(scale_factor=2)))
            fin, fout = widths[stage] * nf, widths[stage + 1] * nf
            layers.append((f'resnet_{stage}_0', ResnetBlock(fin, fout)))
            layers.append((f'resnet_{stage}_1', ResnetBlock(fout, fout)))
        layers += [('img_relu', nn.LeakyReLU(2e-1)),
                   ('conv_img', nn.Conv2d(nf, 3, 3, padding=1)),
                   ('tanh', nn.Tanh())]
        self.layers = nn.Sequential(OrderedDict(layers))

    def forward(self, z, y=None):
        assert (y is None or z.size(0) == y.size(0))
        assert (not self.use_class_labels or y is not None)
        latent = self.condition(z, y) if self.use_class_labels else z
        return self.layers(latent)

    def load_v2_state_dict(self, state_dict):
        """Load a resnet2.Generator state dict into this sequential layout."""
        def rename(key):
            # Drop a DataParallel-style 'module.' prefix.
            if 'module.' in key:
                key = key.split('module.')[1]
            if key.startswith('embedding'):
                return 'condition.' + key
            if key == 'get_latent.embedding.weight':
                return 'condition.embedding.weight'
            return 'layers.' + key

        self.load_state_dict({rename(k): v for k, v in state_dict.items()})
class ConditionGen(nn.Module):
    """Append an L2-normalised class vector to the latent code.

    ``y`` is either int64 class indices, which are looked up in
    ``embedding``, or (any other dtype) a class vector used as is —
    presumably of shape (batch, embed_size); confirm with callers.
    """

    def __init__(self, z_dim, nlabels, embed_size=256):
        super().__init__()
        self.embedding = nn.Embedding(nlabels, embed_size)
        self.latent_dim = z_dim + embed_size
        self.z_dim = z_dim
        self.nlabels = nlabels
        self.embed_size = embed_size

    def forward(self, z, y):
        assert (z.size(0) == y.size(0))
        class_vec = self.embedding(y) if y.dtype is torch.int64 else y
        class_vec = class_vec / torch.norm(class_vec, p=2, dim=1, keepdim=True)
        return torch.cat([z, class_vec], dim=1)
def convert_from_resnet2_generator(gen):
    """Build an equivalent sequential Generator from a resnet2.Generator.

    Class conditioning is detected from a ``get_latent`` module that owns an
    ``embedding``. If there is none, the legacy top-level ``embedding``
    attribute is used. An unconditional resnet2 generator has a parameter-free
    ``get_latent`` without an embedding; it used to raise AttributeError here
    and is now converted without class labels.
    """
    nlabels, embed_size = 0, 0
    use_class_labels = False
    # get_latent takes precedence over a legacy top-level embedding, as before.
    embedding = getattr(getattr(gen, 'get_latent', None), 'embedding', None)
    if embedding is None:
        embedding = getattr(gen, 'embedding', None)
    if embedding is not None:
        nlabels = embedding.num_embeddings
        embed_size = embedding.embedding_dim
        use_class_labels = True
    size = gen.s0 * 32
    newgen = Generator(gen.z_dim, nlabels, size, embed_size, gen.nf,
                       use_class_labels)
    newgen.load_v2_state_dict(gen.state_dict())
    return newgen
class ResnetBlock(nn.Module):
    """Unconditional pre-activation residual block.

    Computes ``shortcut(x) + 0.1 * conv_1(act(conv_0(act(x))))``. The
    shortcut is the identity, or a bias-free 1x1 conv when fin != fout.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        self.is_bias = is_bias
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        self.fhidden = fhidden if fhidden is not None else min(fin, fout)
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden,
                                kernel_size=3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout,
                                kernel_size=3, stride=1, padding=1,
                                bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout,
                                    kernel_size=1, stride=1, padding=0,
                                    bias=False)

    def forward(self, x):
        hidden = self.conv_0(actvn(x))
        residual = self.conv_1(actvn(hidden))
        return self._shortcut(x) + 0.1 * residual

    def _shortcut(self, x):
        if not self.learned_shortcut:
            return x
        return self.conv_s(x)
def actvn(x):
    """LeakyReLU with negative slope 0.2."""
    return F.leaky_relu(x, negative_slope=2e-1)
================================================
FILE: gan_training/models/resnet3.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data
import torch.utils.data.distributed
from collections import OrderedDict
class Generator(nn.Module):
    '''
    Perfectly equivalent to resnet2.Generator (can load state dicts
    from that class), but organizes layers as a sequence for more
    automatic inversion.
    '''

    def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64,
                 use_class_labels=False, **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.z_dim = z_dim
        self.use_class_labels = use_class_labels

        # Submodules
        if use_class_labels:
            self.condition = ConditionGen(z_dim, nlabels, embed_size)
            latent_dim = self.condition.latent_dim
        else:
            latent_dim = z_dim

        self.layers = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(latent_dim, 16*nf*s0*s0)),
            ('reshape', Reshape(16*self.nf, self.s0, self.s0)),
            ('resnet_0_0', ResnetBlock(16*nf, 16*nf)),
            ('resnet_0_1', ResnetBlock(16*nf, 16*nf)),
            ('upsample_1', nn.Upsample(scale_factor=2)),
            ('resnet_1_0', ResnetBlock(16*nf, 16*nf)),
            ('resnet_1_1', ResnetBlock(16*nf, 16*nf)),
            ('upsample_2', nn.Upsample(scale_factor=2)),
            ('resnet_2_0', ResnetBlock(16*nf, 8*nf)),
            ('resnet_2_1', ResnetBlock(8*nf, 8*nf)),
            ('upsample_3', nn.Upsample(scale_factor=2)),
            ('resnet_3_0', ResnetBlock(8*nf, 4*nf)),
            ('resnet_3_1', ResnetBlock(4*nf, 4*nf)),
            ('upsample_4', nn.Upsample(scale_factor=2)),
            ('resnet_4_0', ResnetBlock(4*nf, 2*nf)),
            ('resnet_4_1', ResnetBlock(2*nf, 2*nf)),
            ('upsample_5', nn.Upsample(scale_factor=2)),
            ('resnet_5_0', ResnetBlock(2*nf, 1*nf)),
            ('resnet_5_1', ResnetBlock(1*nf, 1*nf)),
            ('img_relu', nn.LeakyReLU(2e-1)),
            ('conv_img', nn.Conv2d(nf, 3, 3, padding=1)),
            ('tanh', nn.Tanh())
        ]))

    def forward(self, z, y=None):
        assert(y is None or z.size(0) == y.size(0))
        assert(not self.use_class_labels or y is not None)
        if self.use_class_labels:
            z = self.condition(z, y)
        return self.layers(z)

    def load_v2_state_dict(self, state_dict):
        """Load a resnet2.Generator state dict into this sequential layout."""
        converted = {}
        for k, v in state_dict.items():
            # Strip a DataParallel 'module.' prefix (as resnet2s does);
            # otherwise keys become 'layers.module.*' and strict loading fails.
            if 'module.' in k:
                k = k.split('module.')[1]
            if k.startswith('embedding'):
                k = 'condition.' + k
            elif k == 'get_latent.embedding.weight':
                k = 'condition.embedding.weight'
            else:
                k = 'layers.' + k
            converted[k] = v
        self.load_state_dict(converted)
class Reshape(nn.Module):
    """View each sample as ``shape``, keeping the leading batch dimension."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target = (x.shape[0], *self.shape)
        return x.view(*target)
class ConditionGen(nn.Module):
    """Append an L2-normalised class vector to the latent code.

    ``y`` is either int64 class indices, which are looked up in
    ``embedding``, or (any other dtype) a class vector used as is —
    presumably of shape (batch, embed_size); confirm with callers.
    """

    def __init__(self, z_dim, nlabels, embed_size=256):
        super().__init__()
        self.embedding = nn.Embedding(nlabels, embed_size)
        self.latent_dim = z_dim + embed_size
        self.z_dim = z_dim
        self.nlabels = nlabels
        self.embed_size = embed_size

    def forward(self, z, y):
        assert(z.size(0) == y.size(0))
        class_vec = self.embedding(y) if y.dtype is torch.int64 else y
        class_vec = class_vec / torch.norm(class_vec, p=2, dim=1, keepdim=True)
        return torch.cat([z, class_vec], dim=1)
def convert_from_resnet2_generator(gen):
    """Build an equivalent sequential Generator from a resnet2.Generator.

    Class conditioning is taken from ``gen.get_latent.embedding`` when that
    exists, otherwise from ``gen.embedding`` for legacy generators whose
    ``use_class_labels`` flag is set. Anything else is converted
    unconditionally. Previously an unconditional generator raised
    AttributeError (its get_latent has no embedding) or UnboundLocalError
    (use_class_labels was never assigned).
    """
    nlabels, embed_size = 0, 0
    use_class_labels = False
    embedding = getattr(getattr(gen, 'get_latent', None), 'embedding', None)
    if embedding is None and getattr(gen, 'use_class_labels', False):
        embedding = gen.embedding
    if embedding is not None:
        nlabels = embedding.num_embeddings
        embed_size = embedding.embedding_dim
        use_class_labels = True
    size = gen.s0 * 32
    newgen = Generator(gen.z_dim, nlabels, size, embed_size, gen.nf, use_class_labels)
    newgen.load_v2_state_dict(gen.state_dict())
    return newgen
class ResnetBlock(nn.Module):
    """Unconditional pre-activation residual block.

    Computes ``shortcut(x) + 0.1 * conv_1(act(conv_0(act(x))))``. The
    shortcut is the identity, or a bias-free 1x1 conv when fin != fout.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        self.is_bias = is_bias
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        self.fhidden = fhidden if fhidden is not None else min(fin, fout)
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden,
                                kernel_size=3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout,
                                kernel_size=3, stride=1, padding=1,
                                bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout,
                                    kernel_size=1, stride=1, padding=0,
                                    bias=False)

    def forward(self, x):
        hidden = self.conv_0(actvn(x))
        residual = self.conv_1(actvn(hidden))
        return self._shortcut(x) + 0.1 * residual

    def _shortcut(self, x):
        if not self.learned_shortcut:
            return x
        return self.conv_s(x)
def actvn(x):
    """LeakyReLU with negative slope 0.2."""
    return F.leaky_relu(x, negative_slope=2e-1)
================================================
FILE: gan_training/train.py
================================================
# coding: utf-8
import torch
from torch.nn import functional as F
import torch.utils.data
import torch.utils.data.distributed
from torch import autograd
import numpy as np
class Trainer(object):
    """Alternating GAN trainer for a generator/discriminator pair.

    Args:
        generator: called as ``generator(z, y)``.
        discriminator: called as ``discriminator(x, y)``; one logit per sample.
        g_optimizer, d_optimizer: optimizers for the two networks.
        gan_type: 'standard' (BCE on logits) or 'wgan'.
        reg_type: discriminator regulariser. 'real', 'fake' and 'real_fake'
            apply a squared-gradient penalty at real and/or fake inputs.
            'wgangp' and 'wgangp0' apply a gradient-norm penalty at random
            interpolates, centred at 1 or 0. 'none' applies no penalty.
        reg_param: weight applied to the regulariser.
    """

    # Regulariser types understood by discriminator_trainstep.
    REG_TYPES = ('real', 'fake', 'real_fake', 'wgangp', 'wgangp0', 'none')

    def __init__(self,
                 generator,
                 discriminator,
                 g_optimizer,
                 d_optimizer,
                 gan_type,
                 reg_type,
                 reg_param):
        self.generator = generator
        self.discriminator = discriminator
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.gan_type = gan_type
        self.reg_type = reg_type
        self.reg_param = reg_param
        print('D reg gamma', self.reg_param)

    def generator_trainstep(self, y, z):
        """One generator update. Fakes are scored against target 1 (the real
        label). Returns the generator loss as a float."""
        assert (y.size(0) == z.size(0))
        toggle_grad(self.generator, True)
        toggle_grad(self.discriminator, False)
        self.generator.train()
        self.discriminator.train()
        self.g_optimizer.zero_grad()
        x_fake = self.generator(z, y)
        d_fake = self.discriminator(x_fake, y)
        gloss = self.compute_loss(d_fake, 1)
        gloss.backward()
        self.g_optimizer.step()
        return gloss.item()

    def discriminator_trainstep(self, x_real, y, z):
        """One discriminator update.

        Returns:
            (loss, reg): the real+fake discriminator loss and the total of all
            penalties applied in this step (0.0 when none applies).

        Raises:
            ValueError: if ``reg_type`` is not one of ``REG_TYPES``. Previously
            an unknown type silently skipped regularisation and then crashed
            with UnboundLocalError after the optimizer step.
        """
        if self.reg_type not in self.REG_TYPES:
            raise ValueError(
                f'unknown reg_type {self.reg_type!r}; expected one of {self.REG_TYPES}')
        toggle_grad(self.generator, False)
        toggle_grad(self.discriminator, True)
        self.generator.train()
        self.discriminator.train()
        self.d_optimizer.zero_grad()
        # Sum of every penalty applied this step. 'real_fake' used to report
        # only the fake-data term because `reg` was overwritten.
        reg_total = 0.

        # On real data
        x_real.requires_grad_()
        d_real = self.discriminator(x_real, y)
        dloss_real = self.compute_loss(d_real, 1)
        if self.reg_type in ('real', 'real_fake'):
            dloss_real.backward(retain_graph=True)
            reg = self.reg_param * compute_grad2(d_real, x_real).mean()
            reg.backward()
            reg_total += reg.item()
        else:
            dloss_real.backward()

        # On fake data
        with torch.no_grad():
            x_fake = self.generator(z, y)
        x_fake.requires_grad_()
        d_fake = self.discriminator(x_fake, y)
        dloss_fake = self.compute_loss(d_fake, 0)
        if self.reg_type in ('fake', 'real_fake'):
            dloss_fake.backward(retain_graph=True)
            reg = self.reg_param * compute_grad2(d_fake, x_fake).mean()
            reg.backward()
            reg_total += reg.item()
        else:
            dloss_fake.backward()

        if self.reg_type == 'wgangp':
            reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y)
            reg.backward()
            reg_total += reg.item()
        elif self.reg_type == 'wgangp0':
            reg = self.reg_param * self.wgan_gp_reg(
                x_real, x_fake, y, center=0.)
            reg.backward()
            reg_total += reg.item()

        self.d_optimizer.step()
        dloss = (dloss_real + dloss_fake)
        return dloss.item(), float(reg_total)

    def compute_loss(self, d_out, target):
        """Loss of logits ``d_out`` against a constant ``target`` (1 = real, 0 = fake)."""
        targets = d_out.new_full(size=d_out.size(), fill_value=target)
        if self.gan_type == 'standard':
            loss = F.binary_cross_entropy_with_logits(d_out, targets)
        elif self.gan_type == 'wgan':
            # +mean for target 1, -mean for target 0.
            loss = (2 * target - 1) * d_out.mean()
        else:
            raise NotImplementedError
        return loss

    def wgan_gp_reg(self, x_real, x_fake, y, center=1.):
        """Mean of (||grad D(x_interp)|| - center)^2 at random real/fake interpolates."""
        batch_size = y.size(0)
        eps = torch.rand(batch_size, device=y.device).view(batch_size, 1, 1, 1)
        x_interp = (1 - eps) * x_real + eps * x_fake
        x_interp = x_interp.detach()
        x_interp.requires_grad_()
        d_out = self.discriminator(x_interp, y)
        reg = (compute_grad2(d_out, x_interp).sqrt() - center).pow(2).mean()
        return reg
# Utility functions
def toggle_grad(model, requires_grad):
    """Set ``requires_grad`` in place on every parameter of ``model``."""
    params = model.parameters()
    for param in params:
        param.requires_grad_(requires_grad)
def compute_grad2(d_out, x_in):
    """Per-sample squared L2 norm of d(sum(d_out))/d(x_in).

    The graph is built with create_graph=True, so the result can itself be
    back-propagated as a regulariser.

    Returns:
        A tensor with one entry per row of ``x_in``.
    """
    grads, = autograd.grad(outputs=d_out.sum(),
                           inputs=x_in,
                           create_graph=True,
                           retain_graph=True,
                           only_inputs=True)
    squared = grads.pow(2)
    assert (squared.size() == x_in.size())
    return squared.view(x_in.size(0), -1).sum(1)
def update_average(model_tgt, model_src, beta):
    """Exponential moving average in place: tgt <- beta * tgt + (1 - beta) * src.

    Parameters are matched by name. Side effect: gradients are switched off
    on both models (via toggle_grad) and are not switched back on.
    """
    toggle_grad(model_src, False)
    toggle_grad(model_tgt, False)
    src_params = {name: param for name, param in model_src.named_parameters()}
    for name, tgt in model_tgt.named_parameters():
        src = src_params[name]
        assert (src is not tgt)
        tgt.copy_(beta * tgt + (1. - beta) * src)
================================================
FILE: gan_training/utils.py
================================================
import torch
import torch.utils.data
import torch.utils.data.distributed
import torchvision
import os
def save_images(imgs, outfile, nrow=8):
    """Save ``imgs`` as an image grid with ``nrow`` images per row.

    Pixel values are mapped by x -> x/2 + 0.5 first, which assumes the input
    is in [-1, 1] and targets [0, 1].
    """
    torchvision.utils.save_image(imgs * 0.5 + 0.5, outfile, nrow=nrow)
def get_nsamples(data_loader, N):
    """Collect the first ``N`` samples and labels from ``data_loader``.

    Iteration stops as soon as at least ``N`` samples have been gathered, so
    no batch after the one that reaches ``N`` is loaded. If the loader runs
    out first, everything it produced is returned.

    Returns:
        (x, y) tensors with at most ``N`` rows each.
    """
    xs, ys = [], []
    n = 0
    for x_next, y_next in data_loader:
        xs.append(x_next)
        ys.append(y_next)
        n += x_next.size(0)
        # `n > N` used to fetch one extra, discarded batch whenever the
        # batches summed to exactly N.
        if n >= N:
            break
    x = torch.cat(xs, dim=0)[:N]
    y = torch.cat(ys, dim=0)[:N]
    return x, y
def update_average(model_tgt, model_src, beta):
    """In-place EMA: tgt <- beta * tgt + (1 - beta) * src, matching parameters by name.

    The update runs under ``torch.no_grad()``. That lets it work while the
    parameters still require gradients, where an in-place ``copy_`` into a
    grad-requiring leaf would raise, and keeps autograd from recording it.
    """
    param_dict_src = dict(model_src.named_parameters())
    with torch.no_grad():
        for p_name, p_tgt in model_tgt.named_parameters():
            p_src = param_dict_src[p_name]
            assert (p_src is not p_tgt)
            p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)
def get_most_recent(d, ext):
    """Return the largest iteration N among files in ``d`` named like ``...<ext>_N.pt``.

    Returns -1, after printing a message, when ``d`` does not exist or
    contains no matching file.
    """
    if not os.path.exists(d):
        print('Directory', d, 'does not exist')
        return -1
    its = []
    for f in os.listdir(d):
        try:
            it = int(f.split(ext + "_")[1].split('.pt')[0])
        except (IndexError, ValueError):
            # Not of the form "<ext>_<int>.pt"; skip it. Only these parse
            # errors are expected, so anything else is no longer swallowed.
            continue
        its.append(it)
    if not its:
        print('Found no files with extension \"%s\" under %s' % (ext, d))
        return -1
    return max(its)
================================================
FILE: metrics.py
================================================
import argparse
import os
import json
from tqdm import tqdm
import numpy as np
import torch
from gan_training.config import load_config
from seeded_sampler import SeededSampler
# Command-line interface. ArgumentParser's first positional parameter is
# `prog`, so the help text is passed as `description=`; positionally it would
# replace the program name in the usage line.
parser = argparse.ArgumentParser(description='Computes numbers used in paper and caches them to a result files. Examples include FID, IS, reverse-KL, # modes, FSD, cluster NMI, Purity.')
parser.add_argument('paths', nargs='+', type=str, help='list of configs for each experiment')
parser.add_argument('--it', type=int, default=-1, help='If set, computes numbers only for that iteration')
# With the default of -1 nothing is skipped, because it % -1 == 0 for any int.
parser.add_argument('--every', type=int, default=-1, help='skips some checkpoints and only computes those whose iteration number are divisible by every')
parser.add_argument('--fid', action='store_true', help='compute FID metric')
parser.add_argument('--inception', action='store_true', help='compute IS metric')
parser.add_argument('--modes', action='store_true', help='compute # modes and reverse-KL metric')
parser.add_argument('--fsd', action='store_true', help='compute FSD metric')
parser.add_argument('--cluster_metrics', action='store_true', help='compute clustering metrics (NMI, purity)')
parser.add_argument('--device', type=int, default=1, help='device to run the metrics on (can run into OOM issues if same as main device)')
args = parser.parse_args()
device = args.device  # exported as CUDA_VISIBLE_DEVICES for the metric subprocesses
dirs = list(args.paths)  # work list of config files / directories still to visit
N = 50000  # number of generated samples per checkpoint
BS = 100  # sampling batch size
# dataset names, matched as substrings of the config path
datasets = ['imagenet', 'cifar', 'stacked_mnist', 'places']
# reference ("ground truth") image archives used for the FSD metric
dataset_to_img = {
    'places': 'output/places_gt_imgs.npz',
    'imagenet': 'output/imagenet_gt_imgs.npz'}
def load_results(results_dir):
    '''
    Load the per-metric result tables stored as JSON in ``results_dir``.

    Any table file that does not exist yet is first created containing an
    empty JSON object, so later readers can always open it. Returns a list of
    six dicts in this order: FID, IS, reverse-KL, #modes, FSD, cluster metrics.
    '''
    table_names = ('fid_results.json', 'is_results.json', 'kl_results.json',
                   'nmodes_results.json', 'fsd_results.json', 'cluster_metrics.json')
    tables = []
    for table_name in table_names:
        table_path = os.path.join(results_dir, table_name)
        if not os.path.exists(table_path):
            with open(table_path, 'w') as handle:
                handle.write(json.dumps({}))
        with open(table_path) as handle:
            tables.append(json.load(handle))
    return tables
def get_dataset_from_path(path):
    '''
    Return the first name in the module-level ``datasets`` list that occurs as
    a substring of ``path`` (printing it), or None when none matches.
    '''
    found = next((candidate for candidate in datasets if candidate in path), None)
    if found is not None:
        print('Inferred dataset:', found)
    return found
def pt_to_np(imgs):
    '''
    normalizes pytorch image in [-1, 1] to [0, 255]

    Takes a (B, C, H, W) tensor and returns a float (B, H, W, C) numpy array,
    clamped to [0, 255]. The arithmetic is done in place on ``imgs`` to avoid
    an extra copy of a potentially large sample batch, so the caller's tensor
    is modified.
    '''
    channels_last = imgs.permute(0, 2, 3, 1)
    channels_last.mul_(0.5).add_(0.5).mul_(255)
    return channels_last.clamp_(0, 255).numpy()
def sample(sampler):
    '''
    Draw ``N`` images from ``sampler`` in batches of ``BS`` (module globals).

    Returns a float numpy array of shape (N, H, W, C) with values in [0, 255]
    (see ``pt_to_np``).
    '''
    # ceil(N / BS): exactly enough batches to cover N, with no unused extra
    # batch when BS divides N.
    n_batches = -(-N // BS)
    batches = []
    with torch.no_grad():
        for _ in tqdm(range(n_batches)):
            batches.append(sampler.sample(BS)[0].detach().cpu())
    samples = torch.cat(batches, dim=0)[:N]
    return pt_to_np(samples)
# Main driver: walk the given paths (recursing into directories). For every
# experiment config found, compute the requested metrics for each checkpoint.
# Each metric runs as a separate process on GPU `device` and records its
# result in the experiment's results/*.json tables, keyed by iteration.
root = './'
while len(dirs) > 0:
    path = dirs.pop()
    if os.path.isdir(path):  # search down tree for config files
        for d1 in os.listdir(path):
            dirs.append(os.path.join(path, d1))
        continue
    if not path.endswith('.yaml'):
        continue
    config = load_config(path, default_path='configs/default.yaml')
    outdir = config['training']['out_dir']
    if not os.path.exists(outdir) and config['pretrained'] == {}:
        print('Skipping', path, 'outdir', outdir)
        continue
    results_dir = os.path.join(outdir, 'results')
    checkpoint_dir = os.path.join(outdir, 'chkpts')
    os.makedirs(results_dir, exist_ok=True)
    fid_results, is_results, kl_results, nmodes_results, fsd_results, cluster_results = load_results(results_dir)
    checkpoint_files = os.listdir(checkpoint_dir) if os.path.exists(checkpoint_dir) else []
    if config['pretrained'] != {}:
        checkpoint_files = checkpoint_files + ['pretrained']
    for checkpoint in checkpoint_files:
        # 'model.pt' is skipped; '.pt' checkpoints and the pretrained model are considered.
        if not ((checkpoint.endswith('.pt') and checkpoint != 'model.pt') or checkpoint == 'pretrained'):
            continue
        print('Computing for', checkpoint)
        if checkpoint == 'pretrained':
            it = 'pretrained'
        elif 'model_' in checkpoint:
            # infer iteration number from checkpoint file w/o loading it
            try:
                it = int(checkpoint.split('model_')[1].split('.pt')[0])
            except ValueError:
                continue
            if args.every != 0 and it % args.every != 0:
                continue
            # iteration 0 is often useless, skip it
            if it == 0 or args.it != -1 and it != args.it:
                continue
        else:
            # Unrecognized checkpoint name: skip it rather than reuse the
            # iteration number left over from a previous checkpoint.
            continue
        it = str(it)
        clusterer_path = os.path.join(root, checkpoint_dir, f'clusterer{it}.pkl')
        # don't save samples for each iteration for disk space
        samples_path = os.path.join(outdir, 'results', 'samples.npz')
        need_fid = args.fid and it not in fid_results
        need_is = args.inception and it not in is_results
        need_modes = args.modes and (it not in kl_results or it not in nmodes_results)
        need_fsd = args.fsd and it not in fsd_results
        # Cluster metrics take part in the "anything left to do?" check too,
        # so --cluster_metrics on its own is not reported as already generated.
        need_cluster = args.cluster_metrics and it not in cluster_results
        needs_samples = need_fid or need_is or need_modes or need_fsd
        if not (needs_samples or need_cluster):
            print('Already generated', it, path)
            continue
        dataset_name = get_dataset_from_path(path)
        arguments = f'--samples {samples_path} --it {it} --results_dir {results_dir}'
        if needs_samples:
            # cluster_metrics.py works from the config, not from samples, so
            # only generate samples when a sample-based metric is pending.
            sampler = SeededSampler(path,
                                    model_path=os.path.join(root, checkpoint_dir, checkpoint),
                                    clusterer_path=clusterer_path,
                                    pretrained=config['pretrained'])
            samples = sample(sampler)
            np.savez(samples_path, fake=samples, real=dataset_name)
        if need_fid:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python gan_training/metrics/fid.py {arguments}')
        if need_is:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python gan_training/metrics/tf_is/inception_score.py {arguments}')
        if need_modes:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python utils/get_empirical_distribution.py {arguments} --dataset {dataset_name}')
        if need_cluster:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python cluster_metrics.py {path} --model_it {it}')
        if need_fsd:
            gt_path = dataset_to_img[dataset_name]
            os.system(f'CUDA_VISIBLE_DEVICES={device} python -m seeing.fsd {gt_path} {samples_path} --it {it} --results_dir {results_dir}')
================================================
FILE: requirements.txt
================================================
pytorch-gpu==1.3.1
tensorflow-gpu==1.14.0
scikit-learn
scikit-image
torchvision
tqdm
pyyaml
cloudpickle
ipython
opencv
================================================
FILE: seeded_sampler.py
================================================
''' Samples from a (class-conditional) GAN, so that the samples can be reproduced '''
import os
import pickle
import random
import copy
import torch
from torch import nn
from gan_training.checkpoints import CheckpointIO
from gan_training.config import (load_config, build_models)
from seeing.yz_dataset import YZDataset
def get_most_recent(models):
    '''
    Pick the checkpoint file with the highest iteration number.

    Recognized names are 'model.pt' (counted as iteration 0), '<it>model.pt'
    and 'model_<it>.pt'. Any other name (clusterers, logs, ...) is ignored
    instead of crashing the integer parse. Returns the chosen file name as it
    appears in ``models``.

    Raises:
        ValueError: if no name in ``models`` is a recognized checkpoint.
    '''
    names = list(models)
    best_it, best_name = None, None
    for name in names:
        if name == "model.pt":
            it = 0
        elif name.endswith("model.pt"):
            prefix = name[:-len("model.pt")]
            if not prefix.isdigit():
                continue
            it = int(prefix)
        elif name.startswith("model_") and name.endswith(".pt"):
            suffix = name[len("model_"):-len(".pt")]
            if not suffix.isdigit():
                continue
            it = int(suffix)
        else:
            continue
        if best_it is None or it > best_it:
            best_it, best_name = it, name
    if best_name is None:
        raise ValueError('No model checkpoint found among: %s' % (names,))
    return best_name
class SeededSampler():
    '''
    Samples from a trained (optionally class-conditional) GAN so that every
    sample can be reproduced from an integer seed.

    Seeds are turned into a batch of (z, y) by ``self.yz_dist``; images come
    from the generator loaded from the experiment's checkpoint (or pretrained
    weights) and are in [-1, 1].
    '''

    # Seeds are drawn uniformly from [0, _MAX_SEED]. This is an int on purpose:
    # random.randint rejects float bounds such as 1e8 on Python >= 3.12.
    _MAX_SEED = 10 ** 8

    def __init__(
            self,
            config_name,  # name of experiment's config file
            model_path="",  # path to the model. empty string infers the most recent checkpoint
            clusterer_path="",  # path to the clusterer, ignored if gan type doesn't require a clusterer
            pretrained=None,  # urls to the pretrained models; None means no pretrained models
            rootdir='./',
            device='cuda:0'):
        self.config = load_config(os.path.join(rootdir, config_name), 'configs/default.yaml')
        self.model_path = model_path
        self.clusterer_path = clusterer_path
        self.rootdir = rootdir
        self.nlabels = self.config['generator']['nlabels']
        self.device = device
        # None default instead of a shared mutable {} default argument.
        self.pretrained = {} if pretrained is None else pretrained
        self.generator = self.get_generator()
        self.generator.eval()
        self.yz_dist = self.get_yz_dist()

    def _random_seed(self):
        '''Return a fresh random integer seed in [0, _MAX_SEED].'''
        return random.randint(0, self._MAX_SEED)

    def sample(self, nimgs):
        '''
        samples an image using the generator, with z drawn from isotropic gaussian, and y drawn from self.yz_dist.
        For baseline methods, y doesn't matter because y is ignored in the input
        yz_dist is the empirical label distribution for the clustered gans.
        returns the image, and the integer seed used to generate it. generated sample is in [-1, 1]
        '''
        self.generator.eval()
        with torch.no_grad():
            seeds = [self._random_seed() for _ in range(nimgs)]
            z, y = self.yz_dist(seeds)
            return self.generator(z, y), seeds

    def conditional_sample(self, yi, seed=None):
        ''' returns a generated sample for label ``yi``, which is in [-1, 1], seed is an int'''
        self.generator.eval()
        with torch.no_grad():
            if seed is None:
                seed = [self._random_seed()]
            else:
                seed = [seed]
            # Only z is taken from the seed; the label is forced to yi.
            z, _ = self.yz_dist(seed)
            y = torch.LongTensor([yi]).to(self.device)
            return self.generator(z, y)

    def sample_with_seed(self, seeds):
        ''' returns a generated sample, which is in [-1, 1] '''
        self.generator.eval()
        z, y = self.yz_dist(seeds)
        return self.generator(z, y)

    def get_zy(self, seeds):
        '''returns the batch of z, y corresponding to the seeds'''
        return self.yz_dist(seeds)

    def sample_with_zy(self, z, y):
        ''' returns a generated sample given z and y, which is in [-1, 1].'''
        self.generator.eval()
        return self.generator(z, y)

    def get_generator(self):
        ''' loads a generator according to self.model_path '''
        exp_out_dir = os.path.join(self.rootdir, self.config['training']['out_dir'])
        # infer checkpoint if needed
        checkpoint_dir = os.path.join(exp_out_dir, 'chkpts') if self.model_path == "" or 'model' in self.pretrained else "./"
        model_name = get_most_recent(os.listdir(checkpoint_dir)) if self.model_path == "" else self.model_path
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        self.checkpoint_io = checkpoint_io
        generator, _ = build_models(self.config)
        generator = generator.to(self.device)
        generator = nn.DataParallel(generator)
        if self.config['training']['take_model_average']:
            generator_test = copy.deepcopy(generator)
            checkpoint_io.register_modules(generator_test=generator_test)
        else:
            generator_test = generator
            checkpoint_io.register_modules(generator=generator)
        try:
            it = checkpoint_io.load(model_name, pretrained=self.pretrained)
            assert (it != -1)
        except Exception as e:
            # try again without data parallel
            print(e)
            checkpoint_io.register_modules(generator=generator.module)
            checkpoint_io.register_modules(generator_test=generator_test.module)
            it = checkpoint_io.load(model_name, pretrained=self.pretrained)
            assert (it != -1)
        print('Loaded iteration:', it['it'])
        return generator_test

    def get_yz_dist(self):
        '''
        loads the z and y dists used to sample from the generator.

        Raises:
            FileNotFoundError: if the GAN needs a clusterer but neither a
                pretrained clusterer nor ``self.clusterer_path`` is available.
        '''
        if self.config['clusterer']['name'] != 'supervised':
            if 'clusterer' in self.pretrained:
                clusterer = self.checkpoint_io.load_clusterer('pretrained', load_samples=False, pretrained=self.pretrained)
            elif os.path.exists(self.clusterer_path):
                # NOTE: unpickling can execute arbitrary code; only load
                # clusterer files produced by this project's training runs.
                with open(self.clusterer_path, 'rb') as f:
                    clusterer = pickle.load(f)
                if isinstance(clusterer.discriminator, nn.DataParallel):
                    clusterer.discriminator = clusterer.discriminator.module
            else:
                raise FileNotFoundError(
                    'No clusterer found at %r and no pretrained clusterer was given'
                    % (self.clusterer_path,))
            if clusterer.kmeans is not None:
                # use clusterer empirical distribution as sampling
                print('Using k-means empirical distribution')
                distribution = clusterer.get_label_distribution()
                probs = [f / sum(distribution) for f in distribution]
            else:
                # otherwise, use a uniform distribution. this is not desired, unless it's a random label or unconditional GAN
                print("Sampling with uniform distribution over", clusterer.k, "labels")
                probs = [1. / clusterer.k for _ in range(clusterer.k)]
        else:
            # if it's supervised, then sample uniformly over all classes.
            # this might not be the right thing to do, since datasets are usually imbalanced.
            print("Sampling with uniform distribution over", self.nlabels,
                  "labels")
            probs = [1. / self.nlabels for _ in range(self.nlabels)]
        return YZDataset(zdim=self.config['z_dist']['dim'],
                         nlabels=len(probs),
                         distribution=probs,
                         device=self.device)
================================================
FILE: seeing/frechet_distance.py
================================================
#!/usr/bin/env python3
"""Calculates the Frechet Distance (FD) between two samples.
Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of Tensorflow
Copyright 2018 Institute of Bioinformatics, JKU Linz
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import torch
from scipy import linalg
def sample_frechet_distance(sample1, sample2, eps=1e-6,
                            return_components=False):
    '''
    Fit a Gaussian (mean and covariance) to each sample and return the
    Frechet distance between the two Gaussians.

    Both samples should be numpy arrays with one observation per row.
    With ``return_components`` the result is (total, mean term, covariance term).
    '''
    mu1, sigma1 = calculate_activation_statistics(sample1)
    mu2, sigma2 = calculate_activation_statistics(sample2)
    return calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=eps,
                                      return_components=return_components)
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6,
                               return_components=False):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
    d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1 : Mean vector of the first sample.
    -- sigma1: Covariance matrix of the first sample.
    -- mu2 : Mean vector of the second sample (same length as mu1).
    -- sigma2: Covariance matrix of the second sample (same shape as sigma1).
    -- eps : Value added to both covariance diagonals if the matrix square
       root of their product is not finite.
    -- return_components : If True, also return the two terms separately.
    Returns:
    -- : The Frechet Distance, or (distance, mean term, covariance term).
    Raises:
    -- ValueError : if the matrix square root has a significant imaginary part.
    """
    mean_a, mean_b = np.atleast_1d(mu1), np.atleast_1d(mu2)
    cov_a, cov_b = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    assert mean_a.shape == mean_b.shape, \
        'Training and test mean vectors have different lengths'
    assert cov_a.shape == cov_b.shape, \
        'Training and test covariances have different dimensions'
    delta = mean_a - mean_b
    # The product of the covariances may be nearly singular.
    sqrt_product, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(sqrt_product).all():
        print(('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps)
        jitter = np.eye(cov_a.shape[0]) * eps
        sqrt_product = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))
    # Numerical error can leave a small imaginary component; drop it if tiny.
    if np.iscomplexobj(sqrt_product):
        if not np.allclose(np.diagonal(sqrt_product).imag, 0, atol=1e-3):
            largest = np.max(np.abs(sqrt_product.imag))
            raise ValueError('Imaginary component {}'.format(largest))
        sqrt_product = sqrt_product.real
    mean_term = delta.dot(delta)
    cov_term = np.trace(cov_a) + np.trace(cov_b) - 2 * np.trace(sqrt_product)
    total = mean_term + cov_term
    return (total, mean_term, cov_term) if return_components else total
def calculate_activation_statistics(act):
    """Calculation of the statistics used by the FID.
    Params:
    -- act : Numpy array of observations, one per row (features in columns).
    Returns:
    -- mu : The per-feature mean of ``act``.
    -- sigma : The feature covariance matrix of ``act`` (columns are variables).
    """
    return np.mean(act, axis=0), np.cov(act, rowvar=False)
================================================
FILE: seeing/fsd.py
================================================
import torch, argparse, sys, os, numpy
from .sampler import FixedRandomSubsetSampler, FixedSubsetSampler
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from torchvision import transforms, utils
from . import pbar, zdataset, segmenter, frechet_distance, parallelfolder
# Number of segmentation labels tallied per image: used as the bincount
# minlength and as the width of every tally array in this module.
# NOTE(review): presumably equals the size of UnifiedParsingSegmenter's label
# set — TODO confirm against the segmenter.
NUM_OBJECTS = 336
def main():
    '''
    Command-line entry point: compute the Frechet Segmentation Distance (FSD)
    between the segmentation statistics of ``true_dir`` and ``gen_dir``.

    Each of the two inputs may be an image directory or a ``.npz`` file of
    generated samples. Prints FSD with its mean and covariance components.
    With ``--histout``, saves the comparison figure there. With
    ``--results_dir``, records the result under key ``--it`` in
    ``<results_dir>/fsd_results.json`` and saves ``fsd_<it>.png``.
    '''
    parser = argparse.ArgumentParser(description='Net dissect utility')
    parser.add_argument('true_dir')
    parser.add_argument('gen_dir')
    parser.add_argument('--size', type=int, default=10000)
    parser.add_argument('--cachedir', default='results/fsd/cache')
    parser.add_argument('--histout', default=None)
    parser.add_argument('--maxscale', type=float, default=50)
    parser.add_argument('--labelcount', type=int, default=30)
    parser.add_argument('--dpi', type=float, default=100)
    parser.add_argument('--it', type=str, default="-1")
    parser.add_argument('--results_dir', default=None, help='path to results_dir')
    # Show usage when run without arguments. This must happen before
    # parse_args(), which would otherwise exit first with a missing-argument error.
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    print(args.true_dir, args.gen_dir)
    true_dir, gen_dir = args.true_dir, args.gen_dir
    # Comparing a source with itself uses a different subset seed, so the two
    # tallies are not computed on identical images.
    seed1, seed2 = [1, 1 if true_dir != gen_dir else 2]
    true_tally, gen_tally = [
        cached_tally_directory(d,
                               size=args.size,
                               cachedir=args.cachedir,
                               seed=seed)
        for d, seed in [(true_dir, seed1), (gen_dir, seed2)]
    ]
    # Tallies hold pixel fractions; work in percentages.
    true_pct, gen_pct = true_tally * 100, gen_tally * 100
    fsd, meandiff, covdiff = frechet_distance.sample_frechet_distance(
        true_pct, gen_pct, return_components=True)
    print('fsd: %f; meandiff: %f; covdiff: %f' % (fsd, meandiff, covdiff))
    if args.histout is not None:
        diff_figure(true_pct,
                    gen_pct,
                    labelcount=args.labelcount,
                    maxscale=args.maxscale,
                    dpi=args.dpi).savefig(args.histout)
    if args.results_dir is not None:
        import json
        it = args.it
        results_path = os.path.join(args.results_dir, 'fsd_results.json')
        # Start from an empty table if the results file does not exist yet.
        fsd_results = {}
        if os.path.exists(results_path):
            with open(results_path) as f:
                fsd_results = json.load(f)
        fsd_results[it] = (fsd, meandiff, covdiff)
        with open(results_path, 'w') as f:
            f.write(json.dumps(fsd_results))
        diff_figure(true_pct,
                    gen_pct,
                    labelcount=args.labelcount,
                    maxscale=args.maxscale,
                    dpi=args.dpi).savefig(os.path.join(args.results_dir, f'fsd_{it}.png'))
def cached_tally_directory(directory, size=10000, cachedir=None, seed=1):
    '''
    Return ``tally_directory(directory, size, seed)``, caching it as a .npy file.

    The cache file name is derived from ``directory``, ``size`` and ``seed``.
    With ``cachedir`` it is flattened into that directory; without it, it sits
    next to ``directory``. A cached tally is reused for image directories and
    for ``.npz`` files whose path contains 'gt'. Other ``.npz`` files
    (generated samples, which are overwritten between runs) are always
    re-tallied.
    '''
    filename = '%s_segtally_%d.npy' % (directory, size)
    if seed != 1:
        filename = '%d_%s' % (seed, filename)
    if cachedir is not None:
        filename = os.path.join(cachedir, filename.replace('/', '_'))
    #load only if gt stats, or image directory
    if os.path.isfile(filename) and (not directory.endswith('.npz') or 'gt' in directory):
        return numpy.load(filename)
    # Create the directory that will hold the cache file. With cachedir=None
    # this used to call os.makedirs(None) and crash.
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)
    result = tally_directory(directory, size, seed=seed)
    numpy.save(filename, result)
    return result
def tally_directory(directory, size=10000, seed=1):
    '''
    Segment a fixed random subset of up to ``size`` images and tally label areas.

    ``directory`` is either a ``.npz`` file whose ``fake`` array holds images in
    BHWC layout (presumably in [0, 255], as written by metrics.py), or an image
    folder. Images are fed to the segmenter at 256x256 in [-1, 1].

    Returns a ``(size, NUM_OBJECTS)`` float64 array. Row i holds the fraction of
    pixels of image i assigned to each segmentation label.
    '''
    if directory.endswith('.npz'):
        with np.load(directory) as f:
            images = torch.from_numpy(f['fake'])
        images = images.permute(0, 3, 1, 2) #BHWC -> BCHW
        images = (images/127.5) - 1 #normalize in [-1, 1]
        print(images.shape, images.max(), images.min())
        dataset = TensorDataset(images)
    else:
        dataset = parallelfolder.ParallelImageFolders(
            [directory],
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(256),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
    loader = DataLoader(dataset,
                        sampler=FixedRandomSubsetSampler(dataset,
                                                         end=size,
                                                         seed=seed),
                        batch_size=10,
                        pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # numpy.float (removed in NumPy 1.24) was an alias of float64.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            batch = batch.cuda()
            if batch.shape[-2:] != (256, 256):
                # Resize per batch (nearest, as before) instead of upsampling
                # the whole generated set up front, which can exhaust memory.
                batch = torch.nn.functional.interpolate(batch, size=(256, 256))
            seg_result = upp.segment_batch(batch)
            n = len(batch)
            for i in range(n):
                # Fraction of pixels of image i assigned to each label.
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            # Only the first n rows are valid: the last batch may be smaller
            # than batch_size.
            result[batch_index:batch_index + n] = batch_result[:n].cpu().numpy()
            batch_index += n
    return result
def tally_dataset_objects(dataset, size=10000):
    '''
    Segment a fixed random subset of up to ``size`` items of ``dataset``.

    Items are 1-tuples holding an image tensor. Returns a
    ``(size, NUM_OBJECTS)`` float64 array where row i holds the fraction of
    pixels of image i assigned to each segmentation label.
    '''
    loader = DataLoader(dataset,
                        sampler=FixedRandomSubsetSampler(dataset, end=size),
                        batch_size=10,
                        pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # numpy.float (removed in NumPy 1.24) was an alias of float64.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [batch] in pbar(loader):
            seg_result = upp.segment_batch(batch.cuda())
            n = len(batch)
            for i in range(n):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            # Only the first n rows are valid: the last batch may be smaller
            # than batch_size.
            result[batch_index:batch_index + n] = batch_result[:n].cpu().numpy()
            batch_index += n
    return result
def tally_generated_objects(model, size=10000):
    '''
    Generate ``size`` images with ``model`` from ``zdataset.z_dataset_for_model``
    latents and segment them.

    Returns a ``(size, NUM_OBJECTS)`` float64 array where row i holds the
    fraction of pixels of image i assigned to each segmentation label.
    '''
    zds = zdataset.z_dataset_for_model(model, size)
    loader = DataLoader(zds, batch_size=10, pin_memory=True)
    upp = segmenter.UnifiedParsingSegmenter()
    # numpy.float (removed in NumPy 1.24) was an alias of float64.
    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float64)
    batch_result = torch.zeros(loader.batch_size,
                               NUM_OBJECTS,
                               dtype=torch.float).cuda()
    with torch.no_grad():
        batch_index = 0
        for [zbatch] in pbar(loader):
            img = model(zbatch.cuda())
            seg_result = upp.segment_batch(img)
            n = len(zbatch)
            for i in range(n):
                batch_result[i] = (seg_result[i, 0].view(-1).bincount(
                    minlength=NUM_OBJECTS).float() /
                                   (seg_result.shape[2] * seg_result.shape[3]))
            # Only the first n rows are valid: the last batch may be smaller
            # than batch_size.
            result[batch_index:batch_index + n] = batch_result[:n].cpu().numpy()
            batch_index += n
    return result
def diff_figure(ttally,
                gtally,
                labelcount=30,
                labelleft=True,
                dpi=100,
                maxscale=50.0,
                legend=False):
    '''
    Plot how generated per-label areas differ from the training ones.

    ``ttally`` and ``gtally`` are per-image tallies (rows = images, columns =
    segmentation labels) for the training and generated sets; their column
    means are compared. Labels are taken in decreasing order of training area,
    skipping label 0 and labels whose category is 'material', and stopping at
    ``labelcount`` labels or the first label with zero training area.
    The top panel shows the mean area on a log scale, with y limits
    [maxscale / 5000, maxscale]. The bottom panel shows the relative change
    (gen - train) / train; values above 1.15 are printed as text.
    Returns a matplotlib Figure.
    '''
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    from matplotlib.figure import Figure
    tresult, gresult = [t.mean(0) for t in [ttally, gtally]]
    upp = segmenter.UnifiedParsingSegmenter()
    labelnames, _ = upp.get_label_and_category_names()
    x = []
    labels = []
    gen_amount = []
    change_frac = []
    true_amount = []
    for label in numpy.argsort(-tresult):
        if label == 0 or labelnames[label][1] == 'material':
            continue
        if tresult[label] == 0:
            break
        x.append(len(x))
        labels.append(labelnames[label][0].split()[0])
        true_amount.append(tresult[label].item())
        gen_amount.append(gresult[label].item())
        change_frac.append(
            (float(gresult[label] - tresult[label]) / tresult[label]))
        if len(x) >= labelcount:
            break
    fig = Figure(dpi=dpi, figsize=(1.4 + 5.0 * labelcount / 30, 4.0))
    FigureCanvas(fig)
    a1, a0 = fig.subplots(2, 1, gridspec_kw={'height_ratios': [1, 2]})
    a0.bar(x, change_frac, label='relative delta')
    a0.set_xticks(x)
    a0.set_xticklabels(labels, rotation='vertical')
    if labelleft:
        a0.set_ylabel('relative delta\n(gen - train) / train')
    a0.set_xlim(-1.0, len(x))
    a0.set_ylim([-1, 1.1])
    a0.grid(axis='y', antialiased=False, alpha=0.25)
    if legend:
        a0.legend(loc=2)
    # Bars above the y limit get their value printed; consecutive ones are
    # staggered vertically so the labels do not overlap.
    prev_high = None
    for ix, cf in enumerate(change_frac):
        if cf > 1.15:
            if prev_high == (ix - 1):
                offset = 0.1
            else:
                offset = 0.0
            prev_high = ix
            a0.text(ix,
                    1.15 + offset,
                    '%.1f' % cf,
                    horizontalalignment='center',
                    size=6)
    a1.bar(x, true_amount, label='training')
    a1.plot(x, gen_amount, linewidth=3, color='red', label='generated')
    a1.set_yscale('log')
    a1.set_xlim(-1.0, len(x))
    a1.set_ylim(maxscale / 5000, maxscale)
    if labelleft:
        a1.set_ylabel('mean area\nlog scale')
    if legend:
        a1.legend()
    a1.set_yticks([1e-2, 1e-1, 1.0, 1e+1])
    # `minor` must be passed by keyword: since Matplotlib 3.5 the second
    # positional parameter of set_yticks is `labels`.
    a1.set_yticks([
        a * b for a in [1e-2, 1e-1, 1.0, 1e+1]
        for b in range(1, 10) if maxscale / 5000 <= a * b <= maxscale
    ], minor=True)
    a1.set_xticks([])
    fig.tight_layout()
    return fig
# Script entry point, e.g. `python -m seeing.fsd <true_dir> <gen_dir> ...`
# (the invocation used by metrics.py).
if __name__ == '__main__':
    main()
================================================
FILE: seeing/lightbox.html
================================================
<!DOCTYPE html>
<html>
<!--
+lightbox.html, a page for automatically showing all images in a
directory on an Apache server. Just copy it into the directory.
Works by scraping the default directory HTML at "./" - David Bau.
-->
<head>
<script src="https://cdn.jsdelivr.net/npm/vue@2.5.16/dist/vue.js"
integrity="sha256-CMMTrj5gGwOAXBeFi7kNokqowkzbeL8ydAJy39ewjkQ=" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/lodash@4.17.10/lodash.js"
integrity="sha256-qwbDmNVLiCqkqRBpF46q5bjYH11j5cd+K+Y6D3/ja28=" crossorigin="anonymous"></script>
<script src="https://code.jquery.com/jquery-3.3.1.js" integrity="sha256-2Kok7MbOyxpgUVvAk/HJ2jigOSYS2auK4Pfzbm7uH60="
crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/lity/2.3.1/lity.js"
integrity="sha256-28JiZvE/RethQIYCwkMdtSMHgI//KoTLeB2tSm10trs=" crossorigin="anonymous"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/lity/2.3.1/lity.css"
integrity="sha256-76wKiAXVBs5Kyj7j0T43nlBCbvR6pqdeeZmXI4ATnY0=" crossorigin="anonymous" />
<style>
h3 {
font-family: sans-serif;
font-size: 18px;
}
.thumb,
.filter {
font-family: sans-serif;
font-size: 12px;
}
.filter {
padding-bottom: 10px;
}
.thumb {
display: inline-block;
margin: 1px;
text-align: center;
}
.thumb img,
.thumb div {
max-width: 150px;
word-break: break-all;
}
</style>
</head>
<body>
  <div id="app" v-if="images">
    <h3>Images in <a :href="directory">{{ directory }}</a></h3>
    <div class="filter">
      Filter: <input v-model="pattern" placeholder="regexp">
    </div>
    <!-- Filtering is done in a computed property: the Vue style guide
         discourages putting v-if on the same element as v-for. -->
    <div v-for="r in filteredImages" class="thumb">
      <div>{{ r }}</div>
      <a :href="r" data-lity><img :src="r"></a>
    </div>
  </div>
  <!--app-->
</body>
<script>
  var theapp = new Vue({
    el: '#app',
    data: {
      // Directory part of the current URL (everything up to the last '/').
      directory: window.location.pathname.replace(/[^\/]*$/, ''),
      images: null,
      pattern: '',
    },
    created: function () {
      var self = this;
      // Fetch the server's directory listing (random query defeats caching)
      // and keep the file names of the linked images.
      $.get('./?' + Math.random(), function (d) {
        var imgurls = $.map($(d).find('a'),
          x => x.href).filter(
            x => x.match(/\.(jpg|jpeg|png|gif|svg)$/i)).map(
              x => x.replace(/.*\//, ''));
        self.images = imgurls;
      }, 'html');
    },
    computed: {
      // The filter text as a RegExp; an invalid pattern matches everything.
      patternRe: function () {
        try {
          return RegExp(this.pattern);
        } catch (e) {
          return /.*/;
        }
      },
      // Image names matching the current filter.
      filteredImages: function () {
        var re = this.patternRe;
        return (this.images || []).filter(function (r) {
          return re.test(r);
        });
      }
    },
  })
</script>
</html>
================================================
FILE: seeing/parallelfolder.py
================================================
'''
Variants of pytorch's ImageFolder for loading image datasets with more
information, such as parallel feature channels in separate files,
cached files with lists of filenames, etc.
'''
import os, torch, re, random, numpy, itertools
import torch.utils.data as data
from torchvision.datasets.folder import default_loader as tv_default_loader
from PIL import Image
from collections import OrderedDict
from . import pbar
def grayscale_loader(path):
    '''Open the image file at ``path`` and return it converted to PIL mode 'L'.'''
    with open(path, 'rb') as handle:
        image = Image.open(handle)
        return image.convert('L')
class ndarray(numpy.ndarray):
    '''
    Thin numpy.ndarray subclass. Plain ndarrays reject new attributes, so
    arrays loaded from .npy files are viewed as this type, which lets a
    ``shared_state`` attribute be attached to them.
    '''
def default_loader(filename):
    '''
    Load ``filename`` according to its extension.

    '.npy' files come back as an ``ndarray`` view (so attributes can be
    attached), '.npz' archives as numpy's lazy archive object, and anything
    else goes through torchvision's default image loader.
    '''
    if filename.endswith('.npy'):
        return numpy.load(filename).view(ndarray)
    if filename.endswith('.npz'):
        return numpy.load(filename)
    return tv_default_loader(filename)
class ParallelImageFolders(data.Dataset):
    """
    A data loader that looks for parallel image filenames, for example
    photo1/park/004234.jpg
    photo1/park/004236.jpg
    photo1/park/004237.jpg
    photo2/park/004234.png
    photo2/park/004236.png
    photo2/park/004237.png

    Each item holds one loaded (and optionally transformed) source per entry
    of ``image_roots``, as a tuple. When ``classification`` is set, the class
    index is appended; with a ``stacker``, the item is (stacked, classidx).
    If ``stacker`` is given, it combines the sources. ``size`` truncates the
    file list, after optional shuffling with seed ``shuffle``. When
    ``lazy_init`` is True, scanning the directories is deferred until the
    dataset is first used.
    """
    def __init__(self, image_roots,
                 transform=None,
                 loader=default_loader,
                 stacker=None,
                 classification=False,
                 intersection=False,
                 filter_tuples=None,
                 verbose=None,
                 size=None,
                 shuffle=None,
                 lazy_init=True):
        self.image_roots = image_roots
        # A single transform is applied to every root.
        if transform is not None and not hasattr(transform, '__iter__'):
            transform = [transform for _ in image_roots]
        self.transforms = transform
        self.stacker = stacker
        self.loader = loader

        def do_lazy_init():
            self.images, self.classes, self.class_to_idx = (
                make_parallel_dataset(image_roots,
                                      classification=classification,
                                      intersection=intersection,
                                      filter_tuples=filter_tuples,
                                      verbose=verbose))
            if len(self.images) == 0:
                raise RuntimeError("Found 0 images within: %s" % image_roots)
            if shuffle is not None:
                random.Random(shuffle).shuffle(self.images)
            if size is not None:
                # Truncate the list used by __getitem__/__len__ (previously
                # this assigned to a misspelled `self.image`, ignoring size).
                self.images = self.images[:size]
            self._do_lazy_init = None

        # Do slow initialization lazily.
        if lazy_init:
            self._do_lazy_init = do_lazy_init
        else:
            do_lazy_init()

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Read the hook from __dict__:
        # going through `self._do_lazy_init` would re-enter __getattr__ and
        # recurse forever on instances whose __init__ has not run (e.g. while
        # copying or unpickling).
        lazy_init = self.__dict__.get('_do_lazy_init')
        if lazy_init is not None:
            lazy_init()
            return getattr(self, attr)
        raise AttributeError(attr)

    def __getitem__(self, index):
        if self._do_lazy_init is not None:
            self._do_lazy_init()
        paths = self.images[index]
        if self.classes is not None:
            classidx = paths[-1]
            paths = paths[:-1]
        sources = [self.loader(path) for path in paths]
        # Add a common shared state dict to allow random crops/flips to be
        # coordinated.
        shared_state = {}
        for s in sources:
            try:
                s.shared_state = shared_state
            except:
                pass
        if self.transforms is not None:
            sources = [transform(source) if transform is not None else source
                       for source, transform
                       in itertools.zip_longest(sources, self.transforms)]
        if self.stacker is not None:
            sources = self.stacker(sources)
            if self.classes is not None:
                sources = (sources, classidx)
        else:
            if self.classes is not None:
                sources.append(classidx)
            sources = tuple(sources)
        return sources

    def __len__(self):
        if self._do_lazy_init is not None:
            self._do_lazy_init()
        return len(self.images)
def is_npy_file(path):
    '''True if ``path`` ends in '.npy' or '.NPY' (no other casings).'''
    return path.endswith(('.npy', '.NPY'))
def is_image_file(path):
    '''True if ``path`` ends in .jpg, .jpeg or .png (case-insensitive).'''
    return re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE) is not None
def walk_image_files(rootdir, verbose=None):
    '''
    Return the sorted list of image and .npy file paths under ``rootdir``.

    If an index file ``<rootdir>.txt`` exists, its non-blank lines are used
    instead of walking the tree; they are treated as paths relative to the
    parent directory of ``rootdir``. ``verbose`` is currently unused.
    '''
    indexfile = '%s.txt' % rootdir
    if os.path.isfile(indexfile):
        basedir = os.path.dirname(rootdir)
        with open(indexfile) as f:
            # Skip blank lines (e.g. a trailing empty line): they would
            # otherwise turn into the bare `basedir` path.
            return sorted(os.path.join(basedir, line.strip())
                          for line in f if line.strip())
    result = []
    for dirname, _, fnames in sorted(pbar(os.walk(rootdir),
                                          desc='Walking %s' % os.path.basename(rootdir))):
        for fname in sorted(fnames):
            if is_image_file(fname) or is_npy_file(fname):
                result.append(os.path.join(dirname, fname))
    return result
def make_parallel_dataset(image_roots, classification=False,
                          intersection=False, filter_tuples=None, verbose=None):
    """
    Returns ([(img1, img2, clsid), (img1, img2, clsid)..],
             classes, class_to_idx)

    Files under each root are matched by their path relative to that root
    with the extension removed, so photo1/park/1.jpg pairs with photo2/park/1.png.
    Each tuple has one path per root (in root order), plus the class index
    when ``classification`` is True. The class of a file is the name of its
    parent directory, indexed in sorted order. ``classes`` and
    ``class_to_idx`` are None when ``classification`` is False.
    With ``intersection`` False, any file missing from some root raises
    RuntimeError; with it True, such keys are dropped. ``filter_tuples`` is an
    optional predicate; tuples for which it is falsy are dropped.
    """
    image_roots = [os.path.expanduser(d) for d in image_roots]
    image_sets = OrderedDict()
    for j, root in enumerate(image_roots):
        for path in walk_image_files(root, verbose=verbose):
            key = os.path.splitext(os.path.relpath(path, root))[0]
            if key not in image_sets:
                image_sets[key] = []
            # While scanning root j, a key should already hold exactly one path
            # from each of the j earlier roots; otherwise the roots differ.
            if not intersection and len(image_sets[key]) != j:
                raise RuntimeError(
                    'Images not parallel: %s missing from one dir' % (key))
            image_sets[key].append(path)
    if classification:
        classes = sorted(set([os.path.basename(os.path.dirname(k))
                              for k in image_sets.keys()]))
        class_to_idx = dict({k: v for v, k in enumerate(classes)})
        for k, v in image_sets.items():
            v.append(class_to_idx[os.path.basename(os.path.dirname(k))])
    else:
        classes, class_to_idx = None, None
    tuples = []
    for key, value in image_sets.items():
        # Complete tuples have one path per root (+1 slot for the class index).
        # This also catches keys missing from the last root(s), which the
        # scan above cannot detect.
        if len(value) != len(image_roots) + (1 if classification else 0):
            if intersection:
                continue
            else:
                raise RuntimeError(
                    'Images not parallel: %s missing from one dir' % (key))
        value = tuple(value)
        if filter_tuples and not filter_tuples(value):
            continue
        tuples.append(value)
    return tuples, classes, class_to_idx
================================================
FILE: seeing/pbar.py
================================================
'''
Utilities for showing progress bars, controlling default verbosity, etc.
'''
# If the tqdm package is not available, then do not show progress bars;
# just connect print_progress to print.
import sys, types, builtins
try:
    from tqdm import tqdm, tqdm_notebook
except ImportError:
    # tqdm is optional: without it no progress bars are shown and print()
    # falls back to the builtin print.
    tqdm = None
# Whether progress bars and pbar.print output are shown.
default_verbosity = True
# Description queued by descnext() for the next progress loop.
next_description = None
# The builtin print, kept because this module defines its own print().
python_print = builtins.print
def post(**kwargs):
    '''
    Show the given key=value pairs as a status postfix on the right-hand side
    of the innermost active progress bar. Does nothing if no bar is active.
    '''
    bar = innermost_tqdm()
    if bar is None:
        return
    bar.set_postfix(**kwargs)
def desc(desc):
    '''
    Replace the left-hand description of the innermost active progress bar
    with ``str(desc)``. Does nothing if no bar is active.
    '''
    bar = innermost_tqdm()
    if bar is None:
        return
    bar.set_description(str(desc))
def descnext(desc):
    '''
    Queue ``desc`` as the description for the next progress loop to start.
    Ignored when verbosity is off or tqdm is not installed.
    '''
    global next_description
    if default_verbosity and tqdm is not None:
        next_description = desc
def print(*args):
    '''
    Print the space-joined str() of args. When tqdm is available the text
    goes through tqdm.write so it appears above any active progress bar;
    otherwise the builtin print is used. Nothing is printed when
    default_verbosity is falsy. Any description queued by descnext() is
    discarded either way.
    '''
    global next_description
    next_description = None
    if not default_verbosity:
        return
    message = ' '.join(map(str, args))
    writer = python_print if tqdm is None else tqdm.write
    writer(message)
def tqdm_terminal(it, *args, **kwargs):
    '''
    Build a tqdm bar tuned for resizable terminals: dynamic width, ASCII
    glyphs, and leave=True only when another progress bar is already
    active (leave=False for a top-level bar).
    '''
    nested = innermost_tqdm() is not None
    return tqdm(it, *args, dynamic_ncols=True, ascii=True, leave=nested,
                **kwargs)
def in_notebook():
    '''
    Return True when running under a Jupyter kernel (notebook or
    qtconsole); False for terminal IPython, plain Python, or any other
    shell.
    '''
    # Detection approach from https://stackoverflow.com/a/39662359/265298:
    # get_ipython only exists as a builtin inside IPython sessions.
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        return False  # Probably standard Python interpreter
    return shell_name == 'ZMQInteractiveShell'
def innermost_tqdm():
    '''
    Return the active tqdm bar with the largest `pos`, which this module
    treats as the innermost loop; None if there are no active bars (or
    tqdm is unavailable).
    '''
    if not hasattr(tqdm, '_instances') or len(tqdm._instances) == 0:
        return None
    return max(tqdm._instances, key=lambda bar: bar.pos)
def reporthook(*args, **kwargs):
    '''
    Create a callback for urllib.request.urlretrieve that drives a
    progress bar:
        with pbar.reporthook() as hook:
            urllib.request.urlretrieve(url, filename, reporthook=hook)
    Arguments are forwarded to the module-level progress function;
    unit_scale=True and miniters=1 are defaults that kwargs may override.
    When progress output is disabled the hook silently does nothing.
    '''
    options = {'unit_scale': True, 'miniters': 1, **kwargs}
    progress = __call__(None, *args, **options)

    class ReportHook(object):
        def __init__(self, t):
            self.t = t

        def __call__(self, b=1, bsize=1, tsize=None):
            bar = self.t
            if tsize is not None and hasattr(bar, 'total'):
                bar.total = tsize
            if hasattr(bar, 'update'):
                # Advance by the difference between the bytes reported so
                # far (blocks * block size) and the bar's current count n.
                bar.update(b * bsize - bar.n)

        def __enter__(self):
            return self

        def __exit__(self, *exc):
            if hasattr(self.t, '__exit__'):
                self.t.__exit__(*exc)

    return ReportHook(progress)
def __call__(x, *args, **kwargs):
    '''
    Wrap the iterable x in a progress bar according to default_verbosity.

    - If verbosity is falsy or tqdm is unavailable, x is returned as-is.
    - If verbosity equals True, tqdm_notebook is used inside Jupyter and
      tqdm_terminal otherwise.
    - Any other value is treated as a specific progress function and is
      called as fn(x, *args, **kwargs).

    A description queued with descnext() is consumed here and passed as
    the 'desc' keyword argument.
    '''
    global next_description
    if not default_verbosity or tqdm is None:
        return x
    if default_verbosity == True:
        progress_fn = tqdm_notebook if in_notebook() else tqdm_terminal
    else:
        progress_fn = default_verbosity
    if next_description is not None:
        kwargs = {**kwargs, 'desc': next_description}
        next_description = None
    return progress_fn(x, *args, **kwargs)
class VerboseContextManager():
    '''
    Controls the module-wide default_verbosity setting.

    Usable three ways:
      with pbar.verbose: ...          # use self.v for the with-block only
      pbar.verbose(False)             # set verbosity now (never restored)
      with pbar.verbose(False): ...   # set now, restored on block exit
    '''
    def __init__(self, v, entered=False):
        # v: verbosity value installed on entry.
        # saved: stack of previous verbosity values popped by __exit__.
        self.v, self.entered, self.saved = v, False, []
        if entered:
            # Apply the setting immediately, and remember that the previous
            # value is already saved so a later "with" does not push twice.
            self.__enter__()
            self.entered = True

    def __enter__(self):
        global default_verbosity
        if self.entered:
            # Pre-entered by __init__: the old value was already saved.
            self.entered = False
        else:
            self.saved.append(default_verbosity)
        default_verbosity = self.v
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        global default_verbosity
        default_verbosity = self.saved.pop()

    def __call__(self, v=True):
        '''
        Calling the context manager makes a new context that is
        pre-entered, so it works as both a plain function and as a
        factory for a context manager. When this manager's own value is
        falsy (as for `quiet`), the argument is inverted, so quiet(True)
        turns output off.
        '''
        new_v = v if self.v else not v
        # Constructing with entered=True runs __enter__, which is what
        # actually installs new_v as the global default_verbosity.
        return VerboseContextManager(new_v, entered=True)
# Use as either "with pbar.verbose:" or "pbar.verbose(False)", or also
# "with pbar.verbose(False):"
verbose = VerboseContextManager(True)
# Use as either "with pbar.quiet:" or "pbar.quiet(True)", or also
# "with pbar.quiet(True):"
quiet = VerboseContextManager(False)
class CallableModule(types.ModuleType):
    '''
    Module subclass whose instances are callable, so that after the swap
    below, calling the module object forwards to the module-level __call__.
    '''
    def __init__(self):
        super().__init__(__name__)
        # Snapshot every attribute of the real module onto this object.
        # NOTE(review): functions keep using the original module's globals,
        # so later changes (e.g. default_verbosity) are not reflected in
        # attributes read through this object.
        vars(self).update(vars(sys.modules[__name__]))

    def __call__(self, x, *args, **kwargs):
        return __call__(x, *args, **kwargs)

# Replace this module in sys.modules so importers receive a callable module.
sys.modules[__name__] = CallableModule()
================================================
FILE: seeing/pidfile.py
================================================
'''
Utility for simple distribution of work on multiple processes, by
making sure only one process is working on a job at once.
'''
import os, errno, socket, atexit, time, sys
def exit_if_job_done(directory, redo=False, force=False, verbose=True):
    '''
    Exit the process (status 0) if the job in `directory` is locked by
    another process or has already been marked done.

    The lock is directory/lockfile.pid, taken via pidfile_taken. If
    directory/done.txt exists, it is removed when redo or force is set
    (so the job runs again); otherwise its contents are printed (when
    verbose) and the process exits.
    '''
    lock_path = os.path.join(directory, 'lockfile.pid')
    if pidfile_taken(lock_path, force=force, verbose=verbose):
        sys.exit(0)
    donefile = os.path.join(directory, 'done.txt')
    if not os.path.isfile(donefile):
        return
    with open(donefile) as f:
        msg = f.read()
    if not (redo or force):
        if verbose:
            print('%s %s' % (donefile, msg))
        sys.exit(0)
    if verbose:
        print('Removing %s %s' % (donefile, msg))
    os.remove(donefile)
def mark_job_done(directory):
    '''
    Record completion by writing directory/done.txt. The record contains
    the pid, the hostname, $STY (presumably the screen session name; empty
    if unset) and a timestamp.
    '''
    stamp = 'done by %d@%s %s at %s' % (
        os.getpid(), socket.gethostname(), os.getenv('STY', ''),
        time.strftime('%c'))
    with open(os.path.join(directory, 'done.txt'), 'w') as f:
        f.write(stamp)
def pidfile_taken(path, verbose=False, force=False):
    '''
    Usage. To grab an exclusive lock for the remaining duration of the
    current process (and exit if another process already has the lock),
    do this:
        if pidfile_taken('job_423/lockfile.pid', verbose=True):
            sys.exit(0)
    To do a batch of jobs, just run a script that does them all on
    each available machine, sharing a network filesystem. When each
    job grabs a lock, then this will automatically distribute the
    jobs so that each one is done just once on one machine.

    Returns None when the lock was acquired; the file is then scheduled
    for removal at interpreter exit. Otherwise returns a non-empty string
    describing the holder: the stripped file contents, 'empty' for an
    empty file, or 'race' if the file could not be read.
    With force=True an existing lock file is deleted and acquisition is
    retried once. Other OSErrors (e.g. permission denied) are raised.
    '''
    parent = os.path.dirname(path)
    # Try to create the file exclusively and write my pid into it.
    try:
        # A bare filename has no parent component, and os.makedirs('')
        # would raise FileNotFoundError, so only create a non-empty parent.
        if parent:
            os.makedirs(parent, exist_ok=True)
        fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
    except OSError as e:
        if e.errno != errno.EEXIST:
            # Other problems get an exception.
            raise
        # If we cannot because there was a race, yield the conflicter.
        conflicter = 'race'
        try:
            with open(path, 'r') as lockfile:
                conflicter = lockfile.read().strip() or 'empty'
        except (OSError, ValueError):
            # Best effort: the holder may already have removed the file,
            # or its contents may not decode as text.
            pass
        # Force is for manual one-time use, for deleting stale lockfiles.
        if force:
            if verbose:
                print('Removing %s from %s' % (path, conflicter))
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # Already gone; just retry.
            return pidfile_taken(path, verbose=verbose, force=False)
        if verbose:
            print('%s held by %s' % (path, conflicter))
        return conflicter
    # Register to delete this file on exit.
    lockfile = os.fdopen(fd, 'r+')
    atexit.register(delete_pidfile, lockfile, path)
    # Write my pid into the open file.
    lockfile.write('%d@%s %s\n' % (os.getpid(), socket.gethostname(),
                                   os.getenv('STY', '')))
    lockfile.flush()
    os.fsync(lockfile)
    # Return 'None' to say there was not a conflict.
    return None
def delete_pidfile(lockfile, path):
    '''
    Runs at exit after pidfile_taken succeeds: close the lock file object
    (if any) and remove the lock file at `path`.

    Both steps are best-effort. OS-level failures (e.g. the file was
    already deleted) are ignored so interpreter shutdown is not disturbed.
    Unlike a bare except, this does not swallow KeyboardInterrupt or
    SystemExit.
    '''
    if lockfile is not None:
        try:
            lockfile.close()
        except OSError:
            pass
    try:
        os.unlink(path)
    except OSError:
        pass
================================================
FILE: seeing/sampler.py
================================================
'''
A sampler is just a list of integers listing the indexes of the
inputs in a data set to sample. For reproducibility, the
FixedRandomSubsetSampler uses a seeded prng to produce the same
sequence always. FixedSubsetSampler is just a wrapper for an
explicit list of integers.
coordinate_sample solves another sampling problem: when testing
convolutional outputs, we can reduce data explosion by sampling
random points of the feature map rather than the entire feature map.
coordinate_sample does this in a deterministic way that is also
resolution-independent.
'''
import numpy
import random
from torch.utils.data.sampler import Sampler
class FixedSubsetSampler(Sampler):
    """Sampler that replays a fixed, explicit sequence of dataset indices.

    Iterating yields the stored indices in order. subset() derives a
    narrower sampler by picking positions within this sequence.
    """
    def __init__(self, samples):
        # Dataset indices, in the order they will be sampled.
        self.samples = samples

    def __iter__(self):
        return iter(self.samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, key):
        return self.samples[key]

    def subset(self, new_subset):
        '''
        Return a new FixedSubsetSampler over the dataset indices stored at
        positions `new_subset` of this sampler.
        '''
        picked = self.dereference(new_subset)
        return FixedSubsetSampler(picked)

    def dereference(self, indices):
        '''
        Map positions within this sampler (small numbers indexing the
        sample) to the dataset indices stored there (larger numbers
        indexing the original full set).
        '''
        stored = self.samples
        return [stored[pos] for pos in indices]
class FixedRandomSubsetSampler(FixedSubsetSampler):
"""Samples a fixed number of samples from the dataset, deterministically.
Arguments:
data_source,
sa
gitextract_vn0la05j/ ├── .gitignore ├── 2d_mix/ │ ├── .gitignore │ ├── __init__.py │ ├── config.py │ ├── evaluation.py │ ├── inputs.py │ ├── models/ │ │ ├── __init__.py │ │ └── cluster.py │ ├── train.py │ └── visualizations.py ├── LICENSE ├── README.md ├── cluster_metrics.py ├── clusterers/ │ ├── __init__.py │ ├── base_clusterer.py │ ├── kmeans.py │ ├── online.py │ ├── random_labels.py │ └── selfcondgan.py ├── configs/ │ ├── cifar/ │ │ ├── conditional.yaml │ │ ├── default.yaml │ │ ├── selfcondgan.yaml │ │ └── unconditional.yaml │ ├── default.yaml │ ├── imagenet/ │ │ ├── conditional.yaml │ │ ├── default.yaml │ │ ├── selfcondgan.yaml │ │ └── unconditional.yaml │ ├── places/ │ │ ├── conditional.yaml │ │ ├── default.yaml │ │ ├── selfcondgan.yaml │ │ └── unconditional.yaml │ ├── pretrained/ │ │ ├── imagenet/ │ │ │ ├── conditional.yaml │ │ │ ├── selfcondgan.yaml │ │ │ └── unconditional.yaml │ │ └── places/ │ │ ├── conditional.yaml │ │ ├── selfcondgan.yaml │ │ └── unconditional.yaml │ └── stacked_mnist/ │ ├── conditional.yaml │ ├── default.yaml │ ├── selfcondgan.yaml │ └── unconditional.yaml ├── gan_training/ │ ├── __init__.py │ ├── checkpoints.py │ ├── config.py │ ├── distributions.py │ ├── eval.py │ ├── inputs.py │ ├── logger.py │ ├── metrics/ │ │ ├── __init__.py │ │ ├── clustering_metrics.py │ │ ├── fid.py │ │ ├── inception_score.py │ │ └── tf_is/ │ │ ├── LICENSE │ │ ├── README.md │ │ └── inception_score.py │ ├── models/ │ │ ├── __init__.py │ │ ├── blocks.py │ │ ├── dcgan_deep.py │ │ ├── dcgan_shallow.py │ │ ├── resnet2.py │ │ ├── resnet2s.py │ │ └── resnet3.py │ ├── train.py │ └── utils.py ├── metrics.py ├── requirements.txt ├── seeded_sampler.py ├── seeing/ │ ├── frechet_distance.py │ ├── fsd.py │ ├── lightbox.html │ ├── parallelfolder.py │ ├── pbar.py │ ├── pidfile.py │ ├── sampler.py │ ├── segmenter.py │ ├── upsegmodel/ │ │ ├── __init__.py │ │ ├── models.py │ │ ├── prroi_pool/ │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ 
├── functional.py │ │ │ ├── prroi_pool.py │ │ │ ├── src/ │ │ │ │ ├── prroi_pooling_gpu.c │ │ │ │ ├── prroi_pooling_gpu.h │ │ │ │ ├── prroi_pooling_gpu_impl.cu │ │ │ │ └── prroi_pooling_gpu_impl.cuh │ │ │ └── test_prroi_pooling2d.py │ │ ├── resnet.py │ │ └── resnext.py │ ├── yz_dataset.py │ └── zdataset.py ├── train.py ├── utils/ │ ├── classifiers/ │ │ ├── __init__.py │ │ ├── cifar.py │ │ ├── imagenet.py │ │ ├── imagenet_class_index.json │ │ ├── places.py │ │ ├── pytorch_playground/ │ │ │ ├── .gitignore │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── cifar/ │ │ │ │ ├── __init__.py │ │ │ │ ├── dataset.py │ │ │ │ ├── model.py │ │ │ │ └── train.py │ │ │ ├── quantize.py │ │ │ ├── requirements.txt │ │ │ ├── roadmap_zh.md │ │ │ ├── setup.py │ │ │ └── utee/ │ │ │ ├── __init__.py │ │ │ ├── misc.py │ │ │ ├── quant.py │ │ │ └── selector.py │ │ └── stacked_mnist.py │ ├── get_empirical_distribution.py │ ├── get_gt_imgs.py │ └── np_to_pt_img.py └── visualize_clusters.py
SYMBOL INDEX (462 symbols across 63 files)
FILE: 2d_mix/config.py
function get_models (line 8) | def get_models(model_type, conditioning, k_value, d_act_dim, device):
function get_optimizers (line 20) | def get_optimizers(generator, discriminator, lr=1e-4, beta1=0.8, beta2=0...
function get_test (line 30) | def get_test(get_data, batch_size, variance, k_value, device):
function get_dataset (line 37) | def get_dataset(get_data, batch_size, npts, variance, k_value):
FILE: 2d_mix/evaluation.py
function warn (line 1) | def warn(*args, **kwargs):
function percent_good_grid (line 11) | def percent_good_grid(x_fake, var=0.0025, nrows=5, ncols=5):
function percent_good_ring (line 24) | def percent_good_ring(x_fake, var=0.0001, n_clusters=8, radius=2.0):
function percent_good_pts (line 35) | def percent_good_pts(x_fake, means, threshold):
FILE: 2d_mix/inputs.py
function map_labels (line 6) | def map_labels(labels):
function get_data_ring (line 10) | def get_data_ring(batch_size, radius=2.0, var=0.0001, n_clusters=8):
function get_data_grid (line 25) | def get_data_grid(batch_size, radius=2.0, var=0.0025, nrows=5, ncols=5):
FILE: 2d_mix/models/cluster.py
class G (line 13) | class G(nn.Module):
method __init__ (line 14) | def __init__(self,
method forward (line 41) | def forward(self, z, y=None):
class D (line 50) | class D(nn.Module):
class Maxout (line 51) | class Maxout(nn.Module):
method __init__ (line 53) | def __init__(self, d_in, d_out, pool_size=5):
method forward (line 58) | def forward(self, inputs):
method max (line 67) | def max(self, out, dim=5):
method __init__ (line 70) | def __init__(self, conditioning, k_value, act_dim=200, x_dim=2):
method forward (line 83) | def forward(self, x, y=None, get_features=False):
FILE: 2d_mix/train.py
function main (line 66) | def main(outdir):
FILE: 2d_mix/visualizations.py
function visualize_generated (line 40) | def visualize_generated(fake, real, y, it, outdir):
function visualize_clusters (line 63) | def visualize_clusters(x, y, it, outdir):
FILE: cluster_metrics.py
function main (line 28) | def main():
FILE: clusterers/base_clusterer.py
class BaseClusterer (line 6) | class BaseClusterer():
method __init__ (line 7) | def __init__(self,
method get_labels (line 23) | def get_labels(self, x, y):
method recluster (line 26) | def recluster(self, discriminator, **kwargs):
method get_features (line 29) | def get_features(self, x):
method get_cluster_batch_features (line 33) | def get_cluster_batch_features(self):
method get_discriminator_output (line 47) | def get_discriminator_output(self, x):
method get_label_distribution (line 53) | def get_label_distribution(self, x=None):
method sample_y (line 61) | def sample_y(self, batch_size):
method print_label_distribution (line 69) | def print_label_distribution(self, x=None):
FILE: clusterers/kmeans.py
class Clusterer (line 7) | class Clusterer(base_clusterer.BaseClusterer):
method __init__ (line 8) | def __init__(self, **kwargs):
method kmeans_fit_predict (line 12) | def kmeans_fit_predict(self, features, init='k-means++', n_init=10):
method get_labels (line 19) | def get_labels(self, x, y):
FILE: clusterers/online.py
class Clusterer (line 9) | class Clusterer(kmeans.Clusterer):
method __init__ (line 10) | def __init__(self, **kwargs):
method get_initialization (line 14) | def get_initialization(self, features, labels):
method recluster (line 32) | def recluster(self, discriminator, x_batch=None, **kwargs):
FILE: clusterers/random_labels.py
class Clusterer (line 5) | class Clusterer(base_clusterer.BaseClusterer):
method __init__ (line 6) | def __init__(self, **kwargs):
method get_labels (line 9) | def get_labels(self, x, y):
FILE: clusterers/selfcondgan.py
class Clusterer (line 10) | class Clusterer(kmeans.Clusterer):
method __init__ (line 11) | def __init__(self, initialization=True, matching=True, **kwargs):
method get_initialization (line 17) | def get_initialization(self, features, labels):
method fit_means (line 35) | def fit_means(self):
method recluster (line 57) | def recluster(self, discriminator, **kwargs):
method hungarian_match (line 61) | def hungarian_match(self, flat_preds, flat_targets, preds_k, targets_k):
FILE: gan_training/checkpoints.py
class CheckpointIO (line 8) | class CheckpointIO(object):
method __init__ (line 17) | def __init__(self, checkpoint_dir='./chkpts', **kwargs):
method register_modules (line 23) | def register_modules(self, **kwargs):
method save (line 28) | def save(self, filename, **kwargs):
method load (line 42) | def load(self, filename, pretrained={}):
method load_file (line 55) | def load_file(self, filename):
method load_url (line 74) | def load_url(self, url):
method parse_state_dict (line 85) | def parse_state_dict(self, state_dict):
method load_clusterer (line 102) | def load_clusterer(self, it, load_samples, pretrained={}):
method load_models (line 129) | def load_models(self, it, pretrained={}, load_samples=False):
method save_clusterer (line 153) | def save_clusterer(self, clusterer, it):
function is_url (line 161) | def is_url(url):
FILE: gan_training/config.py
function load_config (line 10) | def load_config(path, default_path):
function update_recursive (line 40) | def update_recursive(dict1, dict2):
function get_clusterer (line 59) | def get_clusterer(config):
function build_models (line 63) | def build_models(config):
function build_optimizers (line 83) | def build_optimizers(generator, discriminator, config):
function get_parameter_groups (line 111) | def get_parameter_groups(parameters, gradient_scales, base_lr):
FILE: gan_training/distributions.py
function get_zdist (line 5) | def get_zdist(dist_name, dim, device=None):
function get_ydist (line 24) | def get_ydist(nlabels, device=None):
function interpolate_sphere (line 34) | def interpolate_sphere(z1, z2, t):
FILE: gan_training/eval.py
class Evaluator (line 7) | class Evaluator(object):
method __init__ (line 8) | def __init__(self,
method sample_z (line 26) | def sample_z(self, batch_size):
method get_y (line 29) | def get_y(self, x, y):
method get_fake_real_samples (line 32) | def get_fake_real_samples(self, N):
method compute_inception_score (line 54) | def compute_inception_score(self):
method create_samples (line 64) | def create_samples(self, z, y=None):
FILE: gan_training/inputs.py
function get_dataset (line 15) | def get_dataset(name,
class CachedImageFolder (line 75) | class CachedImageFolder(data.Dataset):
method __init__ (line 84) | def __init__(self, root, transform=None, loader=default_loader):
method __getitem__ (line 95) | def __getitem__(self, index):
method __len__ (line 102) | def __len__(self):
class StackedMNIST (line 105) | class StackedMNIST(data.Dataset):
method __init__ (line 106) | def __init__(self, data_dir, transform, batch_size=100000):
method __getitem__ (line 130) | def __getitem__(self, index):
method __len__ (line 137) | def __len__(self):
function is_npy_file (line 141) | def is_npy_file(path):
function walk_image_files (line 145) | def walk_image_files(rootdir):
function find_classes (line 181) | def find_classes(dir):
function make_class_dataset (line 190) | def make_class_dataset(source_root, class_to_idx):
function npy_loader (line 202) | def npy_loader(path):
FILE: gan_training/logger.py
class Logger (line 7) | class Logger(object):
method __init__ (line 8) | def __init__(self,
method setup_monitoring (line 29) | def setup_monitoring(self, monitoring, monitoring_dir=None):
method add (line 45) | def add(self, category, k, v, it):
method add_imgs (line 60) | def add_imgs(self, imgs, class_name, it):
method get_last (line 73) | def get_last(self, category, k, default=0.):
method save_stats (line 81) | def save_stats(self, filename):
method load_stats (line 86) | def load_stats(self, filename):
FILE: gan_training/metrics/clustering_metrics.py
function warn (line 1) | def warn(*args, **kwargs):
function nmi (line 14) | def nmi(inferred, gt):
function acc (line 18) | def acc(inferred, gt):
function purity_score (line 30) | def purity_score(y_true, y_pred):
function ari (line 36) | def ari(inferred, gt):
function homogeneity (line 40) | def homogeneity(inferred, gt):
FILE: gan_training/metrics/fid.py
function check_or_download_inception (line 13) | def check_or_download_inception(inception_path):
function create_inception_graph (line 31) | def create_inception_graph(pth):
function calculate_activation_statistics (line 40) | def calculate_activation_statistics(images,
function _get_inception_layer (line 67) | def _get_inception_layer(sess):
function get_activations (line 90) | def get_activations(images, sess, batch_size=200, verbose=False):
function calculate_frechet_distance (line 137) | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
function compute_fid_from_npz (line 190) | def compute_fid_from_npz(path):
function compute_fid_from_imgs (line 231) | def compute_fid_from_imgs(fake_imgs, real_imgs):
function compute_stats (line 244) | def compute_stats(exp_path):
FILE: gan_training/metrics/inception_score.py
function inception_score (line 12) | def inception_score(imgs, device=None, batch_size=32, resize=False, spli...
FILE: gan_training/metrics/tf_is/inception_score.py
function inception_logits (line 33) | def inception_logits(images = inception_images, num_splits = 1):
function get_inception_probs (line 57) | def get_inception_probs(inps):
function preds2score (line 66) | def preds2score(preds, splits=10):
function get_inception_score (line 75) | def get_inception_score(images, splits=10):
function compute_is_from_npz (line 87) | def compute_is_from_npz(path):
FILE: gan_training/models/blocks.py
class ResnetBlock (line 7) | class ResnetBlock(nn.Module):
method __init__ (line 8) | def __init__(self,
method forward (line 43) | def forward(self, x, y):
method _shortcut (line 51) | def _shortcut(self, x):
function actvn (line 59) | def actvn(x):
class LatentEmbeddingConcat (line 64) | class LatentEmbeddingConcat(nn.Module):
method __init__ (line 67) | def __init__(self, nlabels, embed_dim):
method forward (line 71) | def forward(self, z, y):
class NormalizeLinear (line 79) | class NormalizeLinear(nn.Module):
method __init__ (line 80) | def __init__(self, act_dim, k_value):
method normalize (line 84) | def normalize(self):
method forward (line 87) | def forward(self, x):
class Identity (line 92) | class Identity(nn.Module):
method __init__ (line 93) | def __init__(self, *args, **kwargs):
method forward (line 96) | def forward(self, inp, *args, **kwargs):
class LinearConditionalMaskLogits (line 100) | class LinearConditionalMaskLogits(nn.Module):
method __init__ (line 103) | def __init__(self, nc, nlabels):
method forward (line 107) | def forward(self, inp, y=None, take_best=False, get_features=False):
class ProjectionDiscriminatorLogits (line 123) | class ProjectionDiscriminatorLogits(nn.Module):
method __init__ (line 126) | def __init__(self, nc, nlabels):
method forward (line 132) | def forward(self, x, y, take_best=False):
class LinearUnconditionalLogits (line 152) | class LinearUnconditionalLogits(nn.Module):
method __init__ (line 155) | def __init__(self, nc):
method forward (line 159) | def forward(self, inp, y, take_best=False):
class Reshape (line 166) | class Reshape(nn.Module):
method __init__ (line 167) | def __init__(self, *shape):
method forward (line 171) | def forward(self, x):
class ConditionalBatchNorm2d (line 176) | class ConditionalBatchNorm2d(nn.Module):
method __init__ (line 179) | def __init__(self, num_features, num_classes):
method forward (line 189) | def forward(self, x, y):
class BatchNorm2d (line 197) | class BatchNorm2d(nn.Module):
method __init__ (line 200) | def __init__(self, nc, nchannels, **kwargs):
method forward (line 204) | def forward(self, x, y):
FILE: gan_training/models/dcgan_deep.py
class Generator (line 9) | class Generator(nn.Module):
method __init__ (line 10) | def __init__(self,
method forward (line 47) | def forward(self, input, y):
class Discriminator (line 59) | class Discriminator(nn.Module):
method __init__ (line 60) | def __init__(self,
method stack (line 97) | def stack(self, x):
method forward (line 109) | def forward(self, input, y=None, get_features=False):
FILE: gan_training/models/dcgan_shallow.py
class Generator (line 9) | class Generator(nn.Module):
method __init__ (line 10) | def __init__(self,
method forward (line 47) | def forward(self, input, y):
class Discriminator (line 60) | class Discriminator(nn.Module):
method __init__ (line 61) | def __init__(self,
method stack (line 100) | def stack(self, x):
method forward (line 111) | def forward(self, input, y=None, get_features=False):
FILE: gan_training/models/resnet2.py
class Generator (line 13) | class Generator(nn.Module):
method __init__ (line 14) | def __init__(self,
method forward (line 63) | def forward(self, z, y):
class Discriminator (line 100) | class Discriminator(nn.Module):
method __init__ (line 101) | def __init__(self,
method forward (line 147) | def forward(self, x, y=None, get_features=False):
function actvn (line 184) | def actvn(x):
FILE: gan_training/models/resnet2s.py
class Reshape (line 10) | class Reshape(nn.Module):
method __init__ (line 11) | def __init__(self, *shape):
method forward (line 15) | def forward(self, x):
class Generator (line 20) | class Generator(nn.Module):
method __init__ (line 27) | def __init__(self,
method forward (line 71) | def forward(self, z, y=None):
method load_v2_state_dict (line 79) | def load_v2_state_dict(self, state_dict):
class ConditionGen (line 93) | class ConditionGen(nn.Module):
method __init__ (line 94) | def __init__(self, z_dim, nlabels, embed_size=256):
method forward (line 102) | def forward(self, z, y):
function convert_from_resnet2_generator (line 113) | def convert_from_resnet2_generator(gen):
class ResnetBlock (line 133) | class ResnetBlock(nn.Module):
method __init__ (line 134) | def __init__(self, fin, fout, fhidden=None, is_bias=True):
method forward (line 166) | def forward(self, x):
method _shortcut (line 174) | def _shortcut(self, x):
function actvn (line 182) | def actvn(x):
FILE: gan_training/models/resnet3.py
class Generator (line 9) | class Generator(nn.Module):
method __init__ (line 15) | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64,
method forward (line 55) | def forward(self, z, y=None):
method load_v2_state_dict (line 63) | def load_v2_state_dict(self, state_dict):
class Reshape (line 75) | class Reshape(nn.Module):
method __init__ (line 76) | def __init__(self, *shape):
method forward (line 79) | def forward(self, x):
class ConditionGen (line 83) | class ConditionGen(nn.Module):
method __init__ (line 84) | def __init__(self, z_dim, nlabels, embed_size=256):
method forward (line 92) | def forward(self, z, y):
function convert_from_resnet2_generator (line 102) | def convert_from_resnet2_generator(gen):
class ResnetBlock (line 121) | class ResnetBlock(nn.Module):
method __init__ (line 122) | def __init__(self, fin, fout, fhidden=None, is_bias=True):
method forward (line 143) | def forward(self, x):
method _shortcut (line 151) | def _shortcut(self, x):
function actvn (line 159) | def actvn(x):
FILE: gan_training/train.py
class Trainer (line 10) | class Trainer(object):
method __init__ (line 11) | def __init__(self,
method generator_trainstep (line 30) | def generator_trainstep(self, y, z):
method discriminator_trainstep (line 48) | def discriminator_trainstep(self, x_real, y, z):
method compute_loss (line 99) | def compute_loss(self, d_out, target):
method wgan_gp_reg (line 111) | def wgan_gp_reg(self, x_real, x_fake, y, center=1.):
function toggle_grad (line 125) | def toggle_grad(model, requires_grad):
function compute_grad2 (line 130) | def compute_grad2(d_out, x_in):
function update_average (line 143) | def update_average(model_tgt, model_src, beta):
FILE: gan_training/utils.py
function save_images (line 9) | def save_images(imgs, outfile, nrow=8):
function get_nsamples (line 14) | def get_nsamples(data_loader, N):
function update_average (line 29) | def update_average(model_tgt, model_src, beta):
function get_most_recent (line 38) | def get_most_recent(d, ext):
FILE: metrics.py
function load_results (line 37) | def load_results(results_dir):
function get_dataset_from_path (line 49) | def get_dataset_from_path(path):
function pt_to_np (line 56) | def pt_to_np(imgs):
function sample (line 61) | def sample(sampler):
FILE: seeded_sampler.py
function get_most_recent (line 16) | def get_most_recent(models):
class SeededSampler (line 24) | class SeededSampler():
method __init__ (line 25) | def __init__(
method sample (line 45) | def sample(self, nimgs):
method conditional_sample (line 59) | def conditional_sample(self, yi, seed=None):
method sample_with_seed (line 71) | def sample_with_seed(self, seeds):
method get_zy (line 77) | def get_zy(self, seeds):
method sample_with_zy (line 81) | def sample_with_zy(self, z, y):
method get_generator (line 86) | def get_generator(self):
method get_yz_dist (line 123) | def get_yz_dist(self):
FILE: seeing/frechet_distance.py
function sample_frechet_distance (line 25) | def sample_frechet_distance(sample1, sample2, eps=1e-6,
function calculate_frechet_distance (line 36) | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6,
function calculate_activation_statistics (line 98) | def calculate_activation_statistics(act):
FILE: seeing/fsd.py
function main (line 11) | def main():
function cached_tally_directory (line 68) | def cached_tally_directory(directory, size=10000, cachedir=None, seed=1):
function tally_directory (line 83) | def tally_directory(directory, size=10000, seed=1):
function tally_dataset_objects (line 127) | def tally_dataset_objects(dataset, size=10000):
function tally_generated_objects (line 152) | def tally_generated_objects(model, size=10000):
function diff_figure (line 176) | def diff_figure(ttally,
FILE: seeing/parallelfolder.py
function grayscale_loader (line 14) | def grayscale_loader(path):
class ndarray (line 18) | class ndarray(numpy.ndarray):
function default_loader (line 25) | def default_loader(filename):
class ParallelImageFolders (line 36) | class ParallelImageFolders(data.Dataset):
method __init__ (line 48) | def __init__(self, image_roots,
method __getattr__ (line 85) | def __getattr__(self, attr):
method __getitem__ (line 91) | def __getitem__(self, index):
method __len__ (line 121) | def __len__(self):
function is_npy_file (line 126) | def is_npy_file(path):
function is_image_file (line 129) | def is_image_file(path):
function walk_image_files (line 132) | def walk_image_files(rootdir, verbose=None):
function make_parallel_dataset (line 148) | def make_parallel_dataset(image_roots, classification=False,
FILE: seeing/pbar.py
function post (line 17) | def post(**kwargs):
function desc (line 27) | def desc(desc):
function descnext (line 36) | def descnext(desc):
function print (line 46) | def print(*args):
function tqdm_terminal (line 59) | def tqdm_terminal(it, *args, **kwargs):
function in_notebook (line 66) | def in_notebook():
function innermost_tqdm (line 82) | def innermost_tqdm():
function reporthook (line 91) | def reporthook(*args, **kwargs):
function __call__ (line 117) | def __call__(x, *args, **kwargs):
class VerboseContextManager (line 141) | class VerboseContextManager():
method __init__ (line 142) | def __init__(self, v, entered=False):
method __enter__ (line 147) | def __enter__(self):
method __exit__ (line 155) | def __exit__(self, exc_type, exc_value, exc_traceback):
method __call__ (line 158) | def __call__(self, v=True):
class CallableModule (line 177) | class CallableModule(types.ModuleType):
method __init__ (line 178) | def __init__(self):
method __call__ (line 182) | def __call__(self, x, *args, **kwargs):
FILE: seeing/pidfile.py
function exit_if_job_done (line 8) | def exit_if_job_done(directory, redo=False, force=False, verbose=True):
function mark_job_done (line 25) | def mark_job_done(directory):
function pidfile_taken (line 32) | def pidfile_taken(path, verbose=False, force=False):
function delete_pidfile (line 83) | def delete_pidfile(lockfile, path):
FILE: seeing/sampler.py
class FixedSubsetSampler (line 19) | class FixedSubsetSampler(Sampler):
method __init__ (line 23) | def __init__(self, samples):
method __iter__ (line 26) | def __iter__(self):
method __len__ (line 29) | def __len__(self):
method __getitem__ (line 32) | def __getitem__(self, key):
method subset (line 35) | def subset(self, new_subset):
method dereference (line 38) | def dereference(self, indices):
class FixedRandomSubsetSampler (line 46) | class FixedRandomSubsetSampler(FixedSubsetSampler):
method __init__ (line 53) | def __init__(self, data_source, start=None, end=None, seed=1):
method class_subset (line 60) | def class_subset(self, class_filter):
function coordinate_sample (line 71) | def coordinate_sample(shape, sample_size, seeds, grid=13, seed=1, flat=F...
function main (line 104) | def main():
function test (line 139) | def test():
FILE: seeing/segmenter.py
class BaseSegmenter (line 9) | class BaseSegmenter:
method get_label_and_category_names (line 10) | def get_label_and_category_names(self):
method segment_batch (line 21) | def segment_batch(self, tensor_images, downsample=1):
class UnifiedParsingSegmenter (line 34) | class UnifiedParsingSegmenter(BaseSegmenter):
method __init__ (line 44) | def __init__(self, segsizes=None):
method get_label_and_category_names (line 92) | def get_label_and_category_names(self, dataset=None):
method raw_seg_prediction (line 114) | def raw_seg_prediction(self, tensor_images, downsample=1):
method segment_batch (line 151) | def segment_batch(self, tensor_images, downsample=1):
function load_unified_parsing_segmentation_model (line 186) | def load_unified_parsing_segmentation_model(segmodel_arch, segvocab, epo...
function ensure_upp_segmenter_downloaded (line 212) | def ensure_upp_segmenter_downloaded(directory):
function test_main (line 226) | def test_main():
FILE: seeing/upsegmodel/models.py
class SegmentationModuleBase (line 12) | class SegmentationModuleBase(nn.Module):
method __init__ (line 13) | def __init__(self):
method pixel_acc (line 17) | def pixel_acc(pred, label, ignore_index=-1):
method part_pixel_acc (line 26) | def part_pixel_acc(pred_part, gt_seg_part, gt_seg_object, object_label...
method part_loss (line 37) | def part_loss(pred_part, gt_seg_part, gt_seg_object, object_label, val...
class SegmentationModule (line 48) | class SegmentationModule(SegmentationModuleBase):
method __init__ (line 49) | def __init__(self, net_enc, net_dec, labeldata, loss_scale=None):
method forward (line 75) | def forward(self, feed_dict, *, seg_size=None):
function conv3x3 (line 136) | def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
function conv3x3_bn_relu (line 142) | def conv3x3_bn_relu(in_planes, out_planes, stride=1):
class ModelBuilder (line 150) | class ModelBuilder:
method __init__ (line 151) | def __init__(self):
method weights_init (line 156) | def weights_init(m):
method build_encoder (line 166) | def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=...
method build_decoder (line 187) | def build_decoder(self, nr_classes,
class Resnet (line 213) | class Resnet(nn.Module):
method __init__ (line 214) | def __init__(self, orig_resnet):
method forward (line 233) | def forward(self, x, return_feature_maps=False):
class UPerNet (line 252) | class UPerNet(nn.Module):
method __init__ (line 253) | def __init__(self, nr_classes, fc_dim=4096,
method forward (line 325) | def forward(self, conv_out, output_switch=None, seg_size=None):
FILE: seeing/upsegmodel/prroi_pool/functional.py
class PrRoIPool2DFunction (line 30) | class PrRoIPool2DFunction(ag.Function):
method forward (line 32) | def forward(ctx, features, rois, pooled_height, pooled_width, spatial_...
method backward (line 55) | def backward(ctx, grad_output):
FILE: seeing/upsegmodel/prroi_pool/prroi_pool.py
class PrRoIPool2D (line 19) | class PrRoIPool2D(nn.Module):
method __init__ (line 20) | def __init__(self, pooled_height, pooled_width, spatial_scale):
method forward (line 27) | def forward(self, features, rois):
FILE: seeing/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c
function output (line 28) | auto output = at::zeros({nr_rois, nr_channels, pooled_height, pooled_wid...
FILE: seeing/upsegmodel/prroi_pool/test_prroi_pooling2d.py
class TestPrRoIPool2D (line 20) | class TestPrRoIPool2D(TorchTestCase):
method test_forward (line 21) | def test_forward(self):
method test_backward_shapeonly (line 37) | def test_backward_shapeonly(self):
FILE: seeing/upsegmodel/resnet.py
function conv3x3 (line 26) | def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock (line 32) | class BasicBlock(nn.Module):
method __init__ (line 35) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 45) | def forward(self, x):
class Bottleneck (line 64) | class Bottleneck(nn.Module):
method __init__ (line 67) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 80) | def forward(self, x):
class ResNet (line 103) | class ResNet(nn.Module):
method __init__ (line 105) | def __init__(self, block, layers, num_classes=1000):
method _make_layer (line 134) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 151) | def forward(self, x):
function resnet50 (line 193) | def resnet50(pretrained=False, **kwargs):
function resnet101 (line 205) | def resnet101(pretrained=False, **kwargs):
function load_url (line 227) | def load_url(url, model_dir='./pretrained', map_location=None):
FILE: seeing/upsegmodel/resnext.py
function conv3x3 (line 26) | def conv3x3(in_planes, out_planes, stride=1):
class GroupBottleneck (line 32) | class GroupBottleneck(nn.Module):
method __init__ (line 35) | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=No...
method forward (line 48) | def forward(self, x):
class ResNeXt (line 71) | class ResNeXt(nn.Module):
method __init__ (line 73) | def __init__(self, block, layers, groups=32, num_classes=1000):
method _make_layer (line 102) | def _make_layer(self, block, planes, blocks, stride=1, groups=1):
method forward (line 119) | def forward(self, x):
function resnext101 (line 151) | def resnext101(pretrained=False, **kwargs):
function load_url (line 175) | def load_url(url, model_dir='./pretrained', map_location=None):
FILE: seeing/yz_dataset.py
class YZDataset (line 4) | class YZDataset():
method __init__ (line 5) | def __init__(self, zdim=256, nlabels=1, distribution=[1.], device='cpu'):
method __call__ (line 12) | def __call__(self, seeds):
FILE: seeing/zdataset.py
function z_dataset_for_model (line 4) | def z_dataset_for_model(model, size=100, seed=1):
function z_sample_for_model (line 7) | def z_sample_for_model(model, size=100, seed=1):
function standard_z_sample (line 26) | def standard_z_sample(size, depth, seed=1, device=None):
FILE: train.py
function main (line 37) | def main():
FILE: utils/classifiers/cifar.py
class Classifier (line 6) | class Classifier():
method __init__ (line 7) | def __init__(self):
method get_predictions (line 10) | def get_predictions(self, x):
FILE: utils/classifiers/imagenet.py
class Classifier (line 8) | class Classifier():
method __init__ (line 9) | def __init__(self):
method transform (line 21) | def transform(self, x):
method get_name (line 26) | def get_name(self, class_id):
method get_predictions_and_confidence (line 29) | def get_predictions_and_confidence(self, x):
method get_predictions (line 35) | def get_predictions(self, x):
FILE: utils/classifiers/places.py
class Classifier (line 8) | class Classifier():
method __init__ (line 9) | def __init__(self):
method get_name (line 45) | def get_name(self, id):
method transform (line 48) | def transform(self, x):
method get_predictions_and_confidence (line 53) | def get_predictions_and_confidence(self, x):
method get_predictions (line 59) | def get_predictions(self, x):
FILE: utils/classifiers/pytorch_playground/cifar/dataset.py
function get10 (line 6) | def get10(batch_size, data_root='/tmp/public_dataset/pytorch', train=Tru...
function get100 (line 38) | def get100(batch_size, data_root='/tmp/public_dataset/pytorch', train=Tr...
FILE: utils/classifiers/pytorch_playground/cifar/model.py
class CIFAR (line 13) | class CIFAR(nn.Module):
method __init__ (line 14) | def __init__(self, features, n_channel, num_classes):
method forward (line 24) | def forward(self, x):
function make_layers (line 30) | def make_layers(cfg, batch_norm=False):
function cifar10 (line 47) | def cifar10(n_channel=128):
FILE: utils/classifiers/pytorch_playground/quantize.py
function main (line 8) | def main():
FILE: utils/classifiers/pytorch_playground/utee/misc.py
class Logger (line 11) | class Logger(object):
method __init__ (line 12) | def __init__(self):
method init (line 15) | def init(self, logdir, name='log'):
method info (line 30) | def info(self, str_info):
function ensure_dir (line 36) | def ensure_dir(path, erase=False):
function load_pickle (line 44) | def load_pickle(path):
function dump_pickle (line 52) | def dump_pickle(obj, path):
function auto_select_gpu (line 57) | def auto_select_gpu(mem_bound=500, utility_bound=0, gpus=(0, 1, 2, 3, 4,...
function expand_user (line 94) | def expand_user(path):
function model_snapshot (line 97) | def model_snapshot(model, new_file, old_file=None, verbose=False):
function load_lmdb (line 117) | def load_lmdb(lmdb_file, n_records=None):
function str2img (line 141) | def str2img(str_b):
function img2str (line 144) | def img2str(img):
function md5 (line 147) | def md5(s):
function eval_model (line 152) | def eval_model(model, ds, n_sample=None, ngpu=1, is_imagenet=False):
function load_state_dict (line 201) | def load_state_dict(model, model_urls, model_root):
FILE: utils/classifiers/pytorch_playground/utee/quant.py
function compute_integral_part (line 8) | def compute_integral_part(input, overflow_rate):
function linear_quantize (line 18) | def linear_quantize(input, sf, bits):
function log_minmax_quantize (line 31) | def log_minmax_quantize(input, bits):
function log_linear_quantize (line 42) | def log_linear_quantize(input, sf, bits):
function min_max_quantize (line 53) | def min_max_quantize(input, bits):
function tanh_quantize (line 71) | def tanh_quantize(input, bits):
class LinearQuant (line 85) | class LinearQuant(nn.Module):
method __init__ (line 86) | def __init__(self, name, bits, sf=None, overflow_rate=0.0, counter=10):
method counter (line 96) | def counter(self):
method forward (line 99) | def forward(self, input):
method __repr__ (line 109) | def __repr__(self):
class LogQuant (line 113) | class LogQuant(nn.Module):
method __init__ (line 114) | def __init__(self, name, bits, sf=None, overflow_rate=0.0, counter=10):
method counter (line 124) | def counter(self):
method forward (line 127) | def forward(self, input):
method __repr__ (line 138) | def __repr__(self):
class NormalQuant (line 142) | class NormalQuant(nn.Module):
method __init__ (line 143) | def __init__(self, name, bits, quant_func):
method counter (line 150) | def counter(self):
method forward (line 153) | def forward(self, input):
method __repr__ (line 157) | def __repr__(self):
function duplicate_model_with_quant (line 160) | def duplicate_model_with_quant(model, bits, overflow_rate=0.0, counter=1...
FILE: utils/classifiers/pytorch_playground/utee/selector.py
function mnist (line 18) | def mnist(cuda=True, model_root=None):
function svhn (line 26) | def svhn(cuda=True, model_root=None):
function cifar10 (line 34) | def cifar10(cuda=True, model_root=None):
function cifar100 (line 42) | def cifar100(cuda=True, model_root=None):
function stl10 (line 50) | def stl10(cuda=True, model_root=None):
function alexnet (line 58) | def alexnet(cuda=True, model_root=None):
function vgg16 (line 66) | def vgg16(cuda=True, model_root=None):
function vgg16_bn (line 74) | def vgg16_bn(cuda=True, model_root=None):
function vgg19 (line 82) | def vgg19(cuda=True, model_root=None):
function vgg19_bn (line 90) | def vgg19_bn(cuda=True, model_root=None):
function inception_v3 (line 98) | def inception_v3(cuda=True, model_root=None):
function resnet18 (line 106) | def resnet18(cuda=True, model_root=None):
function resnet34 (line 114) | def resnet34(cuda=True, model_root=None):
function resnet50 (line 122) | def resnet50(cuda=True, model_root=None):
function resnet101 (line 130) | def resnet101(cuda=True, model_root=None):
function resnet152 (line 138) | def resnet152(cuda=True, model_root=None):
function squeezenet_v0 (line 146) | def squeezenet_v0(cuda=True, model_root=None):
function squeezenet_v1 (line 154) | def squeezenet_v1(cuda=True, model_root=None):
function select (line 162) | def select(model_name, **kwargs):
FILE: utils/classifiers/stacked_mnist.py
class Classifier (line 11) | class Classifier():
method __init__ (line 12) | def __init__(self):
method get_predictions (line 22) | def get_predictions(self, x):
function get_mnist_dataloader (line 29) | def get_mnist_dataloader(batch_size=100):
class MNISTClassifier (line 46) | class MNISTClassifier(nn.Module):
method __init__ (line 47) | def __init__(self, input_dims=1024, n_hiddens=[256, 256], n_class=10):
method forward (line 63) | def forward(self, input):
method get_predictions (line 68) | def get_predictions(self, input):
method load (line 72) | def load(self, path):
method train (line 76) | def train(self):
FILE: utils/get_empirical_distribution.py
function get_empirical_distribution (line 12) | def get_empirical_distribution(path_to_samples):
function get_kl (line 34) | def get_kl(fake, nclasses):
FILE: utils/get_gt_imgs.py
function get_images (line 11) | def get_images(root, N):
function pt_to_np (line 29) | def pt_to_np(imgs):
function get_transform (line 34) | def get_transform(size):
function get_gt_samples (line 43) | def get_gt_samples(dataset, nimgs=50000):
FILE: utils/np_to_pt_img.py
function np_to_pt (line 4) | def np_to_pt(x):
FILE: visualize_clusters.py
function main (line 27) | def main():
Condensed preview — 119 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (409K chars).
[
{
"path": ".gitignore",
"chars": 34,
"preview": "*/**/*pyc*\n*/**/.DS_Store\n.vscode\n"
},
{
"path": "2d_mix/.gitignore",
"chars": 28,
"preview": "**.png\n**.pyc\n**.pt\noutput/\n"
},
{
"path": "2d_mix/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "2d_mix/config.py",
"chars": 1852,
"preview": "import torch\n\nfrom models import generator_dict, discriminator_dict\nfrom torch import optim\nimport torch.utils.data as u"
},
{
"path": "2d_mix/evaluation.py",
"chars": 1905,
"preview": "def warn(*args, **kwargs):\n pass\n\n\nimport warnings\nwarnings.warn = warn\n\nimport numpy as np\n\n\ndef percent_good_grid(x"
},
{
"path": "2d_mix/inputs.py",
"chars": 1199,
"preview": "import numpy as np\nimport random\n\nmapping = list(range(25))\n\ndef map_labels(labels):\n return np.array([mapping[label]"
},
{
"path": "2d_mix/models/__init__.py",
"chars": 116,
"preview": "from models import (cluster)\n\ngenerator_dict = {'standard': cluster.G}\ndiscriminator_dict = {'standard': cluster.D}\n"
},
{
"path": "2d_mix/models/cluster.py",
"chars": 3067,
"preview": "import sys\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\nsys.p"
},
{
"path": "2d_mix/train.py",
"chars": 7432,
"preview": "import argparse\nimport os\nimport sys\n\nimport torch\nfrom torch import optim\nfrom torch import distributions\nfrom torch im"
},
{
"path": "2d_mix/visualizations.py",
"chars": 1576,
"preview": "import matplotlib\nfrom matplotlib import pyplot\nimport os\n\nCOLORS = [\n 'purple',\n 'wheat',\n 'maroon',\n 'red'"
},
{
"path": "LICENSE",
"chars": 1066,
"preview": "MIT License\n\nCopyright (c) 2020 Steven Liu\n\nPermission is hereby granted, free of charge, to any person obtaining a copy"
},
{
"path": "README.md",
"chars": 5593,
"preview": "# Diverse Image Generation via Self-Conditioned GANs\n\n#### [Project](http://selfcondgan.csail.mit.edu/) | [Paper](http"
},
{
"path": "cluster_metrics.py",
"chars": 3084,
"preview": "import argparse\nimport os\nfrom tqdm import tqdm\n\nimport torch\nimport numpy as np\nfrom torch import nn\n\nfrom gan_training"
},
{
"path": "clusterers/__init__.py",
"chars": 265,
"preview": "from clusterers import (base_clusterer, selfcondgan, random_labels, online)\n\nclusterer_dict = {\n 'supervised': base_c"
},
{
"path": "clusterers/base_clusterer.py",
"chars": 2610,
"preview": "import copy\n\nimport torch\nimport numpy as np\n\nclass BaseClusterer():\n def __init__(self,\n discriminat"
},
{
"path": "clusterers/kmeans.py",
"chars": 935,
"preview": "import torch\nimport numpy as np\nfrom sklearn.cluster import KMeans\n\nfrom clusterers import base_clusterer\n\nclass Cluster"
},
{
"path": "clusterers/online.py",
"chars": 2435,
"preview": "import copy, random\n\nimport torch\nimport numpy as np\n\nfrom clusterers import kmeans\n\n\nclass Clusterer(kmeans.Clusterer):"
},
{
"path": "clusterers/random_labels.py",
"chars": 278,
"preview": "import torch\nfrom clusterers import base_clusterer\n\n\nclass Clusterer(base_clusterer.BaseClusterer):\n def __init__(sel"
},
{
"path": "clusterers/selfcondgan.py",
"chars": 3025,
"preview": "import copy, random\n\nimport torch\nimport numpy as np\nfrom sklearn.utils.linear_assignment_ import linear_assignment\n\nfro"
},
{
"path": "configs/cifar/conditional.yaml",
"chars": 187,
"preview": "generator:\n nlabels: 10\n conditioning: embedding\ndiscriminator:\n nlabels: 10\n conditioning: mask\ninherit_from: confi"
},
{
"path": "configs/cifar/default.yaml",
"chars": 572,
"preview": "data:\n type: cifar10\n train_dir: data/CIFAR\n img_size: 32\n nlabels: 10\ngenerator:\n name: dcgan_deep\n nlabels: 1\n "
},
{
"path": "configs/cifar/selfcondgan.yaml",
"chars": 273,
"preview": "generator:\n nlabels: 100\n conditioning: embedding\ndiscriminator:\n nlabels: 100\n conditioning: mask\nclusterer:\n name"
},
{
"path": "configs/cifar/unconditional.yaml",
"chars": 88,
"preview": "inherit_from: configs/cifar/default.yaml\ntraining:\n out_dir: output/cifar/unconditional"
},
{
"path": "configs/default.yaml",
"chars": 910,
"preview": "data:\n type: lsun\n train_dir: data/LSUN\n deterministic: False\n img_size: 128\n nlabels: 1\ngenerator:\n name: resnet\n"
},
{
"path": "configs/imagenet/conditional.yaml",
"chars": 200,
"preview": "generator:\n nlabels: 1000\n conditioning: embedding\ndiscriminator:\n nlabels: 1000\n conditioning: mask\ninherit_from: c"
},
{
"path": "configs/imagenet/default.yaml",
"chars": 542,
"preview": "data:\n type: image\n train_dir: data/ImageNet/train\n test_dir: data/ImageNet/val\n img_size: 128\n nlabels: 1000\ngener"
},
{
"path": "configs/imagenet/selfcondgan.yaml",
"chars": 311,
"preview": "generator:\n nlabels: 100\n conditioning: embedding\ndiscriminator:\n nlabels: 100\n conditioning: mask\nclusterer:\n name"
},
{
"path": "configs/imagenet/unconditional.yaml",
"chars": 206,
"preview": "generator:\n nlabels: 1\n conditioning: unconditional\ndiscriminator:\n nlabels: 1\n conditioning: unconditional\ninherit_"
},
{
"path": "configs/places/conditional.yaml",
"chars": 192,
"preview": "generator:\n nlabels: 365\n conditioning: embedding\ndiscriminator:\n nlabels: 365\n conditioning: mask\ntraining:\n out_d"
},
{
"path": "configs/places/default.yaml",
"chars": 546,
"preview": "data:\n type: image\n train_dir: data/places365/train\n test_dir: data/places365/val\n img_size: 128\n nlabels: 365\ngene"
},
{
"path": "configs/places/selfcondgan.yaml",
"chars": 307,
"preview": "generator:\n nlabels: 100\n conditioning: embedding\ndiscriminator:\n nlabels: 100\n conditioning: mask\nclusterer:\n name"
},
{
"path": "configs/places/unconditional.yaml",
"chars": 190,
"preview": "generator:\n nlabels: 1\n conditioning: embedding\ndiscriminator:\n nlabels: 1\n conditioning: mask\ninherit_from: configs"
},
{
"path": "configs/pretrained/imagenet/conditional.yaml",
"chars": 301,
"preview": "generator:\n nlabels: 1000\n conditioning: embedding\ndiscriminator:\n nlabels: 1000\n conditioning: mask\ninherit_from: c"
},
{
"path": "configs/pretrained/imagenet/selfcondgan.yaml",
"chars": 489,
"preview": "generator:\n nlabels: 100\n conditioning: embedding\ndiscriminator:\n nlabels: 100\n conditioning: mask\nclusterer:\n name"
},
{
"path": "configs/pretrained/imagenet/unconditional.yaml",
"chars": 301,
"preview": "generator:\n nlabels: 1\n conditioning: unconditional\ndiscriminator:\n nlabels: 1\n conditioning: unconditional\ninherit_"
},
{
"path": "configs/pretrained/places/conditional.yaml",
"chars": 295,
"preview": "generator:\n nlabels: 365\n conditioning: embedding\ndiscriminator:\n nlabels: 365\n conditioning: mask\ntraining:\n out_d"
},
{
"path": "configs/pretrained/places/selfcondgan.yaml",
"chars": 460,
"preview": "generator:\n nlabels: 100\n conditioning: embedding\ndiscriminator:\n nlabels: 100\n conditioning: mask\nclusterer:\n name"
},
{
"path": "configs/pretrained/places/unconditional.yaml",
"chars": 284,
"preview": "generator:\n nlabels: 1\n conditioning: embedding\ndiscriminator:\n nlabels: 1\n conditioning: mask\ninherit_from: configs"
},
{
"path": "configs/stacked_mnist/conditional.yaml",
"chars": 207,
"preview": "generator:\n nlabels: 1000\n conditioning: embedding\ndiscriminator:\n nlabels: 1000\n conditioning: mask\ninherit_from: c"
},
{
"path": "configs/stacked_mnist/default.yaml",
"chars": 607,
"preview": "data:\n type: stacked_mnist\n train_dir: data/MNIST\n img_size: 32\n nlabels: 1000\ngenerator:\n name: dcgan_shallow\n nl"
},
{
"path": "configs/stacked_mnist/selfcondgan.yaml",
"chars": 289,
"preview": "generator:\n nlabels: 100\n conditioning: embedding\ndiscriminator:\n nlabels: 100\n conditioning: mask\nclusterer:\n name"
},
{
"path": "configs/stacked_mnist/unconditional.yaml",
"chars": 104,
"preview": "inherit_from: configs/stacked_mnist/default.yaml\ntraining:\n out_dir: output/stacked_mnist/unconditional"
},
{
"path": "gan_training/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "gan_training/checkpoints.py",
"chars": 5740,
"preview": "import os, pickle\nimport urllib\nimport torch\nimport numpy as np\nfrom torch.utils import model_zoo\n\n\nclass CheckpointIO(o"
},
{
"path": "gan_training/config.py",
"chars": 3646,
"preview": "import yaml\nfrom torch import optim\nfrom os import path\nfrom gan_training.models import generator_dict, discriminator_di"
},
{
"path": "gan_training/distributions.py",
"chars": 1143,
"preview": "import torch\nfrom torch import distributions\n\n\ndef get_zdist(dist_name, dim, device=None):\n # Get distribution\n if"
},
{
"path": "gan_training/eval.py",
"chars": 2785,
"preview": "import numpy as np\nimport torch\nfrom torch.nn import functional as F\n\nfrom gan_training.metrics import inception_score\n\n"
},
{
"path": "gan_training/inputs.py",
"chars": 7127,
"preview": "import torch\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport numpy as np\n\nimp"
},
{
"path": "gan_training/logger.py",
"chars": 3174,
"preview": "import pickle\nimport os\nimport torchvision\nimport copy\n\n\nclass Logger(object):\n def __init__(self,\n l"
},
{
"path": "gan_training/metrics/__init__.py",
"chars": 100,
"preview": "from gan_training.metrics.inception_score import inception_score\n\n__all__ = [\n inception_score\n]\n"
},
{
"path": "gan_training/metrics/clustering_metrics.py",
"chars": 1091,
"preview": "def warn(*args, **kwargs):\n pass\n\n\nimport warnings\nwarnings.warn = warn\n\nfrom sklearn.metrics.cluster import normaliz"
},
{
"path": "gan_training/metrics/fid.py",
"chars": 11958,
"preview": "from __future__ import absolute_import, division, print_function\nimport numpy as np\nimport os\nos.environ['TF_CPP_MIN_LOG"
},
{
"path": "gan_training/metrics/inception_score.py",
"chars": 1927,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport torch.utils.data\n\nfrom torchvision.models."
},
{
"path": "gan_training/metrics/tf_is/LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "gan_training/metrics/tf_is/README.md",
"chars": 1776,
"preview": "Inception Score\n=====================================\n\nA new Tensorflow implementation of the \"Inception Score\" (IS) for"
},
{
"path": "gan_training/metrics/tf_is/inception_score.py",
"chars": 4751,
"preview": "'''\nFrom https://github.com/tsc2017/Inception-Score\nCode derived from https://github.com/openai/improved-gan/blob/master"
},
{
"path": "gan_training/models/__init__.py",
"chars": 367,
"preview": "from gan_training.models import (dcgan_deep, dcgan_shallow, resnet2)\n\ngenerator_dict = {\n 'resnet2': resnet2.Generato"
},
{
"path": "gan_training/models/blocks.py",
"chars": 6478,
"preview": "import torch\nfrom torch import nn\nfrom torch.autograd import Variable\nfrom torch.nn import functional as F\n\n\nclass Resne"
},
{
"path": "gan_training/models/dcgan_deep.py",
"chars": 4647,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport torch.utils.data\nimport torch.utils.data.d"
},
{
"path": "gan_training/models/dcgan_shallow.py",
"chars": 4458,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport torch.utils.data\nimport torch.utils.data.d"
},
{
"path": "gan_training/models/resnet2.py",
"chars": 6329,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.autograd import Variable\nimport torch."
},
{
"path": "gan_training/models/resnet2s.py",
"chars": 6583,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.autograd import Variable\nimport torch."
},
{
"path": "gan_training/models/resnet3.py",
"chars": 5561,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.autograd import Variable\nimport torch."
},
{
"path": "gan_training/train.py",
"chars": 4646,
"preview": "# coding: utf-8\nimport torch\nfrom torch.nn import functional as F\nimport torch.utils.data\nimport torch.utils.data.distri"
},
{
"path": "gan_training/utils.py",
"chars": 1290,
"preview": "import torch\nimport torch.utils.data\nimport torch.utils.data.distributed\nimport torchvision\n\nimport os\n\n\ndef save_images"
},
{
"path": "metrics.py",
"chars": 7240,
"preview": "import argparse\nimport os\nimport json\nfrom tqdm import tqdm\n\nimport numpy as np\nimport torch\n\nfrom gan_training.config i"
},
{
"path": "requirements.txt",
"chars": 119,
"preview": "pytorch-gpu==1.3.1\ntensorflow-gpu==1.14.0\nscikit-learn\nscikit-image\ntorchvision\ntqdm \npyyaml\ncloudpickle\nipython\nopencv"
},
{
"path": "seeded_sampler.py",
"chars": 6400,
"preview": "''' Samples from a (class-conditional) GAN, so that the samples can be reproduced '''\n\nimport os\nimport pickle\nimport ra"
},
{
"path": "seeing/frechet_distance.py",
"chars": 4478,
"preview": "#!/usr/bin/env python3\n\"\"\"Calculates the Frechet Distance (FD) between two samples.\n\nCode apapted from https://github.co"
},
{
"path": "seeing/fsd.py",
"chars": 10423,
"preview": "import torch, argparse, sys, os, numpy\nfrom .sampler import FixedRandomSubsetSampler, FixedSubsetSampler\nfrom torch.util"
},
{
"path": "seeing/lightbox.html",
"chars": 2609,
"preview": "<!DOCTYPE html>\n<html>\n<!--\n +lightbox.html, a page for automatically showing all images in a\n directory on an Apache "
},
{
"path": "seeing/parallelfolder.py",
"chars": 6591,
"preview": "'''\nVariants of pytorch's ImageFolder for loading image datasets with more\ninformation, such as parallel feature channel"
},
{
"path": "seeing/pbar.py",
"chars": 5911,
"preview": "'''\nUtilities for showing progress bars, controlling default verbosity, etc.\n'''\n\n# If the tqdm package is not available"
},
{
"path": "seeing/pidfile.py",
"chars": 3292,
"preview": "'''\nUtility for simple distribution of work on multiple processes, by\nmaking sure only one process is working on a job a"
},
{
"path": "seeing/sampler.py",
"chars": 7072,
"preview": "'''\nA sampler is just a list of integer listing the indexes of the\ninputs in a data set to sample. For reproducibility,"
},
{
"path": "seeing/segmenter.py",
"chars": 11750,
"preview": "# Usage as a simple differentiable segmenter base class\n\nimport os, torch, numpy, json, glob\nimport skimage.morphology\nf"
},
{
"path": "seeing/upsegmodel/__init__.py",
"chars": 53,
"preview": "from .models import ModelBuilder, SegmentationModule\n"
},
{
"path": "seeing/upsegmodel/models.py",
"chars": 17942,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\nfrom . import resnet, resnext\ntry:"
},
{
"path": "seeing/upsegmodel/prroi_pool/.gitignore",
"chars": 20,
"preview": "*.o\n/_prroi_pooling\n"
},
{
"path": "seeing/upsegmodel/prroi_pool/README.md",
"chars": 3938,
"preview": "# PreciseRoIPooling\nThis repo implements the **Precise RoI Pooling** (PrRoI Pooling), proposed in the paper **Acquisitio"
},
{
"path": "seeing/upsegmodel/prroi_pool/__init__.py",
"chars": 350,
"preview": "#! /usr/bin/env python3\n# -*- coding: utf-8 -*-\n# File : __init__.py\n# Author : Jiayuan Mao, Tete Xiao\n# Email : maoj"
},
{
"path": "seeing/upsegmodel/prroi_pool/build.py",
"chars": 1343,
"preview": "#! /usr/bin/env python3\n# -*- coding: utf-8 -*-\n# File : build.py\n# Author : Jiayuan Mao, Tete Xiao\n# Email : maojiay"
},
{
"path": "seeing/upsegmodel/prroi_pool/functional.py",
"chars": 2510,
"preview": "#! /usr/bin/env python3\n# -*- coding: utf-8 -*-\n# File : functional.py\n# Author : Jiayuan Mao, Tete Xiao\n# Email : ma"
},
{
"path": "seeing/upsegmodel/prroi_pool/prroi_pool.py",
"chars": 827,
"preview": "#! /usr/bin/env python3\n# -*- coding: utf-8 -*-\n# File : prroi_pool.py\n# Author : Jiayuan Mao, Tete Xiao\n# Email : ma"
},
{
"path": "seeing/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c",
"chars": 3854,
"preview": "/*\n * File : prroi_pooling_gpu.c\n * Author : Jiayuan Mao, Tete Xiao\n * Email : maojiayuan@gmail.com, jasonhsiao97@gma"
},
{
"path": "seeing/upsegmodel/prroi_pool/src/prroi_pooling_gpu.h",
"chars": 865,
"preview": "/*\n * File : prroi_pooling_gpu.h\n * Author : Jiayuan Mao, Tete Xiao\n * Email : maojiayuan@gmail.com, jasonhsiao97@gma"
},
{
"path": "seeing/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cu",
"chars": 17532,
"preview": "/*\n * File : prroi_pooling_gpu_impl.cu\n * Author : Tete Xiao, Jiayuan Mao\n * Email : jasonhsiao97@gmail.com\n *\n * Dis"
},
{
"path": "seeing/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cuh",
"chars": 1589,
"preview": "/*\n * File : prroi_pooling_gpu_impl.cuh\n * Author : Tete Xiao, Jiayuan Mao\n * Email : jasonhsiao97@gmail.com\n *\n * Di"
},
{
"path": "seeing/upsegmodel/prroi_pool/test_prroi_pooling2d.py",
"chars": 1473,
"preview": "# -*- coding: utf-8 -*-\n# File : test_prroi_pooling2d.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date"
},
{
"path": "seeing/upsegmodel/resnet.py",
"chars": 7363,
"preview": "import os\nimport sys\nimport torch\nimport torch.nn as nn\nimport math\ntry:\n from lib.nn import SynchronizedBatchNorm2d\n"
},
{
"path": "seeing/upsegmodel/resnext.py",
"chars": 6057,
"preview": "import os\nimport sys\nimport torch\nimport torch.nn as nn\nimport math\ntry:\n from lib.nn import SynchronizedBatchNorm2d\n"
},
{
"path": "seeing/yz_dataset.py",
"chars": 1057,
"preview": "import torch, numpy\n\n\nclass YZDataset():\n def __init__(self, zdim=256, nlabels=1, distribution=[1.], device='cpu'):\n "
},
{
"path": "seeing/zdataset.py",
"chars": 1644,
"preview": "import os, torch, numpy\nfrom torch.utils.data import TensorDataset\n\ndef z_dataset_for_model(model, size=100, seed=1):\n "
},
{
"path": "train.py",
"chars": 10343,
"preview": "import argparse\nimport os\nimport copy\nimport pprint\nfrom os import path\n\nimport torch\nimport numpy as np\nfrom torch impo"
},
{
"path": "utils/classifiers/__init__.py",
"chars": 233,
"preview": "from classifiers import stacked_mnist, cifar, places, imagenet\n\nclassifier_dict = {\n 'stacked_mnist': stacked_mnist.C"
},
{
"path": "utils/classifiers/cifar.py",
"chars": 302,
"preview": "import sys\nsys.path.append('utils/classifiers')\n\nfrom pytorch_playground.cifar.model import cifar10\n\nclass Classifier():"
},
{
"path": "utils/classifiers/imagenet.py",
"chars": 1106,
"preview": "import torch\nimport torchvision.models as models\nfrom torchvision import transforms as trn\nfrom torch.nn import function"
},
{
"path": "utils/classifiers/imagenet_class_index.json",
"chars": 35363,
"preview": "{\"0\": [\"n01440764\", \"tench\"], \"1\": [\"n01443537\", \"goldfish\"], \"2\": [\"n01484850\", \"great_white_shark\"], \"3\": [\"n01491361\""
},
{
"path": "utils/classifiers/places.py",
"chars": 2259,
"preview": "import torch\nimport torchvision.models as models\nfrom torchvision import transforms as trn\nfrom torch.nn import function"
},
{
"path": "utils/classifiers/pytorch_playground/.gitignore",
"chars": 100,
"preview": "__pycache__\n*.jpg\n*.png\nacc1_acc5.txt\nlog\npytorch_playground.egg-info\nscript/val224_compressed.pkl\n\n"
},
{
"path": "utils/classifiers/pytorch_playground/LICENSE",
"chars": 1067,
"preview": "MIT License\n\nCopyright (c) 2017 Aaron Chen\n\nPermission is hereby granted, free of charge, to any person obtaining a copy"
},
{
"path": "utils/classifiers/pytorch_playground/README.md",
"chars": 4971,
"preview": "This is a playground for pytorch beginners, which contains predefined models on popular dataset. Currently we support \n-"
},
{
"path": "utils/classifiers/pytorch_playground/cifar/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "utils/classifiers/pytorch_playground/cifar/dataset.py",
"chars": 2937,
"preview": "import torch\nfrom torchvision import datasets, transforms\nfrom torch.utils.data import DataLoader\nimport os\n\ndef get10(b"
},
{
"path": "utils/classifiers/pytorch_playground/cifar/model.py",
"chars": 2105,
"preview": "import torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\nfrom IPython import embed\nfrom collections import Ordere"
},
{
"path": "utils/classifiers/pytorch_playground/cifar/train.py",
"chars": 5777,
"preview": "import argparse\nimport os\nimport time\n\nfrom utee import misc\nimport torch\nimport torch.nn.functional as F\nimport torch.o"
},
{
"path": "utils/classifiers/pytorch_playground/quantize.py",
"chars": 4928,
"preview": "import argparse\nfrom utee import misc, quant, selector\nimport torch\nimport torch.backends.cudnn as cudnn\ncudnn.benchmark"
},
{
"path": "utils/classifiers/pytorch_playground/requirements.txt",
"chars": 83,
"preview": "Pillow==6.1\ntorchvision==0.4.2\ntqdm==4.41.1\nopencv-python==4.1.2.30\njoblib==0.14.1\n"
},
{
"path": "utils/classifiers/pytorch_playground/roadmap_zh.md",
"chars": 3788,
"preview": "# 定点化Roadmap\n首先定点化的setting分好几种,主要如下所示 (w代表weight,a代表activation,g代表gradient)\n\n最近两年的目前有13篇直接相关的论文,截止到2016年7月\n\n## float转化为定"
},
{
"path": "utils/classifiers/pytorch_playground/setup.py",
"chars": 447,
"preview": "from setuptools import setup, find_packages\n\nwith open(\"requirements.txt\") as requirements_file:\n REQUIREMENTS = requ"
},
{
"path": "utils/classifiers/pytorch_playground/utee/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "utils/classifiers/pytorch_playground/utee/misc.py",
"chars": 7772,
"preview": "import cv2\nimport os\nimport shutil\nimport pickle as pkl\nimport time\nimport numpy as np\nimport hashlib\n\nfrom IPython impo"
},
{
"path": "utils/classifiers/pytorch_playground/utee/quant.py",
"chars": 6302,
"preview": "from torch.autograd import Variable\nimport torch\nfrom torch import nn\nfrom collections import OrderedDict\nimport math\nfr"
},
{
"path": "utils/classifiers/pytorch_playground/utee/selector.py",
"chars": 5245,
"preview": "from utee import misc\nimport os\nfrom imagenet import dataset\nprint = misc.logger.info\nfrom IPython import embed\n\nknown_m"
},
{
"path": "utils/classifiers/stacked_mnist.py",
"chars": 3638,
"preview": "import torch\nfrom torch import nn\nimport torch.utils.model_zoo as model_zoo\nfrom collections import OrderedDict\nfrom tor"
},
{
"path": "utils/get_empirical_distribution.py",
"chars": 2905,
"preview": "import argparse\nimport os\nfrom tqdm import tqdm\n\nimport json\nimport numpy as np\n\nfrom classifiers import classifier_dict"
},
{
"path": "utils/get_gt_imgs.py",
"chars": 2884,
"preview": "import os\nimport argparse\nfrom tqdm import tqdm\nfrom PIL import Image\nimport torch\nfrom torchvision import transforms, d"
},
{
"path": "utils/np_to_pt_img.py",
"chars": 351,
"preview": "import torch\n\n\ndef np_to_pt(x):\n ''' permutes the appropriate channels to turn numpy formatted images to pt formatted"
},
{
"path": "visualize_clusters.py",
"chars": 4163,
"preview": "import argparse\nimport os\nimport shutil\nimport torch\nimport torchvision\n\nfrom torch import nn\nfrom gan_training import u"
}
]
About this extraction
This page contains the full source code of the stevliu/self-conditioned-gan GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 119 files (375.7 KB), approximately 104.1k tokens, and a symbol index with 462 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.