[
  {
    "path": ".gitignore",
    "content": "*/**/*pyc*\n*/**/.DS_Store\n.vscode\n"
  },
  {
    "path": "2d_mix/.gitignore",
    "content": "**.png\n**.pyc\n**.pt\noutput/\n"
  },
  {
    "path": "2d_mix/__init__.py",
    "content": ""
  },
  {
    "path": "2d_mix/config.py",
    "content": "import torch\n\nfrom models import generator_dict, discriminator_dict\nfrom torch import optim\nimport torch.utils.data as utils\n\n\ndef get_models(model_type, conditioning, k_value, d_act_dim, device):\n    G = generator_dict[model_type]\n    D = discriminator_dict[model_type]\n    generator = G(conditioning, k_value=k_value)\n    discriminator = D(conditioning, k_value=k_value, act_dim=d_act_dim)\n\n    generator.to(device)\n    discriminator.to(device)\n\n    return generator, discriminator\n\n\ndef get_optimizers(generator, discriminator, lr=1e-4, beta1=0.8, beta2=0.999):\n    g_optimizer = optim.Adam(generator.parameters(),\n                             lr=lr,\n                             betas=(beta1, beta2))\n    d_optimizer = optim.Adam(discriminator.parameters(),\n                             lr=lr,\n                             betas=(beta1, beta2))\n    return g_optimizer, d_optimizer\n\n\ndef get_test(get_data, batch_size, variance, k_value, device):\n    x_test, y_test = get_data(batch_size, var=variance)\n    x_test, y_test = torch.from_numpy(x_test).float().to(\n        device), torch.from_numpy(y_test).long().to(device)\n    return x_test, y_test\n\n\ndef get_dataset(get_data, batch_size, npts, variance, k_value):\n    samples, labels = get_data(npts, var=variance)\n    tensor_samples = torch.stack([torch.Tensor(x) for x in samples])\n    tensor_labels = torch.stack([torch.tensor(x) for x in labels])\n    dataset = utils.TensorDataset(tensor_samples, tensor_labels)\n    train_loader = utils.DataLoader(dataset,\n                                    batch_size=batch_size,\n                                    shuffle=True,\n                                    num_workers=0,\n                                    pin_memory=True,\n                                    sampler=None,\n                                    drop_last=True)\n    return train_loader\n"
  },
  {
    "path": "2d_mix/evaluation.py",
    "content": "def warn(*args, **kwargs):\n    pass\n\n\nimport warnings\nwarnings.warn = warn\n\nimport numpy as np\n\n\ndef percent_good_grid(x_fake, var=0.0025, nrows=5, ncols=5):\n    std = np.sqrt(var)\n    x = list(range(nrows))\n    y = list(range(ncols))\n\n    threshold = 3 * std\n    means = []\n    for i in x:\n        for j in y:\n            means.append(np.array([x[i] * 2 - 4, y[j] * 2 - 4]))\n    return percent_good_pts(x_fake, means, threshold)\n\n\ndef percent_good_ring(x_fake, var=0.0001, n_clusters=8, radius=2.0):\n    std = np.sqrt(var)\n    thetas = np.linspace(0, 2 * np.pi, n_clusters + 1)[:n_clusters]\n    x, y = radius * np.sin(thetas), radius * np.cos(thetas)\n    threshold = np.array([std * 3, std * 3])\n    means = []\n    for i in range(n_clusters):\n        means.append(np.array([x[i], y[i]]))\n    return percent_good_pts(x_fake, means, threshold)\n\n\ndef percent_good_pts(x_fake, means, threshold):\n    \"\"\"Calculate %good, #modes, kl\n\n    Keyword arguments:\n    x_fake -- detached generated samples\n    means -- true means\n    threshold -- good point if l_1 distance is within threshold\n    \"\"\"\n    count = 0\n    counts = np.zeros(len(means))\n    visited = set()\n    for point in x_fake:\n        minimum = 0\n        diff_minimum = [1e10, 1e10]\n        for i, mean in enumerate(means):\n            diff = np.abs(point - mean)\n            if np.all(diff < threshold):\n                visited.add(tuple(mean))\n                count += 1\n                break\n        for i, mean in enumerate(means):\n            diff = np.abs(point - mean)\n            if np.linalg.norm(diff) < np.linalg.norm(diff_minimum):\n                minimum = i\n                diff_minimum = diff\n        counts[minimum] += 1\n\n    kl = 0\n    counts = counts / len(x_fake)\n    for generated in counts:\n        if generated != 0:\n            kl += generated * np.log(len(means) * generated)\n\n    return count / len(x_fake), len(visited), kl\n"
  },
  {
    "path": "2d_mix/inputs.py",
    "content": "import numpy as np\nimport random\n\nmapping = list(range(25))\n\ndef map_labels(labels):\n    return np.array([mapping[label] for label in labels])\n\n\ndef get_data_ring(batch_size, radius=2.0, var=0.0001, n_clusters=8):\n    thetas = np.linspace(0, 2 * np.pi, n_clusters + 1)[:n_clusters]\n    xs, ys = radius * np.sin(thetas), radius * np.cos(thetas)\n    classes = np.random.multinomial(batch_size,\n                                    [1.0 / n_clusters] * n_clusters)\n    labels = [i for i in range(n_clusters) for _ in range(classes[i])]\n    random.shuffle(labels)\n    labels = np.array(labels)\n    samples = np.array([\n        np.random.multivariate_normal([xs[i], ys[i]], [[var, 0], [0, var]])\n        for i in labels\n    ])\n    return samples, labels\n\n\ndef get_data_grid(batch_size, radius=2.0, var=0.0025, nrows=5, ncols=5):\n    samples = []\n    labels = []\n    for _ in range(batch_size):\n        i, j = random.randint(0, ncols - 1), random.randint(0, nrows - 1)\n        samples.append(\n            np.random.multivariate_normal([i * 2 - 4, j * 2 - 4],\n                                          [[var, 0], [0, var]]))\n        labels.append(5 * i + j)\n    return np.array(samples), map_labels(labels)\n"
  },
  {
    "path": "2d_mix/models/__init__.py",
    "content": "from models import (cluster)\n\ngenerator_dict = {'standard': cluster.G}\ndiscriminator_dict = {'standard': cluster.D}\n"
  },
  {
    "path": "2d_mix/models/cluster.py",
    "content": "import sys\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\nsys.path.append('../gan_training/models')\n\nfrom blocks import LatentEmbeddingConcat, Identity, LinearUnconditionalLogits, LinearConditionalMaskLogits\n\n\nclass G(nn.Module):\n    def __init__(self,\n                 conditioning,\n                 k_value,\n                 z_dim=2,\n                 embed_size=32,\n                 act_dim=400,\n                 x_dim=2):\n\n        super().__init__()\n        if conditioning == 'unconditional':\n            embed_size = 0\n            self.embedding = Identity()\n        elif conditioning == 'conditional':\n            self.embedding = LatentEmbeddingConcat(k_value, embed_size)\n        else:\n            raise NotImplementedError()\n\n        self.fc1 = nn.Sequential(nn.Linear(z_dim + embed_size, act_dim),\n                                 nn.BatchNorm1d(act_dim), nn.ReLU(True))\n        self.fc2 = nn.Sequential(nn.Linear(act_dim, act_dim),\n                                 nn.BatchNorm1d(act_dim), nn.ReLU(True))\n        self.fc3 = nn.Sequential(nn.Linear(act_dim, act_dim),\n                                 nn.BatchNorm1d(act_dim), nn.ReLU(True))\n        self.fc4 = nn.Sequential(nn.Linear(act_dim, act_dim),\n                                 nn.BatchNorm1d(act_dim), nn.ReLU(True))\n        self.fc_out = nn.Linear(act_dim, x_dim)\n\n    def forward(self, z, y=None):\n        out = self.fc1(self.embedding(z, y))\n        out = self.fc2(out)\n        out = self.fc3(out)\n        out = self.fc4(out)\n        out = self.fc_out(out)\n        return out\n\n\nclass D(nn.Module):\n    class Maxout(nn.Module):\n        # Taken from https://github.com/pytorch/pytorch/issues/805\n        def __init__(self, d_in, d_out, pool_size=5):\n            super().__init__()\n            self.d_in, self.d_out, self.pool_size = d_in, d_out, pool_size\n            self.lin = nn.Linear(d_in, d_out 
* pool_size)\n\n        def forward(self, inputs):\n            shape = list(inputs.size())\n            shape[-1] = self.d_out\n            shape.append(self.pool_size)\n            max_dim = len(shape) - 1\n            out = self.lin(inputs)\n            m, i = out.view(*shape).max(max_dim)\n            return m\n\n    def max(self, out, dim=5):\n        return out.view(out.size(0), -1, dim).max(2)[0]\n\n    def __init__(self, conditioning, k_value, act_dim=200, x_dim=2):\n        super().__init__()\n        self.fc1 = self.Maxout(x_dim, act_dim)\n        self.fc2 = self.Maxout(act_dim, act_dim)\n        self.fc3 = self.Maxout(act_dim, act_dim)\n\n        if conditioning == 'unconditional':\n            self.fc4 = LinearUnconditionalLogits(act_dim)\n        elif conditioning == 'conditional':\n            self.fc4 = LinearConditionalMaskLogits(act_dim, k_value)\n        else:\n            raise NotImplementedError()\n\n    def forward(self, x, y=None, get_features=False):\n        out = self.fc1(x)\n        out = self.fc2(out)\n        out = self.fc3(out)\n        if get_features: return out\n        return self.fc4(out, y, get_features=get_features)\n"
  },
  {
    "path": "2d_mix/train.py",
    "content": "import argparse\nimport os\nimport sys\n\nimport torch\nfrom torch import optim\nfrom torch import distributions\nfrom torch import nn\nimport torch.nn.functional as F\nimport numpy as np\n\nimport evaluation\nimport inputs\n\nfrom config import get_models, get_optimizers, get_test, get_dataset\nfrom visualizations import (visualize_generated, visualize_clusters)\n\nsys.path.append('../')\nfrom clusterers import clusterer_dict\nfrom gan_training.train import Trainer\n\nsys.path.append('../seeing/')\nimport pidfile\n\nparser = argparse.ArgumentParser(description='2d dataset experiments')\nparser.add_argument('--clusterer', help='type of clusterer to use. cluster specifies selfcondgan')\nparser.add_argument('--data_type', help='either grid or ring')\nparser.add_argument('--recluster_every', type=int, default=5000, help='how frequently to recluster')\nparser.add_argument('--nruns', type=int, default=1, help='number of trials to do')\nparser.add_argument('--burnin_time', type=int, default=0, help='wait this amount of iterations before clustering')\n\nparser.add_argument('--variance', type=float, default=None, help='variance of the gaussians')\nparser.add_argument('--model_type', type=str, default='standard', help='model architecture')\nparser.add_argument('--num_clusters', type=int, default=50, help='number of clusters to use for selfcondgan')\nparser.add_argument('--z_dim', type=int, default=2, help='G latent dim')\nparser.add_argument('--d_act_dim', type=int, default=200, help='hidden layer width')\nparser.add_argument('--npts', type=int, default=100000, help='number of points to use in dataset')\nparser.add_argument('--train_batch_size', type=int, default=100, help='training time batch size')\nparser.add_argument('--test_batch_size', type=int, default=50000, help='number of examples to get metrics with')\nparser.add_argument('--nepochs', type=int, default=100, help='number of epochs to run')\nparser.add_argument('--outdir', default='output')\nargs = 
parser.parse_args()\n\ndata_type = args.data_type\nk_value = 8 if data_type == 'ring' else 25\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\nnum_clusters = k_value if args.clusterer == 'supervised' else args.num_clusters\n\nexp_name = f'{args.data_type}_{args.clusterer}_{args.recluster_every}_{num_clusters}/'\nif args.model_type != 'standard':\n    exp_name = f'{args.model_type}_{exp_name}'\nif args.variance is not None:\n    exp_name = f'{args.variance}_{exp_name}'\n\nif args.variance is None:\n    variance = 0.0025 if data_type == 'grid' else 0.0001\nelse:\n    variance = args.variance\nnepochs = args.nepochs\nz_dim = args.z_dim\ntest_batch_size = args.test_batch_size\ntrain_batch_size = args.train_batch_size\nnpts = args.npts\n\n\ndef main(outdir):\n    for subdir in ['all', 'snapshots', 'clusters']:\n        if not os.path.exists(os.path.join(outdir, subdir)):\n            os.makedirs(os.path.join(outdir, subdir), exist_ok=True)\n\n    if data_type == 'grid':\n        get_data = inputs.get_data_grid\n        percent_good = evaluation.percent_good_grid\n    elif data_type == 'ring':\n        get_data = inputs.get_data_ring\n        percent_good = evaluation.percent_good_ring\n    else:\n        raise NotImplementedError()\n\n    zdist = distributions.Normal(torch.zeros(z_dim, device=device),\n                                 torch.ones(z_dim, device=device))\n    z_test = zdist.sample((test_batch_size, ))\n\n    x_test, y_test = get_test(get_data=get_data,\n                              batch_size=test_batch_size,\n                              variance=variance,\n                              k_value=k_value,\n                              device=device)\n\n    x_cluster, _ = get_test(get_data=get_data,\n                            batch_size=10000,\n                            variance=variance,\n                            k_value=k_value,\n                            device=device)\n\n    train_loader = 
get_dataset(get_data=get_data,\n                               batch_size=train_batch_size,\n                               npts=npts,\n                               variance=variance,\n                               k_value=k_value)\n\n    def train(trainer, g, d, clusterer, exp_dir):\n        it = 0\n        if os.path.exists(os.path.join(exp_dir, 'log.txt')):\n            os.remove(os.path.join(exp_dir, 'log.txt'))\n\n        for epoch in range(nepochs):\n            for x_real, y in train_loader:\n                z = zdist.sample((train_batch_size, ))\n                x_real, y = x_real.to(device), y.to(device)\n                y = clusterer.get_labels(x_real, y)\n\n                dloss, _ = trainer.discriminator_trainstep(x_real, y, z)\n                gloss = trainer.generator_trainstep(y, z)\n\n                if it % args.recluster_every == 0 and args.clusterer != 'supervised':\n                    if args.clusterer != 'burnin' or it >= args.burnin_time:\n                        clusterer.recluster(discriminator, x_batch=x_real)\n\n                if it % 1000 == 0:\n                    x_fake = g(z_test, clusterer.get_labels(x_test, y_test)).detach().cpu().numpy()\n\n                    visualize_generated(x_fake,\n                                        x_test.detach().cpu().numpy(), y, it,\n                                        exp_dir)\n\n                    visualize_clusters(x_test.detach().cpu().numpy(),\n                                       clusterer.get_labels(x_test, y_test),\n                                       it, exp_dir)\n\n                    torch.save(\n                        {\n                            'generator': g.state_dict(),\n                            'discriminator': d.state_dict(),\n                            'g_optimizer': g_optimizer.state_dict(),\n                            'd_optimizer': d_optimizer.state_dict()\n                        },\n                        os.path.join(exp_dir, 'snapshots', 
'model_%d.pt' % it))\n\n                if it % 1000 == 0:\n                    g.eval()\n                    d.eval()\n\n                    x_fake = g(z_test, clusterer.get_labels(\n                        x_test, y_test)).detach().cpu().numpy()\n                    percent, modes, kl = percent_good(x_fake, var=variance)\n                    log_message = f'[epoch {epoch} it {it}] dloss = {dloss}, gloss = {gloss}, prop_real = {percent}, modes = {modes}, kl = {kl}'\n                    with open(os.path.join(exp_dir, 'log.txt'), 'a+') as f:\n                        f.write(log_message + '\\n')\n                    print(log_message)\n\n                it += 1\n\n    # train a G/D from scratch\n    generator, discriminator = get_models(args.model_type, 'conditional', num_clusters, args.d_act_dim, device)\n    g_optimizer, d_optimizer = get_optimizers(generator, discriminator)\n    trainer = Trainer(generator, discriminator, g_optimizer, d_optimizer, gan_type='standard', reg_type='none', reg_param=0)\n    clusterer = clusterer_dict[args.clusterer](discriminator=discriminator,\n                                               k_value=num_clusters,\n                                               x_cluster=x_cluster)\n    clusterer.recluster(discriminator=discriminator)\n    train(trainer, generator, discriminator, clusterer, os.path.join(outdir))\n\n\nif __name__ == '__main__':\n    outdir = os.path.join(args.outdir, exp_name)\n    pidfile.exit_if_job_done(outdir)\n    for run_number in range(args.nruns):\n        run_dir = f'{outdir}_run_{run_number}' if args.nruns > 1 else outdir\n        main(run_dir)\n    pidfile.mark_job_done(outdir)\n"
  },
  {
    "path": "2d_mix/visualizations.py",
    "content": "import matplotlib\nfrom matplotlib import pyplot\nimport os\n\nCOLORS = [\n    'purple',\n    'wheat',\n    'maroon',\n    'red',\n    'powderblue',\n    'dodgerblue',\n    'magenta',\n    'tan',\n    'aqua',\n    'yellow',\n    'slategray',\n    'blue',\n    'rosybrown',\n    'violet',\n    'lightseagreen',\n    'pink',\n    'darkorange',\n    'teal',\n    'royalblue',\n    'lawngreen',\n    'gold',\n    'navy',\n    'darkgreen',\n    'deeppink',\n    'palegreen',\n    'silver',\n    'saddlebrown',\n    'plum',\n    'peru',\n    'black',\n]\n\nassert (len(COLORS) == len(set(COLORS)))\n\ndef visualize_generated(fake, real, y, it, outdir):\n    pyplot.plot(real[:, 0], real[:, 1], 'r.')\n    pyplot.plot(fake[:, 0], fake[:, 1], 'b.')\n    pyplot.savefig(os.path.join(outdir, 'all', str(it) + '.png'))\n    pyplot.clf()\n\n    lim = 6\n    axes = pyplot.gca()\n    axes.set_aspect('equal', adjustable='box')\n    axes.set_xlim([-lim, lim])\n    axes.set_ylim([-lim, lim])\n\n    pyplot.locator_params(nbins=4)\n    pyplot.tight_layout()\n\n    pyplot.plot(fake[:, 0], fake[:, 1], 'b.', alpha=0.1)\n    pyplot.savefig(os.path.join(outdir, 'all',\n                                str(it) + 'square.png'),\n                   dpi=100,\n                   bbox_inches='tight')\n    pyplot.clf()\n\n\ndef visualize_clusters(x, y, it, outdir):\n    y = y.detach().cpu().numpy()\n    for i in range(y.max()):\n        pyplot.plot(x[y == i, 0],\n                    x[y == i, 1],\n                    '.',\n                    color=COLORS[i % len(COLORS)])\n    pyplot.savefig(os.path.join(outdir, 'clusters', str(it) + '.png'))\n    pyplot.clf()\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2020 Steven Liu\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# Diverse Image Generation via Self-Conditioned GANs\n\n#### [Project](http://selfcondgan.csail.mit.edu/) |   [Paper](http://selfcondgan.csail.mit.edu/preprint.pdf)\n\n**Diverse Image Generation via Self-Conditioned GANs** <br>\n[Steven Liu](http://people.csail.mit.edu/stevenliu/),\n[Tongzhou Wang](https://ssnl.github.io/),\n[David Bau](http://people.csail.mit.edu/davidbau/home/),\n[Jun-Yan Zhu](http://people.csail.mit.edu/junyanz/),\n[Antonio Torralba](http://web.mit.edu/torralba/www/) <br>\nMIT, Adobe Research<br>\nin CVPR 2020.\n\n![Teaser](images/teaser.png)\n\nOur proposed self-conditioned GAN model learns to perform clustering and image synthesis simultaneously. The model training\nrequires no manual annotation of object classes. Here, we visualize several discovered clusters for both Places365 (top) and ImageNet\n(bottom). For each cluster, we show both real images and the generated samples conditioned on the cluster index.\n\n## Getting Started\n\n### Installation\n- Clone this repo:\n```bash\ngit clone https://github.com/stevliu/self-conditioned-gan.git\ncd self-conditioned-gan\n```\n\n- Install the dependencies\n```bash\nconda create --name selfcondgan python=3.6\nconda activate selfcondgan\nconda install --file requirements.txt\nconda install -c conda-forge tensorboardx\n```\n### Training and Evaluation\n- Train a model on CIFAR:\n```bash\npython train.py configs/cifar/selfcondgan.yaml\n```\n\n- Visualize samples and inferred clusters:\n```bash\npython visualize_clusters.py configs/cifar/selfcondgan.yaml --show_clusters\n```\nThe samples and clusters will be saved to `output/cifar/selfcondgan/clusters`. 
If this directory lies on an Apache server, you can open the URL to `output/cifar/selfcondgan/clusters/+lightbox.html` in the browser and visualize all samples and clusters in one webpage.\n\n- Evaluate the model's FID:\nYou will need to first gather a set of ground truth train set images to compute metrics against.\n```bash\npython utils/get_gt_imgs.py --cifar\npython metrics.py configs/cifar/selfcondgan.yaml --fid --every -1\n```\nYou can also evaluate with other metrics by appending additional flags, such as Inception Score (`--inception`), the number of covered modes + reverse-KL divergence (`--modes`), and cluster metrics (`--cluster_metrics`).\n\n## Pretrained Models\n\nYou can load and evaluate pretrained models on ImageNet and Places. If you have access to ImageNet or Places directories, first fill in paths to your ImageNet and/or Places dataset directories in `configs/imagenet/default.yaml` and `configs/places/default.yaml` respectively. You can use the following config files with the evaluation scripts, and the code will automatically download the appropriate models.\n\n```bash\nconfigs/pretrained/imagenet/selfcondgan.yaml\nconfigs/pretrained/places/selfcondgan.yaml\n\nconfigs/pretrained/imagenet/conditional.yaml\nconfigs/pretrained/places/conditional.yaml\n\nconfigs/pretrained/imagenet/baseline.yaml\nconfigs/pretrained/places/baseline.yaml\n```\n\n## Evaluation\n### Visualizations\n\nTo visualize generated samples and inferred clusters, run\n```bash\npython visualize_clusters.py config-file\n```\nYou can set the flag `--show_clusters` to also visualize the real inferred clusters, but this requires that you have a path to training set images.\n\n### Metrics\nTo obtain generation metrics, fill in paths to your ImageNet or Places dataset directories in `utils/get_gt_imgs.py` and then run\n```bash\npython utils/get_gt_imgs.py --imagenet --places\n```\nto precompute batches of GT images for FID/FSD evaluation.\n\nThen, you can use\n```bash\npython metrics.py 
config-file\n```\nwith the appropriate flags to compute the FID (`--fid`), FSD (`--fsd`), IS (`--inception`), number of covered modes / reverse-KL divergence (`--modes`) and clustering metrics (`--cluster_metrics`) for each of the checkpoints.\n\n## Training models\nTo train a model, set up a configuration file (examples in `configs/`), and run\n```bash\npython train.py config-file\n```\n\nAn example self-conditioned GAN config for ImageNet is `configs/imagenet/selfcondgan.yaml`, and for Places it is `configs/places/selfcondgan.yaml`.\n\nSome models may be too large to fit on one GPU, so you may want to add `--devices DEVICE_NUMBERS` as an additional flag to enable multi-GPU training.\n\n## 2D-experiments\nFor synthetic dataset experiments, first go into the `2d_mix` directory.\n\nTo train a self-conditioned GAN on the 2D-ring and 2D-grid datasets, run\n```bash\npython train.py --clusterer selfcondgan --data_type ring\npython train.py --clusterer selfcondgan --data_type grid\n```\nYou can test several other configurations via the command line arguments.\n\n\n## Acknowledgments\nThis code is heavily based on the [GAN-stability](https://github.com/LMescheder/GAN_stability) code base.\nOur FSD code is taken from the [GANseeing](https://github.com/davidbau/ganseeing) work.\nTo compute inception score, we use the code provided by [Shichang Tang](https://github.com/tsc2017/Inception-Score.git).\nTo compute FID, we use the code provided by [TTUR](https://github.com/bioinf-jku/TTUR).\nWe also use pretrained classifiers provided by the [pytorch-playground](https://github.com/aaron-xichen/pytorch-playground).\n\nWe thank all the authors for their useful code.\n\n## Citation\nIf you use this code for your research, please cite the following work.\n```\n@inproceedings{liu2020selfconditioned,\n title={Diverse Image Generation via Self-Conditioned GANs},\n author={Liu, Steven and Wang, Tongzhou and Bau, David and Zhu, Jun-Yan and Torralba, Antonio},\n booktitle={Proceedings of the IEEE/CVF 
Conference on Computer Vision and Pattern Recognition (CVPR)},\n year={2020}\n}\n```\n"
  },
  {
    "path": "cluster_metrics.py",
    "content": "import argparse\nimport os\nfrom tqdm import tqdm\n\nimport torch\nimport numpy as np\nfrom torch import nn\n\nfrom gan_training import utils\nfrom gan_training.inputs import get_dataset\nfrom gan_training.checkpoints import CheckpointIO\nfrom gan_training.config import load_config\nfrom gan_training.metrics.clustering_metrics import (nmi, purity_score)\n\ntorch.backends.cudnn.benchmark = True\n\n# Arguments\nparser = argparse.ArgumentParser(description='Evaluate the clustering inferred by our method')\nparser.add_argument('config', type=str, help='Path to config file.')\nparser.add_argument('--model_it', type=str)\nparser.add_argument('--random', action='store_true', help='Figure out if the clusters were randomly assigned')\n\nargs = parser.parse_args()\nconfig = load_config(args.config, 'configs/default.yaml')\nout_dir = config['training']['out_dir']\n\n\ndef main():\n    checkpoint_dir = os.path.join(out_dir, 'chkpts')\n    batch_size = config['training']['batch_size']\n\n    if 'cifar' in config['data']['train_dir'].lower():\n        name = 'cifar10'\n    elif 'stacked_mnist' == config['data']['type']:\n        name = 'stacked_mnist'\n    else:\n        name = 'image'\n\n    if os.path.exists(os.path.join(out_dir, 'cluster_preds.npz')):\n        # if we've already computed assignments, load them and move on\n        with np.load(os.path.join(out_dir, 'cluster_preds.npz')) as f:\n            y_reals = f['y_reals']\n            y_preds = f['y_preds']\n    else:\n        train_dataset, _ = get_dataset(\n            name=name,\n            data_dir=config['data']['train_dir'],\n            size=config['data']['img_size'])\n\n        train_loader = torch.utils.data.DataLoader(\n            train_dataset,\n            batch_size=batch_size,\n            num_workers=config['training']['nworkers'],\n            shuffle=True,\n            pin_memory=True,\n            sampler=None,\n            drop_last=True)\n\n        checkpoint_io = 
CheckpointIO(checkpoint_dir=checkpoint_dir)\n\n        print('Loading clusterer:')\n        most_recent = utils.get_most_recent(checkpoint_dir, 'model') if args.model_it is None else args.model_it\n        clusterer = checkpoint_io.load_clusterer(most_recent, load_samples=False, pretrained=config['pretrained'])\n\n        if isinstance(clusterer.discriminator, nn.DataParallel):\n            clusterer.discriminator = clusterer.discriminator.module\n\n        y_preds = []\n        y_reals = []\n\n        for batch_num, (x_real, y_real) in enumerate(tqdm(train_loader, total=len(train_loader))):\n            y_pred = clusterer.get_labels(x_real.cuda(), None)\n            y_preds.append(y_pred.detach().cpu())\n            y_reals.append(y_real)\n\n        y_reals = torch.cat(y_reals).numpy()\n        y_preds = torch.cat(y_preds).numpy()\n\n        np.savez(os.path.join(out_dir, 'cluster_preds.npz'), y_reals=y_reals, y_preds=y_preds)\n\n    if args.random:\n        y_preds = np.random.randint(0, 100, size=y_reals.shape)\n\n    nmi_score = nmi(y_preds, y_reals)\n    purity = purity_score(y_preds, y_reals)\n    print('nmi', nmi_score, 'purity', purity)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "clusterers/__init__.py",
    "content": "from clusterers import (base_clusterer, selfcondgan, random_labels, online)\n\nclusterer_dict = {\n    'supervised': base_clusterer.BaseClusterer,\n    'selfcondgan': selfcondgan.Clusterer,\n    'online': online.Clusterer,\n    'random_labels': random_labels.Clusterer\n}\n"
  },
  {
    "path": "clusterers/base_clusterer.py",
    "content": "import copy\n\nimport torch\nimport numpy as np\n\nclass BaseClusterer():\n    def __init__(self,\n                 discriminator,\n                 k_value=-1,\n                 x_cluster=None,\n                 batch_size=100,\n                 **kwargs):\n        ''' requires that self.x is not on the gpu, or else it hogs too much gpu memory ''' \n        self.cluster_counts = [0] * k_value\n        self.discriminator = copy.deepcopy(discriminator)\n        self.discriminator.eval()\n        self.k = k_value\n        self.kmeans = None\n        self.x = x_cluster\n        self.x_labels = None\n        self.batch_size = batch_size\n\n    def get_labels(self, x, y):\n        return y\n\n    def recluster(self, discriminator, **kwargs):\n        return\n\n    def get_features(self, x):\n        ''' by default gets discriminator, but you can use other things '''\n        return self.get_discriminator_output(x)\n\n    def get_cluster_batch_features(self):\n        ''' returns the discriminator features for the batch self.x as a numpy array '''\n        with torch.no_grad():\n            outputs = []\n            x = self.x\n            for batch in range(x.size(0) // self.batch_size):\n                x_batch = x[batch * self.batch_size:(batch + 1) * self.batch_size].cuda()\n                outputs.append(self.get_features(x_batch).detach().cpu())\n            if (x.size(0) % self.batch_size != 0):\n                x_batch = x[x.size(0) // self.batch_size * self.batch_size:].cuda()\n                outputs.append(self.get_features(x_batch).detach().cpu())\n            result = torch.cat(outputs, dim=0).numpy()\n            return result\n\n    def get_discriminator_output(self, x):\n        '''returns discriminator features'''\n        self.discriminator.eval()\n        with torch.no_grad():\n            return self.discriminator(x, get_features=True)\n\n    def get_label_distribution(self, x=None):\n        '''returns the empirical distributon of 
clustering'''\n        y = self.x_labels if x is None else self.get_labels(x, None)\n        counts = [0] * self.k\n        for yi in y:\n            counts[yi] += 1\n        return counts\n\n    def sample_y(self, batch_size):\n        '''samples y according to the empirical distribution (not sure if used anymore)'''\n        distribution = self.get_label_distribution()\n        distribution = [i / sum(distribution) for i in distribution]\n        m = torch.distributions.Multinomial(batch_size,\n                                            torch.tensor(distribution))\n        return m.sample()\n\n    def print_label_distribution(self, x=None):\n        print(self.get_label_distribution(x))\n"
  },
  {
    "path": "clusterers/kmeans.py",
    "content": "import torch\nimport numpy as np\nfrom sklearn.cluster import KMeans\n\nfrom clusterers import base_clusterer\n\nclass Clusterer(base_clusterer.BaseClusterer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.mapping = list(range(self.k))\n\n    def kmeans_fit_predict(self, features, init='k-means++', n_init=10):\n        '''fits kmeans, and returns the predictions of the kmeans'''\n        print('Fitting k-means w data shape', features.shape)\n        self.kmeans = KMeans(init=init, n_clusters=self.k,\n                             n_init=n_init).fit(features)\n        return self.kmeans.predict(features)\n\n    def get_labels(self, x, y):\n        d_features = self.get_features(x).detach().cpu().numpy()\n        np_prediction = self.kmeans.predict(d_features)\n        permuted_prediction = np.array([self.mapping[x] for x in np_prediction])\n        return torch.from_numpy(permuted_prediction).long().cuda()\n"
  },
  {
    "path": "clusterers/online.py",
    "content": "import copy, random\n\nimport torch\nimport numpy as np\n\nfrom clusterers import kmeans\n\n\nclass Clusterer(kmeans.Clusterer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.burned_in = False\n\n    def get_initialization(self, features, labels):\n        '''given points (from new discriminator) and their old assignments as np arrays, compute the induced means as a np array'''\n        means = []\n        for i in range(self.k):\n            mask = (labels == i)\n            mean = np.zeros(features[0].shape)\n            numels = mask.astype(int).sum()\n            if numels > 0:\n                for index, equal in enumerate(mask):\n                    if equal: mean += features[index]\n                means.append(mean / numels)\n            else:\n                # use kmeans++ init if cluster is starved\n                rand_point = random.randint(0, features.size(0) - 1)\n                means.append(features[rand_point])\n        result = np.array(means)\n        return result\n\n    def recluster(self, discriminator, x_batch=None, **kwargs):\n        if self.kmeans is None:\n            print('kmeans clustering as initialization')\n            self.discriminator = copy.deepcopy(discriminator)\n            features = self.get_cluster_batch_features()\n            self.x_labels = self.kmeans_fit_predict(features)\n        else:\n            self.discriminator = discriminator\n            if not self.burned_in:\n                print('Burned in: computing initialization for kmeans')\n                features = self.get_cluster_batch_features()\n                initialization = self.get_initialization(\n                    features, self.x_labels)\n                self.kmeans_fit_predict(features, init=initialization)\n                self.burned_in = True\n            else:\n                assert x_batch is not None\n                self.discriminator = discriminator\n                features = 
self.get_features(x_batch).detach().cpu().numpy()\n                y_pred = self.kmeans.predict(features)\n\n                for xi, yi in zip(features, y_pred):\n                    self.cluster_counts[yi] += 1\n                    difference = xi - self.kmeans.cluster_centers_[yi]\n                    step_size = 1.0 / self.cluster_counts[yi]\n                    self.kmeans.cluster_centers_[\n                        yi] = self.kmeans.cluster_centers_[yi] + step_size * (\n                            difference)\n"
  },
  {
    "path": "clusterers/random_labels.py",
    "content": "import torch\nfrom clusterers import base_clusterer\n\n\nclass Clusterer(base_clusterer.BaseClusterer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n    def get_labels(self, x, y):\n        return torch.randint(low=0, high=self.k, size=y.shape).long().cuda()"
  },
  {
    "path": "clusterers/selfcondgan.py",
    "content": "import copy, random\n\nimport torch\nimport numpy as np\nfrom sklearn.utils.linear_assignment_ import linear_assignment\n\nfrom clusterers import kmeans\n\n\nclass Clusterer(kmeans.Clusterer):\n    def __init__(self, initialization=True, matching=True, **kwargs):\n        self.initialization = initialization\n        self.matching = matching\n\n        super().__init__(**kwargs)\n\n    def get_initialization(self, features, labels):\n        '''given points (from new discriminator) and their old assignments as np arrays, compute the induced means as a np array'''\n        means = []\n        for i in range(self.k):\n            mask = (labels == i)\n            mean = np.zeros(features[0].shape)\n            numels = mask.astype(int).sum()\n            if numels > 0:\n                for index, equal in enumerate(mask):\n                    if equal: mean += features[index]\n                means.append(mean / numels)\n            else:\n                # use kmeans++ init if cluster is starved\n                rand_point = random.randint(0, features.size(0) - 1)\n                means.append(features[rand_point])\n        result = np.array(means)\n        return result\n\n    def fit_means(self):\n        features = self.get_cluster_batch_features()\n\n        # if clustered already, use old assignments for the cluster mean\n        if self.x_labels is not None and self.initialization:\n            print('Initializing k-means with previous cluster assignments')\n            initialization = self.get_initialization(features, self.x_labels)\n        else:\n            initialization = 'k-means++'\n\n        new_classes = self.kmeans_fit_predict(features, init=initialization)\n\n        # we've clustered already, so compute the permutation\n        if self.x_labels is not None and self.matching:\n            print('Doing cluster matching')\n            matching = self.hungarian_match(new_classes, self.x_labels, self.k,\n                                
            self.k)\n            self.mapping = [int(j) for i, j in sorted(matching)]\n\n        # recompute the fixed labels\n        self.x_labels = np.array([self.mapping[x] for x in new_classes])\n\n    def recluster(self, discriminator, **kwargs):\n        self.discriminator = copy.deepcopy(discriminator)\n        self.fit_means()\n\n    def hungarian_match(self, flat_preds, flat_targets, preds_k, targets_k):\n        '''takes in np arrays flat_preds, flat_targets of integers'''\n        num_samples = flat_targets.shape[0]\n\n        assert (preds_k == targets_k)  # one to one\n        num_k = preds_k\n        num_correct = np.zeros((num_k, num_k))\n\n        for c1 in range(num_k):\n            for c2 in range(num_k):\n                votes = int(((flat_preds == c1) * (flat_targets == c2)).sum())\n                num_correct[c1, c2] = votes\n\n        # num_correct is small\n        match = linear_assignment(num_samples - num_correct)\n\n        # return as list of tuples, out_c to gt_c\n        res = []\n        for out_c, gt_c in match:\n            res.append((out_c, gt_c))\n\n        return res\n"
  },
  {
    "path": "configs/cifar/conditional.yaml",
    "content": "generator:\n  nlabels: 10\n  conditioning: embedding\ndiscriminator:\n  nlabels: 10\n  conditioning: mask\ninherit_from: configs/cifar/default.yaml\ntraining:\n  out_dir: output/cifar/conditional"
  },
  {
    "path": "configs/cifar/default.yaml",
    "content": "data:\n  type: cifar10\n  train_dir: data/CIFAR\n  img_size: 32\n  nlabels: 10\ngenerator:\n  name: dcgan_deep\n  nlabels: 1\n  conditioning: unconditional\n  kwargs:\n    placeholder: None\ndiscriminator:\n  name: dcgan_deep\n  nlabels: 1\n  conditioning: unconditional\n  kwargs:\n    placeholder: None\nz_dist:\n  type: gauss\n  dim: 128\nclusterer:\n  name: supervised\n  nimgs: 25000\n  kwargs: \n    placeholder: None\ntraining:\n  gan_type: standard\n  reg_type: none\n  reg_param: 0.\n  take_model_average: false\n  sample_nlabels: 20\n  log_every: 1000\n  inception_every: 10000\n  batch_size: 64"
  },
  {
    "path": "configs/cifar/selfcondgan.yaml",
    "content": "generator:\n  nlabels: 100\n  conditioning: embedding\ndiscriminator:\n  nlabels: 100\n  conditioning: mask\nclusterer:\n  name: selfcondgan\n  kwargs: \n    k_value: 100\ninherit_from: configs/cifar/default.yaml\ntraining:\n  out_dir: output/cifar/selfcondgan\n  recluster_every: 25000"
  },
  {
    "path": "configs/cifar/unconditional.yaml",
    "content": "inherit_from: configs/cifar/default.yaml\ntraining:\n  out_dir: output/cifar/unconditional"
  },
  {
    "path": "configs/default.yaml",
    "content": "data:\n  type: lsun\n  train_dir: data/LSUN\n  deterministic: False\n  img_size: 128\n  nlabels: 1\ngenerator:\n  name: resnet\n  nlabels: 1\n  conditioning: unconditional\n  kwargs: \n    placeholder: None\ndiscriminator:\n  name: resnet\n  nlabels: 1\n  conditioning: unconditional\n  kwargs: \n    pack_size: 1\n    placeholder: None\nclusterer:\n  name: supervised\n  nimgs: 100\n  kwargs: \n    num_components: -1\nz_dist:\n  type: gauss\n  dim: 256\ntraining:\n  out_dir: output/default\n  gan_type: standard\n  reg_type: real\n  reg_param: 10.\n  log_every: 1\n  batch_size: 128\n  ntest: 128\n  nworkers: 72\n  burnin_time: 0\n  take_model_average: true\n  model_average_beta: 0.999\n  monitoring: tensorboard\n  sample_every: 5000\n  sample_nlabels: 20\n  inception_every: 10000\n  inception_nsamples: 50000\n  backup_every: 10000\n  recluster_every: 10000\n  optimizer: adam\n  lr_g: 0.0001\n  lr_d: 0.0001\n  beta1: 0.0\n  beta2: 0.99\npretrained: {}"
  },
  {
    "path": "configs/imagenet/conditional.yaml",
    "content": "generator:\n  nlabels: 1000\n  conditioning: embedding\ndiscriminator:\n  nlabels: 1000\n  conditioning: mask\ninherit_from: configs/imagenet/default.yaml\ntraining:\n  out_dir: output/imagenet/conditional\n  "
  },
  {
    "path": "configs/imagenet/default.yaml",
    "content": "data:\n  type: image\n  train_dir: data/ImageNet/train\n  test_dir: data/ImageNet/val\n  img_size: 128\n  nlabels: 1000\ngenerator:\n  name: resnet2\n  nlabels: 1\n  conditioning: unconditional\ndiscriminator:\n  name: resnet2\n  nlabels: 1\n  conditioning: unconditional\nz_dist:\n  type: gauss\n  dim: 256\nclusterer:\n  name: supervised\ntraining:\n  gan_type: standard\n  reg_type: real\n  reg_param: 10.\n  take_model_average: true\n  model_average_beta: 0.999\n  sample_nlabels: 20\n  log_every: 10\n  inception_every: 10000\n  backup_every: 5000\n  batch_size: 128"
  },
  {
    "path": "configs/imagenet/selfcondgan.yaml",
    "content": "generator:\n  nlabels: 100\n  conditioning: embedding\ndiscriminator:\n  nlabels: 100\n  conditioning: mask\nclusterer:\n  name: selfcondgan\n  nimgs: 50000\n  kwargs: \n    k_value: 100\ninherit_from: configs/imagenet/default.yaml\ntraining:\n  out_dir: output/imagenet/selfcondgan\n  recluster_every: 75000\n  reg_param: 0.1"
  },
  {
    "path": "configs/imagenet/unconditional.yaml",
    "content": "generator:\n  nlabels: 1\n  conditioning: unconditional\ndiscriminator:\n  nlabels: 1\n  conditioning: unconditional\ninherit_from: configs/imagenet/default.yaml\ntraining:\n  out_dir: output/imagenet/unconditional"
  },
  {
    "path": "configs/places/conditional.yaml",
    "content": "generator:\n  nlabels: 365\n  conditioning: embedding\ndiscriminator:\n  nlabels: 365\n  conditioning: mask\ntraining:\n  out_dir: output/places/conditional\ninherit_from: configs/places/default.yaml\n"
  },
  {
    "path": "configs/places/default.yaml",
    "content": "data:\n  type: image\n  train_dir: data/places365/train\n  test_dir: data/places365/val\n  img_size: 128\n  nlabels: 365\ngenerator:\n  name: resnet2\n  nlabels: 1\n  conditioning: unconditional\ndiscriminator:\n  name: resnet2\n  nlabels: 1\n  conditioning: unconditional\nz_dist:\n  type: gauss\n  dim: 256\nclusterer:\n  name: supervised\ntraining:\n  gan_type: standard\n  reg_type: real\n  reg_param: 10.\n  take_model_average: true\n  model_average_beta: 0.999\n  sample_nlabels: 20\n  log_every: 10\n  inception_every: 10000\n  backup_every: 5000\n  batch_size: 128\n  "
  },
  {
    "path": "configs/places/selfcondgan.yaml",
    "content": "generator:\n  nlabels: 100\n  conditioning: embedding\ndiscriminator:\n  nlabels: 100\n  conditioning: mask\nclusterer:\n  name: selfcondgan\n  nimgs: 50000\n  kwargs: \n    k_value: 100\ninherit_from: configs/places/default.yaml\ntraining:\n  out_dir: output/places/selfcondgan\n  recluster_every: 75000\n  reg_param: 0.1"
  },
  {
    "path": "configs/places/unconditional.yaml",
    "content": "generator:\n  nlabels: 1\n  conditioning: embedding\ndiscriminator:\n  nlabels: 1\n  conditioning: mask\ninherit_from: configs/places/default.yaml\ntraining:\n  out_dir: output/places/unconditional\n"
  },
  {
    "path": "configs/pretrained/imagenet/conditional.yaml",
    "content": "generator:\n  nlabels: 1000\n  conditioning: embedding\ndiscriminator:\n  nlabels: 1000\n  conditioning: mask\ninherit_from: configs/imagenet/default.yaml\ntraining:\n  out_dir: output/pretrained/imagenet/class_conditional\npretrained:\n  model: http://selfcondgan.csail.mit.edu/weights/classcondgan_i_model.pt\n"
  },
  {
    "path": "configs/pretrained/imagenet/selfcondgan.yaml",
    "content": "generator:\n  nlabels: 100\n  conditioning: embedding\ndiscriminator:\n  nlabels: 100\n  conditioning: mask\nclusterer:\n  name: selfcondgan\n  nimgs: 50000\n  kwargs: \n    k_value: 100\ninherit_from: configs/imagenet/default.yaml\ntraining:\n  out_dir: output/pretrained/imagenet/selfcondgan\n  recluster_every: 75000\n  reg_param: 0.1\npretrained:\n  model: http://selfcondgan.csail.mit.edu/weights/selfcondgan_i_model.pt\n  clusterer: http://selfcondgan.csail.mit.edu/weights/selfcondgan_i_clusterer.pkl"
  },
  {
    "path": "configs/pretrained/imagenet/unconditional.yaml",
    "content": "generator:\n  nlabels: 1\n  conditioning: unconditional\ndiscriminator:\n  nlabels: 1\n  conditioning: unconditional\ninherit_from: configs/imagenet/default.yaml\ntraining:\n  out_dir: output/pretrained/imagenet/unconditional\npretrained:\n  model: http://selfcondgan.csail.mit.edu/weights/uncondgan_i_model.pt\n"
  },
  {
    "path": "configs/pretrained/places/conditional.yaml",
    "content": "generator:\n  nlabels: 365\n  conditioning: embedding\ndiscriminator:\n  nlabels: 365\n  conditioning: mask\ntraining:\n  out_dir: output/pretrained/places/class_conditional\ninherit_from: configs/places/default.yaml\npretrained:\n  model: http://selfcondgan.csail.mit.edu/weights/classcondgan_p_model.pt\n"
  },
  {
    "path": "configs/pretrained/places/selfcondgan.yaml",
    "content": "generator:\n  nlabels: 100\n  conditioning: embedding\ndiscriminator:\n  nlabels: 100\n  conditioning: mask\nclusterer:\n  name: selfcondgan\n  nimgs: 50000\n  kwargs: \n    k_value: 100\ninherit_from: configs/places/default.yaml\ntraining:\n  out_dir: output/pretrained/places/selfcondgan\n  reg_param: 0.1\npretrained:\n  model: http://selfcondgan.csail.mit.edu/weights/selfcondgan_p_model.pt\n  clusterer: http://selfcondgan.csail.mit.edu/weights/selfcondgan_p_clusterer.pkl"
  },
  {
    "path": "configs/pretrained/places/unconditional.yaml",
    "content": "generator:\n  nlabels: 1\n  conditioning: embedding\ndiscriminator:\n  nlabels: 1\n  conditioning: mask\ninherit_from: configs/places/default.yaml\ntraining:\n  out_dir: output/pretrained/places/unconditional\npretrained:\n  model: http://selfcondgan.csail.mit.edu/weights/uncondgan_p_model.pt\n"
  },
  {
    "path": "configs/stacked_mnist/conditional.yaml",
    "content": "generator:\n  nlabels: 1000\n  conditioning: embedding\ndiscriminator:\n  nlabels: 1000\n  conditioning: mask\ninherit_from: configs/stacked_mnist/default.yaml\ntraining:\n  out_dir: output/stacked_mnist/conditional"
  },
  {
    "path": "configs/stacked_mnist/default.yaml",
    "content": "data:\n  type: stacked_mnist\n  train_dir: data/MNIST\n  img_size: 32\n  nlabels: 1000\ngenerator:\n  name: dcgan_shallow\n  nlabels: 1\n  conditioning: unconditional\n  kwargs:\n    placeholder: None\ndiscriminator:\n  name: dcgan_shallow\n  nlabels: 1\n  conditioning: unconditional\n  kwargs:\n    placeholder: None\nz_dist:\n  type: gauss\n  dim: 128\nclusterer:\n  name: supervised\n  nimgs: 25000\n  kwargs: \n    placeholder: None\ntraining:\n  gan_type: standard\n  reg_type: none\n  reg_param: 0.\n  take_model_average: false\n  sample_nlabels: 20\n  log_every: 1000\n  backup_every: 5000\n  inception_every: 10000\n  batch_size: 64"
  },
  {
    "path": "configs/stacked_mnist/selfcondgan.yaml",
    "content": "generator:\n  nlabels: 100\n  conditioning: embedding\ndiscriminator:\n  nlabels: 100\n  conditioning: mask\nclusterer:\n  name: selfcondgan\n  kwargs: \n    k_value: 100\ninherit_from: configs/stacked_mnist/default.yaml\ntraining:\n  out_dir: output/stacked_mnist/selfcondgan\n  recluster_every: 25000"
  },
  {
    "path": "configs/stacked_mnist/unconditional.yaml",
    "content": "inherit_from: configs/stacked_mnist/default.yaml\ntraining:\n  out_dir: output/stacked_mnist/unconditional"
  },
  {
    "path": "gan_training/__init__.py",
    "content": ""
  },
  {
    "path": "gan_training/checkpoints.py",
    "content": "import os, pickle\nimport urllib\nimport torch\nimport numpy as np\nfrom torch.utils import model_zoo\n\n\nclass CheckpointIO(object):\n    ''' CheckpointIO class.\n\n    It handles saving and loading checkpoints.\n\n    Args:\n        checkpoint_dir (str): path where checkpoints are saved\n    '''\n\n    def __init__(self, checkpoint_dir='./chkpts', **kwargs):\n        self.module_dict = kwargs\n        self.checkpoint_dir = checkpoint_dir\n        if not os.path.exists(checkpoint_dir):\n            os.makedirs(checkpoint_dir)\n\n    def register_modules(self, **kwargs):\n        ''' Registers modules in current module dictionary.\n        '''\n        self.module_dict.update(kwargs)\n\n    def save(self, filename, **kwargs):\n        ''' Saves the current module dictionary.\n\n        Args:\n            filename (str): name of output file\n        '''\n        if not os.path.isabs(filename):\n            filename = os.path.join(self.checkpoint_dir, filename)\n\n        outdict = kwargs\n        for k, v in self.module_dict.items():\n            outdict[k] = v.state_dict()\n        torch.save(outdict, filename)\n\n    def load(self, filename, pretrained={}):\n        '''Loads a module dictionary from local file or url.\n\n        Args:\n            filename (str): name of saved module dictionary\n        '''\n        if 'model' in pretrained:\n            filename = pretrained['model']\n        if is_url(filename):\n            return self.load_url(filename)\n        else:\n            return self.load_file(filename)\n\n    def load_file(self, filename):\n        '''Loads a module dictionary from file.\n\n        Args:\n            filename (str): name of saved module dictionary\n        '''\n\n        if not os.path.isabs(filename):\n            filename = os.path.join(self.checkpoint_dir, filename)\n\n        if os.path.exists(filename):\n            print('=> Loading checkpoint from local file...', filename)\n            state_dict = 
torch.load(filename)\n            scalars = self.parse_state_dict(state_dict)\n            return scalars\n        else:\n            print('File not found', filename)\n            raise FileNotFoundError\n\n    def load_url(self, url):\n        '''Load a module dictionary from url.\n\n        Args:\n            url (str): url to saved model\n        '''\n        print('=> Loading checkpoint from url...', url)\n        state_dict = model_zoo.load_url(url, model_dir=self.checkpoint_dir, progress=True)\n        scalars = self.parse_state_dict(state_dict)\n        return scalars\n\n    def parse_state_dict(self, state_dict):\n        '''Parse state_dict of model and return scalars.\n\n        Args:\n            state_dict (dict): State dict of model\n    '''\n        for k, v in self.module_dict.items():\n            if k in state_dict:\n                v.load_state_dict(state_dict[k])\n            else:\n                print('Warning: Could not find %s in checkpoint!' % k)\n        scalars = {\n            k: v\n            for k, v in state_dict.items() if k not in self.module_dict\n        }\n        return scalars\n\n    def load_clusterer(self, it, load_samples, pretrained={}):\n        if 'clusterer' in pretrained:\n            pretrained_file = os.path.join(self.checkpoint_dir, 'pretrained_clusterer.pkl')\n            if not os.path.exists(pretrained_file):\n                import cloudpickle as cp\n                from urllib.request import urlopen\n                print('Loading pretrained clusterer from', pretrained['clusterer'])\n                clusterer = cp.load(urlopen(pretrained['clusterer'])) \n                print('Saving pretrained clusterer to', pretrained_file)\n                with open(pretrained_file, 'wb') as f:\n                    f.write(pickle.dumps(clusterer))\n            else:\n                with open(pretrained_file, 'rb') as f:\n                    clusterer = pickle.load(f)\n            return clusterer\n        else:\n           
 print('Loading clusterer:')\n            with open(os.path.join(self.checkpoint_dir, f'clusterer{it}.pkl'), 'rb') as f:\n                clusterer = pickle.load(f)\n\n            if load_samples:\n                print('Loading cluster samples:')\n                with np.load(os.path.join(self.checkpoint_dir, 'cluster_samples.npz')) as f:\n                    x = f['x']\n                clusterer.x = torch.from_numpy(x)\n            return clusterer\n\n    def load_models(self, it, pretrained={}, load_samples=False):\n        try:\n            load_dict = self.load('model_%08d.pt' % it, pretrained)\n            epoch_idx = load_dict.get('epoch_idx', -1)\n        except Exception as e:  #models are not dataparallel modules\n            print('Trying again to load w/o data parallel modules')\n            try:\n                for name, module in self.module_dict.items():\n                    if isinstance(module, torch.nn.DataParallel):\n                        self.module_dict[name] = module.module\n                load_dict = self.load('model_%08d.pt' % it, pretrained)\n                epoch_idx = load_dict.get('epoch_idx', -1)\n            except FileNotFoundError as e:\n                print(e)\n                print(\"Models not found\")\n                it = epoch_idx = -1\n        \n        try:\n            clusterer = self.load_clusterer(it, load_samples, pretrained)\n        except FileNotFoundError as e:\n            clusterer = None\n\n        return it, epoch_idx, clusterer\n    \n    def save_clusterer(self, clusterer, it):\n        with open(os.path.join(self.checkpoint_dir, f'clusterer{it}.pkl'), 'wb') as f:\n            #hack: only save changing data\n            x = clusterer.x\n            clusterer.x = None\n            pickle.dump(clusterer, f)\n            clusterer.x = x\n\ndef is_url(url):\n    scheme = urllib.parse.urlparse(url).scheme\n    return scheme in ('http', 'https')\n"
  },
  {
    "path": "gan_training/config.py",
    "content": "import yaml\nfrom torch import optim\nfrom os import path\nfrom gan_training.models import generator_dict, discriminator_dict\nfrom gan_training.train import toggle_grad\nfrom clusterers import clusterer_dict\n\n\n# General config\ndef load_config(path, default_path):\n    ''' Loads config file.\n\n    Args:  \n        path (str): path to config file\n        default_path (bool): whether to use default path\n    '''\n    # Load configuration from file itself\n    with open(path, 'r') as f:\n        cfg_special = yaml.load(f)\n\n    # Check if we should inherit from a config\n    inherit_from = cfg_special.get('inherit_from')\n\n    # If yes, load this config first as default\n    # If no, use the default_path\n    if inherit_from is not None:\n        cfg = load_config(inherit_from, default_path)\n    elif default_path is not None:\n        with open(default_path, 'r') as f:\n            cfg = yaml.load(f)\n    else:\n        cfg = dict()\n\n    # Include main configuration\n    update_recursive(cfg, cfg_special)\n\n    return cfg\n\n\ndef update_recursive(dict1, dict2):\n    ''' Update two config dictionaries recursively.\n\n    Args:\n        dict1 (dict): first dictionary to be updated\n        dict2 (dict): second dictionary which entries should be used\n\n    '''\n    for k, v in dict2.items():\n        # Add item if not yet in dict1\n        if k not in dict1:\n            dict1[k] = None\n        # Update\n        if isinstance(dict1[k], dict):\n            update_recursive(dict1[k], v)\n        else:\n            dict1[k] = v\n\n\ndef get_clusterer(config):\n    return clusterer_dict[config['clusterer']['name']]\n\n\ndef build_models(config):\n    # Get classes\n    Generator = generator_dict[config['generator']['name']]\n    Discriminator = discriminator_dict[config['discriminator']['name']]\n\n    # Build models\n    generator = Generator(z_dim=config['z_dist']['dim'],\n                          nlabels=config['generator']['nlabels'],\n    
                      size=config['data']['img_size'],\n                          conditioning=config['generator']['conditioning'],\n                          **config['generator']['kwargs'])\n    discriminator = Discriminator(\n        nlabels=config['discriminator']['nlabels'],\n        conditioning=config['discriminator']['conditioning'],\n        size=config['data']['img_size'],\n        **config['discriminator']['kwargs'])\n\n    return generator, discriminator\n\n\ndef build_optimizers(generator, discriminator, config):\n    optimizer = config['training']['optimizer']\n    lr_g = config['training']['lr_g']\n    lr_d = config['training']['lr_d']\n    \n\n    toggle_grad(generator, True)\n    toggle_grad(discriminator, True)\n\n    g_params = generator.parameters()\n    d_params = discriminator.parameters()\n\n    if optimizer == 'rmsprop':\n        g_optimizer = optim.RMSprop(g_params, lr=lr_g, alpha=0.99, eps=1e-8)\n        d_optimizer = optim.RMSprop(d_params, lr=lr_d, alpha=0.99, eps=1e-8)\n    elif optimizer == 'adam':\n        beta1 = config['training']['beta1']\n        beta2 = config['training']['beta2']\n        g_optimizer = optim.Adam(g_params, lr=lr_g, betas=(beta1, beta2), eps=1e-8)\n        d_optimizer = optim.Adam(d_params, lr=lr_d, betas=(beta1, beta2), eps=1e-8)\n    elif optimizer == 'sgd':\n        g_optimizer = optim.SGD(g_params, lr=lr_g, momentum=0.)\n        d_optimizer = optim.SGD(d_params, lr=lr_d, momentum=0.)\n\n    return g_optimizer, d_optimizer\n\n\n# Some utility functions\ndef get_parameter_groups(parameters, gradient_scales, base_lr):\n    param_groups = []\n    for p in parameters:\n        c = gradient_scales.get(p, 1.)\n        param_groups.append({'params': [p], 'lr': c * base_lr})\n    return param_groups\n"
  },
  {
    "path": "gan_training/distributions.py",
    "content": "import torch\nfrom torch import distributions\n\n\ndef get_zdist(dist_name, dim, device=None):\n    # Get distribution\n    if dist_name == 'uniform':\n        low = -torch.ones(dim, device=device)\n        high = torch.ones(dim, device=device)\n        zdist = distributions.Uniform(low, high)\n    elif dist_name == 'gauss':\n        mu = torch.zeros(dim, device=device)\n        scale = torch.ones(dim, device=device)\n        zdist = distributions.Normal(mu, scale)\n    else:\n        raise NotImplementedError\n\n    # Add dim attribute\n    zdist.dim = dim\n\n    return zdist\n\n\ndef get_ydist(nlabels, device=None):\n    logits = torch.zeros(nlabels, device=device)\n    ydist = distributions.categorical.Categorical(logits=logits)\n\n    # Add nlabels attribute\n    ydist.nlabels = nlabels\n\n    return ydist\n\n\ndef interpolate_sphere(z1, z2, t):\n    p = (z1 * z2).sum(dim=-1, keepdim=True)\n    p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt()\n    p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt()\n    omega = torch.acos(p)\n    s1 = torch.sin((1-t)*omega)/torch.sin(omega)\n    s2 = torch.sin(t*omega)/torch.sin(omega)\n    z = s1 * z1 + s2 * z2\n\n    return z\n"
  },
  {
    "path": "gan_training/eval.py",
    "content": "import numpy as np\nimport torch\nfrom torch.nn import functional as F\n\nfrom gan_training.metrics import inception_score\n\nclass Evaluator(object):\n    def __init__(self,\n                 generator,\n                 zdist,\n                 ydist,\n                 train_loader,\n                 clusterer,\n                 batch_size=64,\n                 inception_nsamples=10000,\n                 device=None):\n        self.generator = generator\n        self.clusterer = clusterer\n        self.train_loader = train_loader\n        self.zdist = zdist\n        self.ydist = ydist\n        self.inception_nsamples = inception_nsamples\n        self.batch_size = batch_size\n        self.device = device\n\n    def sample_z(self, batch_size):\n        return self.zdist.sample((batch_size, )).to(self.device)\n\n    def get_y(self, x, y):\n        return self.clusterer.get_labels(x, y).to(self.device)\n\n    def get_fake_real_samples(self, N):\n        ''' returns N fake images and N real images in pytorch form'''\n        with torch.no_grad():\n            self.generator.eval()\n            fake_imgs = []\n            real_imgs = []\n            while len(fake_imgs) < N:\n                for x_real, y_gt in self.train_loader:\n                    x_real = x_real.cuda()\n                    z = self.sample_z(x_real.size(0))\n                    y = self.get_y(x_real, y_gt)\n                    samples = self.generator(z, y)\n                    samples = [s.data.cpu() for s in samples]\n                    fake_imgs.extend(samples)\n                    real_batch = [img.data.cpu() for img in x_real]\n                    real_imgs.extend(real_batch)\n                    assert (len(real_imgs) == len(fake_imgs))\n                    if len(fake_imgs) >= N:\n                        fake_imgs = fake_imgs[:N]\n                        real_imgs = real_imgs[:N]\n                        return fake_imgs, real_imgs\n\n    def compute_inception_score(self):\n 
       imgs, _ = self.get_fake_real_samples(self.inception_nsamples)\n        imgs = [img.numpy() for img in imgs]\n        score, score_std = inception_score(imgs,\n                                           device=self.device,\n                                           resize=True,\n                                           splits=1)\n\n        return score, score_std\n\n    def create_samples(self, z, y=None):\n        self.generator.eval()\n        batch_size = z.size(0)\n        # Parse y\n        if y is None:\n            raise NotImplementedError()\n        elif isinstance(y, int):\n            y = torch.full((batch_size, ),\n                           y,\n                           device=self.device,\n                           dtype=torch.int64)\n        # Sample x\n        with torch.no_grad():\n            x = self.generator(z, y)\n        return x\n\n\n"
  },
  {
    "path": "gan_training/inputs.py",
    "content": "import torch\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\nimport numpy as np\n\nimport os\nimport torch.utils.data as data\nfrom torchvision.datasets.folder import default_loader\nfrom PIL import Image\nimport random\n\nfrom PIL import ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\ndef get_dataset(name,\n                data_dir,\n                size=64,\n                lsun_categories=None,\n                deterministic=False,\n                transform=None):\n                \n    transform = transforms.Compose([\n        t for t in [\n            transforms.Resize(size),\n            transforms.CenterCrop(size),\n            (not deterministic) and transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n            (not deterministic) and\n            transforms.Lambda(lambda x: x + 1. / 128 * torch.rand(x.size())),\n        ] if t is not False\n    ]) if transform == None else transform\n\n    if name == 'image':\n        print('Using image labels')\n        dataset = datasets.ImageFolder(data_dir, transform)\n        nlabels = len(dataset.classes)\n    elif name == 'webp':\n        print('Using no labels from webp')\n        dataset = CachedImageFolder(data_dir, transform)\n        nlabels = len(dataset.classes)\n    elif name == 'npy':\n        # Only support normalization for now\n        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])\n        nlabels = len(dataset.classes)\n    elif name == 'cifar10':\n        dataset = datasets.CIFAR10(root=data_dir,\n                                   train=True,\n                                   download=True,\n                                   transform=transform)\n        nlabels = 10\n    elif name == 'stacked_mnist':\n        dataset = StackedMNIST(data_dir,\n                               transform=transforms.Compose([\n                           
        transforms.Resize(size),\n                                   transforms.CenterCrop(size),\n                                   transforms.ToTensor(),\n                                   transforms.Normalize((0.5, ), (0.5, ))\n                               ]))\n        nlabels = 1000\n    elif name == 'lsun':\n        if lsun_categories is None:\n            lsun_categories = 'train'\n        dataset = datasets.LSUN(data_dir, lsun_categories, transform)\n        nlabels = len(dataset.classes)\n    elif name == 'lsun_class':\n        dataset = datasets.LSUNClass(data_dir,\n                                     transform,\n                                     target_transform=(lambda t: 0))\n        nlabels = 1\n    else:\n        raise NotImplemented\n    return dataset, nlabels\n\nclass CachedImageFolder(data.Dataset):\n    \"\"\"\n    A version of torchvision.dataset.ImageFolder that takes advantage\n    of cached filename lists.\n    photo/park/004234.jpg\n    photo/park/004236.jpg\n    photo/park/004237.jpg\n    \"\"\"\n\n    def __init__(self, root, transform=None, loader=default_loader):\n        classes, class_to_idx = find_classes(root)\n        self.imgs = make_class_dataset(root, class_to_idx)\n        if len(self.imgs) == 0:\n            raise RuntimeError(\"Found 0 images within: %s\" % root)\n        self.root = root\n        self.classes = classes\n        self.class_to_idx = class_to_idx\n        self.transform = transform\n        self.loader = loader\n\n    def __getitem__(self, index):\n        path, classidx = self.imgs[index]\n        source = self.loader(path)\n        if self.transform is not None:\n            source = self.transform(source)\n        return source, classidx\n\n    def __len__(self):\n        return len(self.imgs)\n\nclass StackedMNIST(data.Dataset):\n    def __init__(self, data_dir, transform, batch_size=100000):\n        super().__init__()\n        self.channel1 = datasets.MNIST(data_dir,\n                               
        transform=transform,\n                                       train=True,\n                                       download=True)\n        self.channel2 = datasets.MNIST(data_dir,\n                                       transform=transform,\n                                       train=True,\n                                       download=True)\n        self.channel3 = datasets.MNIST(data_dir,\n                                       transform=transform,\n                                       train=True,\n                                       download=True)\n        self.indices = {\n            k: (random.randint(0,\n                               len(self.channel1) - 1),\n                random.randint(0,\n                               len(self.channel1) - 1),\n                random.randint(0,\n                               len(self.channel1) - 1))\n            for k in range(batch_size)\n        }\n\n    def __getitem__(self, index):\n        index1, index2, index3 = self.indices[index]\n        x1, y1 = self.channel1[index1]\n        x2, y2 = self.channel2[index2]\n        x3, y3 = self.channel3[index3]\n        return torch.cat([x1, x2, x3], dim=0), y1 * 100 + y2 * 10 + y3\n\n    def __len__(self):\n        return len(self.indices)\n        \n\ndef is_npy_file(path):\n    return path.endswith('.npy') or path.endswith('.NPY')\n\n\ndef walk_image_files(rootdir):\n    print(rootdir)\n    if os.path.isfile('%s.txt' % rootdir):\n        print('Loading file list from %s.txt instead of scanning dir' %\n              rootdir)\n        basedir = os.path.dirname(rootdir)\n        with open('%s.txt' % rootdir) as f:\n            result = sorted([\n                os.path.join(basedir, line.strip()) for line in f.readlines()\n            ])\n            import random\n            random.Random(1).shuffle(result)\n            return result\n    result = []\n\n    IMG_EXTENSIONS = [\n        '.jpg',\n        '.JPG',\n        '.jpeg',\n        '.JPEG',\n        
'.png',\n        '.PNG',\n        '.ppm',\n        '.PPM',\n        '.bmp',\n        '.BMP',\n    ]\n\n    for dirname, _, fnames in sorted(os.walk(rootdir)):\n        for fname in sorted(fnames):\n            if any(fname.endswith(extension)\n                   for extension in IMG_EXTENSIONS) or is_npy_file(fname):\n                result.append(os.path.join(dirname, fname))\n    return result\n\n\ndef find_classes(dir):\n    classes = [\n        d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))\n    ]\n    classes.sort()\n    class_to_idx = {classes[i]: i for i in range(len(classes))}\n    return classes, class_to_idx\n\n\ndef make_class_dataset(source_root, class_to_idx):\n    \"\"\"\n    Returns (source, classnum, feature)\n    \"\"\"\n    imagepairs = []\n    source_root = os.path.expanduser(source_root)\n    for path in walk_image_files(source_root):\n        classname = os.path.basename(os.path.dirname(path))\n        imagepairs.append((path, 0))\n    return imagepairs\n\n\ndef npy_loader(path):\n    img = np.load(path)\n\n    if img.dtype == np.uint8:\n        img = img.astype(np.float32)\n        img = img / 127.5 - 1.\n    elif img.dtype == np.float32:\n        img = img * 2 - 1.\n    else:\n        raise NotImplementedError\n\n    img = torch.Tensor(img)\n    if len(img.size()) == 4:\n        img.squeeze_(0)\n\n    return img\n"
  },
  {
    "path": "gan_training/logger.py",
    "content": "import pickle\nimport os\nimport torchvision\nimport copy\n\n\nclass Logger(object):\n    def __init__(self,\n                 log_dir='./logs',\n                 img_dir='./imgs',\n                 monitoring=None,\n                 monitoring_dir=None):\n        self.stats = dict()\n        self.log_dir = log_dir\n        self.img_dir = img_dir\n\n        if not os.path.exists(log_dir):\n            os.makedirs(log_dir)\n\n        if not os.path.exists(img_dir):\n            os.makedirs(img_dir)\n\n        if not (monitoring is None or monitoring == 'none'):\n            self.setup_monitoring(monitoring, monitoring_dir)\n        else:\n            self.monitoring = None\n            self.monitoring_dir = None\n\n    def setup_monitoring(self, monitoring, monitoring_dir=None):\n        self.monitoring = monitoring\n        self.monitoring_dir = monitoring_dir\n\n        if monitoring == 'telemetry':\n            import telemetry\n            self.tm = telemetry.ApplicationTelemetry()\n            if self.tm.get_status() == 0:\n                print('Telemetry successfully connected.')\n        elif monitoring == 'tensorboard':\n            import tensorboardX\n            self.tb = tensorboardX.SummaryWriter(monitoring_dir)\n        else:\n            raise NotImplementedError('Monitoring tool \"%s\" not supported!' 
%\n                                      monitoring)\n\n    def add(self, category, k, v, it):\n        if category not in self.stats:\n            self.stats[category] = {}\n\n        if k not in self.stats[category]:\n            self.stats[category][k] = []\n\n        self.stats[category][k].append((it, v))\n\n        k_name = '%s/%s' % (category, k)\n        if self.monitoring == 'telemetry':\n            self.tm.metric_push_async({'metric': k_name, 'value': v, 'it': it})\n        elif self.monitoring == 'tensorboard':\n            self.tb.add_scalar(k_name, v, it)\n\n    def add_imgs(self, imgs, class_name, it):\n        outdir = os.path.join(self.img_dir, class_name)\n        if not os.path.exists(outdir):\n            os.makedirs(outdir)\n        outfile = os.path.join(outdir, '%08d.png' % it)\n\n        imgs = imgs / 2 + 0.5\n        imgs = torchvision.utils.make_grid(imgs)\n        torchvision.utils.save_image(copy.deepcopy(imgs), outfile, nrow=8)\n\n        if self.monitoring == 'tensorboard':\n            self.tb.add_image(class_name, copy.deepcopy(imgs), it)\n\n    def get_last(self, category, k, default=0.):\n        if category not in self.stats:\n            return default\n        elif k not in self.stats[category]:\n            return default\n        else:\n            return self.stats[category][k][-1][1]\n\n    def save_stats(self, filename):\n        filename = os.path.join(self.log_dir, filename)\n        with open(filename, 'wb') as f:\n            pickle.dump(self.stats, f)\n\n    def load_stats(self, filename):\n        filename = os.path.join(self.log_dir, filename)\n        if not os.path.exists(filename):\n            print('Warning: file \"%s\" does not exist!' % filename)\n            return\n\n        try:\n            with open(filename, 'rb') as f:\n                self.stats = pickle.load(f)\n        except EOFError:\n            print('Warning: log file corrupted!')\n"
  },
  {
    "path": "gan_training/metrics/__init__.py",
    "content": "from gan_training.metrics.inception_score import inception_score\n\n__all__ = [\n    inception_score\n]\n"
  },
  {
    "path": "gan_training/metrics/clustering_metrics.py",
    "content": "def warn(*args, **kwargs):\n    pass\n\n\nimport warnings\nwarnings.warn = warn\n\nfrom sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score, homogeneity_score\nfrom sklearn import metrics\n\nimport numpy as np\n\n\ndef nmi(inferred, gt):\n    return normalized_mutual_info_score(inferred, gt)\n\n\ndef acc(inferred, gt):\n    gt = gt.astype(np.int64)\n    assert inferred.size == gt.size\n    D = max(inferred.max(), gt.max()) + 1\n    w = np.zeros((D, D), dtype=np.int64)\n    for i in range(inferred.size):\n        w[inferred[i], gt[i]] += 1\n    from sklearn.utils.linear_assignment_ import linear_assignment\n    ind = linear_assignment(w.max() - w)\n    return sum([w[i, j] for i, j in ind]) * 1.0 / inferred.size\n\n\ndef purity_score(y_true, y_pred):\n    contingency_matrix = metrics.cluster.contingency_matrix(y_true, y_pred)\n    return np.sum(np.amax(contingency_matrix,\n                          axis=0)) / np.sum(contingency_matrix)\n\n\ndef ari(inferred, gt):\n    return adjusted_rand_score(gt, inferred)\n\n\ndef homogeneity(inferred, gt):\n    return homogeneity_score(gt, inferred)\n"
  },
  {
    "path": "gan_training/metrics/fid.py",
    "content": "from __future__ import absolute_import, division, print_function\nimport numpy as np\nimport os\nos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\nimport tensorflow as tf\nfrom scipy import linalg\nimport pathlib\nimport urllib\nfrom tqdm import tqdm\nimport warnings\n\n\ndef check_or_download_inception(inception_path):\n    ''' Checks if the path to the inception file is valid, or downloads\n        the file if it is not present. '''\n    INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'\n    if inception_path is None:\n        inception_path = '/tmp'\n    inception_path = pathlib.Path(inception_path)\n    model_file = inception_path / 'classify_image_graph_def.pb'\n    if not model_file.exists():\n        print(\"Downloading Inception model\")\n        from urllib import request\n        import tarfile\n        fn, _ = request.urlretrieve(INCEPTION_URL)\n        with tarfile.open(fn, mode='r') as f:\n            f.extract('classify_image_graph_def.pb', str(model_file.parent))\n    return str(model_file)\n\n\ndef create_inception_graph(pth):\n    \"\"\"Creates a graph from saved GraphDef file.\"\"\"\n    # Creates graph from saved graph_def.pb.\n    with tf.io.gfile.GFile(pth, 'rb') as f:\n        graph_def = tf.compat.v1.GraphDef()\n        graph_def.ParseFromString(f.read())\n        _ = tf.import_graph_def(graph_def, name='FID_Inception_Net')\n\n\ndef calculate_activation_statistics(images,\n                                    sess,\n                                    batch_size=200,\n                                    verbose=False):\n    \"\"\"Calculation of the statistics used by the FID.\n    Params:\n    -- images      : Numpy array of dimension (n_images, hi, wi, 3). The values\n                     must lie between 0 and 255.\n    -- sess        : current session\n    -- batch_size  : the images numpy array is split into batches with batch size\n                     batch_size. 
A reasonable batch size depends on the available hardware.\n    -- verbose     : If set to True and parameter out_step is given, the number of calculated\n                     batches is reported.\n    Returns:\n    -- mu    : The mean over samples of the activations of the pool_3 layer of\n               the inception model.\n    -- sigma : The covariance matrix of the activations of the pool_3 layer of\n               the inception model.\n    \"\"\"\n    act = get_activations(images, sess, batch_size, verbose)\n    mu = np.mean(act, axis=0)\n    sigma = np.cov(act, rowvar=False)\n    return mu, sigma\n\n\n# code for handling inception net derived from\n#   https://github.com/openai/improved-gan/blob/master/inception_score/model.py\ndef _get_inception_layer(sess):\n    \"\"\"Prepares inception net for batched usage and returns pool_3 layer. \"\"\"\n    layername = 'FID_Inception_Net/pool_3:0'\n    pool3 = sess.graph.get_tensor_by_name(layername)\n    ops = pool3.graph.get_operations()\n    for op_idx, op in enumerate(ops):\n        for o in op.outputs:\n            shape = o.get_shape()\n            if shape._dims != []:\n                shape = [s.value for s in shape]\n                new_shape = []\n                for j, s in enumerate(shape):\n                    if s == 1 and j == 0:\n                        new_shape.append(None)\n                    else:\n                        new_shape.append(s)\n                o.__dict__['_shape_val'] = tf.TensorShape(new_shape)\n    return pool3\n\n\n#-------------------------------------------------------------------------------\n\n\ndef get_activations(images, sess, batch_size=200, verbose=False):\n    \"\"\"Calculates the activations of the pool_3 layer for all images.\n    Params:\n    -- images      : Numpy array of dimension (n_images, hi, wi, 3). 
The values\n                     must lie between 0 and 255.\n    -- sess        : current session\n    -- batch_size  : the images numpy array is split into batches with batch size\n                     batch_size. A reasonable batch size depends on the available hardware.\n    -- verbose    : If set to True and parameter out_step is given, the number of calculated\n                     batches is reported.\n    Returns:\n    -- A numpy array of dimension (num images, 2048) that contains the\n       activations of the given tensor when feeding inception with the query tensor.\n    \"\"\"\n    inception_layer = _get_inception_layer(sess)\n    n_images = images.shape[0]\n    if batch_size > n_images:\n        print(\n            \"warning: batch size is bigger than the data size. setting batch size to data size\"\n        )\n        batch_size = n_images\n    n_batches = n_images // batch_size\n    pred_arr = np.empty((n_images, 2048))\n    for i in tqdm(range(n_batches)):\n        if verbose:\n            print(\"\\rPropagating batch %d/%d\" % (i + 1, n_batches),\n                  end=\"\",\n                  flush=True)\n        start = i * batch_size\n\n        if start + batch_size < n_images:\n            end = start + batch_size\n        else:\n            end = n_images\n\n        batch = images[start:end]\n        pred = sess.run(inception_layer,\n                        {'FID_Inception_Net/ExpandDims:0': batch})\n        pred_arr[start:end] = pred.reshape(batch_size, -1)\n    if verbose:\n        print(\" done\")\n    return pred_arr\n\n\n#-------------------------------------------------------------------------------\n\n\ndef calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n    \"\"\"Numpy implementation of the Frechet Distance.\n    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)\n    and X_2 ~ N(mu_2, C_2) is\n            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n            \n    Stable 
version by Dougal J. Sutherland.\n    Params:\n    -- mu1 : The sample mean over activations of the pool_3 layer of the\n             inception net (as returned by 'calculate_activation_statistics')\n             for generated samples.\n    -- mu2   : The sample mean over activations of the pool_3 layer, precalculated\n               on a representative data set.\n    -- sigma1: The covariance matrix over activations of the pool_3 layer for\n               generated samples.\n    -- sigma2: The covariance matrix over activations of the pool_3 layer,\n               precalculated on a representative data set.\n    Returns:\n    --   : The Frechet Distance d^2 as defined above.\n    \"\"\"\n\n    mu1 = np.atleast_1d(mu1)\n    mu2 = np.atleast_1d(mu2)\n\n    sigma1 = np.atleast_2d(sigma1)\n    sigma2 = np.atleast_2d(sigma2)\n\n    assert mu1.shape == mu2.shape, \"Training and test mean vectors have different lengths\"\n    assert sigma1.shape == sigma2.shape, \"Training and test covariances have different dimensions\"\n\n    diff = mu1 - mu2\n\n    # product might be almost singular\n    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n    if not np.isfinite(covmean).all():\n        msg = \"fid calculation produces singular product; adding %s to diagonal of cov estimates\" % eps\n        warnings.warn(msg)\n        offset = np.eye(sigma1.shape[0]) * eps\n        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n\n    # numerical error might give slight imaginary component\n    if np.iscomplexobj(covmean):\n        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n            m = np.max(np.abs(covmean.imag))\n            raise ValueError(\"Imaginary component {}\".format(m))\n        covmean = covmean.real\n\n    tr_covmean = np.trace(covmean)\n\n    return diff.dot(diff) + np.trace(sigma1) + np.trace(\n        sigma2) - 2 * tr_covmean\n\n\ndef compute_fid_from_npz(path):\n    print(path)\n    with np.load(path) as data:\n        fake_imgs = 
data['fake']\n\n        name = None\n        for name in ['imagenet', 'cifar', 'places']:\n            if name in path: \n                real_imgs = name\n                break\n        print('Inferred name', name)\n        if name is None:\n            real_imgs = data['real']\n            \n        if fake_imgs.shape[0] < 1000: return 0\n\n    inception_path = check_or_download_inception(None)\n    create_inception_graph(inception_path)\n    with tf.Session() as sess:\n        sess.run(tf.global_variables_initializer())\n        m1, s1 = calculate_activation_statistics(fake_imgs, sess)\n        if isinstance(real_imgs, str):\n            print(f'using cached image stats for {real_imgs}')\n            with np.load(precomputed_stats[real_imgs]) as data:\n                m2, s2 = data['m'], data['s']\n        else:\n            print('computing real images stats from scratch')\n            m2, s2 = calculate_activation_statistics(real_imgs, sess)\n\n    return calculate_frechet_distance(m1, s1, m2, s2)\n\nprecomputed_stats = {\n    'places':\n    'output/places_gt_stats.npz',\n    'imagenet':\n    'output/imagenet_gt_stats.npz',\n    'cifar':\n    'output/cifar_gt_stats.npz'\n}\n\n\ndef compute_fid_from_imgs(fake_imgs, real_imgs):\n    inception_path = check_or_download_inception(None)\n    create_inception_graph(inception_path)\n    with tf.Session() as sess:\n        sess.run(tf.global_variables_initializer())\n        m1, s1 = calculate_activation_statistics(fake_imgs, sess)\n        if isinstance(real_imgs, str):\n            with np.load(precomputed_stats[real_imgs]) as data:\n                m2, s2 = data['m'], data['s']\n        else:\n            m2, s2 = calculate_activation_statistics(real_imgs, sess)\n    return calculate_frechet_distance(m1, s1, m2, s2)\n\ndef compute_stats(exp_path):\n    #TODO: a bit hacky\n    if 'places' in exp_path and not os.path.exists(precomputed_stats['places']):\n        with np.load('output/places_gt_imgs.npz') as 
data_real:\n            real_imgs = data_real['real']\n            print('loaded real places images', real_imgs.shape)\n        inception_path = check_or_download_inception(None)\n        create_inception_graph(inception_path)\n        with tf.Session() as sess:\n            sess.run(tf.global_variables_initializer())\n            m, s = calculate_activation_statistics(real_imgs, sess)\n        np.savez(precomputed_stats['places'], m=m, s=s)\n    \n    if 'imagenet' in exp_path and not os.path.exists(precomputed_stats['imagenet']):\n        with np.load('output/imagenet_gt_imgs.npz') as data_real:\n            real_imgs = data_real['real']\n            print('loaded real imagenet images', real_imgs.shape)\n        inception_path = check_or_download_inception(None)\n        create_inception_graph(inception_path)\n        with tf.Session() as sess:\n            sess.run(tf.global_variables_initializer())\n            m, s = calculate_activation_statistics(real_imgs, sess)\n        np.savez(precomputed_stats['imagenet'], m=m, s=s)\n\n    if 'cifar' in exp_path and not os.path.exists(precomputed_stats['cifar']):\n        with np.load('output/cifar_gt_imgs.npz') as data_real:\n            real_imgs = data_real['real']\n            print('loaded real cifar images', real_imgs.shape)\n        inception_path = check_or_download_inception(None)\n        create_inception_graph(inception_path)\n        with tf.Session() as sess:\n            sess.run(tf.global_variables_initializer())\n            m, s = calculate_activation_statistics(real_imgs, sess)\n        np.savez(precomputed_stats['cifar'], m=m, s=s)\n\nif __name__ == '__main__':\n    import argparse\n    import json\n\n    parser = argparse.ArgumentParser('compute TF FID')\n    parser.add_argument('--samples', help='path to samples')\n    parser.add_argument('--it', type=str, help='path to samples')\n    parser.add_argument('--results_dir', help='path to results_dir')\n    args = parser.parse_args()\n    \n    it = 
args.it\n    results_dir = args.results_dir\n\n    compute_stats(args.samples)\n    mean = compute_fid_from_npz(args.samples)\n    print(f'FID: {mean}')\n    \n    if args.results_dir is not None:\n        with open(os.path.join(args.results_dir, 'fid_results.json')) as f:\n            fid_results = json.load(f)\n\n        fid_results[it] = mean\n        print(f'{results_dir} iteration {it} FID: {mean}')\n        \n        with open(os.path.join(args.results_dir, 'fid_results.json'), 'w') as f:\n            f.write(json.dumps(fid_results))"
  },
  {
    "path": "gan_training/metrics/inception_score.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport torch.utils.data\n\nfrom torchvision.models.inception import inception_v3\n\nimport numpy as np\nfrom scipy.stats import entropy\n\n\ndef inception_score(imgs, device=None, batch_size=32, resize=False, splits=1):\n    \"\"\"Computes the inception score of the generated images imgs.\n\n    Args:\n        imgs: Torch dataset of (3xHxW) numpy images normalized in the\n              range [-1, 1]\n        device: torch device the Inception v3 model and batches are moved to\n        batch_size: batch size for feeding into Inception v3\n        resize: if True, bilinearly upsample each batch to 299x299 first\n        splits: number of splits\n    Returns:\n        (mean, std) of the per-split inception scores\n    \"\"\"\n    N = len(imgs)\n\n    assert batch_size > 0\n    assert N > batch_size\n\n    # Set up dataloader\n    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)\n\n    # Load inception model\n    inception_model = inception_v3(pretrained=True, transform_input=False)\n    inception_model = inception_model.to(device)\n    inception_model.eval()\n    up = nn.Upsample(size=(299, 299), mode='bilinear').to(device)\n\n    def get_pred(x):\n        with torch.no_grad():\n            if resize:\n                x = up(x)\n            x = inception_model(x)\n            out = F.softmax(x, dim=-1)\n        out = out.cpu().numpy()\n        return out\n\n    # Get predictions\n    preds = np.zeros((N, 1000))\n\n    for i, batch in enumerate(dataloader, 0):\n        batchv = batch.to(device)\n        batch_size_i = batch.size()[0]\n\n        preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)\n\n    # Now compute the mean kl-div\n    split_scores = []\n\n    for k in range(splits):\n        part = preds[k * (N // splits):(k + 1) * (N // splits), :]\n        py = np.mean(part, axis=0)\n        scores = []\n        for i in range(part.shape[0]):\n            pyx = part[i, :]\n            scores.append(entropy(pyx, py))\n        split_scores.append(np.exp(np.mean(scores)))\n\n    return np.mean(split_scores), 
np.std(split_scores)\n"
  },
  {
    "path": "gan_training/metrics/tf_is/LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "gan_training/metrics/tf_is/README.md",
    "content": "Inception Score\n=====================================\n\nA new TensorFlow implementation of the \"Inception Score\" (IS) for the evaluation of generative models, with the bug reported in [https://github.com/openai/improved-gan/issues/29](https://github.com/openai/improved-gan/issues/29) fixed.\n\n## Major Dependency\n- `tensorflow >= 1.14` (TensorFlow 1.x only: `inception_score.py` relies on `tf.contrib.gan`, which was removed in TensorFlow 2.0)\n\n## Features\n- Fast, easy-to-use and memory-efficient, written in a way that is similar to the original implementation\n- No prior knowledge of TensorFlow is necessary if you are using a CPU or GPU\n- Makes use of [TF-GAN](https://github.com/tensorflow/gan)\n- Downloads InceptionV1 automatically\n- The upstream version is compatible with both Python 2 and Python 3; this copy of `inception_score.py` contains f-strings in its command-line entry point and therefore requires Python 3.6+\n\n## Usage\n- If you are working with a GPU, use `inception_score.py`; if you are working with a TPU, use `inception_score_tpu.py` and pass a TensorFlow Session and a [TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy) as additional arguments.\n- Call `get_inception_score(images, splits=10)`, where `images` is a numpy array with values ranging from 0 to 255 and shape `[N, 3, HEIGHT, WIDTH]`, where `N`, `HEIGHT` and `WIDTH` can be arbitrary. A `dtype` of `np.uint8` is recommended for the images to save CPU memory.\n- A smaller `BATCH_SIZE` reduces GPU/TPU memory usage, but at the cost of a slight slowdown.\n- If you want to compute a general \"Classifier Score\" with probabilities `preds` from another classifier, call `preds2score(preds, splits=10)`. `preds` must be a numpy array of shape `[N, num_classes]` (each row a probability distribution), where `N` and `num_classes` can be arbitrary.\n\n## Links\n- The Inception Score was proposed in the paper [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498)\n- Code for the [Fréchet Inception Distance](https://github.com/tsc2017/Frechet-Inception-Distance)\n"
  },
  {
    "path": "gan_training/metrics/tf_is/inception_score.py",
    "content": "'''\nFrom https://github.com/tsc2017/Inception-Score\nCode derived from https://github.com/openai/improved-gan/blob/master/inception_score/model.py and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py\n\nUsage:\n    Call get_inception_score(images, splits=10)\nArgs:\n    images: A numpy array with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. A dtype of np.uint8 is recommended to save CPU memory.\n    splits: The number of splits of the images, default is 10.\nReturns:\n    Mean and standard deviation of the Inception Score across the splits.\n'''\n\nimport os\nos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\nimport tensorflow as tf\nimport functools\nimport numpy as np\nimport time\nfrom tqdm import tqdm\nfrom tensorflow.python.ops import array_ops\ntfgan = tf.contrib.gan\n\nsession=tf.compat.v1.InteractiveSession()\n\n# A smaller BATCH_SIZE reduces GPU memory usage, but at the cost of a slight slowdown\nBATCH_SIZE = 64\nINCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'\nINCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb'\n\n# Run images through Inception.\ninception_images = tf.compat.v1.placeholder(tf.float32, [None, 3, None, None])\ndef inception_logits(images = inception_images, num_splits = 1):\n    images = tf.transpose(images, [0, 2, 3, 1])\n    size = 299\n    images = tf.compat.v1.image.resize_bilinear(images, [size, size])\n    generated_images_list = array_ops.split(images, num_or_size_splits = num_splits)\n    logits = tf.map_fn(\n        fn = functools.partial(\n             tfgan.eval.run_inception, \n             default_graph_def_fn = functools.partial(\n             tfgan.eval.get_graph_def_from_url_tarball, \n             INCEPTION_URL, \n             INCEPTION_FROZEN_GRAPH, \n             os.path.basename(INCEPTION_URL)), \n             
output_tensor = 'logits:0'),\n        elems = array_ops.stack(generated_images_list),\n        parallel_iterations = 8,\n        back_prop = False,\n        swap_memory = True,\n        name = 'RunClassifier')\n    logits = array_ops.concat(array_ops.unstack(logits), 0)\n    return logits\n\nlogits=inception_logits()\n\ndef get_inception_probs(inps):\n    n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE))\n    preds = np.zeros([inps.shape[0], 1000], dtype = np.float32)\n    for i in tqdm(range(n_batches)):\n        inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] / 255. * 2 - 1\n        preds[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(logits,{inception_images: inp})[:, :1000]\n    preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True)\n    return preds\n\ndef preds2score(preds, splits=10):\n    scores = []\n    for i in range(splits):\n        part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]\n        kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))\n        kl = np.mean(np.sum(kl, 1))\n        scores.append(np.exp(kl))\n    return np.mean(scores), np.std(scores)\n\ndef get_inception_score(images, splits=10):\n    assert(type(images) == np.ndarray)\n    assert(len(images.shape) == 4)\n    assert(images.shape[1] == 3)\n    assert(np.min(images[0]) >= 0 and np.max(images[0]) > 10), 'Image values should be in the range [0, 255]'\n    print('Calculating Inception Score with %i images in %i splits' % (images.shape[0], splits))\n    start_time=time.time()\n    preds = get_inception_probs(images)\n    mean, std = preds2score(preds, splits)\n    print('Inception Score calculation time: %f s' % (time.time() - start_time))\n    return mean, std  # Reference values: 11.38 for 50000 CIFAR-10 training set images, or mean=11.31, std=0.10 if in 10 splits.\n\ndef compute_is_from_npz(path):\n    with np.load(path) as data:\n        fake_imgs = data['fake']\n    fake_imgs 
= fake_imgs.transpose(0, 3, 1, 2)\n    print(fake_imgs.shape)\n    return get_inception_score(fake_imgs)\n\n\nif __name__ == '__main__':\n    import argparse\n    import json\n\n    parser = argparse.ArgumentParser('compute TF IS')\n    parser.add_argument('--samples', help='path to samples')\n    parser.add_argument('--it', type=str, help='path to samples')\n    parser.add_argument('--results_dir', help='path to results_dir')\n    args = parser.parse_args()\n\n    it = args.it\n    results_dir = args.results_dir\n    mean, std = compute_is_from_npz(args.samples)\n\n    with open(os.path.join(args.results_dir, 'is_results.json')) as f:\n        is_results = json.load(f)\n\n    is_results[it] = float(mean)\n    print(f'{results_dir} iteration {it} IS: {mean}')\n\n    with open(os.path.join(args.results_dir, 'is_results.json'), 'w') as f:\n        f.write(json.dumps(is_results))"
  },
  {
    "path": "gan_training/models/__init__.py",
    "content": "from gan_training.models import (dcgan_deep, dcgan_shallow, resnet2)\n\ngenerator_dict = {\n    'resnet2': resnet2.Generator,\n    'dcgan_deep': dcgan_deep.Generator,\n    'dcgan_shallow': dcgan_shallow.Generator\n}\n\ndiscriminator_dict = {\n    'resnet2': resnet2.Discriminator,\n    'dcgan_deep': dcgan_deep.Discriminator,\n    'dcgan_shallow': dcgan_shallow.Discriminator\n}\n"
  },
  {
    "path": "gan_training/models/blocks.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.autograd import Variable\nfrom torch.nn import functional as F\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self,\n                 fin,\n                 fout,\n                 bn,\n                 nclasses,\n                 fhidden=None,\n                 is_bias=True):\n        super().__init__()\n        # Attributes\n        self.is_bias = is_bias\n        self.learned_shortcut = (fin != fout)\n        self.fin = fin\n        self.fout = fout\n        if fhidden is None:\n            self.fhidden = min(fin, fout)\n        else:\n            self.fhidden = fhidden\n        # Submodules\n        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)\n        self.conv_1 = nn.Conv2d(self.fhidden,\n                      self.fout,\n                      3,\n                      stride=1,\n                      padding=1,\n                      bias=is_bias)\n        if self.learned_shortcut:\n            self.conv_s = nn.Conv2d(self.fin,\n                          self.fout,\n                          1,\n                          stride=1,\n                          padding=0,\n                          bias=False)\n        self.bn0 = bn(self.fin, nclasses)\n        self.bn1 = bn(self.fhidden, nclasses)\n\n    def forward(self, x, y):\n        x_s = self._shortcut(x)\n        dx = self.conv_0(actvn(self.bn0(x, y)))\n        dx = self.conv_1(actvn(self.bn1(dx, y)))\n        out = x_s + 0.1 * dx\n\n        return out\n\n    def _shortcut(self, x):\n        if self.learned_shortcut:\n            x_s = self.conv_s(x)\n        else:\n            x_s = x\n        return x_s\n\n\ndef actvn(x):\n    out = F.leaky_relu(x, 2e-1)\n    return out\n\n\nclass LatentEmbeddingConcat(nn.Module):\n    ''' projects class embedding onto hypersphere and returns the concat of the latent and the class embedding '''\n\n    def __init__(self, nlabels, embed_dim):\n        super().__init__()\n      
  self.embedding = nn.Embedding(nlabels, embed_dim)\n\n    def forward(self, z, y):\n        assert (y.size(0) == z.size(0))\n        yembed = self.embedding(y)\n        yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True)\n        yz = torch.cat([z, yembed], dim=1)\n        return yz\n\n\nclass NormalizeLinear(nn.Module):\n    def __init__(self, act_dim, k_value):\n        super().__init__()\n        self.lin = nn.Linear(act_dim, k_value)\n\n    def normalize(self):\n        self.lin.weight.data = F.normalize(self.lin.weight.data, p=2, dim=1)\n\n    def forward(self, x):\n        self.normalize()\n        return self.lin(x)\n\n\nclass Identity(nn.Module):\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n\n    def forward(self, inp, *args, **kwargs):\n        return inp\n\n\nclass LinearConditionalMaskLogits(nn.Module):\n    ''' runs activated logits through fc and masks out the appropriate discriminator score according to class number'''\n\n    def __init__(self, nc, nlabels):\n        super().__init__()\n        self.fc = nn.Linear(nc, nlabels)\n\n    def forward(self, inp, y=None, take_best=False, get_features=False):\n        out = self.fc(inp)\n        if get_features: return out\n\n        if not take_best:\n            y = y.view(-1)\n            index = Variable(torch.LongTensor(range(out.size(0))))\n            if y.is_cuda:\n                index = index.cuda()\n            return out[index, y]\n        else:\n            # high activation means real, so take the highest activations\n            best_logits, _ = out.max(dim=1)\n            return best_logits\n\n\nclass ProjectionDiscriminatorLogits(nn.Module):\n    ''' takes in activated flattened logits before last linear layer and implements https://arxiv.org/pdf/1802.05637.pdf '''\n\n    def __init__(self, nc, nlabels):\n        super().__init__()\n        self.fc = nn.Linear(nc, 1)\n        self.embedding = nn.Embedding(nlabels, nc)\n        self.nlabels = nlabels\n\n   
 def forward(self, x, y, take_best=False):\n        output = self.fc(x)\n\n        if not take_best:\n            label_info = torch.sum(self.embedding(y) * x, dim=1, keepdim=True)\n            return (output + label_info).view(x.size(0))\n        else:\n            #TODO: this may be computationally expensive, maybe we want to do the global pooling first to reduce x's size\n            index = torch.LongTensor(range(self.nlabels)).cuda()\n            labels = index.repeat((x.size(0), ))\n            x = x.repeat_interleave(self.nlabels, dim=0)\n            label_info = torch.sum(self.embedding(labels) * x,\n                                   dim=1,\n                                   keepdim=True).view(output.size(0),\n                                                      self.nlabels)\n            # high activation means real, so take the highest activations\n            best_logits, _ = label_info.max(dim=1)\n            return output.view(output.size(0)) + best_logits\n\n\nclass LinearUnconditionalLogits(nn.Module):\n    ''' standard discriminator logit layer '''\n\n    def __init__(self, nc):\n        super().__init__()\n        self.fc = nn.Linear(nc, 1)\n\n    def forward(self, inp, y, take_best=False):\n        assert (take_best == False)\n\n        out = self.fc(inp)\n        return out.view(out.size(0))\n\n\nclass Reshape(nn.Module):\n    def __init__(self, *shape):\n        super().__init__()\n        self.shape = shape\n\n    def forward(self, x):\n        batch_size = x.shape[0]\n        return x.view(*((batch_size, ) + self.shape))\n\n\nclass ConditionalBatchNorm2d(nn.Module):\n    ''' from https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775 '''\n\n    def __init__(self, num_features, num_classes):\n        super().__init__()\n        self.num_features = num_features\n        self.bn = nn.BatchNorm2d(num_features, affine=False)\n        self.embed = nn.Embedding(num_classes, num_features * 2)\n        self.embed.weight.data[:, 
:num_features].normal_(\n            1, 0.02)  # Initialize scale at N(1, 0.02)\n        self.embed.weight.data[:, num_features:].zero_(\n        )  # Initialize bias at 0\n\n    def forward(self, x, y):\n        out = self.bn(x)\n        gamma, beta = self.embed(y).chunk(2, 1)\n        out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(\n            -1, self.num_features, 1, 1)\n        return out\n\n\nclass BatchNorm2d(nn.Module):\n    ''' identical to nn.BatchNorm2d but takes in y input that is ignored '''\n\n    def __init__(self, nc, nchannels, **kwargs):\n        super().__init__()\n        self.bn = nn.BatchNorm2d(nc)\n\n    def forward(self, x, y):\n        return self.bn(x)\n"
  },
  {
    "path": "gan_training/models/dcgan_deep.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport torch.utils.data\nimport torch.utils.data.distributed\nfrom gan_training.models import blocks\n\n\nclass Generator(nn.Module):\n    def __init__(self,\n                 nlabels,\n                 conditioning,\n                 z_dim=128,\n                 nc=3,\n                 ngf=64,\n                 embed_dim=256,\n                 **kwargs):\n        super(Generator, self).__init__()\n\n        assert conditioning != 'unconditional' or nlabels == 1\n\n        if conditioning == 'embedding':\n            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim)\n            self.fc = nn.Linear(z_dim + embed_dim, 4 * 4 * ngf * 8)\n        elif conditioning == 'unconditional':\n            self.get_latent = blocks.Identity()\n            self.fc = nn.Linear(z_dim, 4 * 4 * ngf * 8)\n        else:\n            raise NotImplementedError(\n                f\"{conditioning} not implemented for generator\")\n\n        bn = blocks.BatchNorm2d\n\n        self.nlabels = nlabels\n\n        self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1)\n        self.bn1 = bn(ngf * 4, nlabels)\n\n        self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1)\n        self.bn2 = bn(ngf * 2, nlabels)\n\n        self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1)\n        self.bn3 = bn(ngf, nlabels)\n\n        self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh())\n\n    def forward(self, input, y):\n        y = y.clamp(None, self.nlabels - 1)\n        out = self.get_latent(input, y)\n\n        out = self.fc(out)\n        out = out.view(out.size(0), -1, 4, 4)\n        out = F.relu(self.bn1(self.conv1(out), y))\n        out = F.relu(self.bn2(self.conv2(out), y))\n        out = F.relu(self.bn3(self.conv3(out), y))\n        return self.conv_out(out)\n\n\nclass Discriminator(nn.Module):\n    def __init__(self,\n                 nlabels,\n                 
conditioning,\n                 nc=3,\n                 ndf=64,\n                 pack_size=1,\n                 features='penultimate',\n                 **kwargs):\n\n        super(Discriminator, self).__init__()\n\n        assert conditioning != 'unconditional' or nlabels == 1\n\n        self.nlabels = nlabels\n\n        self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 3, 1, 1), nn.LeakyReLU(0.1))\n        self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf, 4, 2, 1), nn.LeakyReLU(0.1))\n        self.conv3 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 3, 1, 1), nn.LeakyReLU(0.1))\n        self.conv4 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 2, 4, 2, 1), nn.LeakyReLU(0.1))\n        self.conv5 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 3, 1, 1), nn.LeakyReLU(0.1))\n        self.conv6 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1), nn.LeakyReLU(0.1))\n        self.conv7 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 3, 1, 1), nn.LeakyReLU(0.1))\n\n        if conditioning == 'mask':\n            self.fc_out = blocks.LinearConditionalMaskLogits(\n                ndf * 8 * 4 * 4, nlabels)\n        elif conditioning == 'unconditional':\n            self.fc_out = blocks.LinearUnconditionalLogits(\n                ndf * 8 * 4 * 4)\n        else:\n            raise NotImplementedError(\n                f\"{conditioning} not implemented for discriminator\")\n\n        self.features = features\n        self.pack_size = pack_size\n        print(f'Getting features from {self.features}')\n\n    def stack(self, x):\n        #pacgan\n        nc = self.pack_size\n        assert (x.size(0) % nc == 0)\n        if nc == 1:\n            return x\n        x_new = []\n        for i in range(x.size(0) // nc):\n            imgs_to_stack = x[i * nc:(i + 1) * nc]\n            x_new.append(torch.cat([t for t in imgs_to_stack], dim=0))\n        return torch.stack(x_new)\n\n    def forward(self, input, y=None, get_features=False):\n        input = self.stack(input)\n        out = 
self.conv1(input)\n        out = self.conv2(out)\n        out = self.conv3(out)\n        out = self.conv4(out)\n        out = self.conv5(out)\n        out = self.conv6(out)\n        out = self.conv7(out)\n\n        if get_features and self.features == \"penultimate\":\n            return out.view(out.size(0), -1)\n        if get_features and self.features == \"summed\":\n            return out.view(out.size(0), out.size(1), -1).sum(dim=2)\n            \n        out = out.view(out.size(0), -1)\n        y = y.clamp(None, self.nlabels - 1)\n        result = self.fc_out(out, y)\n        assert (len(result.shape) == 1)\n        return result\n\n\nif __name__ == '__main__':\n    z = torch.zeros((1, 128))\n    g = Generator()\n    x = torch.zeros((1, 3, 32, 32))\n    d = Discriminator()\n\n    g(z)\n    d(g(z))\n    d(x)\n"
  },
  {
    "path": "gan_training/models/dcgan_shallow.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport torch.utils.data\nimport torch.utils.data.distributed\nfrom gan_training.models import blocks\n\n\nclass Generator(nn.Module):\n    def __init__(self,\n                 nlabels,\n                 conditioning,\n                 z_dim=128,\n                 nc=3,\n                 ngf=64,\n                 embed_dim=256,\n                 **kwargs):\n        super(Generator, self).__init__()\n\n        assert conditioning != 'unconditional' or nlabels == 1\n\n        if conditioning == 'embedding':\n            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim)\n            self.fc = nn.Linear(z_dim + embed_dim, 4 * 4 * ngf * 8)\n        elif conditioning == 'unconditional':\n            self.get_latent = blocks.Identity()\n            self.fc = nn.Linear(z_dim, 4 * 4 * ngf * 8)\n        else:\n            raise NotImplementedError(\n                f\"{conditioning} not implemented for generator\")\n\n        bn = blocks.BatchNorm2d\n\n        self.nlabels = nlabels\n\n        self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1)\n        self.bn1 = bn(ngf * 4, nlabels)\n\n        self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1)\n        self.bn2 = bn(ngf * 2, nlabels)\n\n        self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1)\n        self.bn3 = bn(ngf, nlabels)\n\n        self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh())\n\n    def forward(self, input, y):\n        y = y.clamp(None, self.nlabels - 1)\n\n        out = self.get_latent(input, y)\n        out = self.fc(out)\n\n        out = out.view(out.size(0), -1, 4, 4)\n        out = F.relu(self.bn1(self.conv1(out), y))\n        out = F.relu(self.bn2(self.conv2(out), y))\n        out = F.relu(self.bn3(self.conv3(out), y))\n        return self.conv_out(out)\n\n\nclass Discriminator(nn.Module):\n    def __init__(self,\n                 nlabels,\n                 
conditioning,\n                 features='penultimate',\n                 pack_size=1,\n                 nc=3,\n                 ndf=64,\n                 **kwargs):\n        super(Discriminator, self).__init__()\n\n        assert conditioning != 'unconditional' or nlabels == 1\n\n        self.nlabels = nlabels\n\n        self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 4, 2, 1),\n                                   nn.BatchNorm2d(ndf),\n                                   nn.LeakyReLU(0.2, inplace=True))\n        self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 4, 2, 1),\n                                   nn.BatchNorm2d(ndf * 2),\n                                   nn.LeakyReLU(0.2, inplace=True))\n        self.conv3 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),\n                                   nn.BatchNorm2d(ndf * 4),\n                                   nn.LeakyReLU(0.2, inplace=True))\n        self.conv4 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),\n                                   nn.BatchNorm2d(ndf * 8),\n                                   nn.LeakyReLU(0.2, inplace=True))\n\n        if conditioning == 'mask':\n            self.fc_out = blocks.LinearConditionalMaskLogits(ndf * 8 * 4 , nlabels)\n        elif conditioning == 'unconditional':\n            self.fc_out = blocks.LinearUnconditionalLogits(ndf * 8 * 4)\n        else:\n            raise NotImplementedError(\n                f\"{conditioning} not implemented for discriminator\")\n\n        self.pack_size = pack_size\n        self.features = features\n        print(f'Getting features from {self.features}')\n\n    def stack(self, x):\n        #pacgan\n        nc = self.pack_size\n        if nc == 1:\n            return x\n        x_new = []\n        for i in range(x.size(0) // nc):\n            imgs_to_stack = x[i * nc:(i + 1) * nc]\n            x_new.append(torch.cat([t for t in imgs_to_stack], dim=0))\n        return torch.stack(x_new)\n\n    def forward(self, input, 
y=None, get_features=False):\n        input = self.stack(input)\n        out = self.conv1(input)\n        out = self.conv2(out)\n        out = self.conv3(out)\n        out = self.conv4(out)\n        out = out.view(out.size(0), -1)\n        if get_features: return out.view(out.size(0), -1)\n        y = y.clamp(None, self.nlabels - 1)\n        result = self.fc_out(out, y)\n        assert (len(result.shape) == 1)\n        return result\n\n\nif __name__ == '__main__':\n    z = torch.zeros((1, 128))\n    g = Generator()\n    x = torch.zeros((1, 3, 32, 32))\n    d = Discriminator()\n\n    g(z)\n    d(g(z))\n    d(x)\n"
  },
  {
    "path": "gan_training/models/resnet2.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.autograd import Variable\nimport torch.utils.data\nimport torch.utils.data.distributed\n\nfrom gan_training.models import blocks\nfrom gan_training.models.blocks import ResnetBlock\nfrom torch.nn.utils.spectral_norm import spectral_norm\n\n\nclass Generator(nn.Module):\n    def __init__(self,\n                 z_dim,\n                 nlabels,\n                 size,\n                 conditioning,\n                 embed_size=256,\n                 nfilter=64,\n                 **kwargs):\n        super().__init__()\n        s0 = self.s0 = size // 32\n        nf = self.nf = nfilter\n        self.nlabels = nlabels\n        self.z_dim = z_dim\n\n        assert conditioning != 'unconditional' or nlabels == 1\n\n        if conditioning == 'embedding':\n            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_size)\n            self.fc = nn.Linear(z_dim + embed_size, 16 * nf * s0 * s0)\n        elif conditioning == 'unconditional':\n            self.get_latent = blocks.Identity()\n            self.fc = nn.Linear(z_dim, 16 * nf * s0 * s0)\n        else:\n            raise NotImplementedError(\n                f\"{conditioning} not implemented for generator\")\n\n        #either use conditional batch norm, or use no batch norm\n        bn = blocks.Identity\n\n        self.resnet_0_0 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels)\n        self.resnet_0_1 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels)\n\n        self.resnet_1_0 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels)\n        self.resnet_1_1 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels)\n\n        self.resnet_2_0 = ResnetBlock(16 * nf, 8 * nf, bn, nlabels)\n        self.resnet_2_1 = ResnetBlock(8 * nf, 8 * nf, bn, nlabels)\n\n        self.resnet_3_0 = ResnetBlock(8 * nf, 4 * nf, bn, nlabels)\n        self.resnet_3_1 = ResnetBlock(4 * nf, 4 * nf, bn, nlabels)\n\n        self.resnet_4_0 = ResnetBlock(4 * nf, 
2 * nf, bn, nlabels)\n        self.resnet_4_1 = ResnetBlock(2 * nf, 2 * nf, bn, nlabels)\n\n        self.resnet_5_0 = ResnetBlock(2 * nf, 1 * nf, bn, nlabels)\n        self.resnet_5_1 = ResnetBlock(1 * nf, 1 * nf, bn, nlabels)\n\n        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)\n\n    def forward(self, z, y):\n        y = y.clamp(None, self.nlabels - 1)\n        out = self.get_latent(z, y)\n\n        out = self.fc(out)\n\n        out = out.view(z.size(0), 16 * self.nf, self.s0, self.s0)\n\n        out = self.resnet_0_0(out, y)\n        out = self.resnet_0_1(out, y)\n\n        out = F.interpolate(out, scale_factor=2)\n        out = self.resnet_1_0(out, y)\n        out = self.resnet_1_1(out, y)\n\n        out = F.interpolate(out, scale_factor=2)\n        out = self.resnet_2_0(out, y)\n        out = self.resnet_2_1(out, y)\n\n        out = F.interpolate(out, scale_factor=2)\n        out = self.resnet_3_0(out, y)\n        out = self.resnet_3_1(out, y)\n\n        out = F.interpolate(out, scale_factor=2)\n        out = self.resnet_4_0(out, y)\n        out = self.resnet_4_1(out, y)\n\n        out = F.interpolate(out, scale_factor=2)\n        out = self.resnet_5_0(out, y)\n        out = self.resnet_5_1(out, y)\n\n        out = self.conv_img(actvn(out))\n        out = torch.tanh(out)\n\n        return out\n\n\nclass Discriminator(nn.Module):\n    def __init__(self,\n                 nlabels,\n                 size,\n                 conditioning,\n                 nfilter=64,\n                 features='penultimate',\n                 **kwargs):\n        super().__init__()\n        s0 = self.s0 = size // 32\n        nf = self.nf = nfilter\n        self.nlabels = nlabels\n\n        assert conditioning != 'unconditional' or nlabels == 1\n        bn = blocks.Identity\n\n        self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1)\n\n        self.resnet_0_0 = ResnetBlock(1 * nf, 1 * nf, bn, nlabels)\n        self.resnet_0_1 = ResnetBlock(1 * nf, 2 * nf, bn, nlabels)\n\n     
   self.resnet_1_0 = ResnetBlock(2 * nf, 2 * nf, bn, nlabels)\n        self.resnet_1_1 = ResnetBlock(2 * nf, 4 * nf, bn, nlabels)\n\n        self.resnet_2_0 = ResnetBlock(4 * nf, 4 * nf, bn, nlabels)\n        self.resnet_2_1 = ResnetBlock(4 * nf, 8 * nf, bn, nlabels)\n\n        self.resnet_3_0 = ResnetBlock(8 * nf, 8 * nf, bn, nlabels)\n        self.resnet_3_1 = ResnetBlock(8 * nf, 16 * nf, bn, nlabels)\n\n        self.resnet_4_0 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels)\n        self.resnet_4_1 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels)\n\n        self.resnet_5_0 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels)\n        self.resnet_5_1 = ResnetBlock(16 * nf, 16 * nf, bn, nlabels)\n\n        if conditioning == 'mask':\n            self.fc_out = blocks.LinearConditionalMaskLogits(\n                16 * nf * s0 * s0, nlabels)\n        elif conditioning == 'unconditional':\n            self.fc_out = blocks.LinearUnconditionalLogits(16 * nf * s0 * s0)\n        else:\n            raise NotImplementedError(\n                f\"{conditioning} not implemented for discriminator\")\n\n        self.features = features\n\n    def forward(self, x, y=None, get_features=False):\n        batch_size = x.size(0)\n        if y is not None:\n            y = y.clamp(None, self.nlabels - 1)\n\n        out = self.conv_img(x)\n\n        out = self.resnet_0_0(out, y)\n        out = self.resnet_0_1(out, y)\n        out = F.avg_pool2d(out, 3, stride=2, padding=1)\n        out = self.resnet_1_0(out, y)\n        out = self.resnet_1_1(out, y)\n        out = F.avg_pool2d(out, 3, stride=2, padding=1)\n        out = self.resnet_2_0(out, y)\n        out = self.resnet_2_1(out, y)\n        out = F.avg_pool2d(out, 3, stride=2, padding=1)\n        out = self.resnet_3_0(out, y)\n        out = self.resnet_3_1(out, y)\n        out = F.avg_pool2d(out, 3, stride=2, padding=1)\n        out = self.resnet_4_0(out, y)\n        out = self.resnet_4_1(out, y)\n        out = F.avg_pool2d(out, 3, stride=2, 
padding=1)\n        out = self.resnet_5_0(out, y)\n        out = self.resnet_5_1(out, y)\n        out = actvn(out)\n\n        if get_features and self.features == 'summed':\n            return out.view(out.size(0), out.size(1), -1).sum(dim=2)\n\n        out = out.view(batch_size, 16 * self.nf * self.s0 * self.s0)\n\n        if get_features: return out.view(batch_size, -1)\n        result = self.fc_out(out, y)\n        assert (len(result.shape) == 1)\n        return result\n\n\ndef actvn(x):\n    out = F.leaky_relu(x, 2e-1)\n    return out"
  },
  {
    "path": "gan_training/models/resnet2s.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.autograd import Variable\nimport torch.utils.data\nimport torch.utils.data.distributed\nfrom collections import OrderedDict\n\n\nclass Reshape(nn.Module):\n    def __init__(self, *shape):\n        super().__init__()\n        self.shape = shape\n\n    def forward(self, x):\n        batch_size = x.shape[0]\n        return x.view(*((batch_size, ) + self.shape))\n\n\nclass Generator(nn.Module):\n    '''\n    Perfectly equivalent to resnet2.Generator (can load state dicts\n    from that class), but organizes layers as a sequence for more\n    automatic inversion.\n    '''\n\n    def __init__(self,\n                 z_dim,\n                 nlabels,\n                 size,\n                 embed_size=256,\n                 nfilter=64,\n                 use_class_labels=False,\n                 **kwargs):\n        super().__init__()\n        s0 = self.s0 = size // 32\n        nf = self.nf = nfilter\n        self.z_dim = z_dim\n        self.use_class_labels = use_class_labels\n        # Submodules\n        if use_class_labels:\n            self.condition = ConditionGen(z_dim, nlabels, embed_size)\n            latent_dim = self.condition.latent_dim\n        else:\n            latent_dim = z_dim\n\n        self.layers = nn.Sequential(\n            OrderedDict([('fc', nn.Linear(latent_dim, 16 * nf * s0 * s0)),\n                         ('reshape', Reshape(16 * self.nf, self.s0, self.s0)),\n                         ('resnet_0_0', ResnetBlock(16 * nf, 16 * nf)),\n                         ('resnet_0_1', ResnetBlock(16 * nf, 16 * nf)),\n                         ('upsample_1', nn.Upsample(scale_factor=2)),\n                         ('resnet_1_0', ResnetBlock(16 * nf, 16 * nf)),\n                         ('resnet_1_1', ResnetBlock(16 * nf, 16 * nf)),\n                         ('upsample_2', nn.Upsample(scale_factor=2)),\n                         ('resnet_2_0', ResnetBlock(16 * nf, 
8 * nf)),\n                         ('resnet_2_1', ResnetBlock(8 * nf, 8 * nf)),\n                         ('upsample_3', nn.Upsample(scale_factor=2)),\n                         ('resnet_3_0', ResnetBlock(8 * nf, 4 * nf)),\n                         ('resnet_3_1', ResnetBlock(4 * nf, 4 * nf)),\n                         ('upsample_4', nn.Upsample(scale_factor=2)),\n                         ('resnet_4_0', ResnetBlock(4 * nf, 2 * nf)),\n                         ('resnet_4_1', ResnetBlock(2 * nf, 2 * nf)),\n                         ('upsample_5', nn.Upsample(scale_factor=2)),\n                         ('resnet_5_0', ResnetBlock(2 * nf, 1 * nf)),\n                         ('resnet_5_1', ResnetBlock(1 * nf, 1 * nf)),\n                         ('img_relu', nn.LeakyReLU(2e-1)),\n                         ('conv_img', nn.Conv2d(nf, 3, 3, padding=1)),\n                         ('tanh', nn.Tanh())]))\n\n    def forward(self, z, y=None):\n        assert (y is None or z.size(0) == y.size(0))\n        assert (not self.use_class_labels or y is not None)\n        batch_size = z.size(0)\n        if self.use_class_labels:\n            z = self.condition(z, y)\n        return self.layers(z)\n\n    def load_v2_state_dict(self, state_dict):\n        converted = {}\n        for k, v in state_dict.items():\n            if 'module.' in k: k = k.split('module.')[1]\n            if k.startswith('embedding'):\n                k = 'condition.' + k\n            elif k == 'get_latent.embedding.weight':\n                k = 'condition.embedding.weight'\n            else:\n                k = 'layers.' 
+ k\n            converted[k] = v\n        self.load_state_dict(converted)\n\n\nclass ConditionGen(nn.Module):\n    def __init__(self, z_dim, nlabels, embed_size=256):\n        super().__init__()\n        self.embedding = nn.Embedding(nlabels, embed_size)\n        self.latent_dim = z_dim + embed_size\n        self.z_dim = z_dim\n        self.nlabels = nlabels\n        self.embed_size = embed_size\n\n    def forward(self, z, y):\n        assert (z.size(0) == y.size(0))\n        batch_size = z.size(0)\n        if y.dtype is torch.int64:\n            yembed = self.embedding(y)\n        else:\n            yembed = y\n        yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True)\n        return torch.cat([z, yembed], dim=1)\n\n\ndef convert_from_resnet2_generator(gen):\n    nlabels, embed_size = 0, 0\n    use_class_labels = False\n    if hasattr(gen, 'embedding'):\n        # new version does not have gen.use_class_labels..\n        nlabels = gen.embedding.num_embeddings\n        embed_size = gen.embedding.embedding_dim\n        use_class_labels = True\n    if hasattr(gen, 'get_latent'):\n        # new version does not have gen.use_class_labels..\n        nlabels = gen.get_latent.embedding.num_embeddings\n        embed_size = gen.get_latent.embedding.embedding_dim\n        use_class_labels = True\n    size = gen.s0 * 32\n    newgen = Generator(gen.z_dim, nlabels, size, embed_size, gen.nf,\n                       use_class_labels)\n    newgen.load_v2_state_dict(gen.state_dict())\n    return newgen\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, fin, fout, fhidden=None, is_bias=True):\n        super().__init__()\n        # Attributes\n        self.is_bias = is_bias\n        self.learned_shortcut = (fin != fout)\n        self.fin = fin\n        self.fout = fout\n        if fhidden is None:\n            self.fhidden = min(fin, fout)\n        else:\n            self.fhidden = fhidden\n\n        # Submodules\n        self.conv_0 = nn.Conv2d(self.fin,\n      
                          self.fhidden,\n                                kernel_size=3,\n                                stride=1,\n                                padding=1)\n        self.conv_1 = nn.Conv2d(self.fhidden,\n                                self.fout,\n                                kernel_size=3,\n                                stride=1,\n                                padding=1,\n                                bias=is_bias)\n        if self.learned_shortcut:\n            self.conv_s = nn.Conv2d(self.fin,\n                                    self.fout,\n                                    kernel_size=1,\n                                    stride=1,\n                                    padding=0,\n                                    bias=False)\n\n    def forward(self, x):\n        x_s = self._shortcut(x)\n        dx = self.conv_0(actvn(x))\n        dx = self.conv_1(actvn(dx))\n        out = x_s + 0.1 * dx\n\n        return out\n\n    def _shortcut(self, x):\n        if self.learned_shortcut:\n            x_s = self.conv_s(x)\n        else:\n            x_s = x\n        return x_s\n\n\ndef actvn(x):\n    out = F.leaky_relu(x, 2e-1)\n    return out\n\n\n"
  },
  {
    "path": "gan_training/models/resnet3.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.autograd import Variable\nimport torch.utils.data\nimport torch.utils.data.distributed\nfrom collections import OrderedDict\n\nclass Generator(nn.Module):\n    '''\n    Perfectly equivalent to resnet2.Generator (can load state dicts\n    from that class), but organizes layers as a sequence for more\n    automatic inversion.\n    '''\n    def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64,\n            use_class_labels=False, **kwargs):\n        super().__init__()\n        s0 = self.s0 = size // 32\n        nf = self.nf = nfilter\n        self.z_dim = z_dim\n        self.use_class_labels = use_class_labels\n\n        # Submodules\n        if use_class_labels:\n            self.condition = ConditionGen(z_dim, nlabels, embed_size)\n            latent_dim = self.condition.latent_dim\n        else:\n            latent_dim = z_dim\n\n        self.layers = nn.Sequential(OrderedDict([\n            ('fc', nn.Linear(latent_dim, 16*nf*s0*s0)),\n            ('reshape', Reshape(16*self.nf, self.s0, self.s0)),\n            ('resnet_0_0', ResnetBlock(16*nf, 16*nf)),\n            ('resnet_0_1', ResnetBlock(16*nf, 16*nf)),\n            ('upsample_1', nn.Upsample(scale_factor=2)),\n            ('resnet_1_0', ResnetBlock(16*nf, 16*nf)),\n            ('resnet_1_1', ResnetBlock(16*nf, 16*nf)),\n            ('upsample_2', nn.Upsample(scale_factor=2)),\n            ('resnet_2_0', ResnetBlock(16*nf, 8*nf)),\n            ('resnet_2_1', ResnetBlock(8*nf, 8*nf)),\n            ('upsample_3', nn.Upsample(scale_factor=2)),\n            ('resnet_3_0', ResnetBlock(8*nf, 4*nf)),\n            ('resnet_3_1', ResnetBlock(4*nf, 4*nf)),\n            ('upsample_4', nn.Upsample(scale_factor=2)),\n            ('resnet_4_0', ResnetBlock(4*nf, 2*nf)),\n            ('resnet_4_1', ResnetBlock(2*nf, 2*nf)),\n            ('upsample_5', nn.Upsample(scale_factor=2)),\n            ('resnet_5_0', 
ResnetBlock(2*nf, 1*nf)),\n            ('resnet_5_1', ResnetBlock(1*nf, 1*nf)),\n            ('img_relu', nn.LeakyReLU(2e-1)),\n            ('conv_img', nn.Conv2d(nf, 3, 3, padding=1)),\n            ('tanh', nn.Tanh())\n        ]))\n\n    def forward(self, z, y=None):\n        assert(y is None or z.size(0) == y.size(0))\n        assert(not self.use_class_labels or y is not None)\n        batch_size = z.size(0)\n        if self.use_class_labels:\n            z = self.condition(z, y)\n        return self.layers(z)\n\n    def load_v2_state_dict(self, state_dict):\n        converted = {}\n        for k, v in state_dict.items():\n            if k.startswith('embedding'):\n                k = 'condition.' + k\n            elif k == 'get_latent.embedding.weight':\n                k = 'condition.embedding.weight'\n            else:\n                k = 'layers.' + k\n            converted[k] = v\n        self.load_state_dict(converted)\n\nclass Reshape(nn.Module):\n    def __init__(self, *shape):\n        super().__init__()\n        self.shape = shape\n    def forward(self, x):\n        batch_size = x.shape[0]\n        return x.view(*((batch_size,) + self.shape))\n\nclass ConditionGen(nn.Module):\n    def __init__(self, z_dim, nlabels, embed_size=256):\n        super().__init__()\n        self.embedding = nn.Embedding(nlabels, embed_size)\n        self.latent_dim = z_dim + embed_size\n        self.z_dim = z_dim\n        self.nlabels = nlabels\n        self.embed_size = embed_size\n\n    def forward(self, z, y):\n        assert(z.size(0) == y.size(0))\n        batch_size = z.size(0)\n        if y.dtype is torch.int64:\n            yembed = self.embedding(y)\n        else:\n            yembed = y\n        yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True)\n        return torch.cat([z, yembed], dim=1)\n\ndef convert_from_resnet2_generator(gen):\n    nlabels, embed_size = 0, 0\n    \n    if hasattr(gen, 'get_latent'):\n        # new version does not have 
gen.use_class_labels..\n        nlabels = gen.get_latent.embedding.num_embeddings\n        embed_size = gen.get_latent.embedding.embedding_dim\n        use_class_labels = True\n    elif gen.use_class_labels:\n        nlabels = gen.embedding.num_embeddings\n        embed_size = gen.embedding.embedding_dim\n        use_class_labels = True\n\n    size = gen.s0 * 32\n    newgen = Generator(gen.z_dim, nlabels, size, embed_size, gen.nf, use_class_labels)\n    newgen.load_v2_state_dict(gen.state_dict())\n    return newgen\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, fin, fout, fhidden=None, is_bias=True):\n        super().__init__()\n        # Attributes\n        self.is_bias = is_bias\n        self.learned_shortcut = (fin != fout)\n        self.fin = fin\n        self.fout = fout\n        if fhidden is None:\n            self.fhidden = min(fin, fout)\n        else:\n            self.fhidden = fhidden\n\n        # Submodules\n        self.conv_0 = nn.Conv2d(self.fin, self.fhidden,\n                kernel_size=3, stride=1, padding=1)\n        self.conv_1 = nn.Conv2d(self.fhidden, self.fout,\n                kernel_size=3, stride=1, padding=1, bias=is_bias)\n        if self.learned_shortcut:\n            self.conv_s = nn.Conv2d(self.fin, self.fout,\n                    kernel_size=1, stride=1, padding=0, bias=False)\n\n    def forward(self, x):\n        x_s = self._shortcut(x)\n        dx = self.conv_0(actvn(x))\n        dx = self.conv_1(actvn(dx))\n        out = x_s + 0.1*dx\n\n        return out\n\n    def _shortcut(self, x):\n        if self.learned_shortcut:\n            x_s = self.conv_s(x)\n        else:\n            x_s = x\n        return x_s\n\n\ndef actvn(x):\n    out = F.leaky_relu(x, 2e-1)\n    return out"
  },
  {
    "path": "gan_training/train.py",
    "content": "# coding: utf-8\nimport torch\nfrom torch.nn import functional as F\nimport torch.utils.data\nimport torch.utils.data.distributed\nfrom torch import autograd\nimport numpy as np\n\n\nclass Trainer(object):\n    def __init__(self,\n                 generator,\n                 discriminator,\n                 g_optimizer,\n                 d_optimizer,\n                 gan_type,\n                 reg_type,\n                 reg_param):\n\n        self.generator = generator\n        self.discriminator = discriminator\n        self.g_optimizer = g_optimizer\n        self.d_optimizer = d_optimizer\n        self.gan_type = gan_type\n        self.reg_type = reg_type\n        self.reg_param = reg_param\n\n        print('D reg gamma', self.reg_param)\n\n    def generator_trainstep(self, y, z):\n        assert (y.size(0) == z.size(0))\n        toggle_grad(self.generator, True)\n        toggle_grad(self.discriminator, False)\n\n        self.generator.train()\n        self.discriminator.train()\n        self.g_optimizer.zero_grad()\n\n        x_fake = self.generator(z, y)\n        d_fake = self.discriminator(x_fake, y)\n        gloss = self.compute_loss(d_fake, 1)\n        gloss.backward()\n\n        self.g_optimizer.step()\n\n        return gloss.item()\n\n    def discriminator_trainstep(self, x_real, y, z):\n        toggle_grad(self.generator, False)\n        toggle_grad(self.discriminator, True)\n        self.generator.train()\n        self.discriminator.train()\n        self.d_optimizer.zero_grad()\n\n        # On real data\n        x_real.requires_grad_()\n\n        d_real = self.discriminator(x_real, y)\n        dloss_real = self.compute_loss(d_real, 1)\n\n        if self.reg_type == 'real' or self.reg_type == 'real_fake':\n            dloss_real.backward(retain_graph=True)\n            reg = self.reg_param * compute_grad2(d_real, x_real).mean()\n            reg.backward()\n        else:\n            dloss_real.backward()\n\n        # On fake data\n     
   with torch.no_grad():\n            x_fake = self.generator(z, y)\n\n        x_fake.requires_grad_()\n        d_fake = self.discriminator(x_fake, y)\n        dloss_fake = self.compute_loss(d_fake, 0)\n\n        if self.reg_type == 'fake' or self.reg_type == 'real_fake':\n            dloss_fake.backward(retain_graph=True)\n            reg = self.reg_param * compute_grad2(d_fake, x_fake).mean()\n            reg.backward()\n        else:\n            dloss_fake.backward()\n\n        if self.reg_type == 'wgangp':\n            reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y)\n            reg.backward()\n        elif self.reg_type == 'wgangp0':\n            reg = self.reg_param * self.wgan_gp_reg(\n                x_real, x_fake, y, center=0.)\n            reg.backward()\n\n        self.d_optimizer.step()\n\n        dloss = (dloss_real + dloss_fake)\n        if self.reg_type == 'none':\n            reg = torch.tensor(0.)\n\n        return dloss.item(), reg.item()\n\n    def compute_loss(self, d_out, target):\n        targets = d_out.new_full(size=d_out.size(), fill_value=target)\n\n        if self.gan_type == 'standard':\n            loss = F.binary_cross_entropy_with_logits(d_out, targets)\n        elif self.gan_type == 'wgan':\n            loss = (2 * target - 1) * d_out.mean()\n        else:\n            raise NotImplementedError\n\n        return loss\n\n    def wgan_gp_reg(self, x_real, x_fake, y, center=1.):\n        batch_size = y.size(0)\n        eps = torch.rand(batch_size, device=y.device).view(batch_size, 1, 1, 1)\n        x_interp = (1 - eps) * x_real + eps * x_fake\n        x_interp = x_interp.detach()\n        x_interp.requires_grad_()\n        d_out = self.discriminator(x_interp, y)\n\n        reg = (compute_grad2(d_out, x_interp).sqrt() - center).pow(2).mean()\n\n        return reg\n\n\n# Utility functions\ndef toggle_grad(model, requires_grad):\n    for p in model.parameters():\n        p.requires_grad_(requires_grad)\n\n\ndef 
compute_grad2(d_out, x_in):\n    batch_size = x_in.size(0)\n    grad_dout = autograd.grad(outputs=d_out.sum(),\n                              inputs=x_in,\n                              create_graph=True,\n                              retain_graph=True,\n                              only_inputs=True)[0]\n    grad_dout2 = grad_dout.pow(2)\n    assert (grad_dout2.size() == x_in.size())\n    reg = grad_dout2.view(batch_size, -1).sum(1)\n    return reg\n\n\ndef update_average(model_tgt, model_src, beta):\n    toggle_grad(model_src, False)\n    toggle_grad(model_tgt, False)\n\n    param_dict_src = dict(model_src.named_parameters())\n\n    for p_name, p_tgt in model_tgt.named_parameters():\n        p_src = param_dict_src[p_name]\n        assert (p_src is not p_tgt)\n        p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)\n"
  },
  {
    "path": "gan_training/utils.py",
    "content": "import torch\nimport torch.utils.data\nimport torch.utils.data.distributed\nimport torchvision\n\nimport os\n\n\ndef save_images(imgs, outfile, nrow=8):\n    imgs = imgs / 2 + 0.5  # unnormalize\n    torchvision.utils.save_image(imgs, outfile, nrow=nrow)\n\n\ndef get_nsamples(data_loader, N):\n    x = []\n    y = []\n    n = 0\n    for x_next, y_next in data_loader:\n        x.append(x_next)\n        y.append(y_next)\n        n += x_next.size(0)\n        if n > N:\n            break\n    x = torch.cat(x, dim=0)[:N]\n    y = torch.cat(y, dim=0)[:N]\n    return x, y\n\n\ndef update_average(model_tgt, model_src, beta):\n    param_dict_src = dict(model_src.named_parameters())\n\n    for p_name, p_tgt in model_tgt.named_parameters():\n        p_src = param_dict_src[p_name]\n        assert (p_src is not p_tgt)\n        p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)\n\n\ndef get_most_recent(d, ext):\n    if not os.path.exists(d):\n        print('Directory', d, 'does not exist')\n        return -1 \n    its = []\n    for f in os.listdir(d):\n        try:\n            it = int(f.split(ext + \"_\")[1].split('.pt')[0])\n            its.append(it)\n        except Exception as e:\n            pass\n    if len(its) == 0:\n        print('Found no files with extension \\\"%s\\\" under %s' % (ext, d))\n        return -1\n    return max(its)\n"
  },
  {
    "path": "metrics.py",
    "content": "import argparse\nimport os\nimport json\nfrom tqdm import tqdm\n\nimport numpy as np\nimport torch\n\nfrom gan_training.config import load_config\nfrom seeded_sampler import SeededSampler\n\nparser = argparse.ArgumentParser('Computes numbers used in paper and caches them to a result files. Examples include FID, IS, reverse-KL, # modes, FSD, cluster NMI, Purity.')\nparser.add_argument('paths', nargs='+', type=str, help='list of configs for each experiment')\nparser.add_argument('--it', type=int, default=-1, help='If set, computes numbers only for that iteration')\nparser.add_argument('--every', type=int, default=-1, help='skips some checkpoints and only computes those whose iteration number are divisible by every')\nparser.add_argument('--fid', action='store_true', help='compute FID metric')\nparser.add_argument('--inception', action='store_true', help='compute IS metric')\nparser.add_argument('--modes', action='store_true', help='compute # modes and reverse-KL metric')\nparser.add_argument('--fsd', action='store_true', help='compute FSD metric')\nparser.add_argument('--cluster_metrics', action='store_true', help='compute clustering metrics (NMI, purity)')\nparser.add_argument('--device', type=int, default=1, help='device to run the metrics on (can run into OOM issues if same as main device)')\nargs = parser.parse_args()\n\ndevice = args.device\ndirs = list(args.paths)\n\nN = 50000\nBS = 100\n\ndatasets = ['imagenet', 'cifar', 'stacked_mnist', 'places']\n\ndataset_to_img = {\n    'places': 'output/places_gt_imgs.npz',\n    'imagenet': 'output/imagenet_gt_imgs.npz'}\n\n\ndef load_results(results_dir):\n    results = []\n    for results_file in ['fid_results.json', 'is_results.json', 'kl_results.json', 'nmodes_results.json', 'fsd_results.json', 'cluster_metrics.json']:\n        results_file = os.path.join(results_dir, results_file)\n        if not os.path.exists(results_file):\n            with open(results_file, 'w') as f:\n                
f.write(json.dumps({}))\n        with open(results_file) as f:\n            results.append(json.load(f))\n    return results\n\n\ndef get_dataset_from_path(path):\n    for name in datasets:\n        if name in path:\n            print('Inferred dataset:', name)\n            return name\n\n\ndef pt_to_np(imgs):\n    '''normalizes pytorch image in [-1, 1] to [0, 255]'''\n    return (imgs.permute(0, 2, 3, 1).mul_(0.5).add_(0.5).mul_(255)).clamp_(0, 255).numpy()\n\n\ndef sample(sampler):\n    with torch.no_grad():\n        samples = []\n        for _ in tqdm(range(N // BS + 1)):\n            x_real = sampler.sample(BS)[0].detach().cpu()\n            x_real = [x.detach().cpu() for x in x_real]\n            samples.extend(x_real)\n        samples = torch.stack(samples[:N], dim=0)\n        return pt_to_np(samples)\n\n\nroot = './'\n\nwhile len(dirs) > 0:\n    path = dirs.pop()\n    if os.path.isdir(path):     # search down tree for config files\n        for d1 in os.listdir(path):\n            dirs.append(os.path.join(path, d1))\n    else:\n        if path.endswith('.yaml'):\n            config = load_config(path, default_path='configs/default.yaml')\n            outdir = config['training']['out_dir']\n\n            if not os.path.exists(outdir) and config['pretrained'] == {}:\n                print('Skipping', path, 'outdir', outdir)\n                continue\n\n            results_dir = os.path.join(outdir, 'results')\n            checkpoint_dir = os.path.join(outdir, 'chkpts')\n            os.makedirs(results_dir, exist_ok=True)\n\n            fid_results, is_results, kl_results, nmodes_results, fsd_results, cluster_results = load_results(results_dir)\n\n            checkpoint_files = os.listdir(checkpoint_dir) if os.path.exists(checkpoint_dir) else []\n            if config['pretrained'] != {}:\n                checkpoint_files = checkpoint_files + ['pretrained']\n\n            for checkpoint in checkpoint_files:\n                if (checkpoint.endswith('.pt') and 
checkpoint != 'model.pt') or checkpoint == 'pretrained':\n                    print('Computing for', checkpoint)\n                    if 'model' in checkpoint:\n                        # infer iteration number from checkpoint file w/o loading it\n                        if 'model_' in checkpoint:\n                            it = int(checkpoint.split('model_')[1].split('.pt')[0])\n                        else:\n                            continue\n                        if args.every != 0 and it % args.every != 0:\n                            continue\n                        # iteration 0 is often useless, skip it\n                        if it == 0 or args.it != -1 and it != args.it:\n                            continue\n                    elif checkpoint == 'pretrained':\n                        it = 'pretrained'\n                    it = str(it)\n\n                    clusterer_path = os.path.join(root, checkpoint_dir, f'clusterer{it}.pkl')\n                    #  don't save samples for each iteration for disk space\n                    samples_path = os.path.join(outdir, 'results', 'samples.npz')\n\n                    targets = []\n                    if args.inception:\n                        targets = targets + [is_results]\n                    if args.fid:\n                        targets = targets + [fid_results]\n                    if args.modes:\n                        targets = targets + [kl_results, nmodes_results]\n                    if args.fsd:\n                        targets = targets + [fsd_results]\n\n                    if all([it in result for result in targets]):\n                        print('Already generated', it, path)\n                    else:\n                        sampler = SeededSampler(path,\n                                                model_path=os.path.join(root, checkpoint_dir, checkpoint),\n                                                clusterer_path=clusterer_path,\n                                             
   pretrained=config['pretrained'])\n                        samples = sample(sampler)\n                        dataset_name = get_dataset_from_path(path)\n                        np.savez(samples_path, fake=samples, real=dataset_name)\n\n                    arguments = f'--samples {samples_path} --it {it} --results_dir {results_dir}'\n                    if args.fid and it not in fid_results:\n                        os.system(f'CUDA_VISIBLE_DEVICES={device} python gan_training/metrics/fid.py {arguments}')\n                    if args.inception and it not in is_results:\n                        os.system(f'CUDA_VISIBLE_DEVICES={device} python gan_training/metrics/tf_is/inception_score.py {arguments}')\n                    if args.modes and (it not in kl_results or it not in nmodes_results):\n                        os.system(f'CUDA_VISIBLE_DEVICES={device} python utils/get_empirical_distribution.py {arguments} --dataset {dataset_name}')\n                    if args.cluster_metrics and it not in cluster_results:\n                        os.system(f'CUDA_VISIBLE_DEVICES={device} python cluster_metrics.py {path} --model_it {it}')\n                    if args.fsd and it not in fsd_results:\n                        gt_path = dataset_to_img[dataset_name]\n                        os.system(f'CUDA_VISIBLE_DEVICES={device} python -m seeing.fsd {gt_path} {samples_path} --it {it} --results_dir {results_dir}')\n"
  },
  {
    "path": "requirements.txt",
    "content": "pytorch-gpu==1.3.1\ntensorflow-gpu==1.14.0\nscikit-learn\nscikit-image\ntorchvision\ntqdm \npyyaml\ncloudpickle\nipython\nopencv"
  },
  {
    "path": "seeded_sampler.py",
    "content": "''' Samples from a (class-conditional) GAN, so that the samples can be reproduced '''\n\nimport os\nimport pickle\nimport random\nimport copy\n\nimport torch\nfrom torch import nn\n\nfrom gan_training.checkpoints import CheckpointIO\nfrom gan_training.config import (load_config, build_models)\nfrom seeing.yz_dataset import YZDataset\n\n\ndef get_most_recent(models):\n    model_numbers = [\n        int(model.split(\"model.pt\")[0]) if model != \"model.pt\" else 0\n        for model in models\n    ]\n    return str(max(model_numbers)) + \"model.pt\"\n\n\nclass SeededSampler():\n    def __init__(\n            self,\n            config_name,        # name of experiment's config file\n            model_path=\"\",      # path to the model. empty string infers the most recent checkpoint\n            clusterer_path=\"\",  # path to the clusterer, ignored if gan type doesn't require a clusterer\n            pretrained={},      # urls to the pretrained models\n            rootdir='./',\n            device='cuda:0'):\n        self.config = load_config(os.path.join(rootdir, config_name), 'configs/default.yaml')\n        self.model_path = model_path\n        self.clusterer_path = clusterer_path\n        self.rootdir = rootdir\n        self.nlabels = self.config['generator']['nlabels']\n        self.device = device\n        self.pretrained = pretrained\n\n        self.generator = self.get_generator()\n        self.generator.eval()\n        self.yz_dist = self.get_yz_dist()\n\n    def sample(self, nimgs):\n        '''\n        samples an image using the generator, with z drawn from isotropic gaussian, and y drawn from self.yz_dist.\n        For baseline methods, y doesn't matter because y is ignored in the input\n        yz_dist is the empirical label distribution for the clustered gans.\n\n        returns the image, and the integer seed used to generate it. 
generated sample is in [-1, 1]\n        '''\n        self.generator.eval()\n        with torch.no_grad():\n            seeds = [random.randint(0, 1e8) for _ in range(nimgs)]\n            z, y = self.yz_dist(seeds)\n            return self.generator(z, y), seeds\n\n    def conditional_sample(self, yi, seed=None):\n        ''' returns a generated sample, which is in [-1, 1], seed is an int'''\n        self.generator.eval()\n        with torch.no_grad():\n            if seed is None:\n                seed = [random.randint(0, 1e8)]\n            else:\n                seed = [seed]\n            z, _ = self.yz_dist(seed)\n            y = torch.LongTensor([yi]).to(self.device)\n            return self.generator(z, y)\n\n    def sample_with_seed(self, seeds):\n        ''' returns a generated sample, which is in [-1, 1] '''\n        self.generator.eval()\n        z, y = self.yz_dist(seeds)\n        return self.generator(z, y)\n\n    def get_zy(self, seeds):\n        '''returns the batch of z, y corresponding to the seeds'''\n        return self.yz_dist(seeds)\n\n    def sample_with_zy(self, z, y):\n        ''' returns a generated sample given z and y, which is in [-1, 1].'''\n        self.generator.eval()\n        return self.generator(z, y)\n\n    def get_generator(self):\n        ''' loads a generator according to self.model_path '''\n\n        exp_out_dir = os.path.join(self.rootdir, self.config['training']['out_dir'])\n        # infer checkpoint if neeeded\n        checkpoint_dir = os.path.join(exp_out_dir, 'chkpts') if self.model_path == \"\" or 'model' in self.pretrained else \"./\"\n        model_name = get_most_recent(os.listdir(checkpoint_dir)) if self.model_path == \"\" else self.model_path\n\n        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)\n        self.checkpoint_io = checkpoint_io\n\n        generator, _ = build_models(self.config)\n        generator = generator.to(self.device)\n        generator = nn.DataParallel(generator)\n\n        if 
self.config['training']['take_model_average']:\n            generator_test = copy.deepcopy(generator)\n            checkpoint_io.register_modules(generator_test=generator_test)\n        else:\n            generator_test = generator\n\n        checkpoint_io.register_modules(generator=generator)\n\n        try:\n            it = checkpoint_io.load(model_name, pretrained=self.pretrained)\n            assert (it != -1)\n        except Exception as e:\n            # try again without data parallel\n            print(e)\n            checkpoint_io.register_modules(generator=generator.module)\n            checkpoint_io.register_modules(generator_test=generator_test.module)\n            it = checkpoint_io.load(model_name, pretrained=self.pretrained)\n            assert (it != -1)\n\n        print('Loaded iteration:', it['it'])\n        return generator_test\n\n    def get_yz_dist(self):\n        '''loads the z and y dists used to sample from the generator.'''\n\n        if self.config['clusterer']['name'] != 'supervised':\n            if 'clusterer' in self.pretrained:\n                clusterer = self.checkpoint_io.load_clusterer('pretrained', load_samples=False, pretrained=self.pretrained)\n            elif os.path.exists(self.clusterer_path):\n                with open(self.clusterer_path, 'rb') as f:\n                    clusterer = pickle.load(f)\n\n            if isinstance(clusterer.discriminator, nn.DataParallel):\n                clusterer.discriminator = clusterer.discriminator.module\n\n            if clusterer.kmeans is not None:\n                # use clusterer empirical distribution as sampling\n                print('Using k-means empirical distribution')\n                distribution = clusterer.get_label_distribution()\n                probs = [f / sum(distribution) for f in distribution]\n            else:\n                # otherwise, use a uniform distribution. 
this is not desired, unless it's a random label or unconditional GAN\n                print(\"Sampling with uniform distribution over\", clusterer.k, \"labels\")\n                probs = [1. / clusterer.k for _ in range(clusterer.k)]\n        else:\n            # if it's supervised, then sample uniformly over all classes.\n            # this might not be the right thing to do, since datasets are usually imbalanced.\n            print(\"Sampling with uniform distribution over\", self.nlabels,\n                  \"labels\")\n            probs = [1. / self.nlabels for _ in range(self.nlabels)]\n\n        return YZDataset(zdim=self.config['z_dist']['dim'],\n                         nlabels=len(probs),\n                         distribution=probs,\n                         device=self.device)\n"
  },
  {
    "path": "seeing/frechet_distance.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Calculates the Frechet Distance (FD) between two samples.\n\nCode apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead\nof Tensorflow\n\nCopyright 2018 Institute of Bioinformatics, JKU Linz\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n   http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\nimport numpy as np\nimport torch\nfrom scipy import linalg\n\ndef sample_frechet_distance(sample1, sample2, eps=1e-6,\n        return_components=False):\n    '''\n    Both samples should be numpy arrays.\n    Returns the Frechet distance.\n    '''\n    (mu1, sigma1), (mu2, sigma2) = [calculate_activation_statistics(s)\n            for s in [sample1, sample2]]\n    return calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=eps,\n            return_components=return_components)\n\ndef calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6,\n        return_components=False):\n    \"\"\"Numpy implementation of the Frechet Distance.\n    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)\n    and X_2 ~ N(mu_2, C_2) is\n            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n\n    Stable version by Dougal J. 
Sutherland.\n\n    Params:\n    -- mu1   : Numpy array containing the activations of a layer of the\n               inception net (like returned by the function 'get_predictions')\n               for generated samples.\n    -- mu2   : The sample mean over activations, precalculated on an\n               representative data set.\n    -- sigma1: The covariance matrix over activations for generated samples.\n    -- sigma2: The covariance matrix over activations, precalculated on an\n               representative data set.\n\n    Returns:\n    --   : The Frechet Distance.\n    \"\"\"\n\n    mu1 = np.atleast_1d(mu1)\n    mu2 = np.atleast_1d(mu2)\n\n    sigma1 = np.atleast_2d(sigma1)\n    sigma2 = np.atleast_2d(sigma2)\n\n    assert mu1.shape == mu2.shape, \\\n        'Training and test mean vectors have different lengths'\n    assert sigma1.shape == sigma2.shape, \\\n        'Training and test covariances have different dimensions'\n\n    diff = mu1 - mu2\n\n    # Product might be almost singular\n    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n    if not np.isfinite(covmean).all():\n        msg = ('fid calculation produces singular product; '\n               'adding %s to diagonal of cov estimates') % eps\n        print(msg)\n        offset = np.eye(sigma1.shape[0]) * eps\n        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n\n    # Numerical error might give slight imaginary component\n    if np.iscomplexobj(covmean):\n        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n            m = np.max(np.abs(covmean.imag))\n            raise ValueError('Imaginary component {}'.format(m))\n        covmean = covmean.real\n\n    tr_covmean = np.trace(covmean)\n\n    meandiff = diff.dot(diff)\n    covdiff = np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean\n    if return_components:\n        return (meandiff + covdiff, meandiff, covdiff)\n    else:\n        return meandiff + covdiff\n\n\ndef 
calculate_activation_statistics(act):\n    \"\"\"Calculation of the statistics used by the Frechet distance.\n    Params:\n    -- act   : Array-like of shape (n_samples, n_features); each row is one\n               sample and each column one feature (np.cov is called with\n               rowvar=False).\n    Returns:\n    -- mu    : The per-feature mean over samples (mean along axis 0).\n    -- sigma : The n_features x n_features covariance matrix of the features.\n    \"\"\"\n    mu = np.mean(act, axis=0)\n    sigma = np.cov(act, rowvar=False)\n    return mu, sigma\n"
  },
  {
    "path": "seeing/fsd.py",
    "content": "import torch, argparse, sys, os, numpy\nfrom .sampler import FixedRandomSubsetSampler, FixedSubsetSampler\nfrom torch.utils.data import DataLoader, TensorDataset\nimport numpy as np\nfrom torchvision import transforms, utils\nfrom . import pbar, zdataset, segmenter, frechet_distance, parallelfolder\n\nNUM_OBJECTS = 336\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='Net dissect utility')\n    parser.add_argument('true_dir')\n    parser.add_argument('gen_dir')\n    parser.add_argument('--size', type=int, default=10000)\n    parser.add_argument('--cachedir', default='results/fsd/cache')\n    parser.add_argument('--histout', default=None)\n    parser.add_argument('--maxscale', type=float, default=50)\n    parser.add_argument('--labelcount', type=int, default=30)\n    parser.add_argument('--dpi', type=float, default=100)\n    parser.add_argument('--it', type=str, default=\"-1\")\n    parser.add_argument('--results_dir', default=None, help='path to results_dir')\n    args = parser.parse_args()\n    if len(sys.argv) == 1:\n        parser.print_usage(sys.stderr)\n        sys.exit(1)\n    args = parser.parse_args()\n    print(args.true_dir, args.gen_dir)\n    true_dir, gen_dir = args.true_dir, args.gen_dir\n    seed1, seed2 = [1, 1 if true_dir != gen_dir else 2]\n    true_tally, gen_tally = [\n        cached_tally_directory(d,\n                               size=args.size,\n                               cachedir=args.cachedir,\n                               seed=seed)\n        for d, seed in [(true_dir, seed1), (gen_dir, seed2)]\n    ]\n    fsd, meandiff, covdiff = frechet_distance.sample_frechet_distance(\n        true_tally * 100, gen_tally * 100, return_components=True)\n    print('fsd: %f; meandiff: %f; covdiff: %f' % (fsd, meandiff, covdiff))\n    if args.histout is not None:\n        diff_figure(true_tally * 100,\n                    gen_tally * 100,\n                    labelcount=args.labelcount,\n                    
maxscale=args.maxscale,\n                    dpi=args.dpi).savefig(args.histout)\n\n    if args.results_dir is not None:\n        import json\n\n        it = args.it\n        results_dir = args.results_dir\n\n        with open(os.path.join(args.results_dir, 'fsd_results.json')) as f:\n            fsd_results = json.load(f)\n\n        fsd_results[it] = (fsd, meandiff, covdiff)\n        \n        with open(os.path.join(args.results_dir, 'fsd_results.json'), 'w') as f:\n            f.write(json.dumps(fsd_results))\n\n        diff_figure(true_tally * 100,\n                    gen_tally * 100,\n                    labelcount=args.labelcount,\n                    maxscale=args.maxscale,\n                    dpi=args.dpi).savefig(os.path.join(args.results_dir, f'fsd_{it}.png'))\n    \ndef cached_tally_directory(directory, size=10000, cachedir=None, seed=1):\n    filename = '%s_segtally_%d.npy' % (directory, size)\n    if seed != 1:\n        filename = '%d_%s' % (seed, filename)\n    if cachedir is not None:\n        filename = os.path.join(cachedir, filename.replace('/', '_'))\n    #load only if gt stats, or image directory\n    if os.path.isfile(filename) and (not directory.endswith('.npz') or 'gt' in directory):\n        return numpy.load(filename)\n    os.makedirs(cachedir, exist_ok=True)\n    result = tally_directory(directory, size, seed=seed)\n    numpy.save(filename, result)\n    return result\n\n\ndef tally_directory(directory, size=10000, seed=1):\n    if directory.endswith('.npz'):\n        with np.load(directory) as f:\n            images = torch.from_numpy(f['fake'])\n            images = images.permute(0, 3, 1, 2) #BHWC -> BCHW\n            images = (images/127.5) - 1 #normalize in [-1, 1]\n            images = torch.nn.functional.interpolate(images, size=(256, 256))\n            print(images.shape, images.max(), images.min())\n        dataset = TensorDataset(images)\n    else:  \n        dataset = parallelfolder.ParallelImageFolders(\n            
[directory],\n            transform=transforms.Compose([\n                transforms.Resize(256),\n                transforms.CenterCrop(256),\n                transforms.ToTensor(),\n                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n            ]))\n    loader = DataLoader(dataset,\n                        sampler=FixedRandomSubsetSampler(dataset,\n                                                         end=size,\n                                                         seed=seed),\n                        batch_size=10,\n                        pin_memory=True)\n    upp = segmenter.UnifiedParsingSegmenter()\n    labelnames, catnames = upp.get_label_and_category_names()\n    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float)\n    batch_result = torch.zeros(loader.batch_size,\n                               NUM_OBJECTS,\n                               dtype=torch.float).cuda()\n    with torch.no_grad():\n        batch_index = 0\n        for [batch] in pbar(loader):\n            seg_result = upp.segment_batch(batch.cuda())\n            for i in range(len(batch)):\n                batch_result[i] = (seg_result[i, 0].view(-1).bincount(\n                    minlength=NUM_OBJECTS).float() /\n                                   (seg_result.shape[2] * seg_result.shape[3]))\n            result[batch_index:batch_index +\n                   len(batch)] = (batch_result.cpu().numpy())\n            batch_index += len(batch)\n    return result\n\n\ndef tally_dataset_objects(dataset, size=10000):\n    loader = DataLoader(dataset,\n                        sampler=FixedRandomSubsetSampler(dataset, end=size),\n                        batch_size=10,\n                        pin_memory=True)\n    upp = segmenter.UnifiedParsingSegmenter()\n    labelnames, catnames = upp.get_label_and_category_names()\n    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float)\n    batch_result = torch.zeros(loader.batch_size,\n                               
NUM_OBJECTS,\n                               dtype=torch.float).cuda()\n    with torch.no_grad():\n        batch_index = 0\n        for [batch] in pbar(loader):\n            seg_result = upp.segment_batch(batch.cuda())\n            for i in range(len(batch)):\n                batch_result[i] = (seg_result[i, 0].view(-1).bincount(\n                    minlength=NUM_OBJECTS).float() /\n                                   (seg_result.shape[2] * seg_result.shape[3]))\n            result[batch_index:batch_index +\n                   len(batch)] = (batch_result.cpu().numpy())\n            batch_index += len(batch)\n    return result\n\n\ndef tally_generated_objects(model, size=10000):\n    zds = zdataset.z_dataset_for_model(model, size)\n    loader = DataLoader(zds, batch_size=10, pin_memory=True)\n    upp = segmenter.UnifiedParsingSegmenter()\n    labelnames, catnames = upp.get_label_and_category_names()\n    result = numpy.zeros((size, NUM_OBJECTS), dtype=numpy.float)\n    batch_result = torch.zeros(loader.batch_size,\n                               NUM_OBJECTS,\n                               dtype=torch.float).cuda()\n    with torch.no_grad():\n        batch_index = 0\n        for [zbatch] in pbar(loader):\n            img = model(zbatch.cuda())\n            seg_result = upp.segment_batch(img)\n            for i in range(len(zbatch)):\n                batch_result[i] = (seg_result[i, 0].view(-1).bincount(\n                    minlength=NUM_OBJECTS).float() /\n                                   (seg_result.shape[2] * seg_result.shape[3]))\n            result[batch_index:batch_index +\n                   len(zbatch)] = (batch_result.cpu().numpy())\n            batch_index += len(zbatch)\n    return result\n\n\ndef diff_figure(ttally,\n                gtally,\n                labelcount=30,\n                labelleft=True,\n                dpi=100,\n                maxscale=50.0,\n                legend=False):\n    from matplotlib.backends.backend_agg import 
FigureCanvasAgg as FigureCanvas\n    from matplotlib.figure import Figure\n    tresult, gresult = [t.mean(0) for t in [ttally, gtally]]\n    upp = segmenter.UnifiedParsingSegmenter()\n    labelnames, catnames = upp.get_label_and_category_names()\n    x = []\n    labels = []\n    gen_amount = []\n    change_frac = []\n    true_amount = []\n    for label in numpy.argsort(-tresult):\n        if label == 0 or labelnames[label][1] == 'material':\n            continue\n        if tresult[label] == 0:\n            break\n        x.append(len(x))\n        labels.append(labelnames[label][0].split()[0])\n        true_amount.append(tresult[label].item())\n        gen_amount.append(gresult[label].item())\n        change_frac.append(\n            (float(gresult[label] - tresult[label]) / tresult[label]))\n        if len(x) >= labelcount:\n            break\n    fig = Figure(dpi=dpi, figsize=(1.4 + 5.0 * labelcount / 30, 4.0))\n    FigureCanvas(fig)\n    a1, a0 = fig.subplots(2, 1, gridspec_kw={'height_ratios': [1, 2]})\n    a0.bar(x, change_frac, label='relative delta')\n    a0.set_xticks(x)\n    a0.set_xticklabels(labels, rotation='vertical')\n    if labelleft:\n        a0.set_ylabel('relative delta\\n(gen - train) / train')\n    a0.set_xlim(-1.0, len(x))\n    a0.set_ylim([-1, 1.1])\n    a0.grid(axis='y', antialiased=False, alpha=0.25)\n    if legend:\n        a0.legend(loc=2)\n    prev_high = None\n    for ix, cf in enumerate(change_frac):\n        if cf > 1.15:\n            if prev_high == (ix - 1):\n                offset = 0.1\n            else:\n                offset = 0.0\n                prev_high = ix\n            a0.text(ix,\n                    1.15 + offset,\n                    '%.1f' % cf,\n                    horizontalalignment='center',\n                    size=6)\n\n    a1.bar(x, true_amount, label='training')\n    a1.plot(x, gen_amount, linewidth=3, color='red', label='generated')\n    a1.set_yscale('log')\n    a1.set_xlim(-1.0, len(x))\n    
a1.set_ylim(maxscale / 5000, maxscale)\n    from matplotlib.ticker import LogLocator\n    # a1.yaxis.set_major_locator(LogLocator(subs=(1,)))\n    # a1.yaxis.set_minor_locator(LogLocator(subs=(1,), numdecs=10))\n    # a1.yaxis.set_minor_locator(LogLocator(subs=(1,2,3,4,5,6,7,8,9)))\n    # a1.yaxis.set_minor_locator(yminor_locator)\n    if labelleft:\n        a1.set_ylabel('mean area\\nlog scale')\n    if legend:\n        a1.legend()\n    a1.set_yticks([1e-2, 1e-1, 1.0, 1e+1])\n    a1.set_yticks([\n        a * b for a in [1e-2, 1e-1, 1.0, 1e+1]\n        for b in range(1, 10) if maxscale / 5000 <= a * b <= maxscale\n    ], True)  # minor ticks.\n    a1.set_xticks([])\n    fig.tight_layout()\n    return fig\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "seeing/lightbox.html",
    "content": "<!DOCTYPE html>\n<html>\n<!--\n  +lightbox.html, a page for automatically showing all images in a\n  directory on an Apache server. Just copy it into the directory.\n  Works by scraping the default directory HTML at \"./\" - David Bau.\n-->\n\n<head>\n  <script src=\"https://cdn.jsdelivr.net/npm/vue@2.5.16/dist/vue.js\"\n    integrity=\"sha256-CMMTrj5gGwOAXBeFi7kNokqowkzbeL8ydAJy39ewjkQ=\" crossorigin=\"anonymous\"></script>\n  <script src=\"https://cdn.jsdelivr.net/npm/lodash@4.17.10/lodash.js\"\n    integrity=\"sha256-qwbDmNVLiCqkqRBpF46q5bjYH11j5cd+K+Y6D3/ja28=\" crossorigin=\"anonymous\"></script>\n  <script src=\"https://code.jquery.com/jquery-3.3.1.js\" integrity=\"sha256-2Kok7MbOyxpgUVvAk/HJ2jigOSYS2auK4Pfzbm7uH60=\"\n    crossorigin=\"anonymous\"></script>\n  <script src=\"https://cdnjs.cloudflare.com/ajax/libs/lity/2.3.1/lity.js\"\n    integrity=\"sha256-28JiZvE/RethQIYCwkMdtSMHgI//KoTLeB2tSm10trs=\" crossorigin=\"anonymous\"></script>\n  <link rel=\"stylesheet\" href=\"https://cdnjs.cloudflare.com/ajax/libs/lity/2.3.1/lity.css\"\n    integrity=\"sha256-76wKiAXVBs5Kyj7j0T43nlBCbvR6pqdeeZmXI4ATnY0=\" crossorigin=\"anonymous\" />\n  <style>\n    h3 {\n      font-family: sans-serif;\n      font-size: 18px;\n    }\n\n    .thumb,\n    .filter {\n      font-family: sans-serif;\n      font-size: 12px;\n    }\n\n    .filter {\n      padding-bottom: 10px;\n    }\n\n    .thumb {\n      display: inline-block;\n      margin: 1px;\n      text-align: center;\n    }\n\n    .thumb img,\n    .thumb div {\n      max-width: 150px;\n      word-break: break-all;\n    }\n  </style>\n</head>\n\n<body>\n  <div id=\"app\" v-if=\"images\">\n    <h3>Images in <a :href=\"directory\">{{ directory }}</a></h3>\n    <div class=\"filter\">\n      Filter: <input v-model=\"pattern\" placeholder=\"regexp\">\n    </div>\n    <div v-for=\"r in images\" class=\"thumb\" v-if=\"patternRe.test(r)\">\n      <div>{{ r }}</div>\n      <a :href=\"r\" data-lity><img :src=\"r\"></a>\n    
</div>\n  </div>\n  <!--app-->\n</body>\n<script>\n  var theapp = new Vue({\n    el: '#app',\n    data: {\n      directory: window.location.pathname.replace(/[^\\/]*$/, ''),\n      images: null,\n      pattern: '',\n    },\n    created: function () {\n      var self = this;\n      $.get('./?' + Math.random(), function (d) {\n        var imgurls = $.map($(d).find('a'),\n          x => x.href).filter(\n            x => x.match(/\\.(jpg|jpeg|png|gif|svg)$/i)).map(\n              x => x.replace(/.*\\//, ''));\n        self.images = imgurls;\n      }, 'html');\n    },\n    computed: {\n      patternRe: function () {\n        try {\n          return RegExp(this.pattern);\n        } catch (e) {\n          return /.*/;\n        }\n      }\n    },\n  })\n</script>\n\n</html>"
  },
  {
    "path": "seeing/parallelfolder.py",
    "content": "'''\nVariants of pytorch's ImageFolder for loading image datasets with more\ninformation, such as parallel feature channels in separate files,\ncached files with lists of filenames, etc.\n'''\n\nimport os, torch, re, random, numpy, itertools\nimport torch.utils.data as data\nfrom torchvision.datasets.folder import default_loader as tv_default_loader\nfrom PIL import Image\nfrom collections import OrderedDict\nfrom . import pbar\n\ndef grayscale_loader(path):\n    with open(path, 'rb') as f:\n        return Image.open(f).convert('L')\n\nclass ndarray(numpy.ndarray):\n    '''\n    Wrapper to make ndarrays into heap objects so that shared_state can\n    be attached as an attribute.\n    '''\n    pass\n\ndef default_loader(filename):\n    '''\n    Handles both numpy files and image formats.\n    '''\n    if filename.endswith('.npy'):\n        return numpy.load(filename).view(ndarray)\n    elif filename.endswith('.npz'):\n        return numpy.load(filename)\n    else:\n        return tv_default_loader(filename)\n\nclass ParallelImageFolders(data.Dataset):\n    \"\"\"\n    A data loader that looks for parallel image filenames, for example\n\n    photo1/park/004234.jpg\n    photo1/park/004236.jpg\n    photo1/park/004237.jpg\n\n    photo2/park/004234.png\n    photo2/park/004236.png\n    photo2/park/004237.png\n    \"\"\"\n    def __init__(self, image_roots,\n            transform=None,\n            loader=default_loader,\n            stacker=None,\n            classification=False,\n            intersection=False,\n            filter_tuples=None,\n            verbose=None,\n            size=None,\n            shuffle=None,\n            lazy_init=True):\n        self.image_roots = image_roots\n        if transform is not None and not hasattr(transform, '__iter__'):\n            transform = [transform for _ in image_roots]\n        self.transforms = transform\n        self.stacker = stacker\n        self.loader = loader\n        def do_lazy_init():\n          
  self.images, self.classes, self.class_to_idx = (\n                    make_parallel_dataset(image_roots,\n                        classification=classification,\n                        intersection=intersection,\n                        filter_tuples=filter_tuples,\n                        verbose=verbose))\n            if len(self.images) == 0:\n                raise RuntimeError(\"Found 0 images within: %s\" % image_roots)\n            if shuffle is not None:\n                random.Random(shuffle).shuffle(self.images)\n            if size is not None:\n                # Keep at most size entries, taken after the optional shuffle.\n                self.images = self.images[:size]\n            self._do_lazy_init = None\n        # Do slow initialization lazily.\n        if lazy_init:\n            self._do_lazy_init = do_lazy_init\n        else:\n            do_lazy_init()\n\n    def __getattr__(self, attr):\n        if self._do_lazy_init is not None:\n            self._do_lazy_init()\n            return getattr(self, attr)\n        raise AttributeError()\n\n    def __getitem__(self, index):\n        if self._do_lazy_init is not None:\n            self._do_lazy_init()\n        paths = self.images[index]\n        if self.classes is not None:\n            classidx = paths[-1]\n            paths = paths[:-1]\n        sources = [self.loader(path) for path in paths]\n        # Add a common shared state dict to allow random crops/flips to be\n        # coordinated.\n        shared_state = {}\n        for s in sources:\n            try:\n                s.shared_state = shared_state\n            except:\n                pass\n        if self.transforms is not None:\n            sources = [transform(source) if transform is not None else source\n                    for source, transform\n                    in itertools.zip_longest(sources, self.transforms)]\n        if self.stacker is not None:\n            sources = self.stacker(sources)\n            if self.classes is not None:\n                sources = (sources, classidx)\n        else:\n         
   if self.classes is not None:\n                sources.append(classidx)\n            sources = tuple(sources)\n        return sources\n\n    def __len__(self):\n        if self._do_lazy_init is not None:\n            self._do_lazy_init()\n        return len(self.images)\n\ndef is_npy_file(path):\n    return path.endswith('.npy') or path.endswith('.NPY')\n\ndef is_image_file(path):\n    return None != re.search(r'\\.(jpe?g|png)$', path, re.IGNORECASE)\n\ndef walk_image_files(rootdir, verbose=None):\n    indexfile = '%s.txt' % rootdir\n    if os.path.isfile(indexfile):\n        basedir = os.path.dirname(rootdir)\n        with open(indexfile) as f:\n            result = sorted([os.path.join(basedir, line.strip())\n                for line in f.readlines()])\n            return result\n    result = []\n    for dirname, _, fnames in sorted(pbar(os.walk(rootdir),\n            desc='Walking %s' % os.path.basename(rootdir))):\n        for fname in sorted(fnames):\n            if is_image_file(fname) or is_npy_file(fname):\n                result.append(os.path.join(dirname, fname))\n    return result\n\ndef make_parallel_dataset(image_roots, classification=False,\n        intersection=False, filter_tuples=None, verbose=None):\n    \"\"\"\n    Returns ([(img1, img2, clsid), (img1, img2, clsid)..],\n             classes, class_to_idx)\n    \"\"\"\n    image_roots = [os.path.expanduser(d) for d in image_roots]\n    image_sets = OrderedDict()\n    for j, root in enumerate(image_roots):\n        for path in walk_image_files(root, verbose=verbose):\n            key = os.path.splitext(os.path.relpath(path, root))[0]\n            if key not in image_sets:\n                image_sets[key] = []\n            if not intersection and len(image_sets[key]) != j:\n                raise RuntimeError(\n                    'Images not parallel: %s missing from one dir' % (key))\n            image_sets[key].append(path)\n    if classification:\n        classes = 
sorted(set([os.path.basename(os.path.dirname(k))\n            for k in image_sets.keys()]))\n        class_to_idx = dict({k: v for v, k in enumerate(classes)})\n        for k, v in image_sets.items():\n            v.append(class_to_idx[os.path.basename(os.path.dirname(k))])\n    else:\n        classes, class_to_idx = None, None\n    tuples = []\n    for key, value in image_sets.items():\n        if len(value) != len(image_roots) + (1 if classification else 0):\n            if intersection:\n                continue\n            else:\n                raise RuntimeError(\n                    'Images not parallel: %s missing from one dir' % (key))\n        value = tuple(value)\n        if filter_tuples and not filter_tuples(value):\n            continue\n        tuples.append(value)\n    return tuples, classes, class_to_idx\n\n"
  },
  {
    "path": "seeing/pbar.py",
    "content": "'''\nUtilities for showing progress bars, controlling default verbosity, etc.\n'''\n\n# If the tqdm package is not available, then do not show progress bars;\n# just connect print_progress to print.\nimport sys, types, builtins\ntry:\n    from tqdm import tqdm, tqdm_notebook\nexcept:\n    tqdm = None\n\ndefault_verbosity = True\nnext_description = None\npython_print = builtins.print\n\ndef post(**kwargs):\n    '''\n    When within a progress loop, pbar.post(k=str) will display\n    the given k=str status on the right-hand-side of the progress\n    status bar.  If not within a visible progress bar, does nothing.\n    '''\n    innermost = innermost_tqdm()\n    if innermost is not None:\n        innermost.set_postfix(**kwargs)\n\ndef desc(desc):\n    '''\n    When within a progress loop, pbar.desc(str) changes the\n    left-hand-side description of the loop toe the given description.\n    '''\n    innermost = innermost_tqdm()\n    if innermost is not None:\n        innermost.set_description(str(desc))\n\ndef descnext(desc):\n    '''\n    Called before starting a progress loop, pbar.descnext(str)\n    sets the description text that will be used in the following loop.\n    '''\n    global next_description\n    if not default_verbosity or tqdm is None:\n        return\n    next_description = desc\n\ndef print(*args):\n    '''\n    When within a progress loop, will print above the progress loop.\n    '''\n    global next_description\n    next_description = None\n    if default_verbosity:\n        msg = ' '.join(str(s) for s in args)\n        if tqdm is None:\n            python_print(msg)\n        else:\n            tqdm.write(msg)\n\ndef tqdm_terminal(it, *args, **kwargs):\n    '''\n    Some settings for tqdm that make it run better in resizable terminals.\n    '''\n    return tqdm(it, *args, dynamic_ncols=True, ascii=True,\n            leave=(innermost_tqdm() is not None), **kwargs)\n\ndef in_notebook():\n    '''\n    True if running inside a Jupyter 
notebook.\n    '''\n    # From https://stackoverflow.com/a/39662359/265298\n    try:\n        shell = get_ipython().__class__.__name__\n        if shell == 'ZMQInteractiveShell':\n            return True   # Jupyter notebook or qtconsole\n        elif shell == 'TerminalInteractiveShell':\n            return False  # Terminal running IPython\n        else:\n            return False  # Other type (?)\n    except NameError:\n        return False      # Probably standard Python interpreter\n\ndef innermost_tqdm():\n    '''\n    Returns the innermost active tqdm progress loop on the stack.\n    '''\n    if hasattr(tqdm, '_instances') and len(tqdm._instances) > 0:\n        return max(tqdm._instances, key=lambda x: x.pos)\n    else:\n        return None\n\ndef reporthook(*args, **kwargs):\n    '''\n    For use with urllib.request.urlretrieve.\n\n    with pbar.reporthook() as hook:\n        urllib.request.urlretrieve(url, filename, reporthook=hook)\n    '''\n    kwargs2 = dict(unit_scale=True, miniters=1)\n    kwargs2.update(kwargs)\n    bar = __call__(None, *args, **kwargs2)\n    class ReportHook(object):\n        def __init__(self, t):\n            self.t = t\n        def __call__(self, b=1, bsize=1, tsize=None):\n            if hasattr(self.t, 'total'):\n                if tsize is not None:\n                    self.t.total = tsize\n            if hasattr(self.t, 'update'):\n                self.t.update(b * bsize - self.t.n)\n        def __enter__(self):\n            return self\n        def __exit__(self, *exc):\n            if hasattr(self.t, '__exit__'):\n                self.t.__exit__(*exc)\n    return ReportHook(bar)\n\ndef __call__(x, *args, **kwargs):\n    '''\n    Invokes a progress function that can wrap iterators to print\n    progress messages, if verbose is True.\n   \n    If verbose is False or tqdm is unavailable, then a quiet\n    non-printing identity function is used.\n\n    verbose can also be set to a spefific progress function rather\n    than 
True, and that function will be used.\n    '''\n    global default_verbosity, next_description\n    if not default_verbosity or tqdm is None:\n        return x\n    if default_verbosity == True:\n        fn = tqdm_notebook if in_notebook() else tqdm_terminal\n    else:\n        fn = default_verbosity\n    if next_description is not None:\n        kwargs = dict(kwargs)\n        kwargs['desc'] = next_description\n        next_description = None\n    return fn(x, *args, **kwargs)\n\nclass VerboseContextManager():\n    def __init__(self, v, entered=False):\n        self.v, self.entered, self.saved = v, False, []\n        if entered:\n            self.__enter__()\n            self.entered = True\n    def __enter__(self):\n        global default_verbosity\n        if self.entered:\n            self.entered = False\n        else:\n            self.saved.append(default_verbosity)\n            default_verbosity = self.v\n        return self\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        global default_verbosity\n        default_verbosity = self.saved.pop()\n    def __call__(self, v=True):\n        '''\n        Calling the context manager makes a new context that is\n        pre-entered, so it works as both a plain function and as a\n        factory for a context manager.\n        '''\n        new_v = v if self.v else not v\n        cm = VerboseContextManager(new_v, entered=True)\n        default_verbosity = new_v\n        return cm\n\n# Use as either \"with pbar.verbose:\" or \"pbar.verbose(False)\", or also\n# \"with pbar.verbose(False):\"\nverbose = VerboseContextManager(True)\n\n# Use as either \"with @pbar.quiet\" or \"pbar.quiet(True)\". 
or also\n# \"with pbar.quiet(True):\"\nquiet = VerboseContextManager(False)\n\nclass CallableModule(types.ModuleType):\n    def __init__(self):\n        # or super().__init__(__name__) for Python 3\n        types.ModuleType.__init__(self, __name__)\n        self.__dict__.update(sys.modules[__name__].__dict__)\n    def __call__(self, x, *args, **kwargs):\n        return __call__(x, *args, **kwargs)\n\nsys.modules[__name__] = CallableModule()\n\n"
  },
  {
    "path": "seeing/pidfile.py",
    "content": "'''\nUtility for simple distribution of work on multiple processes, by\nmaking sure only one process is working on a job at once.\n'''\n\nimport os, errno, socket, atexit, time, sys\n\ndef exit_if_job_done(directory, redo=False, force=False, verbose=True):\n    if pidfile_taken(os.path.join(directory, 'lockfile.pid'),\n            force=force, verbose=verbose):\n        sys.exit(0)\n    donefile = os.path.join(directory, 'done.txt')\n    if os.path.isfile(donefile):\n        with open(donefile) as f:\n            msg = f.read()\n        if redo or force:\n            if verbose:\n                print('Removing %s %s' % (donefile, msg))\n            os.remove(donefile)\n        else:\n            if verbose:\n                print('%s %s' % (donefile, msg))\n            sys.exit(0)\n\ndef mark_job_done(directory):\n    with open(os.path.join(directory, 'done.txt'), 'w') as f:\n        f.write('done by %d@%s %s at %s' %\n                (os.getpid(), socket.gethostname(),\n                 os.getenv('STY', ''),\n                 time.strftime('%c')))\n\ndef pidfile_taken(path, verbose=False, force=False):\n    '''\n    Usage.  To grab an exclusive lock for the remaining duration of the\n    current process (and exit if another process already has the lock),\n    do this:\n\n    if pidfile_taken('job_423/lockfile.pid', verbose=True):\n        sys.exit(0)\n\n    To do a batch of jobs, just run a script that does them all on\n    each available machine, sharing a network filesystem.  
When each\n    job grabs a lock, then this will automatically distribute the\n    jobs so that each one is done just once on one machine.\n    '''\n\n    # Try to create the file exclusively and write my pid into it.\n    try:\n        os.makedirs(os.path.dirname(path), exist_ok=True)\n        fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)\n    except OSError as e:\n        if e.errno == errno.EEXIST:\n            # If we cannot because there was a race, yield the conflicter.\n            conflicter = 'race'\n            try:\n                with open(path, 'r') as lockfile:\n                    conflicter = lockfile.read().strip() or 'empty'\n            except:\n                pass\n            # Force is for manual one-time use, for deleting stale lockfiles.\n            if force:\n                if verbose:\n                    print('Removing %s from %s' % (path, conflicter))\n                os.remove(path)\n                return pidfile_taken(path, verbose=verbose, force=False)\n            if verbose:\n                print('%s held by %s' % (path, conflicter))\n            return conflicter\n        else:\n            # Other problems get an exception.\n            raise\n    # Register to delete this file on exit.\n    lockfile = os.fdopen(fd, 'r+')\n    atexit.register(delete_pidfile, lockfile, path)\n    # Write my pid into the open file.\n    lockfile.write('%d@%s %s\\n' % (os.getpid(), socket.gethostname(),\n        os.getenv('STY', '')))\n    lockfile.flush()\n    os.fsync(lockfile)\n    # Return 'None' to say there was not a conflict.\n    return None\n\ndef delete_pidfile(lockfile, path):\n    '''\n    Runs at exit after pidfile_taken succeeds.\n    '''\n    if lockfile is not None:\n        try:\n            lockfile.close()\n        except:\n            pass\n    try:\n        os.unlink(path)\n    except:\n        pass\n"
  },
  {
    "path": "seeing/sampler.py",
    "content": "'''\nA sampler is just a list of integer listing the indexes of the\ninputs in a data set to sample.  For reproducibility, the\nFixedRandomSubsetSampler uses a seeded prng to produce the same\nsequence always.  FixedSubsetSampler is just a wrapper for an\nexplicit list of integers.\n\ncoordinate_sample solves another sampling problem: when testing\nconvolutional outputs, we can reduce data explosing by sampling\nrandom points of the feature map rather than the entire feature map.\ncoordinate_sample does this in a deterministic way that is also\nresolution-independent.\n'''\n\nimport numpy\nimport random\nfrom torch.utils.data.sampler import Sampler\n\nclass FixedSubsetSampler(Sampler):\n    \"\"\"Represents a fixed sequence of data set indices.\n    Subsets can be created by specifying a subset of output indexes.\n    \"\"\"\n    def __init__(self, samples):\n        self.samples = samples\n\n    def __iter__(self):\n        return iter(self.samples)\n\n    def __len__(self):\n        return len(self.samples)\n\n    def __getitem__(self, key):\n        return self.samples[key]\n\n    def subset(self, new_subset):\n        return FixedSubsetSampler(self.dereference(new_subset))\n\n    def dereference(self, indices):\n        '''\n        Translate output sample indices (small numbers indexing the sample)\n        to input sample indices (larger number indexing the original full set)\n        '''\n        return [self.samples[i] for i in indices]\n\n\nclass FixedRandomSubsetSampler(FixedSubsetSampler):\n    \"\"\"Samples a fixed number of samples from the dataset, deterministically.\n    Arguments:\n        data_source,\n        sample_size,\n        seed (optional)\n    \"\"\"\n    def __init__(self, data_source, start=None, end=None, seed=1):\n        rng = random.Random(seed)\n        shuffled = list(range(len(data_source)))\n        rng.shuffle(shuffled)\n        self.data_source = data_source\n        super(FixedRandomSubsetSampler, 
self).__init__(shuffled[start:end])\n\n    def class_subset(self, class_filter):\n        '''\n        Returns only the subset matching the given rule.\n        '''\n        if isinstance(class_filter, int):\n            rule = lambda d: d[1] == class_filter\n        else:\n            rule = class_filter\n        return self.subset([i for i, j in enumerate(self.samples)\n                if rule(self.data_source[j])])\n\ndef coordinate_sample(shape, sample_size, seeds, grid=13, seed=1, flat=False):\n    '''\n    Returns a (end-start) sets of sample_size grid points within\n    the shape given.  If the shape dimensions are a multiple of 'grid',\n    then sampled points within the same row will never be duplicated.\n    '''\n    if flat:\n        sampind = numpy.zeros((len(seeds), sample_size), dtype=int)\n    else:\n        sampind = numpy.zeros((len(seeds), 2, sample_size), dtype=int)\n    assert sample_size <= grid\n    for j, seed in enumerate(seeds):\n        rng = numpy.random.RandomState(seed)\n        # Shuffle the 169 random grid squares, and pick :sample_size.\n        square_count = grid ** len(shape)\n        square = numpy.stack(numpy.unravel_index(\n            rng.choice(square_count, square_count)[:sample_size],\n            (grid,) * len(shape)))\n        # Then add a random offset to each x, y and put in the range [0...1)\n        # Notice this selects the same locations regardless of resolution.\n        uniform = (square + rng.uniform(size=square.shape)) / grid\n        # TODO: support affine scaling so that we can align receptive field\n        # centers exactly when sampling neurons in different layers.\n        coords = (uniform * numpy.array(shape)[:,None]).astype(int)\n        # Now take sample_size without replacement.  
We do this in a way\n        # such that if sample_size is decreased or increased up to 'grid',\n        # the selected points become a subset, not totally different points.\n        if flat:\n            sampind[j] = numpy.ravel_multi_index(coords, dims=shape)\n        else:\n            sampind[j] = coords\n    return sampind\n\ndef main():\n    from . import parallelfolder\n    import argparse, os, shutil\n\n    parser = argparse.ArgumentParser(description='Net dissect utility',\n            prog='python -m %s.sampler' % (__package__))\n    parser.add_argument('indir')\n    parser.add_argument('outdir')\n    parser.add_argument('--size', type=int, default=100)\n    parser.add_argument('--test', action='store_true', default=False)\n    args = parser.parse_args()\n    if os.path.exists(args.outdir):\n        print('%s already exists' % args.outdir)\n        sys.exit(1)\n    os.makedirs(args.outdir)\n    dataset = parallelfolder.ParallelImageFolders([args.indir])\n    sampler = FixedRandomSubsetSampler(dataset, end=args.size)\n    seen_filenames = set()\n    def number_filename(filename, number):\n        if '.' in filename:\n            a, b = filename.rsplit('.', 1)\n            return a + '_%d.' 
% number + b\n        return filename + '_%d' % number\n    for i in sampler.dereference(range(args.size)):\n        sourcefile = dataset.images[i][0]\n        filename = os.path.basename(sourcefile)\n        template = filename\n        num = 0\n        while filename in seen_filenames:\n            num += 1\n            filename = number_filename(template, num)\n        seen_filenames.add(filename)\n        shutil.copy(os.path.join(args.indir, sourcefile),\n                os.path.join(args.outdir, filename))\n\ndef test():\n    from numpy.testing import assert_almost_equal\n    # Test that coordinate_sample is deterministic, in-range, and scalable.\n    assert_almost_equal(coordinate_sample((26, 26), 10, range(101, 102)),\n            [[[14,  0, 12, 11,  8, 13, 11, 20,  7, 20],\n              [ 9, 22,  7, 11, 23, 18, 21, 15,  2,  5]]])\n    assert_almost_equal(coordinate_sample((13, 13), 10, range(101, 102)),\n            [[[ 7,  0,  6,  5,  4,  6,  5, 10,  3, 20 // 2],\n              [ 4, 11,  3,  5, 11,  9, 10,  7,  1,  5 // 2]]])\n    assert_almost_equal(coordinate_sample((13, 13), 10, range(100, 102),\n        flat=True),\n            [[  8,  24,  67, 103,  87,  79, 138,  94,  98,  53],\n             [ 95,  11,  81,  70,  63,  87,  75, 137,  40, 2+10*13]])\n    assert_almost_equal(coordinate_sample((13, 13), 10, range(101, 103),\n        flat=True),\n            [[ 95,  11,  81,  70,  63,  87,  75, 137,  40, 132],\n             [  0,  78, 114, 111,  66,  45,  72,  73,  79, 135]])\n    assert_almost_equal(coordinate_sample((26, 26), 10, range(101, 102),\n        flat=True),\n            [[373,  22, 319, 297, 231, 356, 307, 535, 184, 5+20*26]])\n    # Test FixedRandomSubsetSampler\n    fss = FixedRandomSubsetSampler(range(10))\n    assert len(fss) == 10\n    assert_almost_equal(list(fss), [6, 8, 9, 7, 5, 3, 0, 4, 1, 2])\n    fss = FixedRandomSubsetSampler(range(10), 3, 8)\n    assert len(fss) == 5\n    assert_almost_equal(list(fss), [7, 5, 3, 0, 4])\n    fss = 
FixedRandomSubsetSampler([(i, i % 3) for i in range(10)]\n            ).class_subset(class_filter=1)\n    assert len(fss) == 3\n    assert_almost_equal(list(fss), [7, 4, 1])\n\nif __name__ == '__main__':\n    import sys\n    if '--test' in sys.argv[1:]:\n        test()\n    else:\n        main()\n"
  },
  {
    "path": "seeing/segmenter.py",
    "content": "# Usage as a simple differentiable segmenter base class\n\nimport os, torch, numpy, json, glob\nimport skimage.morphology\nfrom collections import OrderedDict\nfrom . import upsegmodel\nfrom urllib.request import urlretrieve\n\nclass BaseSegmenter:\n    def get_label_and_category_names(self):\n        '''\n        Returns two lists: first, a list of tuples [(label, category), ...]\n        where the label and category are human-readable strings indicating\n        the meaning of a segmentation class.  The 0th segmentation class\n        should be reserved for a label ('-') that means \"no prediction.\"\n        The second list should just be a list of [category,...] listing\n        all categories in a canonical order.\n        '''\n        raise NotImplemented()\n\n    def segment_batch(self, tensor_images, downsample=1):\n        '''\n        Returns a multilabel segmentation for the given batch of (RGB [-1...1])\n        images.  Each pixel of the result is a torch.long indicating a\n        predicted class number.  Multiple classes can be predicted for\n        the same pixel: output shape is (n, multipred, y, x), where\n        multipred is 3, 5, or 6, for how many different predicted labels can\n        be given for each pixel (depending on whether subdivision is being\n        used).  
If downsample is specified, then the output y and x dimensions\n        are downsampled from the original image.\n        '''\n        raise NotImplemented()\n\nclass UnifiedParsingSegmenter(BaseSegmenter):\n    '''\n    This is a wrapper for a more complicated multi-class segmenter,\n    as described in https://arxiv.org/pdf/1807.10221.pdf, and as\n    released in https://github.com/CSAILVision/unifiedparsing.\n    For our purposes and to simplify processing, we do not use\n    whole-scene predictions, and we only consume part segmentations\n    for the three largest object classes (sky, building, person).\n    '''\n\n    def __init__(self, segsizes=None):\n        # Create a segmentation model\n        if segsizes is None:\n            segsizes = [256]\n        segvocab = 'upp'\n        segarch = ('resnet50', 'upernet')\n        epoch = 40\n        ensure_upp_segmenter_downloaded('datasets/segmodel')\n        segmodel = load_unified_parsing_segmentation_model(\n                segarch, segvocab, epoch)\n        segmodel.cuda()\n        self.segmodel = segmodel\n        self.segsizes = segsizes\n        # Assign class numbers for parts.\n        first_partnumber = (1 +\n                (len(segmodel.labeldata['object']) - 1) +\n                (len(segmodel.labeldata['material']) - 1))\n        partobjects = segmodel.labeldata['object_part'].keys()\n        partnumbers = {}\n        partnames = []\n        objectnumbers = {k: v\n                for v, k in enumerate(segmodel.labeldata['object'])}\n        part_index_translation = []\n        # We merge some classes.  For example \"door\" is both an object\n        # and a part of a building.  
To avoid confusion, we just count\n        # such classes as objects, and add part scores to the same index.\n        for owner in partobjects:\n            part_list = segmodel.labeldata['object_part'][owner]\n            numeric_part_list = []\n            for part in part_list:\n                if part in objectnumbers:\n                    numeric_part_list.append(objectnumbers[part])\n                elif part in partnumbers:\n                    numeric_part_list.append(partnumbers[part])\n                else:\n                    partnumbers[part] = len(partnames) + first_partnumber\n                    partnames.append(part)\n                    numeric_part_list.append(partnumbers[part])\n            part_index_translation.append(torch.tensor(numeric_part_list))\n        self.objects_with_parts = [objectnumbers[obj] for obj in partobjects]\n        self.part_index = part_index_translation\n        self.part_names = partnames\n        # For now we'll just do object and material labels.\n        self.num_classes = 1 + (\n                len(segmodel.labeldata['object']) - 1) + (\n                len(segmodel.labeldata['material']) - 1) + len(partnames)\n        self.num_object_classes = len(self.segmodel.labeldata['object']) - 1\n\n    def get_label_and_category_names(self, dataset=None):\n        '''\n        Lists label and category names.\n        '''\n        # Labels are ordered as follows:\n        # 0, [object labels] [divided object labels] [materials] [parts]\n        # The zero label is reserved to mean 'no prediction'.\n        suffixes = []\n        divided_labels = []\n        for suffix in suffixes:\n            divided_labels.extend([('%s-%s' % (label, suffix), 'part')\n                for label in self.segmodel.labeldata['object'][1:]])\n        # Create the whole list of labels\n        labelcats = (\n                [(label, 'object')\n                    for label in self.segmodel.labeldata['object']] +\n                divided_labels +\n 
               [(label, 'material')\n                    for label in self.segmodel.labeldata['material'][1:]] +\n                [(label, 'part') for label in self.part_names])\n        return labelcats, ['object', 'part', 'material']\n\n    def raw_seg_prediction(self, tensor_images, downsample=1):\n        '''\n        Generates a segmentation by applying multiresolution voting on\n        the segmentation model, using (rounded to 32 pixels) a set of\n        resolutions in the example benchmark code.\n        '''\n        y, x = tensor_images.shape[2:]\n        b = len(tensor_images)\n        tensor_images = (tensor_images + 1) / 2 * 255\n        tensor_images = torch.flip(tensor_images, (1,)) # BGR!!!?\n        tensor_images -= torch.tensor([102.9801, 115.9465, 122.7717]).to(\n                   dtype=tensor_images.dtype, device=tensor_images.device\n                   )[None,:,None,None]\n        seg_shape = (y // downsample, x // downsample)\n        # We want these to be multiples of 32 for the model.\n        sizes = [(s, s) for s in self.segsizes]\n        pred = {category: torch.zeros(\n            len(tensor_images), len(self.segmodel.labeldata[category]),\n            seg_shape[0], seg_shape[1]).cuda()\n            for category in ['object', 'material']}\n        part_pred = {partobj_index: torch.zeros(\n            len(tensor_images), len(partindex),\n            seg_shape[0], seg_shape[1]).cuda()\n            for partobj_index, partindex in enumerate(self.part_index)}\n        for size in sizes:\n            if size == tensor_images.shape[2:]:\n                resized = tensor_images\n            else:\n                resized = torch.nn.AdaptiveAvgPool2d(size)(tensor_images)\n            r_pred = self.segmodel(\n                dict(img=resized), seg_size=seg_shape)\n            for k in pred:\n                pred[k] += r_pred[k]\n            for k in part_pred:\n                part_pred[k] += r_pred['part'][k]\n        return pred, part_pred\n\n  
  def segment_batch(self, tensor_images, downsample=1):\n        '''\n        Returns a multilabel segmentation for the given batch of (RGB [-1...1])\n        images.  Each pixel of the result is a torch.long indicating a\n        predicted class number.  Multiple classes can be predicted for\n        the same pixel: output shape is (n, multipred, y, x), where\n        multipred is 3, 5, or 6, for how many different predicted labels can\n        be given for each pixel (depending on whether subdivision is being\n        used).  If downsample is specified, then the output y and x dimensions\n        are downsampled from the original image.\n        '''\n        pred, part_pred = self.raw_seg_prediction(tensor_images,\n                downsample=downsample)\n        y, x = tensor_images.shape[2:]\n        seg_shape = (y // downsample, x // downsample)\n        segs = torch.zeros(len(tensor_images), 3, # objects, materials, parts\n                seg_shape[0], seg_shape[1],\n                dtype=torch.long, device=tensor_images.device)\n        _, segs[:,0] = torch.max(pred['object'], dim=1)\n        # Get materials and translate to shared numbering scheme\n        _, segs[:,1] = torch.max(pred['material'], dim=1)\n        maskout = (segs[:,1] == 0)\n        segs[:,1] += (len(self.segmodel.labeldata['object']) - 1)\n        segs[:,1][maskout] = 0\n        # Now deal with subparts of sky, buildings, people\n        for i, object_index in enumerate(self.objects_with_parts):\n            trans = self.part_index[i].to(segs.device)\n            # Get the argmax, and then translate to shared numbering scheme\n            seg = trans[torch.max(part_pred[i], dim=1)[1]]\n            # Only trust the parts where the prediction also predicts the\n            # owning object.\n            mask = (segs[:,0] == object_index)\n            segs[:,2][mask] = seg[mask]\n        return segs\n\ndef load_unified_parsing_segmentation_model(segmodel_arch, segvocab, epoch):\n    
segmodel_dir = 'datasets/segmodel/%s-%s-%s' % ((segvocab,) + segmodel_arch)\n    # Load json of class names and part/object structure\n    with open(os.path.join(segmodel_dir, 'labels.json')) as f:\n        labeldata = json.load(f)\n    nr_classes={k: len(labeldata[k])\n                for k in ['object', 'scene', 'material']}\n    nr_classes['part'] = sum(len(p) for p in labeldata['object_part'].values())\n    # Create a segmentation model\n    segbuilder = upsegmodel.ModelBuilder()\n    # example segmodel_arch = ('resnet101', 'upernet')\n    seg_encoder = segbuilder.build_encoder(\n            arch=segmodel_arch[0],\n            fc_dim=2048,\n            weights=os.path.join(segmodel_dir, 'encoder_epoch_%d.pth' % epoch))\n    seg_decoder = segbuilder.build_decoder(\n            arch=segmodel_arch[1],\n            fc_dim=2048, use_softmax=True,\n            nr_classes=nr_classes,\n            weights=os.path.join(segmodel_dir, 'decoder_epoch_%d.pth' % epoch))\n    segmodel = upsegmodel.SegmentationModule(\n            seg_encoder, seg_decoder, labeldata)\n    segmodel.categories = ['object', 'part', 'material']\n    segmodel.eval()\n    return segmodel\n\ndef ensure_upp_segmenter_downloaded(directory):\n    baseurl = 'http://netdissect.csail.mit.edu/data/segmodel'\n    dirname = 'upp-resnet50-upernet'\n    files = ['decoder_epoch_40.pth', 'encoder_epoch_40.pth', 'labels.json']\n    download_dir = os.path.join(directory, dirname)\n    os.makedirs(download_dir, exist_ok=True)\n    for fn in files:\n        if os.path.isfile(os.path.join(download_dir, fn)):\n            continue # Skip files already downloaded\n        url = '%s/%s/%s' % (baseurl, dirname, fn)\n        print('Downloading %s' % url)\n        urlretrieve(url, os.path.join(download_dir, fn))\n    assert os.path.isfile(os.path.join(directory, dirname, 'labels.json'))\n\ndef test_main():\n    '''\n    Test the unified segmenter.\n    '''\n    from PIL import Image\n    testim = 
Image.open('script/testdata/test_church_242.jpg')\n    tensor_im = (torch.from_numpy(numpy.asarray(testim)).permute(2, 0, 1)\n            .float() / 255 * 2 - 1)[None, :, :, :].cuda()\n    segmenter = UnifiedParsingSegmenter()\n    seg = segmenter.segment_batch(tensor_im)\n    bc = torch.bincount(seg.view(-1))\n    labels, cats = segmenter.get_label_and_category_names()\n    for label in bc.nonzero()[:,0]:\n        if label.item():\n            # What is the prediction for this class?\n            pred, mask = segmenter.predict_single_class(tensor_im, label.item())\n            assert mask.sum().item() == bc[label].item()\n            assert len(((seg == label).max(1)[0] - mask).nonzero()) == 0\n            inside_pred = pred[mask].mean().item()\n            outside_pred = pred[~mask].mean().item()\n            print('%s (%s, #%d): %d pixels, pred %.2g inside %.2g outside' %\n                (labels[label.item()] + (label.item(), bc[label].item(),\n                    inside_pred, outside_pred)))\n\nif __name__ == '__main__':\n    test_main()\n"
  },
  {
    "path": "seeing/upsegmodel/__init__.py",
    "content": "from .models import ModelBuilder, SegmentationModule\n"
  },
  {
    "path": "seeing/upsegmodel/models.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\nfrom . import resnet, resnext\ntry:\n    from lib.nn import SynchronizedBatchNorm2d\nexcept ImportError:\n    from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d\n\n\nclass SegmentationModuleBase(nn.Module):\n    def __init__(self):\n        super(SegmentationModuleBase, self).__init__()\n\n    @staticmethod\n    def pixel_acc(pred, label, ignore_index=-1):\n        _, preds = torch.max(pred, dim=1)\n        valid = (label != ignore_index).long()\n        acc_sum = torch.sum(valid * (preds == label).long())\n        pixel_sum = torch.sum(valid)\n        acc = acc_sum.float() / (pixel_sum.float() + 1e-10)\n        return acc\n\n    @staticmethod\n    def part_pixel_acc(pred_part, gt_seg_part, gt_seg_object, object_label, valid):\n        mask_object = (gt_seg_object == object_label)\n        _, pred = torch.max(pred_part, dim=1)\n        acc_sum = mask_object * (pred == gt_seg_part)\n        acc_sum = torch.sum(acc_sum.view(acc_sum.size(0), -1), dim=1)\n        acc_sum = torch.sum(acc_sum * valid)\n        pixel_sum = torch.sum(mask_object.view(mask_object.size(0), -1), dim=1)\n        pixel_sum = torch.sum(pixel_sum * valid)\n        return acc_sum, pixel_sum \n\n    @staticmethod\n    def part_loss(pred_part, gt_seg_part, gt_seg_object, object_label, valid):\n        mask_object = (gt_seg_object == object_label)\n        loss = F.nll_loss(pred_part, gt_seg_part * mask_object.long(), reduction='none')\n        loss = loss * mask_object.float()\n        loss = torch.sum(loss.view(loss.size(0), -1), dim=1)\n        nr_pixel = torch.sum(mask_object.view(mask_object.shape[0], -1), dim=1)\n        sum_pixel = (nr_pixel * valid).sum()\n        loss = (loss * valid.float()).sum() / torch.clamp(sum_pixel, 1).float()\n        return loss\n\n\nclass SegmentationModule(SegmentationModuleBase):\n    def __init__(self, net_enc, net_dec, labeldata, 
loss_scale=None):\n        super(SegmentationModule, self).__init__()\n        self.encoder = net_enc\n        self.decoder = net_dec\n        self.crit_dict = nn.ModuleDict()\n        if loss_scale is None:\n            self.loss_scale = {\"object\": 1, \"part\": 0.5, \"scene\": 0.25, \"material\": 1}\n        else:\n            self.loss_scale = loss_scale\n\n        # criterion\n        self.crit_dict[\"object\"] = nn.NLLLoss(ignore_index=0)  # ignore background 0\n        self.crit_dict[\"material\"] = nn.NLLLoss(ignore_index=0)  # ignore background 0\n        self.crit_dict[\"scene\"] = nn.NLLLoss(ignore_index=-1)  # ignore unlabelled -1\n\n        # Label data - read from json\n        self.labeldata = labeldata\n        object_to_num = {k: v for v, k in enumerate(labeldata['object'])}\n        part_to_num = {k: v for v, k in enumerate(labeldata['part'])}\n        self.object_part = {object_to_num[k]:\n                [part_to_num[p] for p in v]\n                for k, v in labeldata['object_part'].items()}\n        self.object_with_part = sorted(self.object_part.keys())\n        self.decoder.object_part = self.object_part\n        self.decoder.object_with_part = self.object_with_part\n\n    def forward(self, feed_dict, *, seg_size=None):\n        if seg_size is None: # training\n\n            if feed_dict['source_idx'] == 0:\n                output_switch = {\"object\": True, \"part\": True, \"scene\": True, \"material\": False}\n            elif feed_dict['source_idx'] == 1:\n                output_switch = {\"object\": False, \"part\": False, \"scene\": False, \"material\": True}\n            else:\n                raise ValueError\n\n            pred = self.decoder(\n                self.encoder(feed_dict['img'], return_feature_maps=True),\n                output_switch=output_switch\n            )\n\n            # loss\n            loss_dict = {}\n            if pred['object'] is not None:  # object\n                loss_dict['object'] = 
self.crit_dict['object'](pred['object'], feed_dict['seg_object'])\n            if pred['part'] is not None:  # part\n                part_loss = 0\n                for idx_part, object_label in enumerate(self.object_with_part):\n                    part_loss += self.part_loss(\n                        pred['part'][idx_part], feed_dict['seg_part'],\n                        feed_dict['seg_object'], object_label, feed_dict['valid_part'][:, idx_part])\n                loss_dict['part'] = part_loss\n            if pred['scene'] is not None:  # scene\n                loss_dict['scene'] = self.crit_dict['scene'](pred['scene'], feed_dict['scene_label'])\n            if pred['material'] is not None:  # material\n                loss_dict['material'] = self.crit_dict['material'](pred['material'], feed_dict['seg_material'])\n            loss_dict['total'] = sum([loss_dict[k] * self.loss_scale[k] for k in loss_dict.keys()])\n\n            # metric \n            metric_dict= {}\n            if pred['object'] is not None:\n                metric_dict['object'] = self.pixel_acc(\n                    pred['object'], feed_dict['seg_object'], ignore_index=0)\n            if pred['material'] is not None:\n                metric_dict['material'] = self.pixel_acc(\n                    pred['material'], feed_dict['seg_material'], ignore_index=0)\n            if pred['part'] is not None:\n                acc_sum, pixel_sum = 0, 0\n                for idx_part, object_label in enumerate(self.object_with_part):\n                    acc, pixel = self.part_pixel_acc(\n                        pred['part'][idx_part], feed_dict['seg_part'], feed_dict['seg_object'],\n                        object_label, feed_dict['valid_part'][:, idx_part])\n                    acc_sum += acc\n                    pixel_sum += pixel\n                metric_dict['part'] = acc_sum.float() / (pixel_sum.float() + 1e-10)\n            if pred['scene'] is not None:\n                metric_dict['scene'] = 
self.pixel_acc(\n                    pred['scene'], feed_dict['scene_label'], ignore_index=-1)\n\n            return {'metric': metric_dict, 'loss': loss_dict}\n        else: # inference\n            output_switch = {\"object\": True, \"part\": True, \"scene\": True, \"material\": True}\n            pred = self.decoder(self.encoder(feed_dict['img'], return_feature_maps=True),\n                                output_switch=output_switch, seg_size=seg_size)\n            return pred\n\n\ndef conv3x3(in_planes, out_planes, stride=1, has_bias=False):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=has_bias)\n\n\ndef conv3x3_bn_relu(in_planes, out_planes, stride=1):\n    return nn.Sequential(\n            conv3x3(in_planes, out_planes, stride),\n            SynchronizedBatchNorm2d(out_planes),\n            nn.ReLU(inplace=True),\n            )\n\n\nclass ModelBuilder:\n    def __init__(self):\n        pass\n\n    # custom weights initialization\n    @staticmethod\n    def weights_init(m):\n        classname = m.__class__.__name__\n        if classname.find('Conv') != -1:\n            nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')\n        elif classname.find('BatchNorm') != -1:\n            m.weight.data.fill_(1.)\n            m.bias.data.fill_(1e-4)\n        #elif classname.find('Linear') != -1:\n        #    m.weight.data.normal_(0.0, 0.0001)\n\n    def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):\n        pretrained = True if len(weights) == 0 else False\n        if arch == 'resnet50':\n            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)\n            net_encoder = Resnet(orig_resnet)\n        elif arch == 'resnet101':\n            orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)\n            net_encoder = Resnet(orig_resnet)\n        elif arch == 'resnext101':\n            orig_resnext = 
resnext.__dict__['resnext101'](pretrained=pretrained)\n            net_encoder = Resnet(orig_resnext) # we can still use class Resnet\n        else:\n            raise Exception('Architecture undefined!')\n\n        # net_encoder.apply(self.weights_init)\n        if len(weights) > 0:\n            # print('Loading weights for net_encoder')\n            net_encoder.load_state_dict(\n                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)\n        return net_encoder\n\n    def build_decoder(self, nr_classes,\n                      arch='ppm_bilinear_deepsup', fc_dim=512,\n                      weights='', use_softmax=False):\n        if arch == 'upernet_lite':\n            net_decoder = UPerNet(\n                nr_classes=nr_classes,\n                fc_dim=fc_dim,\n                use_softmax=use_softmax,\n                fpn_dim=256)\n        elif arch == 'upernet':\n            net_decoder = UPerNet(\n                nr_classes=nr_classes,\n                fc_dim=fc_dim,\n                use_softmax=use_softmax,\n                fpn_dim=512)\n        else:\n            raise Exception('Architecture undefined!')\n\n        net_decoder.apply(self.weights_init)\n        if len(weights) > 0:\n            # print('Loading weights for net_decoder')\n            net_decoder.load_state_dict(\n                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)\n        return net_decoder\n\n\nclass Resnet(nn.Module):\n    def __init__(self, orig_resnet):\n        super(Resnet, self).__init__()\n\n        # take pretrained resnet, except AvgPool and FC\n        self.conv1 = orig_resnet.conv1\n        self.bn1 = orig_resnet.bn1\n        self.relu1 = orig_resnet.relu1\n        self.conv2 = orig_resnet.conv2\n        self.bn2 = orig_resnet.bn2\n        self.relu2 = orig_resnet.relu2\n        self.conv3 = orig_resnet.conv3\n        self.bn3 = orig_resnet.bn3\n        self.relu3 = orig_resnet.relu3\n        
self.maxpool = orig_resnet.maxpool\n        self.layer1 = orig_resnet.layer1\n        self.layer2 = orig_resnet.layer2\n        self.layer3 = orig_resnet.layer3\n        self.layer4 = orig_resnet.layer4\n\n    def forward(self, x, return_feature_maps=False):\n        conv_out = []\n\n        x = self.relu1(self.bn1(self.conv1(x)))\n        x = self.relu2(self.bn2(self.conv2(x)))\n        x = self.relu3(self.bn3(self.conv3(x)))\n        x = self.maxpool(x)\n\n        x = self.layer1(x); conv_out.append(x);\n        x = self.layer2(x); conv_out.append(x);\n        x = self.layer3(x); conv_out.append(x);\n        x = self.layer4(x); conv_out.append(x);\n\n        if return_feature_maps:\n            return conv_out\n        return [x]\n\n\n# upernet\nclass UPerNet(nn.Module):\n    def __init__(self, nr_classes, fc_dim=4096,\n                 use_softmax=False, pool_scales=(1, 2, 3, 6),\n                 fpn_inplanes=(256,512,1024,2048), fpn_dim=256):\n        # Lazy import so that compilation isn't needed if not being used.\n        from .prroi_pool import PrRoIPool2D\n        super(UPerNet, self).__init__()\n        self.use_softmax = use_softmax\n\n        # PPM Module\n        self.ppm_pooling = []\n        self.ppm_conv = []\n\n        for scale in pool_scales:\n            # we use the feature map size instead of input image size, so down_scale = 1.0\n            self.ppm_pooling.append(PrRoIPool2D(scale, scale, 1.))\n            self.ppm_conv.append(nn.Sequential(\n                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),\n                SynchronizedBatchNorm2d(512),\n                nn.ReLU(inplace=True)\n            ))\n        self.ppm_pooling = nn.ModuleList(self.ppm_pooling)\n        self.ppm_conv = nn.ModuleList(self.ppm_conv)\n        self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1)\n\n        # FPN Module\n        self.fpn_in = []\n        for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer\n            
self.fpn_in.append(nn.Sequential(\n                nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),\n                SynchronizedBatchNorm2d(fpn_dim),\n                nn.ReLU(inplace=True)\n            ))\n        self.fpn_in = nn.ModuleList(self.fpn_in)\n\n        self.fpn_out = []\n        for i in range(len(fpn_inplanes) - 1): # skip the top layer\n            self.fpn_out.append(nn.Sequential(\n                conv3x3_bn_relu(fpn_dim, fpn_dim, 1),\n            ))\n        self.fpn_out = nn.ModuleList(self.fpn_out)\n\n        self.conv_fusion = conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1)\n\n        # background included. if ignore in loss, output channel 0 will not be trained.\n        self.nr_scene_class, self.nr_object_class, self.nr_part_class, self.nr_material_class = \\\n            nr_classes['scene'], nr_classes['object'], nr_classes['part'], nr_classes['material']\n\n        # input: PPM out, input_dim: fpn_dim\n        self.scene_head = nn.Sequential(\n            conv3x3_bn_relu(fpn_dim, fpn_dim, 1),\n            nn.AdaptiveAvgPool2d(1),\n            nn.Conv2d(fpn_dim, self.nr_scene_class, kernel_size=1, bias=True)\n        )\n\n        # input: Fusion out, input_dim: fpn_dim\n        self.object_head = nn.Sequential(\n            conv3x3_bn_relu(fpn_dim, fpn_dim, 1),\n            nn.Conv2d(fpn_dim, self.nr_object_class, kernel_size=1, bias=True)\n        )\n\n        # input: Fusion out, input_dim: fpn_dim\n        self.part_head = nn.Sequential(\n            conv3x3_bn_relu(fpn_dim, fpn_dim, 1),\n            nn.Conv2d(fpn_dim, self.nr_part_class, kernel_size=1, bias=True)\n        )\n\n        # input: FPN_2 (P2), input_dim: fpn_dim\n        self.material_head = nn.Sequential(\n            conv3x3_bn_relu(fpn_dim, fpn_dim, 1),\n            nn.Conv2d(fpn_dim, self.nr_material_class, kernel_size=1, bias=True)\n        )\n\n    def forward(self, conv_out, output_switch=None, seg_size=None):\n\n        output_dict = {k: None for 
k in output_switch.keys()}\n\n        conv5 = conv_out[-1]\n        input_size = conv5.size()\n        ppm_out = [conv5]\n        roi = [] # fake rois, just used for pooling\n        for i in range(input_size[0]): # batch size\n            roi.append(torch.Tensor([i, 0, 0, input_size[3], input_size[2]]).view(1, -1)) # b, x0, y0, x1, y1\n        roi = torch.cat(roi, dim=0).type_as(conv5)\n        ppm_out = [conv5]\n        for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):\n            ppm_out.append(pool_conv(F.interpolate(\n                pool_scale(conv5, roi.detach()),\n                (input_size[2], input_size[3]),\n                mode='bilinear', align_corners=False)))\n        ppm_out = torch.cat(ppm_out, 1)\n        f = self.ppm_last_conv(ppm_out)\n\n        if output_switch['scene']: # scene\n            output_dict['scene'] = self.scene_head(f)\n\n        if output_switch['object'] or output_switch['part'] or output_switch['material']:\n            fpn_feature_list = [f]\n            for i in reversed(range(len(conv_out) - 1)):\n                conv_x = conv_out[i]\n                conv_x = self.fpn_in[i](conv_x) # lateral branch\n\n                f = F.interpolate(\n                    f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch\n                f = conv_x + f\n\n                fpn_feature_list.append(self.fpn_out[i](f))\n            fpn_feature_list.reverse() # [P2 - P5]\n\n            # material\n            if output_switch['material']:\n                output_dict['material'] = self.material_head(fpn_feature_list[0])\n\n            if output_switch['object'] or output_switch['part']:\n                output_size = fpn_feature_list[0].size()[2:]\n                fusion_list = [fpn_feature_list[0]]\n                for i in range(1, len(fpn_feature_list)):\n                    fusion_list.append(F.interpolate(\n                        fpn_feature_list[i],\n                        output_size,\n 
                       mode='bilinear', align_corners=False))\n                fusion_out = torch.cat(fusion_list, 1)\n                x = self.conv_fusion(fusion_out)\n\n                if output_switch['object']: # object\n                    output_dict['object'] = self.object_head(x)\n                if output_switch['part']:\n                    output_dict['part'] = self.part_head(x)\n\n        if self.use_softmax:  # is True during inference\n            # inference scene\n            x = output_dict['scene']\n            x = x.squeeze(3).squeeze(2)\n            x = F.softmax(x, dim=1)\n            output_dict['scene'] = x\n\n            # inference object, material\n            for k in ['object', 'material']:\n                x = output_dict[k]\n                x = F.interpolate(x, size=seg_size, mode='bilinear', align_corners=False)\n                x = F.softmax(x, dim=1)\n                output_dict[k] = x\n\n            # inference part\n            x = output_dict['part']\n            x = F.interpolate(x, size=seg_size, mode='bilinear', align_corners=False)\n            part_pred_list, head = [], 0\n            for idx_part, object_label in enumerate(self.object_with_part):\n                n_part = len(self.object_part[object_label])\n                _x = F.interpolate(x[:, head: head + n_part], size=seg_size, mode='bilinear', align_corners=False)\n                _x = F.softmax(_x, dim=1)\n                part_pred_list.append(_x)\n                head += n_part\n            output_dict['part'] = part_pred_list\n\n        else:   # Training\n            # object, scene, material\n            for k in ['object', 'scene', 'material']:\n                if output_dict[k] is None:\n                    continue\n                x = output_dict[k]\n                x = F.log_softmax(x, dim=1)\n                if k == \"scene\":  # for scene\n                    x = x.squeeze(3).squeeze(2)\n                output_dict[k] = x\n            if 
output_dict['part'] is not None:\n                part_pred_list, head = [], 0\n                for idx_part, object_label in enumerate(self.object_with_part):\n                    n_part = len(self.object_part[object_label])\n                    x = output_dict['part'][:, head: head + n_part]\n                    x = F.log_softmax(x, dim=1)\n                    part_pred_list.append(x)\n                    head += n_part\n                output_dict['part'] = part_pred_list\n\n        return output_dict\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/.gitignore",
    "content": "*.o\n/_prroi_pooling\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/README.md",
    "content": "# PreciseRoIPooling\nThis repo implements the **Precise RoI Pooling** (PrRoI Pooling), proposed in the paper **Acquisition of Localization Confidence for Accurate Object Detection** published at ECCV 2018 (Oral Presentation).\n\n**Acquisition of Localization Confidence for Accurate Object Detection**\n\n_Borui Jiang*, Ruixuan Luo*, Jiayuan Mao*, Tete Xiao, Yuning Jiang_ (* indicates equal contribution.)\n\nhttps://arxiv.org/abs/1807.11590\n\n## Brief\n\nIn short, Precise RoI Pooling is an integration-based (bilinear interpolation) average pooling method for RoI Pooling. It avoids any quantization and has a continuous gradient on bounding box coordinates. It is:\n\n- different from the original RoI Pooling proposed in [Fast R-CNN](https://arxiv.org/abs/1504.08083). PrRoI Pooling uses average pooling instead of max pooling for each bin and has a continuous gradient on bounding box coordinates. That is, one can take the derivatives of some loss function w.r.t the coordinates of each RoI and optimize the RoI coordinates.\n- different from the RoI Align proposed in [Mask R-CNN](https://arxiv.org/abs/1703.06870). PrRoI Pooling uses a full integration-based average pooling instead of sampling a constant number of points. This makes the gradient w.r.t. the coordinates continuous.\n\nFor a better illustration, we illustrate RoI Pooling, RoI Align and PrRoI Pooing in the following figure. More details including the gradient computation can be found in our paper.\n\n<center><img src=\"./_assets/prroi_visualization.png\" width=\"80%\"></center>\n\n## Implementation\n\nPrRoI Pooling was originally implemented by [Tete Xiao](http://tetexiao.com/) based on MegBrain, an (internal) deep learning framework built by Megvii Inc. It was later adapted into open-source deep learning frameworks. Currently, we only support PyTorch. 
Unfortunately, we don't have any specific plans to port it to other frameworks such as TensorFlow, but any contributions (pull requests) will be more than welcome.\n\n## Usage (PyTorch 1.0)\n\nIn the directory `pytorch/`, we provide a PyTorch-based implementation of PrRoI Pooling. It requires PyTorch 1.0+ and only supports CUDA (CPU mode is not implemented).\nSince we use PyTorch JIT for cxx/cuda code compilation, to use the module in your code, simply do:\n\n```\nfrom prroi_pool import PrRoIPool2D\n\navg_pool = PrRoIPool2D(window_height, window_width, spatial_scale)\nroi_features = avg_pool(features, rois)\n\n# for those who want to use the \"functional\"\n\nfrom prroi_pool.functional import prroi_pool2d\nroi_features = prroi_pool2d(features, rois, window_height, window_width, spatial_scale)\n```\n\n\n## Usage (PyTorch 0.4)\n\n**!!! Please first check out the branch pytorch0.4.**\n\nIn the directory `pytorch/`, we provide a PyTorch-based implementation of PrRoI Pooling. It requires PyTorch 0.4 and only supports CUDA (CPU mode is not implemented).\nTo use the PrRoI Pooling module, first go to `pytorch/prroi_pool` and execute `./travis.sh` to compile the essential components (you may need `nvcc` for this step). To use the module in your code, simply do:\n\n```\nfrom prroi_pool import PrRoIPool2D\n\navg_pool = PrRoIPool2D(window_height, window_width, spatial_scale)\nroi_features = avg_pool(features, rois)\n\n# for those who want to use the \"functional\"\n\nfrom prroi_pool.functional import prroi_pool2d\nroi_features = prroi_pool2d(features, rois, window_height, window_width, spatial_scale)\n```\n\nHere,\n\n- `rois` is an `m * 5` float tensor of format `(batch_index, x0, y0, x1, y1)`, following the convention in the original Caffe implementation of RoI Pooling, although in some frameworks the batch indices are provided by an integer tensor.\n- The RoIs are multiplied by `spatial_scale`. 
For example, if your feature maps are down-sampled by a factor of 16 (w.r.t. the input image), you should use a spatial scale of `1/16`.\n- RoI coordinates follow the half-open [L, R) convention. That is, `(0, 0, 4, 4)` denotes a box of size `4x4`.\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/__init__.py",
    "content": "#! /usr/bin/env python3\n# -*- coding: utf-8 -*-\n# File   : __init__.py\n# Author : Jiayuan Mao, Tete Xiao\n# Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com\n# Date   : 07/13/2018\n# \n# This file is part of PreciseRoIPooling.\n# Distributed under terms of the MIT license.\n# Copyright (c) 2017 Megvii Technology Limited.\n\nfrom .prroi_pool import *\n\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/build.py",
    "content": "#! /usr/bin/env python3\n# -*- coding: utf-8 -*-\n# File   : build.py\n# Author : Jiayuan Mao, Tete Xiao\n# Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com\n# Date   : 07/13/2018\n# \n# This file is part of PreciseRoIPooling.\n# Distributed under terms of the MIT license.\n# Copyright (c) 2017 Megvii Technology Limited.\n\nimport os\nimport torch\n\nfrom torch.utils.ffi import create_extension\n\nheaders = []\nsources = []\ndefines = []\nextra_objects = []\nwith_cuda = False\n\nif torch.cuda.is_available():\n    with_cuda = True\n\n    headers+= ['src/prroi_pooling_gpu.h']\n    sources += ['src/prroi_pooling_gpu.c']\n    defines += [('WITH_CUDA', None)]\n\n    this_file = os.path.dirname(os.path.realpath(__file__))\n    extra_objects_cuda = ['src/prroi_pooling_gpu_impl.cu.o']\n    extra_objects_cuda = [os.path.join(this_file, fname) for fname in extra_objects_cuda]\n    extra_objects.extend(extra_objects_cuda)\nelse:\n    # TODO(Jiayuan Mao @ 07/13): remove this restriction after we support the cpu implementation.\n    raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implememtations.')\n\nffi = create_extension(\n    '_prroi_pooling',\n    headers=headers,\n    sources=sources,\n    define_macros=defines,\n    relative_to=__file__,\n    with_cuda=with_cuda,\n    extra_objects=extra_objects\n)\n\nif __name__ == '__main__':\n    ffi.build()\n\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/functional.py",
    "content": "#! /usr/bin/env python3\n# -*- coding: utf-8 -*-\n# File   : functional.py\n# Author : Jiayuan Mao, Tete Xiao\n# Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com\n# Date   : 07/13/2018\n#\n# This file is part of PreciseRoIPooling.\n# Distributed under terms of the MIT license.\n# Copyright (c) 2017 Megvii Technology Limited.\n\nimport torch\nimport torch.autograd as ag\n\ntry:\n    from os.path import join as pjoin, dirname\n    from torch.utils.cpp_extension import load as load_extension\n    root_dir = pjoin(dirname(__file__), 'src')\n    _prroi_pooling = load_extension(\n        '_prroi_pooling',\n        [pjoin(root_dir, 'prroi_pooling_gpu.c'), pjoin(root_dir, 'prroi_pooling_gpu_impl.cu')],\n        verbose=False\n    )\nexcept ImportError:\n    raise ImportError('Can not compile Precise RoI Pooling library.')\n\n__all__ = ['prroi_pool2d']\n\n\nclass PrRoIPool2DFunction(ag.Function):\n    @staticmethod\n    def forward(ctx, features, rois, pooled_height, pooled_width, spatial_scale):\n        assert 'FloatTensor' in features.type() and 'FloatTensor' in rois.type(), \\\n                'Precise RoI Pooling only takes float input, got {} for features and {} for rois.'.format(features.type(), rois.type())\n\n        pooled_height = int(pooled_height)\n        pooled_width = int(pooled_width)\n        spatial_scale = float(spatial_scale)\n\n        features = features.contiguous()\n        rois = rois.contiguous()\n        params = (pooled_height, pooled_width, spatial_scale)\n\n        if features.is_cuda:\n            output = _prroi_pooling.prroi_pooling_forward_cuda(features, rois, *params)\n            ctx.params = params\n            # everything here is contiguous.\n            ctx.save_for_backward(features, rois, output)\n        else:\n            raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implememtations.')\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        
features, rois, output = ctx.saved_tensors\n        grad_input = grad_coor = None\n\n        if features.requires_grad:\n            grad_output = grad_output.contiguous()\n            grad_input = _prroi_pooling.prroi_pooling_backward_cuda(features, rois, output, grad_output, *ctx.params)\n        if rois.requires_grad:\n            grad_output = grad_output.contiguous()\n            grad_coor = _prroi_pooling.prroi_pooling_coor_backward_cuda(features, rois, output, grad_output, *ctx.params)\n\n        return grad_input, grad_coor, None, None, None\n\n\nprroi_pool2d = PrRoIPool2DFunction.apply\n\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/prroi_pool.py",
    "content": "#! /usr/bin/env python3\n# -*- coding: utf-8 -*-\n# File   : prroi_pool.py\n# Author : Jiayuan Mao, Tete Xiao\n# Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com\n# Date   : 07/13/2018\n# \n# This file is part of PreciseRoIPooling.\n# Distributed under terms of the MIT license.\n# Copyright (c) 2017 Megvii Technology Limited.\n\nimport torch.nn as nn\n\nfrom .functional import prroi_pool2d\n\n__all__ = ['PrRoIPool2D']\n\n\nclass PrRoIPool2D(nn.Module):\n    def __init__(self, pooled_height, pooled_width, spatial_scale):\n        super().__init__()\n\n        self.pooled_height = int(pooled_height)\n        self.pooled_width = int(pooled_width)\n        self.spatial_scale = float(spatial_scale)\n\n    def forward(self, features, rois):\n        return prroi_pool2d(features, rois, self.pooled_height, self.pooled_width, self.spatial_scale)\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c",
    "content": "/*\n * File   : prroi_pooling_gpu.c\n * Author : Jiayuan Mao, Tete Xiao\n * Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com\n * Date   : 07/13/2018\n *\n * Distributed under terms of the MIT license.\n * Copyright (c) 2017 Megvii Technology Limited.\n */\n\n#include <math.h>\n#include <torch/extension.h>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <THC/THC.h>\n\n#include \"prroi_pooling_gpu_impl.cuh\"\n\n\nat::Tensor prroi_pooling_forward_cuda(const at::Tensor &features, const at::Tensor &rois, int pooled_height, int pooled_width, float spatial_scale) {\n    int nr_rois = rois.size(0);\n    int nr_channels = features.size(1);\n    int height = features.size(2);\n    int width = features.size(3);\n    int top_count = nr_rois * nr_channels * pooled_height * pooled_width;\n    auto output = at::zeros({nr_rois, nr_channels, pooled_height, pooled_width}, features.options());\n\n    if (output.numel() == 0) {\n        THCudaCheck(cudaGetLastError());\n        return output;\n    }\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    PrRoIPoolingForwardGpu(\n        stream, features.data<float>(), rois.data<float>(), output.data<float>(),\n        nr_channels, height, width, pooled_height, pooled_width, spatial_scale,\n        top_count\n    );\n\n    THCudaCheck(cudaGetLastError());\n    return output;\n}\n\nat::Tensor prroi_pooling_backward_cuda(\n    const at::Tensor &features, const at::Tensor &rois, const at::Tensor &output, const at::Tensor &output_diff,\n    int pooled_height, int pooled_width, float spatial_scale) {\n\n    auto features_diff = at::zeros_like(features);\n\n    int nr_rois = rois.size(0);\n    int batch_size = features.size(0);\n    int nr_channels = features.size(1);\n    int height = features.size(2);\n    int width = features.size(3);\n    int top_count = nr_rois * nr_channels * pooled_height * pooled_width;\n    int bottom_count = batch_size * nr_channels * height * width;\n\n    
if (output.numel() == 0) {\n        THCudaCheck(cudaGetLastError());\n        return features_diff;\n    }\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    PrRoIPoolingBackwardGpu(\n        stream,\n        features.data<float>(), rois.data<float>(), output.data<float>(), output_diff.data<float>(),\n        features_diff.data<float>(),\n        nr_channels, height, width, pooled_height, pooled_width, spatial_scale,\n        top_count, bottom_count\n    );\n\n    THCudaCheck(cudaGetLastError());\n    return features_diff;\n}\n\nat::Tensor prroi_pooling_coor_backward_cuda(\n    const at::Tensor &features, const at::Tensor &rois, const at::Tensor &output, const at::Tensor &output_diff,\n    int pooled_height, int pooled_width, float spatial_scale) {\n\n    auto coor_diff = at::zeros_like(rois);\n\n    int nr_rois = rois.size(0);\n    int nr_channels = features.size(1);\n    int height = features.size(2);\n    int width = features.size(3);\n    int top_count = nr_rois * nr_channels * pooled_height * pooled_width;\n    int bottom_count = nr_rois * 5;\n\n    if (output.numel() == 0) {\n        THCudaCheck(cudaGetLastError());\n        return coor_diff;\n    }\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    PrRoIPoolingCoorBackwardGpu(\n        stream,\n        features.data<float>(), rois.data<float>(), output.data<float>(), output_diff.data<float>(),\n        coor_diff.data<float>(),\n        nr_channels, height, width, pooled_height, pooled_width, spatial_scale,\n        top_count, bottom_count\n    );\n\n    THCudaCheck(cudaGetLastError());\n    return coor_diff;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"prroi_pooling_forward_cuda\", &prroi_pooling_forward_cuda, \"PRRoIPooling_forward\");\n    m.def(\"prroi_pooling_backward_cuda\", &prroi_pooling_backward_cuda, \"PRRoIPooling_backward\");\n    m.def(\"prroi_pooling_coor_backward_cuda\", &prroi_pooling_coor_backward_cuda, \"PRRoIPooling_backward_coor\");\n}\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/src/prroi_pooling_gpu.h",
    "content": "/*\n * File   : prroi_pooling_gpu.h\n * Author : Jiayuan Mao, Tete Xiao\n * Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com \n * Date   : 07/13/2018\n * \n * Distributed under terms of the MIT license.\n * Copyright (c) 2017 Megvii Technology Limited.\n */\n\nint prroi_pooling_forward_cuda(THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, int pooled_height, int pooled_width, float spatial_scale);\n\nint prroi_pooling_backward_cuda(\n    THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff,\n    int pooled_height, int pooled_width, float spatial_scale\n);\n\nint prroi_pooling_coor_backward_cuda(\n    THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff,\n    int pooled_height, int pooled_width, float spatial_scal\n);\n\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cu",
    "content": "/*\n * File   : prroi_pooling_gpu_impl.cu\n * Author : Tete Xiao, Jiayuan Mao\n * Email  : jasonhsiao97@gmail.com\n *\n * Distributed under terms of the MIT license.\n * Copyright (c) 2017 Megvii Technology Limited.\n */\n\n#include \"prroi_pooling_gpu_impl.cuh\"\n\n#include <cstdio>\n#include <cfloat>\n\n#define CUDA_KERNEL_LOOP(i, n) \\\n    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \\\n        i < (n); \\\n        i += blockDim.x * gridDim.x)\n\n#define CUDA_POST_KERNEL_CHECK \\\n    do { \\\n        cudaError_t err = cudaGetLastError(); \\\n        if (cudaSuccess != err) { \\\n            fprintf(stderr, \"cudaCheckError() failed : %s\\n\", cudaGetErrorString(err)); \\\n            exit(-1); \\\n        } \\\n    } while(0)\n\n#define CUDA_NUM_THREADS 512\n\nnamespace {\n\nstatic int CUDA_NUM_BLOCKS(const int N) {\n  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;\n}\n\n__device__ static float PrRoIPoolingGetData(F_DEVPTR_IN data, const int h, const int w, const int height, const int width)\n{\n    bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);\n    float retVal = overflow ? 0.0f : data[h * width + w];\n    return retVal;\n}\n\n__device__ static float PrRoIPoolingGetCoeff(float dh, float dw){\n    dw = dw > 0 ? dw : -dw;\n    dh = dh > 0 ? 
dh : -dh;\n    return (1.0f - dh) * (1.0f - dw);\n}\n\n__device__ static float PrRoIPoolingSingleCoorIntegral(float s, float t, float c1, float c2) {\n    return 0.5 * (t * t - s * s) * c2 + (t - 0.5 * t * t - s + 0.5 * s * s) * c1;\n}\n\n__device__ static float PrRoIPoolingInterpolation(F_DEVPTR_IN data, const float h, const float w, const int height, const int width){\n    float retVal = 0.0f;\n    int h1 = floorf(h);\n    int w1 = floorf(w);\n    retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));\n    h1 = floorf(h)+1;\n    w1 = floorf(w);\n    retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));\n    h1 = floorf(h);\n    w1 = floorf(w)+1;\n    retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));\n    h1 = floorf(h)+1;\n    w1 = floorf(w)+1;\n    retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));\n    return retVal;\n}\n\n__device__ static float PrRoIPoolingMatCalculation(F_DEVPTR_IN this_data, const int s_h, const int s_w, const int e_h, const int e_w,\n        const float y0, const float x0, const float y1, const float x1, const int h0, const int w0)\n{\n    float alpha, beta, lim_alpha, lim_beta, tmp;\n    float sum_out = 0;\n\n    alpha = x0 - float(s_w);\n    beta = y0 - float(s_h);\n    lim_alpha = x1 - float(s_w);\n    lim_beta = y1 - float(s_h);\n    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)\n        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);\n    sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp;\n\n    alpha = float(e_w) - x1;\n    lim_alpha = float(e_w) - x0;\n    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)\n        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);\n    
sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp;\n\n    alpha = x0 - float(s_w);\n    beta = float(e_h) - y1;\n    lim_alpha = x1 - float(s_w);\n    lim_beta = float(e_h) - y0;\n    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)\n        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);\n    sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp;\n\n    alpha = float(e_w) - x1;\n    lim_alpha = float(e_w) - x0;\n    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)\n        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);\n    sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp;\n\n    return sum_out;\n}\n\n__device__ static void PrRoIPoolingDistributeDiff(F_DEVPTR_OUT diff, const float top_diff, const int h, const int w, const int height, const int width, const float coeff)\n{\n    bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);\n    if (!overflow)\n        atomicAdd(diff + h * width + w, top_diff * coeff);\n}\n\n__device__ static void PrRoIPoolingMatDistributeDiff(F_DEVPTR_OUT diff, const float top_diff, const int s_h, const int s_w, const int e_h, const int e_w,\n        const float y0, const float x0, const float y1, const float x1, const int h0, const int w0)\n{\n    float alpha, beta, lim_alpha, lim_beta, tmp;\n\n    alpha = x0 - float(s_w);\n    beta = y0 - float(s_h);\n    lim_alpha = x1 - float(s_w);\n    lim_beta = y1 - float(s_h);\n    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)\n        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);\n    PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, w0, tmp);\n\n    alpha = float(e_w) - x1;\n    lim_alpha = float(e_w) - x0;\n    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)\n        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * 
beta);\n    PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp);\n\n    alpha = x0 - float(s_w);\n    beta = float(e_h) - y1;\n    lim_alpha = x1 - float(s_w);\n    lim_beta = float(e_h) - y0;\n    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)\n        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);\n    PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp);\n\n    alpha = float(e_w) - x1;\n    lim_alpha = float(e_w) - x0;\n    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)\n        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);\n    PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp);\n}\n\n__global__ void PrRoIPoolingForward(\n        const int nthreads,\n        F_DEVPTR_IN bottom_data,\n        F_DEVPTR_IN bottom_rois,\n        F_DEVPTR_OUT top_data,\n        const int channels,\n        const int height,\n        const int width,\n        const int pooled_height,\n        const int pooled_width,\n        const float spatial_scale) {\n\n  CUDA_KERNEL_LOOP(index, nthreads) {\n    // (n, c, ph, pw) is an element in the pooled output\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int c = (index / pooled_width / pooled_height) % channels;\n    int n = index / pooled_width / pooled_height / channels;\n\n    bottom_rois += n * 5;\n    int roi_batch_ind = bottom_rois[0];\n\n    float roi_start_w = bottom_rois[1] * spatial_scale;\n    float roi_start_h = bottom_rois[2] * spatial_scale;\n    float roi_end_w = bottom_rois[3] * spatial_scale;\n    float roi_end_h = bottom_rois[4] * spatial_scale;\n\n    float roi_width = max(roi_end_w - roi_start_w, ((float)0.0));\n    float roi_height = max(roi_end_h - roi_start_h, ((float)0.0));\n    float bin_size_h = roi_height / static_cast<float>(pooled_height);\n    float bin_size_w = roi_width / static_cast<float>(pooled_width);\n\n 
   const float *this_data = bottom_data + (roi_batch_ind * channels + c) * height * width;\n    float *this_out = top_data + index;\n\n    float win_start_w = roi_start_w + bin_size_w * pw;\n    float win_start_h = roi_start_h + bin_size_h * ph;\n    float win_end_w = win_start_w + bin_size_w;\n    float win_end_h = win_start_h + bin_size_h;\n\n    float win_size = max(float(0.0), bin_size_w * bin_size_h);\n    if (win_size == 0) {\n        *this_out = 0;\n        return;\n    }\n\n    float sum_out = 0;\n\n    int s_w, s_h, e_w, e_h;\n\n    s_w = floorf(win_start_w);\n    e_w = ceilf(win_end_w);\n    s_h = floorf(win_start_h);\n    e_h = ceilf(win_end_h);\n\n    for (int w_iter = s_w; w_iter < e_w; ++w_iter)\n        for (int h_iter = s_h; h_iter < e_h; ++h_iter)\n            sum_out += PrRoIPoolingMatCalculation(this_data, h_iter, w_iter, h_iter + 1, w_iter + 1,\n                max(win_start_h, float(h_iter)), max(win_start_w, float(w_iter)),\n                min(win_end_h, float(h_iter) + 1.0), min(win_end_w, float(w_iter + 1.0)),\n                height, width);\n    *this_out = sum_out / win_size;\n  }\n}\n\n__global__ void PrRoIPoolingBackward(\n        const int nthreads,\n        F_DEVPTR_IN bottom_rois,\n        F_DEVPTR_IN top_diff,\n        F_DEVPTR_OUT bottom_diff,\n        const int channels,\n        const int height,\n        const int width,\n        const int pooled_height,\n        const int pooled_width,\n        const float spatial_scale) {\n\n  CUDA_KERNEL_LOOP(index, nthreads) {\n    // (n, c, ph, pw) is an element in the pooled output\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int c = (index / pooled_width / pooled_height) % channels;\n    int n = index / pooled_width / pooled_height / channels;\n    bottom_rois += n * 5;\n\n    int roi_batch_ind = bottom_rois[0];\n    float roi_start_w = bottom_rois[1] * spatial_scale;\n    float roi_start_h = bottom_rois[2] * spatial_scale;\n    float 
roi_end_w = bottom_rois[3] * spatial_scale;\n    float roi_end_h = bottom_rois[4] * spatial_scale;\n\n    float roi_width = max(roi_end_w - roi_start_w, (float)0);\n    float roi_height = max(roi_end_h - roi_start_h, (float)0);\n    float bin_size_h = roi_height / static_cast<float>(pooled_height);\n    float bin_size_w = roi_width / static_cast<float>(pooled_width);\n\n    const float *this_out_grad = top_diff + index;\n    float *this_data_grad = bottom_diff + (roi_batch_ind * channels + c) * height * width;\n\n    float win_start_w = roi_start_w + bin_size_w * pw;\n    float win_start_h = roi_start_h + bin_size_h * ph;\n    float win_end_w = win_start_w + bin_size_w;\n    float win_end_h = win_start_h + bin_size_h;\n\n    float win_size = max(float(0.0), bin_size_w * bin_size_h);\n\n    float sum_out = win_size == float(0) ? float(0) : *this_out_grad / win_size;\n\n    int s_w, s_h, e_w, e_h;\n\n    s_w = floorf(win_start_w);\n    e_w = ceilf(win_end_w);\n    s_h = floorf(win_start_h);\n    e_h = ceilf(win_end_h);\n\n    for (int w_iter = s_w; w_iter < e_w; ++w_iter)\n        for (int h_iter = s_h; h_iter < e_h; ++h_iter)\n            PrRoIPoolingMatDistributeDiff(this_data_grad, sum_out, h_iter, w_iter, h_iter + 1, w_iter + 1,\n                max(win_start_h, float(h_iter)), max(win_start_w, float(w_iter)),\n                min(win_end_h, float(h_iter) + 1.0), min(win_end_w, float(w_iter + 1.0)),\n                height, width);\n\n  }\n}\n\n__global__ void PrRoIPoolingCoorBackward(\n        const int nthreads,\n        F_DEVPTR_IN bottom_data,\n        F_DEVPTR_IN bottom_rois,\n        F_DEVPTR_IN top_data,\n        F_DEVPTR_IN top_diff,\n        F_DEVPTR_OUT bottom_diff,\n        const int channels,\n        const int height,\n        const int width,\n        const int pooled_height,\n        const int pooled_width,\n        const float spatial_scale) {\n\n  CUDA_KERNEL_LOOP(index, nthreads) {\n    // (n, c, ph, pw) is an element in the pooled output\n    
int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int c = (index / pooled_width / pooled_height) % channels;\n    int n = index / pooled_width / pooled_height / channels;\n    bottom_rois += n * 5;\n\n    int roi_batch_ind = bottom_rois[0];\n    float roi_start_w = bottom_rois[1] * spatial_scale;\n    float roi_start_h = bottom_rois[2] * spatial_scale;\n    float roi_end_w = bottom_rois[3] * spatial_scale;\n    float roi_end_h = bottom_rois[4] * spatial_scale;\n\n    float roi_width = max(roi_end_w - roi_start_w, (float)0);\n    float roi_height = max(roi_end_h - roi_start_h, (float)0);\n    float bin_size_h = roi_height / static_cast<float>(pooled_height);\n    float bin_size_w = roi_width / static_cast<float>(pooled_width);\n\n    const float *this_out_grad = top_diff + index;\n    const float *this_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;\n    const float *this_top_data = top_data + index;\n    float *this_data_grad = bottom_diff + n * 5;\n\n    float win_start_w = roi_start_w + bin_size_w * pw;\n    float win_start_h = roi_start_h + bin_size_h * ph;\n    float win_end_w = win_start_w + bin_size_w;\n    float win_end_h = win_start_h + bin_size_h;\n\n    float win_size = max(float(0.0), bin_size_w * bin_size_h);\n\n    float sum_out = win_size == float(0) ? 
float(0) : *this_out_grad / win_size;\n\n    // WARNING: to be discussed\n    if (sum_out == 0)\n        return;\n\n    int s_w, s_h, e_w, e_h;\n\n    s_w = floorf(win_start_w);\n    e_w = ceilf(win_end_w);\n    s_h = floorf(win_start_h);\n    e_h = ceilf(win_end_h);\n\n    float g_x1_y = 0, g_x2_y = 0, g_x_y1 = 0, g_x_y2 = 0;\n    for (int h_iter = s_h; h_iter < e_h; ++h_iter) {\n        g_x1_y += PrRoIPoolingSingleCoorIntegral(max(win_start_h, float(h_iter)) - h_iter,\n                min(win_end_h, float(h_iter + 1)) - h_iter,\n                PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_start_w, height, width),\n                PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_start_w, height, width));\n\n        g_x2_y += PrRoIPoolingSingleCoorIntegral(max(win_start_h, float(h_iter)) - h_iter,\n                min(win_end_h, float(h_iter + 1)) - h_iter,\n                PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_end_w, height, width),\n                PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_end_w, height, width));\n    }\n\n    for (int w_iter = s_w; w_iter < e_w; ++w_iter) {\n        g_x_y1 += PrRoIPoolingSingleCoorIntegral(max(win_start_w, float(w_iter)) - w_iter,\n                min(win_end_w, float(w_iter + 1)) - w_iter,\n                PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter, height, width),\n                PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter + 1, height, width));\n\n        g_x_y2 += PrRoIPoolingSingleCoorIntegral(max(win_start_w, float(w_iter)) - w_iter,\n                min(win_end_w, float(w_iter + 1)) - w_iter,\n                PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter, height, width),\n                PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter + 1, height, width));\n    }\n\n    float partial_x1 = -g_x1_y + (win_end_h - win_start_h) * (*this_top_data);\n    float partial_y1 = -g_x_y1 + (win_end_w - 
win_start_w) * (*this_top_data);\n    float partial_x2 = g_x2_y - (win_end_h - win_start_h) * (*this_top_data);\n    float partial_y2 = g_x_y2 - (win_end_w - win_start_w) * (*this_top_data);\n\n    partial_x1 = partial_x1 / win_size * spatial_scale;\n    partial_x2 = partial_x2 / win_size * spatial_scale;\n    partial_y1 = partial_y1 / win_size * spatial_scale;\n    partial_y2 = partial_y2 / win_size * spatial_scale;\n\n    // (b, x1, y1, x2, y2)\n\n    this_data_grad[0] = 0;\n    atomicAdd(this_data_grad + 1, (partial_x1 * (1.0 - float(pw) / pooled_width) + partial_x2 * (1.0 - float(pw + 1) / pooled_width))\n            * (*this_out_grad));\n    atomicAdd(this_data_grad + 2, (partial_y1 * (1.0 - float(ph) / pooled_height) + partial_y2 * (1.0 - float(ph + 1) / pooled_height))\n            * (*this_out_grad));\n    atomicAdd(this_data_grad + 3, (partial_x2 * float(pw + 1) / pooled_width + partial_x1 * float(pw) / pooled_width)\n            * (*this_out_grad));\n    atomicAdd(this_data_grad + 4, (partial_y2 * float(ph + 1) / pooled_height + partial_y1 * float(ph) / pooled_height)\n            * (*this_out_grad));\n  }\n}\n\n} /* !anonymous namespace */\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nvoid PrRoIPoolingForwardGpu(\n    cudaStream_t stream,\n    F_DEVPTR_IN bottom_data,\n    F_DEVPTR_IN bottom_rois,\n    F_DEVPTR_OUT top_data,\n    const int channels_, const int height_, const int width_,\n    const int pooled_height_, const int pooled_width_,\n    const float spatial_scale_,\n    const int top_count) {\n\n    PrRoIPoolingForward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>(\n        top_count, bottom_data, bottom_rois, top_data,\n        channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_);\n\n    CUDA_POST_KERNEL_CHECK;\n}\n\nvoid PrRoIPoolingBackwardGpu(\n    cudaStream_t stream,\n    F_DEVPTR_IN bottom_data,\n    F_DEVPTR_IN bottom_rois,\n    F_DEVPTR_IN top_data,\n    F_DEVPTR_IN top_diff,\n    F_DEVPTR_OUT 
bottom_diff,\n    const int channels_, const int height_, const int width_,\n    const int pooled_height_, const int pooled_width_,\n    const float spatial_scale_,\n    const int top_count, const int bottom_count) {\n\n    cudaMemsetAsync(bottom_diff, 0, sizeof(float) * bottom_count, stream);\n    PrRoIPoolingBackward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>(\n        top_count, bottom_rois, top_diff, bottom_diff,\n        channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_);\n    CUDA_POST_KERNEL_CHECK;\n}\n\nvoid PrRoIPoolingCoorBackwardGpu(\n    cudaStream_t stream,\n    F_DEVPTR_IN bottom_data,\n    F_DEVPTR_IN bottom_rois,\n    F_DEVPTR_IN top_data,\n    F_DEVPTR_IN top_diff,\n    F_DEVPTR_OUT bottom_diff,\n    const int channels_, const int height_, const int width_,\n    const int pooled_height_, const int pooled_width_,\n    const float spatial_scale_,\n    const int top_count, const int bottom_count) {\n\n    cudaMemsetAsync(bottom_diff, 0, sizeof(float) * bottom_count, stream);\n    PrRoIPoolingCoorBackward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>(\n        top_count, bottom_data, bottom_rois, top_data, top_diff, bottom_diff,\n        channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_);\n    CUDA_POST_KERNEL_CHECK;\n}\n\n} /* !extern \"C\" */\n\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cuh",
    "content": "/*\n * File   : prroi_pooling_gpu_impl.cuh\n * Author : Tete Xiao, Jiayuan Mao\n * Email  : jasonhsiao97@gmail.com\n *\n * Distributed under terms of the MIT license.\n * Copyright (c) 2017 Megvii Technology Limited.\n */\n\n#ifndef PRROI_POOLING_GPU_IMPL_CUH\n#define PRROI_POOLING_GPU_IMPL_CUH\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n#define F_DEVPTR_IN const float *\n#define F_DEVPTR_OUT float *\n\nvoid PrRoIPoolingForwardGpu(\n    cudaStream_t stream,\n    F_DEVPTR_IN bottom_data,\n    F_DEVPTR_IN bottom_rois,\n    F_DEVPTR_OUT top_data,\n    const int channels_, const int height_, const int width_,\n    const int pooled_height_, const int pooled_width_,\n    const float spatial_scale_,\n    const int top_count);\n\nvoid PrRoIPoolingBackwardGpu(\n    cudaStream_t stream,\n    F_DEVPTR_IN bottom_data,\n    F_DEVPTR_IN bottom_rois,\n    F_DEVPTR_IN top_data,\n    F_DEVPTR_IN top_diff,\n    F_DEVPTR_OUT bottom_diff,\n    const int channels_, const int height_, const int width_,\n    const int pooled_height_, const int pooled_width_,\n    const float spatial_scale_,\n    const int top_count, const int bottom_count);\n\nvoid PrRoIPoolingCoorBackwardGpu(\n    cudaStream_t stream,\n    F_DEVPTR_IN bottom_data,\n    F_DEVPTR_IN bottom_rois,\n    F_DEVPTR_IN top_data,\n    F_DEVPTR_IN top_diff,\n    F_DEVPTR_OUT bottom_diff,\n    const int channels_, const int height_, const int width_,\n    const int pooled_height_, const int pooled_width_,\n    const float spatial_scale_,\n    const int top_count, const int bottom_count);\n\n#ifdef __cplusplus\n} /* !extern \"C\" */\n#endif\n\n#endif /* !PRROI_POOLING_GPU_IMPL_CUH */\n\n"
  },
  {
    "path": "seeing/upsegmodel/prroi_pool/test_prroi_pooling2d.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : test_prroi_pooling2d.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 18/02/2018\n#\n# This file is part of Jacinle.\n\nimport unittest\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom jactorch.utils.unittest import TorchTestCase\n\nfrom prroi_pool import PrRoIPool2D\n\n\nclass TestPrRoIPool2D(TorchTestCase):\n    def test_forward(self):\n        pool = PrRoIPool2D(7, 7, spatial_scale=0.5)\n        features = torch.rand((4, 16, 24, 32)).cuda()\n        rois = torch.tensor([\n            [0, 0, 0, 14, 14],\n            [1, 14, 14, 28, 28],\n        ]).float().cuda()\n\n        out = pool(features, rois)\n        out_gold = F.avg_pool2d(features, kernel_size=2, stride=1)\n\n        self.assertTensorClose(out, torch.stack((\n            out_gold[0, :, :7, :7],\n            out_gold[1, :, 7:14, 7:14],\n        ), dim=0))\n\n    def test_backward_shapeonly(self):\n        pool = PrRoIPool2D(2, 2, spatial_scale=0.5)\n\n        features = torch.rand((4, 2, 24, 32)).cuda()\n        rois = torch.tensor([\n            [0, 0, 0, 4, 4],\n            [1, 14, 14, 18, 18],\n        ]).float().cuda()\n        features.requires_grad = rois.requires_grad = True\n        out = pool(features, rois)\n\n        loss = out.sum()\n        loss.backward()\n\n        self.assertTupleEqual(features.size(), features.grad.size())\n        self.assertTupleEqual(rois.size(), rois.grad.size())\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "seeing/upsegmodel/resnet.py",
    "content": "import os\nimport sys\nimport torch\nimport torch.nn as nn\nimport math\ntry:\n    from lib.nn import SynchronizedBatchNorm2d\nexcept ImportError:\n    from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d\n\ntry:\n    from urllib import urlretrieve\nexcept ImportError:\n    from urllib.request import urlretrieve\n\n\n__all__ = ['ResNet', 'resnet50', 'resnet101'] # resnet101 is coming soon!\n\n\nmodel_urls = {\n    'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',\n    'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'\n}\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = SynchronizedBatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = SynchronizedBatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = 
SynchronizedBatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = SynchronizedBatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = SynchronizedBatchNorm2d(planes * 4)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n\n    def __init__(self, block, layers, num_classes=1000):\n        self.inplanes = 128\n        super(ResNet, self).__init__()\n        self.conv1 = conv3x3(3, 64, stride=2)\n        self.bn1 = SynchronizedBatchNorm2d(64)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(64, 64)\n        self.bn2 = SynchronizedBatchNorm2d(64)\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv3 = conv3x3(64, 128)\n        self.bn3 = SynchronizedBatchNorm2d(128)\n        self.relu3 = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.avgpool = nn.AvgPool2d(7, stride=1)\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n         
   if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, SynchronizedBatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                SynchronizedBatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.relu1(self.bn1(self.conv1(x)))\n        x = self.relu2(self.bn2(self.conv2(x)))\n        x = self.relu3(self.bn3(self.conv3(x)))\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\n'''\ndef resnet18(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-18 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Places\n    \"\"\"\n    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)\n    if pretrained:\n        model.load_state_dict(load_url(model_urls['resnet18']))\n    return model\n\n\ndef resnet34(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-34 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Places\n    \"\"\"\n    model = 
ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)\n    if pretrained:\n        model.load_state_dict(load_url(model_urls['resnet34']))\n    return model\n'''\n\ndef resnet50(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-50 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Places\n    \"\"\"\n    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)\n    if pretrained:\n        model.load_state_dict(load_url(model_urls['resnet50']), strict=False)\n    return model\n\n\ndef resnet101(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-101 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Places\n    \"\"\"\n    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)\n    if pretrained:\n        model.load_state_dict(load_url(model_urls['resnet101']), strict=False)\n    return model\n\n# def resnet152(pretrained=False, **kwargs):\n#     \"\"\"Constructs a ResNet-152 model.\n#\n#     Args:\n#         pretrained (bool): If True, returns a model pre-trained on Places\n#     \"\"\"\n#     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)\n#     if pretrained:\n#         model.load_state_dict(load_url(model_urls['resnet152']))\n#     return model\n\ndef load_url(url, model_dir='./pretrained', map_location=None):\n    if not os.path.exists(model_dir):\n        os.makedirs(model_dir)\n    filename = url.split('/')[-1]\n    cached_file = os.path.join(model_dir, filename)\n    if not os.path.exists(cached_file):\n        sys.stderr.write('Downloading: \"{}\" to {}\\n'.format(url, cached_file))\n        urlretrieve(url, cached_file)\n    return torch.load(cached_file, map_location=map_location)\n"
  },
  {
    "path": "seeing/upsegmodel/resnext.py",
    "content": "import os\nimport sys\nimport torch\nimport torch.nn as nn\nimport math\ntry:\n    from lib.nn import SynchronizedBatchNorm2d\nexcept ImportError:\n    from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d\n\ntry:\n    from urllib import urlretrieve\nexcept ImportError:\n    from urllib.request import urlretrieve\n\n\n__all__ = ['ResNeXt', 'resnext101'] # support resnext 101\n\n\nmodel_urls = {\n    #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth',\n    'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'\n}\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass GroupBottleneck(nn.Module):\n    expansion = 2\n\n    def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None):\n        super(GroupBottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = SynchronizedBatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, groups=groups, bias=False)\n        self.bn2 = SynchronizedBatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)\n        self.bn3 = SynchronizedBatchNorm2d(planes * 2)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out 
+= residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNeXt(nn.Module):\n\n    def __init__(self, block, layers, groups=32, num_classes=1000):\n        self.inplanes = 128\n        super(ResNeXt, self).__init__()\n        self.conv1 = conv3x3(3, 64, stride=2)\n        self.bn1 = SynchronizedBatchNorm2d(64)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(64, 64)\n        self.bn2 = SynchronizedBatchNorm2d(64)\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv3 = conv3x3(64, 128)\n        self.bn3 = SynchronizedBatchNorm2d(128)\n        self.relu3 = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n        self.layer1 = self._make_layer(block, 128, layers[0], groups=groups)\n        self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups)\n        self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups)\n        self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups)\n        self.avgpool = nn.AvgPool2d(7, stride=1)\n        self.fc = nn.Linear(1024 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups\n                m.weight.data.normal_(0, math.sqrt(2. 
/ n))\n            elif isinstance(m, SynchronizedBatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=1, groups=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                SynchronizedBatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, groups, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=groups))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.relu1(self.bn1(self.conv1(x)))\n        x = self.relu2(self.bn2(self.conv2(x)))\n        x = self.relu3(self.bn3(self.conv3(x)))\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\n\n'''\ndef resnext50(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-50 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Places\n    \"\"\"\n    model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs)\n    if pretrained:\n        model.load_state_dict(load_url(model_urls['resnext50']), strict=False)\n    return model\n'''\n\n\ndef resnext101(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-101 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Places\n    \"\"\"\n    model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)\n    if pretrained:\n        
model.load_state_dict(load_url(model_urls['resnext101']), strict=False)\n    return model\n\n\n# def resnext152(pretrained=False, **kwargs):\n#     \"\"\"Constructs a ResNeXt-152 model.\n#\n#     Args:\n#         pretrained (bool): If True, returns a model pre-trained on Places\n#     \"\"\"\n#     model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs)\n#     if pretrained:\n#         model.load_state_dict(load_url(model_urls['resnext152']))\n#     return model\n\n\ndef load_url(url, model_dir='./pretrained', map_location=None):\n    if not os.path.exists(model_dir):\n        os.makedirs(model_dir)\n    filename = url.split('/')[-1]\n    cached_file = os.path.join(model_dir, filename)\n    if not os.path.exists(cached_file):\n        sys.stderr.write('Downloading: \"{}\" to {}\\n'.format(url, cached_file))\n        urlretrieve(url, cached_file)\n    return torch.load(cached_file, map_location=map_location)\n"
  },
  {
    "path": "seeing/yz_dataset.py",
    "content": "import torch, numpy\n\n\nclass YZDataset():\n    def __init__(self, zdim=256, nlabels=1, distribution=[1.], device='cpu'):\n        self.zdim = zdim\n        self.nlabels = nlabels\n        self.device = device\n        self.distribution = distribution\n        assert (len(distribution) == nlabels)\n\n    def __call__(self, seeds):\n        zs, ys = [], []\n        for seed in seeds:\n            rng = numpy.random.RandomState(seed)\n            z = torch.from_numpy(\n                rng.standard_normal(self.zdim).reshape(\n                    1, self.zdim)).float().to(self.device)\n            y = torch.from_numpy(\n                rng.choice(self.nlabels, 1, replace=False,\n                           p=self.distribution)).long().to(self.device)\n            zs.append(z)\n            ys.append(y)\n        return torch.cat(zs, dim=0), torch.cat(ys, dim=0)\n\n\nif __name__ == '__main__':\n    sampler = YZDataset()\n    a, d = sampler([10, 11])\n    b, e = sampler([12, 13])\n    assert ((a - b).mean() > 1e-3)\n    c, f = sampler([10, 11])\n    assert ((a - c).mean() < 1e-3)\n"
  },
  {
    "path": "seeing/zdataset.py",
    "content": "import os, torch, numpy\nfrom torch.utils.data import TensorDataset\n\ndef z_dataset_for_model(model, size=100, seed=1):\n    return TensorDataset(z_sample_for_model(model, size, seed))\n\ndef z_sample_for_model(model, size=100, seed=1):\n    # If the model is marked with an input shape, use it.\n    if hasattr(model, 'input_shape'):\n        sample = standard_z_sample(size, model.input_shape[1], seed=seed).view(\n                (size,) + model.input_shape[1:])\n        return sample\n    # Examine first conv in model to determine input feature size.\n    first_layer = [c for c in model.modules()\n            if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,\n                torch.nn.Linear))][0]\n    # 4d input if convolutional, 2d input if first layer is linear.\n    if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):\n        sample = standard_z_sample(\n                size, first_layer.in_channels, seed=seed)[:,:,None,None]\n    else:\n        sample = standard_z_sample(\n                size, first_layer.in_features, seed=seed)\n    return sample\n\ndef standard_z_sample(size, depth, seed=1, device=None):\n\t'''\n\tGenerate a standard set of random Z as a (size, z_dimension) tensor.\n\tWith the same random seed, it always returns the same z (e.g.,\n\tthe first one is always the same regardless of the size.)\n\t'''\n\t# Use numpy RandomState since it can be done deterministically\n\t# without affecting global state\n\trng = numpy.random.RandomState(seed)\n\tresult = torch.from_numpy(\n\t\t\trng.standard_normal(size * depth)\n\t\t\t.reshape(size, depth)).float()\n\tif device is not None:\n\t\tresult = result.to(device)\n\treturn result\n\n"
  },
  {
    "path": "train.py",
    "content": "import argparse\nimport os\nimport copy\nimport pprint\nfrom os import path\n\nimport torch\nimport numpy as np\nfrom torch import nn\n\nfrom gan_training import utils\nfrom gan_training.train import Trainer, update_average\nfrom gan_training.logger import Logger\nfrom gan_training.checkpoints import CheckpointIO\nfrom gan_training.inputs import get_dataset\nfrom gan_training.distributions import get_ydist, get_zdist\nfrom gan_training.eval import Evaluator\nfrom gan_training.config import (load_config, get_clusterer, build_models, build_optimizers)\nfrom seeing.pidfile import exit_if_job_done, mark_job_done\n\ntorch.backends.cudnn.benchmark = True\n\n# Arguments\nparser = argparse.ArgumentParser(\n    description='Train a GAN with different regularization strategies.')\nparser.add_argument('config', type=str, help='Path to config file.')\nparser.add_argument('--outdir', type=str, help='used to override outdir (useful for multiple runs)')\nparser.add_argument('--nepochs', type=int, default=250, help='number of epochs to run before terminating')\nparser.add_argument('--model_it', type=int, default=-1, help='which model iteration to load from, -1 loads the most recent model')\nparser.add_argument('--devices', nargs='+', type=str, default=['0'], help='devices to use')\n\nargs = parser.parse_args()\nconfig = load_config(args.config, 'configs/default.yaml')\nout_dir = config['training']['out_dir'] if args.outdir is None else args.outdir\n\n\ndef main():\n    pp = pprint.PrettyPrinter(indent=1)\n    pp.pprint({\n        'data': config['data'],\n        'generator': config['generator'],\n        'discriminator': config['discriminator'],\n        'clusterer': config['clusterer'],\n        'training': config['training']\n    })\n    is_cuda = torch.cuda.is_available()\n\n    # Short hands\n    batch_size = config['training']['batch_size']\n    log_every = config['training']['log_every']\n    inception_every = config['training']['inception_every']\n    
backup_every = config['training']['backup_every']\n    sample_nlabels = config['training']['sample_nlabels']\n    nlabels = config['data']['nlabels']\n    sample_nlabels = min(nlabels, sample_nlabels)\n\n    checkpoint_dir = path.join(out_dir, 'chkpts')\n\n    # Create missing directories\n    if not path.exists(out_dir):\n        os.makedirs(out_dir)\n    if not path.exists(checkpoint_dir):\n        os.makedirs(checkpoint_dir)\n\n    # Logger\n    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)\n\n    device = torch.device(\"cuda:0\" if is_cuda else \"cpu\")\n\n    train_dataset, _ = get_dataset(\n        name=config['data']['type'],\n        data_dir=config['data']['train_dir'],\n        size=config['data']['img_size'],\n        deterministic=config['data']['deterministic'])\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=batch_size,\n        num_workers=config['training']['nworkers'],\n        shuffle=True,\n        pin_memory=True,\n        sampler=None,\n        drop_last=True)\n\n    # Create models\n    generator, discriminator = build_models(config)\n\n    # Put models on gpu if needed\n    generator = generator.to(device)\n    discriminator = discriminator.to(device)\n\n    for name, module in discriminator.named_modules():\n        if isinstance(module, nn.Sigmoid):\n            print('Found sigmoid layer in discriminator; not compatible with BCE with logits')\n            exit()\n\n    g_optimizer, d_optimizer = build_optimizers(generator, discriminator, config)\n\n    devices = [int(x) for x in args.devices]\n    generator = nn.DataParallel(generator, device_ids=devices)\n    discriminator = nn.DataParallel(discriminator, device_ids=devices)\n\n    # Register modules to checkpoint\n    checkpoint_io.register_modules(generator=generator,\n                                   discriminator=discriminator,\n                                   g_optimizer=g_optimizer,\n                                   
d_optimizer=d_optimizer)\n\n    # Logger\n    logger = Logger(log_dir=path.join(out_dir, 'logs'),\n                    img_dir=path.join(out_dir, 'imgs'),\n                    monitoring=config['training']['monitoring'],\n                    monitoring_dir=path.join(out_dir, 'monitoring'))\n\n    # Distributions\n    ydist = get_ydist(nlabels, device=device)\n    zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], device=device)\n\n    ntest = config['training']['ntest']\n    x_test, y_test = utils.get_nsamples(train_loader, ntest)\n    x_cluster, y_cluster = utils.get_nsamples(train_loader, config['clusterer']['nimgs'])\n    x_test, y_test = x_test.to(device), y_test.to(device)\n    z_test = zdist.sample((ntest, ))\n    utils.save_images(x_test, path.join(out_dir, 'real.png'))\n    logger.add_imgs(x_test, 'gt', 0)\n\n    # Test generator\n    if config['training']['take_model_average']:\n        print('Taking model average')\n        bad_modules = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]\n        for model in [generator, discriminator]:\n            for name, module in model.named_modules():\n                for bad_module in bad_modules:\n                    if isinstance(module, bad_module):\n                        print('Batch norm in discriminator not compatible with exponential moving average')\n                        exit()\n        generator_test = copy.deepcopy(generator)\n        checkpoint_io.register_modules(generator_test=generator_test)\n    else:\n        generator_test = generator\n\n    clusterer = get_clusterer(config)(discriminator=discriminator,\n                                      x_cluster=x_cluster,\n                                      x_labels=y_cluster,\n                                      gt_nlabels=config['data']['nlabels'],\n                                      **config['clusterer']['kwargs'])\n\n    # Load checkpoint if it exists\n    it = utils.get_most_recent(checkpoint_dir, 'model') if args.model_it 
== -1 else args.model_it\n    it, epoch_idx, loaded_clusterer = checkpoint_io.load_models(it=it, load_samples='supervised' != config['clusterer']['name'])\n\n    if loaded_clusterer is None:\n        print('Initializing new clusterer. The first clustering can be quite slow.')\n        clusterer.recluster(discriminator=discriminator)\n        checkpoint_io.save_clusterer(clusterer, it=0)\n        np.savez(os.path.join(checkpoint_dir, 'cluster_samples.npz'), x=x_cluster)\n    else:\n        print('Using loaded clusterer')\n        clusterer = loaded_clusterer\n\n    # Evaluator\n    evaluator = Evaluator(\n        generator_test,\n        zdist,\n        ydist,\n        train_loader=train_loader,\n        clusterer=clusterer,\n        batch_size=batch_size,\n        device=device,\n        inception_nsamples=config['training']['inception_nsamples'])\n\n    # Trainer\n    trainer = Trainer(generator,\n                      discriminator,\n                      g_optimizer,\n                      d_optimizer,\n                      gan_type=config['training']['gan_type'],\n                      reg_type=config['training']['reg_type'],\n                      reg_param=config['training']['reg_param'])\n\n    # Training loop\n    print('Start training...')\n    while it < args.nepochs * len(train_loader):\n        epoch_idx += 1\n\n        for x_real, y in train_loader:\n            it += 1\n\n            x_real, y = x_real.to(device), y.to(device)\n            z = zdist.sample((batch_size, ))\n            y = clusterer.get_labels(x_real, y).to(device)\n\n            # Discriminator updates\n            dloss, reg = trainer.discriminator_trainstep(x_real, y, z)\n            logger.add('losses', 'discriminator', dloss, it=it)\n            logger.add('losses', 'regularizer', reg, it=it)\n\n            # Generators updates\n            gloss = trainer.generator_trainstep(y, z)\n            logger.add('losses', 'generator', gloss, it=it)\n\n            if 
config['training']['take_model_average']:\n                update_average(generator_test, generator, beta=config['training']['model_average_beta'])\n\n            # Print stats\n            if it % log_every == 0:\n                g_loss_last = logger.get_last('losses', 'generator')\n                d_loss_last = logger.get_last('losses', 'discriminator')\n                d_reg_last = logger.get_last('losses', 'regularizer')\n                print('[epoch %0d, it %4d] g_loss = %.4f, d_loss = %.4f, reg=%.4f'\n                      % (epoch_idx, it, g_loss_last, d_loss_last, d_reg_last))\n\n            if it % config['training']['recluster_every'] == 0 and it > config['training']['burnin_time']:\n                # print cluster distribution for online methods\n                if it % 100 == 0 and config['training']['recluster_every'] <= 100:\n                    print(f'[epoch {epoch_idx}, it {it}], distribution: {clusterer.get_label_distribution(x_real)}')\n                clusterer.recluster(discriminator=discriminator, x_batch=x_real)\n\n            # (i) Sample if necessary\n            if it % config['training']['sample_every'] == 0:\n                print('Creating samples...')\n                x = evaluator.create_samples(z_test, y_test)\n                x = evaluator.create_samples(z_test, clusterer.get_labels(x_test, y_test).to(device))\n                logger.add_imgs(x, 'all', it)\n\n                for y_inst in range(sample_nlabels):\n                    x = evaluator.create_samples(z_test, y_inst)\n                    logger.add_imgs(x, '%04d' % y_inst, it)\n\n            # (ii) Compute inception if necessary\n            if it % inception_every == 0 and it > 0:\n                print('PyTorch Inception score...')\n                inception_mean, inception_std = evaluator.compute_inception_score()\n                logger.add('metrics', 'pt_inception_mean', inception_mean, it=it)\n                logger.add('metrics', 'pt_inception_stddev', 
inception_std, it=it)\n                print(f'[epoch {epoch_idx}, it {it}] pt_inception_mean: {inception_mean}, pt_inception_stddev: {inception_std}')\n\n            # (iii) Backup if necessary\n            if it % backup_every == 0:\n                print('Saving backup...')\n                checkpoint_io.save('model_%08d.pt' % it, it=it)\n                checkpoint_io.save_clusterer(clusterer, int(it))\n                logger.save_stats('stats_%08d.p' % it)\n\n                if it > 0:\n                    checkpoint_io.save('model.pt', it=it)\n\n\nif __name__ == '__main__':\n    exit_if_job_done(out_dir)\n    main()\n    mark_job_done(out_dir)\n"
  },
  {
    "path": "utils/classifiers/__init__.py",
    "content": "from classifiers import stacked_mnist, cifar, places, imagenet\n\n# Maps dataset name -> classifier wrapper class.\nclassifier_dict = {\n    'stacked_mnist': stacked_mnist.Classifier,\n    'cifar': cifar.Classifier,\n    'places': places.Classifier,\n    'imagenet': imagenet.Classifier,\n}\n"
  },
  {
    "path": "utils/classifiers/cifar.py",
    "content": "\"\"\"Wrapper around the pytorch_playground CIFAR-10 model.\"\"\"\nimport sys\n\nimport torch\n\n# Make pytorch_playground (expected under utils/classifiers) importable.\nsys.path.append('utils/classifiers')\n\nfrom pytorch_playground.cifar.model import cifar10  # noqa: E402\n\n\nclass Classifier():\n    def __init__(self):\n        # Inference only: eval() makes any batch-norm layers use running statistics\n        # and disables any dropout (imagenet.Classifier does the same).\n        self.classifier = cifar10().cuda().eval()\n\n    @torch.no_grad()\n    def get_predictions(self, x):\n        \"\"\"Return the argmax class index per image; x must have 3 channels in dim 1.\"\"\"\n        assert(x.size(1) == 3)\n        return self.classifier(x).argmax(dim=1)\n"
  },
  {
    "path": "utils/classifiers/imagenet.py",
    "content": "\"\"\"Wrapper around a pretrained torchvision ResNet-50 ImageNet classifier.\"\"\"\nimport json\n\nimport torch\nimport torchvision.models as models\nfrom torchvision import transforms as trn\nfrom torch.nn import functional as F\n\n\nclass Classifier():\n    \"\"\"Pretrained ResNet-50 in eval mode on the GPU.\n\n    Prediction inputs are divided by 255 before ImageNet normalization, so\n    pixel values are presumably in [0, 255] -- NOTE(review): confirm with callers.\n    \"\"\"\n\n    def __init__(self):\n        self.model = models.resnet50(pretrained=True).cuda()\n        self.model.eval()\n\n        # ImageNet per-channel mean/std used by torchvision's pretrained models.\n        self.mean = [0.485, 0.456, 0.406]\n        self.std = [0.229, 0.224, 0.225]\n        self.trn = trn.Normalize(self.mean, self.std)\n\n        # Maps str(class_id) -> [wordnet_id, human_readable_name].\n        with open(\"utils/classifiers/imagenet_class_index.json\") as f:\n            self.class_idx = json.load(f)\n\n    def transform(self, x):\n        \"\"\"Resize to 224x224, scale by 1/255, normalize and move to the GPU.\"\"\"\n        x = F.interpolate(x, size=(224, 224)) / 255.\n        x = torch.stack([self.trn(xi) for xi in x]).cuda()\n        return x\n\n    def get_name(self, class_id):\n        \"\"\"Return the human-readable name for an integer ImageNet class id.\"\"\"\n        return self.class_idx[str(class_id)][1]\n\n    @torch.no_grad()\n    def get_predictions_and_confidence(self, x):\n        \"\"\"Return (top class index, top raw logit) per image.\n\n        The 'confidence' is the pre-softmax logit, not a probability.\n        \"\"\"\n        logit = self.model(self.transform(x))\n        values, ind = logit.max(dim=1)\n        return ind, values\n\n    @torch.no_grad()\n    def get_predictions(self, x):\n        \"\"\"Return the argmax class index for each image in the batch.\"\"\"\n        logit = self.model(self.transform(x))\n        return logit.argmax(dim=1)\n"
  },
  {
    "path": "utils/classifiers/imagenet_class_index.json",
    "content": "{\"0\": [\"n01440764\", \"tench\"], \"1\": [\"n01443537\", \"goldfish\"], \"2\": [\"n01484850\", \"great_white_shark\"], \"3\": [\"n01491361\", \"tiger_shark\"], \"4\": [\"n01494475\", \"hammerhead\"], \"5\": [\"n01496331\", \"electric_ray\"], \"6\": [\"n01498041\", \"stingray\"], \"7\": [\"n01514668\", \"cock\"], \"8\": [\"n01514859\", \"hen\"], \"9\": [\"n01518878\", \"ostrich\"], \"10\": [\"n01530575\", \"brambling\"], \"11\": [\"n01531178\", \"goldfinch\"], \"12\": [\"n01532829\", \"house_finch\"], \"13\": [\"n01534433\", \"junco\"], \"14\": [\"n01537544\", \"indigo_bunting\"], \"15\": [\"n01558993\", \"robin\"], \"16\": [\"n01560419\", \"bulbul\"], \"17\": [\"n01580077\", \"jay\"], \"18\": [\"n01582220\", \"magpie\"], \"19\": [\"n01592084\", \"chickadee\"], \"20\": [\"n01601694\", \"water_ouzel\"], \"21\": [\"n01608432\", \"kite\"], \"22\": [\"n01614925\", \"bald_eagle\"], \"23\": [\"n01616318\", \"vulture\"], \"24\": [\"n01622779\", \"great_grey_owl\"], \"25\": [\"n01629819\", \"European_fire_salamander\"], \"26\": [\"n01630670\", \"common_newt\"], \"27\": [\"n01631663\", \"eft\"], \"28\": [\"n01632458\", \"spotted_salamander\"], \"29\": [\"n01632777\", \"axolotl\"], \"30\": [\"n01641577\", \"bullfrog\"], \"31\": [\"n01644373\", \"tree_frog\"], \"32\": [\"n01644900\", \"tailed_frog\"], \"33\": [\"n01664065\", \"loggerhead\"], \"34\": [\"n01665541\", \"leatherback_turtle\"], \"35\": [\"n01667114\", \"mud_turtle\"], \"36\": [\"n01667778\", \"terrapin\"], \"37\": [\"n01669191\", \"box_turtle\"], \"38\": [\"n01675722\", \"banded_gecko\"], \"39\": [\"n01677366\", \"common_iguana\"], \"40\": [\"n01682714\", \"American_chameleon\"], \"41\": [\"n01685808\", \"whiptail\"], \"42\": [\"n01687978\", \"agama\"], \"43\": [\"n01688243\", \"frilled_lizard\"], \"44\": [\"n01689811\", \"alligator_lizard\"], \"45\": [\"n01692333\", \"Gila_monster\"], \"46\": [\"n01693334\", \"green_lizard\"], \"47\": [\"n01694178\", \"African_chameleon\"], \"48\": 
[\"n01695060\", \"Komodo_dragon\"], \"49\": [\"n01697457\", \"African_crocodile\"], \"50\": [\"n01698640\", \"American_alligator\"], \"51\": [\"n01704323\", \"triceratops\"], \"52\": [\"n01728572\", \"thunder_snake\"], \"53\": [\"n01728920\", \"ringneck_snake\"], \"54\": [\"n01729322\", \"hognose_snake\"], \"55\": [\"n01729977\", \"green_snake\"], \"56\": [\"n01734418\", \"king_snake\"], \"57\": [\"n01735189\", \"garter_snake\"], \"58\": [\"n01737021\", \"water_snake\"], \"59\": [\"n01739381\", \"vine_snake\"], \"60\": [\"n01740131\", \"night_snake\"], \"61\": [\"n01742172\", \"boa_constrictor\"], \"62\": [\"n01744401\", \"rock_python\"], \"63\": [\"n01748264\", \"Indian_cobra\"], \"64\": [\"n01749939\", \"green_mamba\"], \"65\": [\"n01751748\", \"sea_snake\"], \"66\": [\"n01753488\", \"horned_viper\"], \"67\": [\"n01755581\", \"diamondback\"], \"68\": [\"n01756291\", \"sidewinder\"], \"69\": [\"n01768244\", \"trilobite\"], \"70\": [\"n01770081\", \"harvestman\"], \"71\": [\"n01770393\", \"scorpion\"], \"72\": [\"n01773157\", \"black_and_gold_garden_spider\"], \"73\": [\"n01773549\", \"barn_spider\"], \"74\": [\"n01773797\", \"garden_spider\"], \"75\": [\"n01774384\", \"black_widow\"], \"76\": [\"n01774750\", \"tarantula\"], \"77\": [\"n01775062\", \"wolf_spider\"], \"78\": [\"n01776313\", \"tick\"], \"79\": [\"n01784675\", \"centipede\"], \"80\": [\"n01795545\", \"black_grouse\"], \"81\": [\"n01796340\", \"ptarmigan\"], \"82\": [\"n01797886\", \"ruffed_grouse\"], \"83\": [\"n01798484\", \"prairie_chicken\"], \"84\": [\"n01806143\", \"peacock\"], \"85\": [\"n01806567\", \"quail\"], \"86\": [\"n01807496\", \"partridge\"], \"87\": [\"n01817953\", \"African_grey\"], \"88\": [\"n01818515\", \"macaw\"], \"89\": [\"n01819313\", \"sulphur-crested_cockatoo\"], \"90\": [\"n01820546\", \"lorikeet\"], \"91\": [\"n01824575\", \"coucal\"], \"92\": [\"n01828970\", \"bee_eater\"], \"93\": [\"n01829413\", \"hornbill\"], \"94\": [\"n01833805\", \"hummingbird\"], \"95\": 
[\"n01843065\", \"jacamar\"], \"96\": [\"n01843383\", \"toucan\"], \"97\": [\"n01847000\", \"drake\"], \"98\": [\"n01855032\", \"red-breasted_merganser\"], \"99\": [\"n01855672\", \"goose\"], \"100\": [\"n01860187\", \"black_swan\"], \"101\": [\"n01871265\", \"tusker\"], \"102\": [\"n01872401\", \"echidna\"], \"103\": [\"n01873310\", \"platypus\"], \"104\": [\"n01877812\", \"wallaby\"], \"105\": [\"n01882714\", \"koala\"], \"106\": [\"n01883070\", \"wombat\"], \"107\": [\"n01910747\", \"jellyfish\"], \"108\": [\"n01914609\", \"sea_anemone\"], \"109\": [\"n01917289\", \"brain_coral\"], \"110\": [\"n01924916\", \"flatworm\"], \"111\": [\"n01930112\", \"nematode\"], \"112\": [\"n01943899\", \"conch\"], \"113\": [\"n01944390\", \"snail\"], \"114\": [\"n01945685\", \"slug\"], \"115\": [\"n01950731\", \"sea_slug\"], \"116\": [\"n01955084\", \"chiton\"], \"117\": [\"n01968897\", \"chambered_nautilus\"], \"118\": [\"n01978287\", \"Dungeness_crab\"], \"119\": [\"n01978455\", \"rock_crab\"], \"120\": [\"n01980166\", \"fiddler_crab\"], \"121\": [\"n01981276\", \"king_crab\"], \"122\": [\"n01983481\", \"American_lobster\"], \"123\": [\"n01984695\", \"spiny_lobster\"], \"124\": [\"n01985128\", \"crayfish\"], \"125\": [\"n01986214\", \"hermit_crab\"], \"126\": [\"n01990800\", \"isopod\"], \"127\": [\"n02002556\", \"white_stork\"], \"128\": [\"n02002724\", \"black_stork\"], \"129\": [\"n02006656\", \"spoonbill\"], \"130\": [\"n02007558\", \"flamingo\"], \"131\": [\"n02009229\", \"little_blue_heron\"], \"132\": [\"n02009912\", \"American_egret\"], \"133\": [\"n02011460\", \"bittern\"], \"134\": [\"n02012849\", \"crane\"], \"135\": [\"n02013706\", \"limpkin\"], \"136\": [\"n02017213\", \"European_gallinule\"], \"137\": [\"n02018207\", \"American_coot\"], \"138\": [\"n02018795\", \"bustard\"], \"139\": [\"n02025239\", \"ruddy_turnstone\"], \"140\": [\"n02027492\", \"red-backed_sandpiper\"], \"141\": [\"n02028035\", \"redshank\"], \"142\": [\"n02033041\", \"dowitcher\"], \"143\": 
[\"n02037110\", \"oystercatcher\"], \"144\": [\"n02051845\", \"pelican\"], \"145\": [\"n02056570\", \"king_penguin\"], \"146\": [\"n02058221\", \"albatross\"], \"147\": [\"n02066245\", \"grey_whale\"], \"148\": [\"n02071294\", \"killer_whale\"], \"149\": [\"n02074367\", \"dugong\"], \"150\": [\"n02077923\", \"sea_lion\"], \"151\": [\"n02085620\", \"Chihuahua\"], \"152\": [\"n02085782\", \"Japanese_spaniel\"], \"153\": [\"n02085936\", \"Maltese_dog\"], \"154\": [\"n02086079\", \"Pekinese\"], \"155\": [\"n02086240\", \"Shih-Tzu\"], \"156\": [\"n02086646\", \"Blenheim_spaniel\"], \"157\": [\"n02086910\", \"papillon\"], \"158\": [\"n02087046\", \"toy_terrier\"], \"159\": [\"n02087394\", \"Rhodesian_ridgeback\"], \"160\": [\"n02088094\", \"Afghan_hound\"], \"161\": [\"n02088238\", \"basset\"], \"162\": [\"n02088364\", \"beagle\"], \"163\": [\"n02088466\", \"bloodhound\"], \"164\": [\"n02088632\", \"bluetick\"], \"165\": [\"n02089078\", \"black-and-tan_coonhound\"], \"166\": [\"n02089867\", \"Walker_hound\"], \"167\": [\"n02089973\", \"English_foxhound\"], \"168\": [\"n02090379\", \"redbone\"], \"169\": [\"n02090622\", \"borzoi\"], \"170\": [\"n02090721\", \"Irish_wolfhound\"], \"171\": [\"n02091032\", \"Italian_greyhound\"], \"172\": [\"n02091134\", \"whippet\"], \"173\": [\"n02091244\", \"Ibizan_hound\"], \"174\": [\"n02091467\", \"Norwegian_elkhound\"], \"175\": [\"n02091635\", \"otterhound\"], \"176\": [\"n02091831\", \"Saluki\"], \"177\": [\"n02092002\", \"Scottish_deerhound\"], \"178\": [\"n02092339\", \"Weimaraner\"], \"179\": [\"n02093256\", \"Staffordshire_bullterrier\"], \"180\": [\"n02093428\", \"American_Staffordshire_terrier\"], \"181\": [\"n02093647\", \"Bedlington_terrier\"], \"182\": [\"n02093754\", \"Border_terrier\"], \"183\": [\"n02093859\", \"Kerry_blue_terrier\"], \"184\": [\"n02093991\", \"Irish_terrier\"], \"185\": [\"n02094114\", \"Norfolk_terrier\"], \"186\": [\"n02094258\", \"Norwich_terrier\"], \"187\": [\"n02094433\", \"Yorkshire_terrier\"], 
\"188\": [\"n02095314\", \"wire-haired_fox_terrier\"], \"189\": [\"n02095570\", \"Lakeland_terrier\"], \"190\": [\"n02095889\", \"Sealyham_terrier\"], \"191\": [\"n02096051\", \"Airedale\"], \"192\": [\"n02096177\", \"cairn\"], \"193\": [\"n02096294\", \"Australian_terrier\"], \"194\": [\"n02096437\", \"Dandie_Dinmont\"], \"195\": [\"n02096585\", \"Boston_bull\"], \"196\": [\"n02097047\", \"miniature_schnauzer\"], \"197\": [\"n02097130\", \"giant_schnauzer\"], \"198\": [\"n02097209\", \"standard_schnauzer\"], \"199\": [\"n02097298\", \"Scotch_terrier\"], \"200\": [\"n02097474\", \"Tibetan_terrier\"], \"201\": [\"n02097658\", \"silky_terrier\"], \"202\": [\"n02098105\", \"soft-coated_wheaten_terrier\"], \"203\": [\"n02098286\", \"West_Highland_white_terrier\"], \"204\": [\"n02098413\", \"Lhasa\"], \"205\": [\"n02099267\", \"flat-coated_retriever\"], \"206\": [\"n02099429\", \"curly-coated_retriever\"], \"207\": [\"n02099601\", \"golden_retriever\"], \"208\": [\"n02099712\", \"Labrador_retriever\"], \"209\": [\"n02099849\", \"Chesapeake_Bay_retriever\"], \"210\": [\"n02100236\", \"German_short-haired_pointer\"], \"211\": [\"n02100583\", \"vizsla\"], \"212\": [\"n02100735\", \"English_setter\"], \"213\": [\"n02100877\", \"Irish_setter\"], \"214\": [\"n02101006\", \"Gordon_setter\"], \"215\": [\"n02101388\", \"Brittany_spaniel\"], \"216\": [\"n02101556\", \"clumber\"], \"217\": [\"n02102040\", \"English_springer\"], \"218\": [\"n02102177\", \"Welsh_springer_spaniel\"], \"219\": [\"n02102318\", \"cocker_spaniel\"], \"220\": [\"n02102480\", \"Sussex_spaniel\"], \"221\": [\"n02102973\", \"Irish_water_spaniel\"], \"222\": [\"n02104029\", \"kuvasz\"], \"223\": [\"n02104365\", \"schipperke\"], \"224\": [\"n02105056\", \"groenendael\"], \"225\": [\"n02105162\", \"malinois\"], \"226\": [\"n02105251\", \"briard\"], \"227\": [\"n02105412\", \"kelpie\"], \"228\": [\"n02105505\", \"komondor\"], \"229\": [\"n02105641\", \"Old_English_sheepdog\"], \"230\": [\"n02105855\", 
\"Shetland_sheepdog\"], \"231\": [\"n02106030\", \"collie\"], \"232\": [\"n02106166\", \"Border_collie\"], \"233\": [\"n02106382\", \"Bouvier_des_Flandres\"], \"234\": [\"n02106550\", \"Rottweiler\"], \"235\": [\"n02106662\", \"German_shepherd\"], \"236\": [\"n02107142\", \"Doberman\"], \"237\": [\"n02107312\", \"miniature_pinscher\"], \"238\": [\"n02107574\", \"Greater_Swiss_Mountain_dog\"], \"239\": [\"n02107683\", \"Bernese_mountain_dog\"], \"240\": [\"n02107908\", \"Appenzeller\"], \"241\": [\"n02108000\", \"EntleBucher\"], \"242\": [\"n02108089\", \"boxer\"], \"243\": [\"n02108422\", \"bull_mastiff\"], \"244\": [\"n02108551\", \"Tibetan_mastiff\"], \"245\": [\"n02108915\", \"French_bulldog\"], \"246\": [\"n02109047\", \"Great_Dane\"], \"247\": [\"n02109525\", \"Saint_Bernard\"], \"248\": [\"n02109961\", \"Eskimo_dog\"], \"249\": [\"n02110063\", \"malamute\"], \"250\": [\"n02110185\", \"Siberian_husky\"], \"251\": [\"n02110341\", \"dalmatian\"], \"252\": [\"n02110627\", \"affenpinscher\"], \"253\": [\"n02110806\", \"basenji\"], \"254\": [\"n02110958\", \"pug\"], \"255\": [\"n02111129\", \"Leonberg\"], \"256\": [\"n02111277\", \"Newfoundland\"], \"257\": [\"n02111500\", \"Great_Pyrenees\"], \"258\": [\"n02111889\", \"Samoyed\"], \"259\": [\"n02112018\", \"Pomeranian\"], \"260\": [\"n02112137\", \"chow\"], \"261\": [\"n02112350\", \"keeshond\"], \"262\": [\"n02112706\", \"Brabancon_griffon\"], \"263\": [\"n02113023\", \"Pembroke\"], \"264\": [\"n02113186\", \"Cardigan\"], \"265\": [\"n02113624\", \"toy_poodle\"], \"266\": [\"n02113712\", \"miniature_poodle\"], \"267\": [\"n02113799\", \"standard_poodle\"], \"268\": [\"n02113978\", \"Mexican_hairless\"], \"269\": [\"n02114367\", \"timber_wolf\"], \"270\": [\"n02114548\", \"white_wolf\"], \"271\": [\"n02114712\", \"red_wolf\"], \"272\": [\"n02114855\", \"coyote\"], \"273\": [\"n02115641\", \"dingo\"], \"274\": [\"n02115913\", \"dhole\"], \"275\": [\"n02116738\", \"African_hunting_dog\"], \"276\": [\"n02117135\", 
\"hyena\"], \"277\": [\"n02119022\", \"red_fox\"], \"278\": [\"n02119789\", \"kit_fox\"], \"279\": [\"n02120079\", \"Arctic_fox\"], \"280\": [\"n02120505\", \"grey_fox\"], \"281\": [\"n02123045\", \"tabby\"], \"282\": [\"n02123159\", \"tiger_cat\"], \"283\": [\"n02123394\", \"Persian_cat\"], \"284\": [\"n02123597\", \"Siamese_cat\"], \"285\": [\"n02124075\", \"Egyptian_cat\"], \"286\": [\"n02125311\", \"cougar\"], \"287\": [\"n02127052\", \"lynx\"], \"288\": [\"n02128385\", \"leopard\"], \"289\": [\"n02128757\", \"snow_leopard\"], \"290\": [\"n02128925\", \"jaguar\"], \"291\": [\"n02129165\", \"lion\"], \"292\": [\"n02129604\", \"tiger\"], \"293\": [\"n02130308\", \"cheetah\"], \"294\": [\"n02132136\", \"brown_bear\"], \"295\": [\"n02133161\", \"American_black_bear\"], \"296\": [\"n02134084\", \"ice_bear\"], \"297\": [\"n02134418\", \"sloth_bear\"], \"298\": [\"n02137549\", \"mongoose\"], \"299\": [\"n02138441\", \"meerkat\"], \"300\": [\"n02165105\", \"tiger_beetle\"], \"301\": [\"n02165456\", \"ladybug\"], \"302\": [\"n02167151\", \"ground_beetle\"], \"303\": [\"n02168699\", \"long-horned_beetle\"], \"304\": [\"n02169497\", \"leaf_beetle\"], \"305\": [\"n02172182\", \"dung_beetle\"], \"306\": [\"n02174001\", \"rhinoceros_beetle\"], \"307\": [\"n02177972\", \"weevil\"], \"308\": [\"n02190166\", \"fly\"], \"309\": [\"n02206856\", \"bee\"], \"310\": [\"n02219486\", \"ant\"], \"311\": [\"n02226429\", \"grasshopper\"], \"312\": [\"n02229544\", \"cricket\"], \"313\": [\"n02231487\", \"walking_stick\"], \"314\": [\"n02233338\", \"cockroach\"], \"315\": [\"n02236044\", \"mantis\"], \"316\": [\"n02256656\", \"cicada\"], \"317\": [\"n02259212\", \"leafhopper\"], \"318\": [\"n02264363\", \"lacewing\"], \"319\": [\"n02268443\", \"dragonfly\"], \"320\": [\"n02268853\", \"damselfly\"], \"321\": [\"n02276258\", \"admiral\"], \"322\": [\"n02277742\", \"ringlet\"], \"323\": [\"n02279972\", \"monarch\"], \"324\": [\"n02280649\", \"cabbage_butterfly\"], \"325\": [\"n02281406\", 
\"sulphur_butterfly\"], \"326\": [\"n02281787\", \"lycaenid\"], \"327\": [\"n02317335\", \"starfish\"], \"328\": [\"n02319095\", \"sea_urchin\"], \"329\": [\"n02321529\", \"sea_cucumber\"], \"330\": [\"n02325366\", \"wood_rabbit\"], \"331\": [\"n02326432\", \"hare\"], \"332\": [\"n02328150\", \"Angora\"], \"333\": [\"n02342885\", \"hamster\"], \"334\": [\"n02346627\", \"porcupine\"], \"335\": [\"n02356798\", \"fox_squirrel\"], \"336\": [\"n02361337\", \"marmot\"], \"337\": [\"n02363005\", \"beaver\"], \"338\": [\"n02364673\", \"guinea_pig\"], \"339\": [\"n02389026\", \"sorrel\"], \"340\": [\"n02391049\", \"zebra\"], \"341\": [\"n02395406\", \"hog\"], \"342\": [\"n02396427\", \"wild_boar\"], \"343\": [\"n02397096\", \"warthog\"], \"344\": [\"n02398521\", \"hippopotamus\"], \"345\": [\"n02403003\", \"ox\"], \"346\": [\"n02408429\", \"water_buffalo\"], \"347\": [\"n02410509\", \"bison\"], \"348\": [\"n02412080\", \"ram\"], \"349\": [\"n02415577\", \"bighorn\"], \"350\": [\"n02417914\", \"ibex\"], \"351\": [\"n02422106\", \"hartebeest\"], \"352\": [\"n02422699\", \"impala\"], \"353\": [\"n02423022\", \"gazelle\"], \"354\": [\"n02437312\", \"Arabian_camel\"], \"355\": [\"n02437616\", \"llama\"], \"356\": [\"n02441942\", \"weasel\"], \"357\": [\"n02442845\", \"mink\"], \"358\": [\"n02443114\", \"polecat\"], \"359\": [\"n02443484\", \"black-footed_ferret\"], \"360\": [\"n02444819\", \"otter\"], \"361\": [\"n02445715\", \"skunk\"], \"362\": [\"n02447366\", \"badger\"], \"363\": [\"n02454379\", \"armadillo\"], \"364\": [\"n02457408\", \"three-toed_sloth\"], \"365\": [\"n02480495\", \"orangutan\"], \"366\": [\"n02480855\", \"gorilla\"], \"367\": [\"n02481823\", \"chimpanzee\"], \"368\": [\"n02483362\", \"gibbon\"], \"369\": [\"n02483708\", \"siamang\"], \"370\": [\"n02484975\", \"guenon\"], \"371\": [\"n02486261\", \"patas\"], \"372\": [\"n02486410\", \"baboon\"], \"373\": [\"n02487347\", \"macaque\"], \"374\": [\"n02488291\", \"langur\"], \"375\": [\"n02488702\", 
\"colobus\"], \"376\": [\"n02489166\", \"proboscis_monkey\"], \"377\": [\"n02490219\", \"marmoset\"], \"378\": [\"n02492035\", \"capuchin\"], \"379\": [\"n02492660\", \"howler_monkey\"], \"380\": [\"n02493509\", \"titi\"], \"381\": [\"n02493793\", \"spider_monkey\"], \"382\": [\"n02494079\", \"squirrel_monkey\"], \"383\": [\"n02497673\", \"Madagascar_cat\"], \"384\": [\"n02500267\", \"indri\"], \"385\": [\"n02504013\", \"Indian_elephant\"], \"386\": [\"n02504458\", \"African_elephant\"], \"387\": [\"n02509815\", \"lesser_panda\"], \"388\": [\"n02510455\", \"giant_panda\"], \"389\": [\"n02514041\", \"barracouta\"], \"390\": [\"n02526121\", \"eel\"], \"391\": [\"n02536864\", \"coho\"], \"392\": [\"n02606052\", \"rock_beauty\"], \"393\": [\"n02607072\", \"anemone_fish\"], \"394\": [\"n02640242\", \"sturgeon\"], \"395\": [\"n02641379\", \"gar\"], \"396\": [\"n02643566\", \"lionfish\"], \"397\": [\"n02655020\", \"puffer\"], \"398\": [\"n02666196\", \"abacus\"], \"399\": [\"n02667093\", \"abaya\"], \"400\": [\"n02669723\", \"academic_gown\"], \"401\": [\"n02672831\", \"accordion\"], \"402\": [\"n02676566\", \"acoustic_guitar\"], \"403\": [\"n02687172\", \"aircraft_carrier\"], \"404\": [\"n02690373\", \"airliner\"], \"405\": [\"n02692877\", \"airship\"], \"406\": [\"n02699494\", \"altar\"], \"407\": [\"n02701002\", \"ambulance\"], \"408\": [\"n02704792\", \"amphibian\"], \"409\": [\"n02708093\", \"analog_clock\"], \"410\": [\"n02727426\", \"apiary\"], \"411\": [\"n02730930\", \"apron\"], \"412\": [\"n02747177\", \"ashcan\"], \"413\": [\"n02749479\", \"assault_rifle\"], \"414\": [\"n02769748\", \"backpack\"], \"415\": [\"n02776631\", \"bakery\"], \"416\": [\"n02777292\", \"balance_beam\"], \"417\": [\"n02782093\", \"balloon\"], \"418\": [\"n02783161\", \"ballpoint\"], \"419\": [\"n02786058\", \"Band_Aid\"], \"420\": [\"n02787622\", \"banjo\"], \"421\": [\"n02788148\", \"bannister\"], \"422\": [\"n02790996\", \"barbell\"], \"423\": [\"n02791124\", \"barber_chair\"], 
\"424\": [\"n02791270\", \"barbershop\"], \"425\": [\"n02793495\", \"barn\"], \"426\": [\"n02794156\", \"barometer\"], \"427\": [\"n02795169\", \"barrel\"], \"428\": [\"n02797295\", \"barrow\"], \"429\": [\"n02799071\", \"baseball\"], \"430\": [\"n02802426\", \"basketball\"], \"431\": [\"n02804414\", \"bassinet\"], \"432\": [\"n02804610\", \"bassoon\"], \"433\": [\"n02807133\", \"bathing_cap\"], \"434\": [\"n02808304\", \"bath_towel\"], \"435\": [\"n02808440\", \"bathtub\"], \"436\": [\"n02814533\", \"beach_wagon\"], \"437\": [\"n02814860\", \"beacon\"], \"438\": [\"n02815834\", \"beaker\"], \"439\": [\"n02817516\", \"bearskin\"], \"440\": [\"n02823428\", \"beer_bottle\"], \"441\": [\"n02823750\", \"beer_glass\"], \"442\": [\"n02825657\", \"bell_cote\"], \"443\": [\"n02834397\", \"bib\"], \"444\": [\"n02835271\", \"bicycle-built-for-two\"], \"445\": [\"n02837789\", \"bikini\"], \"446\": [\"n02840245\", \"binder\"], \"447\": [\"n02841315\", \"binoculars\"], \"448\": [\"n02843684\", \"birdhouse\"], \"449\": [\"n02859443\", \"boathouse\"], \"450\": [\"n02860847\", \"bobsled\"], \"451\": [\"n02865351\", \"bolo_tie\"], \"452\": [\"n02869837\", \"bonnet\"], \"453\": [\"n02870880\", \"bookcase\"], \"454\": [\"n02871525\", \"bookshop\"], \"455\": [\"n02877765\", \"bottlecap\"], \"456\": [\"n02879718\", \"bow\"], \"457\": [\"n02883205\", \"bow_tie\"], \"458\": [\"n02892201\", \"brass\"], \"459\": [\"n02892767\", \"brassiere\"], \"460\": [\"n02894605\", \"breakwater\"], \"461\": [\"n02895154\", \"breastplate\"], \"462\": [\"n02906734\", \"broom\"], \"463\": [\"n02909870\", \"bucket\"], \"464\": [\"n02910353\", \"buckle\"], \"465\": [\"n02916936\", \"bulletproof_vest\"], \"466\": [\"n02917067\", \"bullet_train\"], \"467\": [\"n02927161\", \"butcher_shop\"], \"468\": [\"n02930766\", \"cab\"], \"469\": [\"n02939185\", \"caldron\"], \"470\": [\"n02948072\", \"candle\"], \"471\": [\"n02950826\", \"cannon\"], \"472\": [\"n02951358\", \"canoe\"], \"473\": [\"n02951585\", 
\"can_opener\"], \"474\": [\"n02963159\", \"cardigan\"], \"475\": [\"n02965783\", \"car_mirror\"], \"476\": [\"n02966193\", \"carousel\"], \"477\": [\"n02966687\", \"carpenter's_kit\"], \"478\": [\"n02971356\", \"carton\"], \"479\": [\"n02974003\", \"car_wheel\"], \"480\": [\"n02977058\", \"cash_machine\"], \"481\": [\"n02978881\", \"cassette\"], \"482\": [\"n02979186\", \"cassette_player\"], \"483\": [\"n02980441\", \"castle\"], \"484\": [\"n02981792\", \"catamaran\"], \"485\": [\"n02988304\", \"CD_player\"], \"486\": [\"n02992211\", \"cello\"], \"487\": [\"n02992529\", \"cellular_telephone\"], \"488\": [\"n02999410\", \"chain\"], \"489\": [\"n03000134\", \"chainlink_fence\"], \"490\": [\"n03000247\", \"chain_mail\"], \"491\": [\"n03000684\", \"chain_saw\"], \"492\": [\"n03014705\", \"chest\"], \"493\": [\"n03016953\", \"chiffonier\"], \"494\": [\"n03017168\", \"chime\"], \"495\": [\"n03018349\", \"china_cabinet\"], \"496\": [\"n03026506\", \"Christmas_stocking\"], \"497\": [\"n03028079\", \"church\"], \"498\": [\"n03032252\", \"cinema\"], \"499\": [\"n03041632\", \"cleaver\"], \"500\": [\"n03042490\", \"cliff_dwelling\"], \"501\": [\"n03045698\", \"cloak\"], \"502\": [\"n03047690\", \"clog\"], \"503\": [\"n03062245\", \"cocktail_shaker\"], \"504\": [\"n03063599\", \"coffee_mug\"], \"505\": [\"n03063689\", \"coffeepot\"], \"506\": [\"n03065424\", \"coil\"], \"507\": [\"n03075370\", \"combination_lock\"], \"508\": [\"n03085013\", \"computer_keyboard\"], \"509\": [\"n03089624\", \"confectionery\"], \"510\": [\"n03095699\", \"container_ship\"], \"511\": [\"n03100240\", \"convertible\"], \"512\": [\"n03109150\", \"corkscrew\"], \"513\": [\"n03110669\", \"cornet\"], \"514\": [\"n03124043\", \"cowboy_boot\"], \"515\": [\"n03124170\", \"cowboy_hat\"], \"516\": [\"n03125729\", \"cradle\"], \"517\": [\"n03126707\", \"crane\"], \"518\": [\"n03127747\", \"crash_helmet\"], \"519\": [\"n03127925\", \"crate\"], \"520\": [\"n03131574\", \"crib\"], \"521\": [\"n03133878\", 
\"Crock_Pot\"], \"522\": [\"n03134739\", \"croquet_ball\"], \"523\": [\"n03141823\", \"crutch\"], \"524\": [\"n03146219\", \"cuirass\"], \"525\": [\"n03160309\", \"dam\"], \"526\": [\"n03179701\", \"desk\"], \"527\": [\"n03180011\", \"desktop_computer\"], \"528\": [\"n03187595\", \"dial_telephone\"], \"529\": [\"n03188531\", \"diaper\"], \"530\": [\"n03196217\", \"digital_clock\"], \"531\": [\"n03197337\", \"digital_watch\"], \"532\": [\"n03201208\", \"dining_table\"], \"533\": [\"n03207743\", \"dishrag\"], \"534\": [\"n03207941\", \"dishwasher\"], \"535\": [\"n03208938\", \"disk_brake\"], \"536\": [\"n03216828\", \"dock\"], \"537\": [\"n03218198\", \"dogsled\"], \"538\": [\"n03220513\", \"dome\"], \"539\": [\"n03223299\", \"doormat\"], \"540\": [\"n03240683\", \"drilling_platform\"], \"541\": [\"n03249569\", \"drum\"], \"542\": [\"n03250847\", \"drumstick\"], \"543\": [\"n03255030\", \"dumbbell\"], \"544\": [\"n03259280\", \"Dutch_oven\"], \"545\": [\"n03271574\", \"electric_fan\"], \"546\": [\"n03272010\", \"electric_guitar\"], \"547\": [\"n03272562\", \"electric_locomotive\"], \"548\": [\"n03290653\", \"entertainment_center\"], \"549\": [\"n03291819\", \"envelope\"], \"550\": [\"n03297495\", \"espresso_maker\"], \"551\": [\"n03314780\", \"face_powder\"], \"552\": [\"n03325584\", \"feather_boa\"], \"553\": [\"n03337140\", \"file\"], \"554\": [\"n03344393\", \"fireboat\"], \"555\": [\"n03345487\", \"fire_engine\"], \"556\": [\"n03347037\", \"fire_screen\"], \"557\": [\"n03355925\", \"flagpole\"], \"558\": [\"n03372029\", \"flute\"], \"559\": [\"n03376595\", \"folding_chair\"], \"560\": [\"n03379051\", \"football_helmet\"], \"561\": [\"n03384352\", \"forklift\"], \"562\": [\"n03388043\", \"fountain\"], \"563\": [\"n03388183\", \"fountain_pen\"], \"564\": [\"n03388549\", \"four-poster\"], \"565\": [\"n03393912\", \"freight_car\"], \"566\": [\"n03394916\", \"French_horn\"], \"567\": [\"n03400231\", \"frying_pan\"], \"568\": [\"n03404251\", \"fur_coat\"], \"569\": 
[\"n03417042\", \"garbage_truck\"], \"570\": [\"n03424325\", \"gasmask\"], \"571\": [\"n03425413\", \"gas_pump\"], \"572\": [\"n03443371\", \"goblet\"], \"573\": [\"n03444034\", \"go-kart\"], \"574\": [\"n03445777\", \"golf_ball\"], \"575\": [\"n03445924\", \"golfcart\"], \"576\": [\"n03447447\", \"gondola\"], \"577\": [\"n03447721\", \"gong\"], \"578\": [\"n03450230\", \"gown\"], \"579\": [\"n03452741\", \"grand_piano\"], \"580\": [\"n03457902\", \"greenhouse\"], \"581\": [\"n03459775\", \"grille\"], \"582\": [\"n03461385\", \"grocery_store\"], \"583\": [\"n03467068\", \"guillotine\"], \"584\": [\"n03476684\", \"hair_slide\"], \"585\": [\"n03476991\", \"hair_spray\"], \"586\": [\"n03478589\", \"half_track\"], \"587\": [\"n03481172\", \"hammer\"], \"588\": [\"n03482405\", \"hamper\"], \"589\": [\"n03483316\", \"hand_blower\"], \"590\": [\"n03485407\", \"hand-held_computer\"], \"591\": [\"n03485794\", \"handkerchief\"], \"592\": [\"n03492542\", \"hard_disc\"], \"593\": [\"n03494278\", \"harmonica\"], \"594\": [\"n03495258\", \"harp\"], \"595\": [\"n03496892\", \"harvester\"], \"596\": [\"n03498962\", \"hatchet\"], \"597\": [\"n03527444\", \"holster\"], \"598\": [\"n03529860\", \"home_theater\"], \"599\": [\"n03530642\", \"honeycomb\"], \"600\": [\"n03532672\", \"hook\"], \"601\": [\"n03534580\", \"hoopskirt\"], \"602\": [\"n03535780\", \"horizontal_bar\"], \"603\": [\"n03538406\", \"horse_cart\"], \"604\": [\"n03544143\", \"hourglass\"], \"605\": [\"n03584254\", \"iPod\"], \"606\": [\"n03584829\", \"iron\"], \"607\": [\"n03590841\", \"jack-o'-lantern\"], \"608\": [\"n03594734\", \"jean\"], \"609\": [\"n03594945\", \"jeep\"], \"610\": [\"n03595614\", \"jersey\"], \"611\": [\"n03598930\", \"jigsaw_puzzle\"], \"612\": [\"n03599486\", \"jinrikisha\"], \"613\": [\"n03602883\", \"joystick\"], \"614\": [\"n03617480\", \"kimono\"], \"615\": [\"n03623198\", \"knee_pad\"], \"616\": [\"n03627232\", \"knot\"], \"617\": [\"n03630383\", \"lab_coat\"], \"618\": [\"n03633091\", 
\"ladle\"], \"619\": [\"n03637318\", \"lampshade\"], \"620\": [\"n03642806\", \"laptop\"], \"621\": [\"n03649909\", \"lawn_mower\"], \"622\": [\"n03657121\", \"lens_cap\"], \"623\": [\"n03658185\", \"letter_opener\"], \"624\": [\"n03661043\", \"library\"], \"625\": [\"n03662601\", \"lifeboat\"], \"626\": [\"n03666591\", \"lighter\"], \"627\": [\"n03670208\", \"limousine\"], \"628\": [\"n03673027\", \"liner\"], \"629\": [\"n03676483\", \"lipstick\"], \"630\": [\"n03680355\", \"Loafer\"], \"631\": [\"n03690938\", \"lotion\"], \"632\": [\"n03691459\", \"loudspeaker\"], \"633\": [\"n03692522\", \"loupe\"], \"634\": [\"n03697007\", \"lumbermill\"], \"635\": [\"n03706229\", \"magnetic_compass\"], \"636\": [\"n03709823\", \"mailbag\"], \"637\": [\"n03710193\", \"mailbox\"], \"638\": [\"n03710637\", \"maillot\"], \"639\": [\"n03710721\", \"maillot\"], \"640\": [\"n03717622\", \"manhole_cover\"], \"641\": [\"n03720891\", \"maraca\"], \"642\": [\"n03721384\", \"marimba\"], \"643\": [\"n03724870\", \"mask\"], \"644\": [\"n03729826\", \"matchstick\"], \"645\": [\"n03733131\", \"maypole\"], \"646\": [\"n03733281\", \"maze\"], \"647\": [\"n03733805\", \"measuring_cup\"], \"648\": [\"n03742115\", \"medicine_chest\"], \"649\": [\"n03743016\", \"megalith\"], \"650\": [\"n03759954\", \"microphone\"], \"651\": [\"n03761084\", \"microwave\"], \"652\": [\"n03763968\", \"military_uniform\"], \"653\": [\"n03764736\", \"milk_can\"], \"654\": [\"n03769881\", \"minibus\"], \"655\": [\"n03770439\", \"miniskirt\"], \"656\": [\"n03770679\", \"minivan\"], \"657\": [\"n03773504\", \"missile\"], \"658\": [\"n03775071\", \"mitten\"], \"659\": [\"n03775546\", \"mixing_bowl\"], \"660\": [\"n03776460\", \"mobile_home\"], \"661\": [\"n03777568\", \"Model_T\"], \"662\": [\"n03777754\", \"modem\"], \"663\": [\"n03781244\", \"monastery\"], \"664\": [\"n03782006\", \"monitor\"], \"665\": [\"n03785016\", \"moped\"], \"666\": [\"n03786901\", \"mortar\"], \"667\": [\"n03787032\", \"mortarboard\"], \"668\": 
[\"n03788195\", \"mosque\"], \"669\": [\"n03788365\", \"mosquito_net\"], \"670\": [\"n03791053\", \"motor_scooter\"], \"671\": [\"n03792782\", \"mountain_bike\"], \"672\": [\"n03792972\", \"mountain_tent\"], \"673\": [\"n03793489\", \"mouse\"], \"674\": [\"n03794056\", \"mousetrap\"], \"675\": [\"n03796401\", \"moving_van\"], \"676\": [\"n03803284\", \"muzzle\"], \"677\": [\"n03804744\", \"nail\"], \"678\": [\"n03814639\", \"neck_brace\"], \"679\": [\"n03814906\", \"necklace\"], \"680\": [\"n03825788\", \"nipple\"], \"681\": [\"n03832673\", \"notebook\"], \"682\": [\"n03837869\", \"obelisk\"], \"683\": [\"n03838899\", \"oboe\"], \"684\": [\"n03840681\", \"ocarina\"], \"685\": [\"n03841143\", \"odometer\"], \"686\": [\"n03843555\", \"oil_filter\"], \"687\": [\"n03854065\", \"organ\"], \"688\": [\"n03857828\", \"oscilloscope\"], \"689\": [\"n03866082\", \"overskirt\"], \"690\": [\"n03868242\", \"oxcart\"], \"691\": [\"n03868863\", \"oxygen_mask\"], \"692\": [\"n03871628\", \"packet\"], \"693\": [\"n03873416\", \"paddle\"], \"694\": [\"n03874293\", \"paddlewheel\"], \"695\": [\"n03874599\", \"padlock\"], \"696\": [\"n03876231\", \"paintbrush\"], \"697\": [\"n03877472\", \"pajama\"], \"698\": [\"n03877845\", \"palace\"], \"699\": [\"n03884397\", \"panpipe\"], \"700\": [\"n03887697\", \"paper_towel\"], \"701\": [\"n03888257\", \"parachute\"], \"702\": [\"n03888605\", \"parallel_bars\"], \"703\": [\"n03891251\", \"park_bench\"], \"704\": [\"n03891332\", \"parking_meter\"], \"705\": [\"n03895866\", \"passenger_car\"], \"706\": [\"n03899768\", \"patio\"], \"707\": [\"n03902125\", \"pay-phone\"], \"708\": [\"n03903868\", \"pedestal\"], \"709\": [\"n03908618\", \"pencil_box\"], \"710\": [\"n03908714\", \"pencil_sharpener\"], \"711\": [\"n03916031\", \"perfume\"], \"712\": [\"n03920288\", \"Petri_dish\"], \"713\": [\"n03924679\", \"photocopier\"], \"714\": [\"n03929660\", \"pick\"], \"715\": [\"n03929855\", \"pickelhaube\"], \"716\": [\"n03930313\", \"picket_fence\"], 
\"717\": [\"n03930630\", \"pickup\"], \"718\": [\"n03933933\", \"pier\"], \"719\": [\"n03935335\", \"piggy_bank\"], \"720\": [\"n03937543\", \"pill_bottle\"], \"721\": [\"n03938244\", \"pillow\"], \"722\": [\"n03942813\", \"ping-pong_ball\"], \"723\": [\"n03944341\", \"pinwheel\"], \"724\": [\"n03947888\", \"pirate\"], \"725\": [\"n03950228\", \"pitcher\"], \"726\": [\"n03954731\", \"plane\"], \"727\": [\"n03956157\", \"planetarium\"], \"728\": [\"n03958227\", \"plastic_bag\"], \"729\": [\"n03961711\", \"plate_rack\"], \"730\": [\"n03967562\", \"plow\"], \"731\": [\"n03970156\", \"plunger\"], \"732\": [\"n03976467\", \"Polaroid_camera\"], \"733\": [\"n03976657\", \"pole\"], \"734\": [\"n03977966\", \"police_van\"], \"735\": [\"n03980874\", \"poncho\"], \"736\": [\"n03982430\", \"pool_table\"], \"737\": [\"n03983396\", \"pop_bottle\"], \"738\": [\"n03991062\", \"pot\"], \"739\": [\"n03992509\", \"potter's_wheel\"], \"740\": [\"n03995372\", \"power_drill\"], \"741\": [\"n03998194\", \"prayer_rug\"], \"742\": [\"n04004767\", \"printer\"], \"743\": [\"n04005630\", \"prison\"], \"744\": [\"n04008634\", \"projectile\"], \"745\": [\"n04009552\", \"projector\"], \"746\": [\"n04019541\", \"puck\"], \"747\": [\"n04023962\", \"punching_bag\"], \"748\": [\"n04026417\", \"purse\"], \"749\": [\"n04033901\", \"quill\"], \"750\": [\"n04033995\", \"quilt\"], \"751\": [\"n04037443\", \"racer\"], \"752\": [\"n04039381\", \"racket\"], \"753\": [\"n04040759\", \"radiator\"], \"754\": [\"n04041544\", \"radio\"], \"755\": [\"n04044716\", \"radio_telescope\"], \"756\": [\"n04049303\", \"rain_barrel\"], \"757\": [\"n04065272\", \"recreational_vehicle\"], \"758\": [\"n04067472\", \"reel\"], \"759\": [\"n04069434\", \"reflex_camera\"], \"760\": [\"n04070727\", \"refrigerator\"], \"761\": [\"n04074963\", \"remote_control\"], \"762\": [\"n04081281\", \"restaurant\"], \"763\": [\"n04086273\", \"revolver\"], \"764\": [\"n04090263\", \"rifle\"], \"765\": [\"n04099969\", \"rocking_chair\"], 
\"766\": [\"n04111531\", \"rotisserie\"], \"767\": [\"n04116512\", \"rubber_eraser\"], \"768\": [\"n04118538\", \"rugby_ball\"], \"769\": [\"n04118776\", \"rule\"], \"770\": [\"n04120489\", \"running_shoe\"], \"771\": [\"n04125021\", \"safe\"], \"772\": [\"n04127249\", \"safety_pin\"], \"773\": [\"n04131690\", \"saltshaker\"], \"774\": [\"n04133789\", \"sandal\"], \"775\": [\"n04136333\", \"sarong\"], \"776\": [\"n04141076\", \"sax\"], \"777\": [\"n04141327\", \"scabbard\"], \"778\": [\"n04141975\", \"scale\"], \"779\": [\"n04146614\", \"school_bus\"], \"780\": [\"n04147183\", \"schooner\"], \"781\": [\"n04149813\", \"scoreboard\"], \"782\": [\"n04152593\", \"screen\"], \"783\": [\"n04153751\", \"screw\"], \"784\": [\"n04154565\", \"screwdriver\"], \"785\": [\"n04162706\", \"seat_belt\"], \"786\": [\"n04179913\", \"sewing_machine\"], \"787\": [\"n04192698\", \"shield\"], \"788\": [\"n04200800\", \"shoe_shop\"], \"789\": [\"n04201297\", \"shoji\"], \"790\": [\"n04204238\", \"shopping_basket\"], \"791\": [\"n04204347\", \"shopping_cart\"], \"792\": [\"n04208210\", \"shovel\"], \"793\": [\"n04209133\", \"shower_cap\"], \"794\": [\"n04209239\", \"shower_curtain\"], \"795\": [\"n04228054\", \"ski\"], \"796\": [\"n04229816\", \"ski_mask\"], \"797\": [\"n04235860\", \"sleeping_bag\"], \"798\": [\"n04238763\", \"slide_rule\"], \"799\": [\"n04239074\", \"sliding_door\"], \"800\": [\"n04243546\", \"slot\"], \"801\": [\"n04251144\", \"snorkel\"], \"802\": [\"n04252077\", \"snowmobile\"], \"803\": [\"n04252225\", \"snowplow\"], \"804\": [\"n04254120\", \"soap_dispenser\"], \"805\": [\"n04254680\", \"soccer_ball\"], \"806\": [\"n04254777\", \"sock\"], \"807\": [\"n04258138\", \"solar_dish\"], \"808\": [\"n04259630\", \"sombrero\"], \"809\": [\"n04263257\", \"soup_bowl\"], \"810\": [\"n04264628\", \"space_bar\"], \"811\": [\"n04265275\", \"space_heater\"], \"812\": [\"n04266014\", \"space_shuttle\"], \"813\": [\"n04270147\", \"spatula\"], \"814\": [\"n04273569\", \"speedboat\"], 
\"815\": [\"n04275548\", \"spider_web\"], \"816\": [\"n04277352\", \"spindle\"], \"817\": [\"n04285008\", \"sports_car\"], \"818\": [\"n04286575\", \"spotlight\"], \"819\": [\"n04296562\", \"stage\"], \"820\": [\"n04310018\", \"steam_locomotive\"], \"821\": [\"n04311004\", \"steel_arch_bridge\"], \"822\": [\"n04311174\", \"steel_drum\"], \"823\": [\"n04317175\", \"stethoscope\"], \"824\": [\"n04325704\", \"stole\"], \"825\": [\"n04326547\", \"stone_wall\"], \"826\": [\"n04328186\", \"stopwatch\"], \"827\": [\"n04330267\", \"stove\"], \"828\": [\"n04332243\", \"strainer\"], \"829\": [\"n04335435\", \"streetcar\"], \"830\": [\"n04336792\", \"stretcher\"], \"831\": [\"n04344873\", \"studio_couch\"], \"832\": [\"n04346328\", \"stupa\"], \"833\": [\"n04347754\", \"submarine\"], \"834\": [\"n04350905\", \"suit\"], \"835\": [\"n04355338\", \"sundial\"], \"836\": [\"n04355933\", \"sunglass\"], \"837\": [\"n04356056\", \"sunglasses\"], \"838\": [\"n04357314\", \"sunscreen\"], \"839\": [\"n04366367\", \"suspension_bridge\"], \"840\": [\"n04367480\", \"swab\"], \"841\": [\"n04370456\", \"sweatshirt\"], \"842\": [\"n04371430\", \"swimming_trunks\"], \"843\": [\"n04371774\", \"swing\"], \"844\": [\"n04372370\", \"switch\"], \"845\": [\"n04376876\", \"syringe\"], \"846\": [\"n04380533\", \"table_lamp\"], \"847\": [\"n04389033\", \"tank\"], \"848\": [\"n04392985\", \"tape_player\"], \"849\": [\"n04398044\", \"teapot\"], \"850\": [\"n04399382\", \"teddy\"], \"851\": [\"n04404412\", \"television\"], \"852\": [\"n04409515\", \"tennis_ball\"], \"853\": [\"n04417672\", \"thatch\"], \"854\": [\"n04418357\", \"theater_curtain\"], \"855\": [\"n04423845\", \"thimble\"], \"856\": [\"n04428191\", \"thresher\"], \"857\": [\"n04429376\", \"throne\"], \"858\": [\"n04435653\", \"tile_roof\"], \"859\": [\"n04442312\", \"toaster\"], \"860\": [\"n04443257\", \"tobacco_shop\"], \"861\": [\"n04447861\", \"toilet_seat\"], \"862\": [\"n04456115\", \"torch\"], \"863\": [\"n04458633\", \"totem_pole\"], 
\"864\": [\"n04461696\", \"tow_truck\"], \"865\": [\"n04462240\", \"toyshop\"], \"866\": [\"n04465501\", \"tractor\"], \"867\": [\"n04467665\", \"trailer_truck\"], \"868\": [\"n04476259\", \"tray\"], \"869\": [\"n04479046\", \"trench_coat\"], \"870\": [\"n04482393\", \"tricycle\"], \"871\": [\"n04483307\", \"trimaran\"], \"872\": [\"n04485082\", \"tripod\"], \"873\": [\"n04486054\", \"triumphal_arch\"], \"874\": [\"n04487081\", \"trolleybus\"], \"875\": [\"n04487394\", \"trombone\"], \"876\": [\"n04493381\", \"tub\"], \"877\": [\"n04501370\", \"turnstile\"], \"878\": [\"n04505470\", \"typewriter_keyboard\"], \"879\": [\"n04507155\", \"umbrella\"], \"880\": [\"n04509417\", \"unicycle\"], \"881\": [\"n04515003\", \"upright\"], \"882\": [\"n04517823\", \"vacuum\"], \"883\": [\"n04522168\", \"vase\"], \"884\": [\"n04523525\", \"vault\"], \"885\": [\"n04525038\", \"velvet\"], \"886\": [\"n04525305\", \"vending_machine\"], \"887\": [\"n04532106\", \"vestment\"], \"888\": [\"n04532670\", \"viaduct\"], \"889\": [\"n04536866\", \"violin\"], \"890\": [\"n04540053\", \"volleyball\"], \"891\": [\"n04542943\", \"waffle_iron\"], \"892\": [\"n04548280\", \"wall_clock\"], \"893\": [\"n04548362\", \"wallet\"], \"894\": [\"n04550184\", \"wardrobe\"], \"895\": [\"n04552348\", \"warplane\"], \"896\": [\"n04553703\", \"washbasin\"], \"897\": [\"n04554684\", \"washer\"], \"898\": [\"n04557648\", \"water_bottle\"], \"899\": [\"n04560804\", \"water_jug\"], \"900\": [\"n04562935\", \"water_tower\"], \"901\": [\"n04579145\", \"whiskey_jug\"], \"902\": [\"n04579432\", \"whistle\"], \"903\": [\"n04584207\", \"wig\"], \"904\": [\"n04589890\", \"window_screen\"], \"905\": [\"n04590129\", \"window_shade\"], \"906\": [\"n04591157\", \"Windsor_tie\"], \"907\": [\"n04591713\", \"wine_bottle\"], \"908\": [\"n04592741\", \"wing\"], \"909\": [\"n04596742\", \"wok\"], \"910\": [\"n04597913\", \"wooden_spoon\"], \"911\": [\"n04599235\", \"wool\"], \"912\": [\"n04604644\", \"worm_fence\"], \"913\": 
[\"n04606251\", \"wreck\"], \"914\": [\"n04612504\", \"yawl\"], \"915\": [\"n04613696\", \"yurt\"], \"916\": [\"n06359193\", \"web_site\"], \"917\": [\"n06596364\", \"comic_book\"], \"918\": [\"n06785654\", \"crossword_puzzle\"], \"919\": [\"n06794110\", \"street_sign\"], \"920\": [\"n06874185\", \"traffic_light\"], \"921\": [\"n07248320\", \"book_jacket\"], \"922\": [\"n07565083\", \"menu\"], \"923\": [\"n07579787\", \"plate\"], \"924\": [\"n07583066\", \"guacamole\"], \"925\": [\"n07584110\", \"consomme\"], \"926\": [\"n07590611\", \"hot_pot\"], \"927\": [\"n07613480\", \"trifle\"], \"928\": [\"n07614500\", \"ice_cream\"], \"929\": [\"n07615774\", \"ice_lolly\"], \"930\": [\"n07684084\", \"French_loaf\"], \"931\": [\"n07693725\", \"bagel\"], \"932\": [\"n07695742\", \"pretzel\"], \"933\": [\"n07697313\", \"cheeseburger\"], \"934\": [\"n07697537\", \"hotdog\"], \"935\": [\"n07711569\", \"mashed_potato\"], \"936\": [\"n07714571\", \"head_cabbage\"], \"937\": [\"n07714990\", \"broccoli\"], \"938\": [\"n07715103\", \"cauliflower\"], \"939\": [\"n07716358\", \"zucchini\"], \"940\": [\"n07716906\", \"spaghetti_squash\"], \"941\": [\"n07717410\", \"acorn_squash\"], \"942\": [\"n07717556\", \"butternut_squash\"], \"943\": [\"n07718472\", \"cucumber\"], \"944\": [\"n07718747\", \"artichoke\"], \"945\": [\"n07720875\", \"bell_pepper\"], \"946\": [\"n07730033\", \"cardoon\"], \"947\": [\"n07734744\", \"mushroom\"], \"948\": [\"n07742313\", \"Granny_Smith\"], \"949\": [\"n07745940\", \"strawberry\"], \"950\": [\"n07747607\", \"orange\"], \"951\": [\"n07749582\", \"lemon\"], \"952\": [\"n07753113\", \"fig\"], \"953\": [\"n07753275\", \"pineapple\"], \"954\": [\"n07753592\", \"banana\"], \"955\": [\"n07754684\", \"jackfruit\"], \"956\": [\"n07760859\", \"custard_apple\"], \"957\": [\"n07768694\", \"pomegranate\"], \"958\": [\"n07802026\", \"hay\"], \"959\": [\"n07831146\", \"carbonara\"], \"960\": [\"n07836838\", \"chocolate_sauce\"], \"961\": [\"n07860988\", \"dough\"], 
\"962\": [\"n07871810\", \"meat_loaf\"], \"963\": [\"n07873807\", \"pizza\"], \"964\": [\"n07875152\", \"potpie\"], \"965\": [\"n07880968\", \"burrito\"], \"966\": [\"n07892512\", \"red_wine\"], \"967\": [\"n07920052\", \"espresso\"], \"968\": [\"n07930864\", \"cup\"], \"969\": [\"n07932039\", \"eggnog\"], \"970\": [\"n09193705\", \"alp\"], \"971\": [\"n09229709\", \"bubble\"], \"972\": [\"n09246464\", \"cliff\"], \"973\": [\"n09256479\", \"coral_reef\"], \"974\": [\"n09288635\", \"geyser\"], \"975\": [\"n09332890\", \"lakeside\"], \"976\": [\"n09399592\", \"promontory\"], \"977\": [\"n09421951\", \"sandbar\"], \"978\": [\"n09428293\", \"seashore\"], \"979\": [\"n09468604\", \"valley\"], \"980\": [\"n09472597\", \"volcano\"], \"981\": [\"n09835506\", \"ballplayer\"], \"982\": [\"n10148035\", \"groom\"], \"983\": [\"n10565667\", \"scuba_diver\"], \"984\": [\"n11879895\", \"rapeseed\"], \"985\": [\"n11939491\", \"daisy\"], \"986\": [\"n12057211\", \"yellow_lady's_slipper\"], \"987\": [\"n12144580\", \"corn\"], \"988\": [\"n12267677\", \"acorn\"], \"989\": [\"n12620546\", \"hip\"], \"990\": [\"n12768682\", \"buckeye\"], \"991\": [\"n12985857\", \"coral_fungus\"], \"992\": [\"n12998815\", \"agaric\"], \"993\": [\"n13037406\", \"gyromitra\"], \"994\": [\"n13040303\", \"stinkhorn\"], \"995\": [\"n13044778\", \"earthstar\"], \"996\": [\"n13052670\", \"hen-of-the-woods\"], \"997\": [\"n13054560\", \"bolete\"], \"998\": [\"n13133613\", \"ear\"], \"999\": [\"n15075141\", \"toilet_tissue\"]}"
  },
  {
    "path": "utils/classifiers/places.py",
    "content": "import torch\nimport torchvision.models as models\nfrom torchvision import transforms as trn\nfrom torch.nn import functional as F\nimport os\n\n\nclass Classifier():\n    def __init__(self):\n        # the architecture to use\n        arch = 'resnet50'\n\n        # load the pre-trained weights\n        model_file = '%s_places365.pth.tar' % arch\n        if not os.access(model_file, os.W_OK):\n            weight_url = 'http://places2.csail.mit.edu/models_places365/' + model_file\n            os.system('wget ' + weight_url)\n\n        model = models.__dict__[arch](num_classes=365)\n        checkpoint = torch.load(model_file,\n                                map_location=lambda storage, loc: storage)\n        state_dict = {\n            str.replace(k, 'module.', ''): v\n            for k, v in checkpoint['state_dict'].items()\n        }\n        model.load_state_dict(state_dict)\n        model.cuda()\n        model.eval()\n        self.model = model\n        self.mean = [0.485, 0.456, 0.406]\n        self.std =  [0.229, 0.224, 0.225]\n        self.trn = trn.Normalize(self.mean, self.std)\n\n        file_name = 'categories_places365.txt'\n        if not os.access(file_name, os.W_OK):\n            synset_url = 'https://raw.githubusercontent.com/csailvision/places365/master/categories_places365.txt'\n            os.system('wget ' + synset_url)\n        classes = list()\n        with open(file_name) as class_file:\n            for line in class_file:\n                class_name = line.strip().split(' ')[0][3:]\n                classes.append(''.join(class_name.split('/')))\n        self.classes = classes\n\n    def get_name(self, id):\n        return self.classes[id]\n        \n    def transform(self, x):\n        x = F.interpolate(x, size=(224, 224)) / 255.\n        x = torch.stack([self.trn(xi) for xi in x]).cuda()\n        return x\n\n    def get_predictions_and_confidence(self, x):\n        x = self.transform(x)\n        logit = 
self.model.forward(x)\n        values, ind = logit.max(dim=1)\n        return ind, values\n\n    def get_predictions(self, x):\n        x = self.transform(x)\n        logit = self.model.forward(x)\n        return logit.argmax(dim=1)\n\nif __name__ == '__main__':\n    x = torch.randn((2,3,128,128))\n    c = Classifier()\n    x = c.get_predictions(x)\n    print(x)\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/.gitignore",
    "content": "__pycache__\n*.jpg\n*.png\nacc1_acc5.txt\nlog\npytorch_playground.egg-info\nscript/val224_compressed.pkl\n\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2017 Aaron Chen\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/README.md",
    "content": "This is a playground for pytorch beginners, which contains predefined models on popular dataset. Currently we support \n- mnist, svhn\n- cifar10, cifar100\n- stl10\n- alexnet\n- vgg16, vgg16_bn, vgg19, vgg19_bn\n- resnet18, resnet34, resnet50, resnet101, resnet152\n- squeezenet_v0, squeezenet_v1\n- inception_v3\n\nHere is an example for MNIST dataset. This will download the dataset and pre-trained model automatically.\n```\nimport torch\nfrom torch.autograd import Variable\nfrom utee import selector\nmodel_raw, ds_fetcher, is_imagenet = selector.select('mnist')\nds_val = ds_fetcher(batch_size=10, train=False, val=True)\nfor idx, (data, target) in enumerate(ds_val):\n    data =  Variable(torch.FloatTensor(data)).cuda()\n    output = model_raw(data)\n```\n\nAlso, if want to train the MLP model on mnist, simply run `python mnist/train.py`\n\n\n# Install\n```\npython3 setup.py develop --user\n```\n\n# ImageNet dataset\nWe provide precomputed imagenet validation dataset with 224x224x3 size. We first resize the shorter size of image to 256, then we crop 224x224 image in the center. Then we encode the cropped images to jpg string and dump to pickle. \n- `cd script`\n- Download the `val224_compressed.pkl` ([Tsinghua](http://ml.cs.tsinghua.edu.cn/~chenxi/dataset/val224_compressed.pkl) /  [Google Drive](https://drive.google.com/file/d/1U8ir2fOR4Sir3FCj9b7FQRPSVsycTfVc/view?usp=sharing))\n- `python convert.py` (needs 48G memory, thanks [@jnorwood](https://github.com/aaron-xichen/pytorch-playground/issues/18) )\n\n\n# Quantization\nWe also provide a simple demo to quantize these models to specified bit-width with several methods, including linear method, minmax method and non-linear method.\n\n`quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1`\n   \n## Top1 Accuracy\nWe evaluate the performance of popular dataset and models with linear quantized method. 
The bit-width of running mean and running variance in BN are 10 bits for all results. (except for 32-float)\n\n\n|Model|32-float  |12-bit  |10-bit |8-bit  |6-bit  |\n|:----|:--------:|:------:|:-----:|:-----:|:-----:|\n|[MNIST](http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/mnist-b07bb66b.pth)|98.42|98.43|98.44|98.44|98.32|\n|[SVHN](http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/svhn-f564f3d8.pth)|96.03|96.03|96.04|96.02|95.46|\n|[CIFAR10](http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/cifar10-d875770b.pth)|93.78|93.79|93.80|93.58|90.86|\n|[CIFAR100](http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/cifar100-3a55a987.pth)|74.27|74.21|74.19|73.70|66.32|\n|[STL10](http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/stl10-866321e9.pth)|77.59|77.65|77.70|77.59|73.40|\n|[AlexNet](https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth)|55.70/78.42|55.66/78.41|55.54/78.39|54.17/77.29|18.19/36.25|\n|[VGG16](https://download.pytorch.org/models/vgg16-397923af.pth)|70.44/89.43|70.45/89.43|70.44/89.33|69.99/89.17|53.33/76.32|\n|[VGG19](https://download.pytorch.org/models/vgg19-dcbb9e9d.pth)|71.36/89.94|71.35/89.93|71.34/89.88|70.88/89.62|56.00/78.62|\n|[ResNet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)|68.63/88.31|68.62/88.33|68.49/88.25|66.80/87.20|19.14/36.49|\n|[ResNet34](https://download.pytorch.org/models/resnet34-333f7ec4.pth)|72.50/90.86|72.46/90.82|72.45/90.85|71.47/90.00|32.25/55.71|\n|[ResNet50](https://download.pytorch.org/models/resnet50-19c8e357.pth)|74.98/92.17|74.94/92.12|74.91/92.09|72.54/90.44|2.43/5.36|\n|[ResNet101](https://download.pytorch.org/models/resnet101-5d3b4d8f.pth)|76.69/93.30|76.66/93.25|76.22/92.90|65.69/79.54|1.41/1.18|\n|[ResNet152](https://download.pytorch.org/models/resnet152-b121ed2d.pth)|77.55/93.59|77.51/93.62|77.40/93.54|74.95/92.46|9.29/16.75|\n|[SqueezeNetV0](https://download.pytorch.org/models/squeezenet1_0-a815701f.pth)|56.73/79.39|56.75/79.40|56.70/79.27|53.93/77.04|14.21/29.74|\n|[SqueezeNetV
1](https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth)|56.52/79.13|56.52/79.15|56.24/79.03|54.56/77.33|17.10/32.46|\n|[InceptionV3](https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth)|76.41/92.78|76.43/92.71|76.44/92.73|73.67/91.34|1.50/4.82|\n\n**Note: ImageNet 32-float models are directly from torchvision**\n\n\n## Selected Arguments\nHere we give an overview of selected arguments of `quantize.py`\n\n|Flag                          |Default value|Description & Options|\n|:-----------------------------|:-----------------------:|:--------------------------------|\n|type|cifar10|mnist,svhn,cifar10,cifar100,stl10,alexnet,vgg16,vgg16_bn,vgg19,vgg19_bn,resnet18,resnet34,resnet50,resnet101,resnet152,squeezenet_v0,squeezenet_v1,inception_v3|\n|quant_method|linear|quantization method:linear,minmax,log,tanh|\n|param_bits|8|bit-width of weights and bias|\n|fwd_bits|8|bit-width of activation|\n|bn_bits|32|bit-width of running mean and running variance|\n|overflow_rate|0.0|overflow rate threshold for linear quantization method|\n|n_samples|20|number of samples used to collect activation statistics|\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/cifar/__init__.py",
    "content": ""
  },
  {
    "path": "utils/classifiers/pytorch_playground/cifar/dataset.py",
    "content": "import torch\nfrom torchvision import datasets, transforms\nfrom torch.utils.data import DataLoader\nimport os\n\ndef get10(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):\n    data_root = os.path.expanduser(os.path.join(data_root, 'cifar10-data'))\n    num_workers = kwargs.setdefault('num_workers', 1)\n    kwargs.pop('input_size', None)\n    print(\"Building CIFAR-10 data loader with {} workers\".format(num_workers))\n    ds = []\n    if train:\n        train_loader = torch.utils.data.DataLoader(\n            datasets.CIFAR10(\n                root=data_root, train=True, download=True,\n                transform=transforms.Compose([\n                    transforms.Pad(4),\n                    transforms.RandomCrop(32),\n                    transforms.RandomHorizontalFlip(),\n                    transforms.ToTensor(),\n                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n                ])),\n            batch_size=batch_size, shuffle=True, **kwargs)\n        ds.append(train_loader)\n    if val:\n        test_loader = torch.utils.data.DataLoader(\n            datasets.CIFAR10(\n                root=data_root, train=False, download=True,\n                transform=transforms.Compose([\n                    transforms.ToTensor(),\n                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n                ])),\n            batch_size=batch_size, shuffle=False, **kwargs)\n        ds.append(test_loader)\n    ds = ds[0] if len(ds) == 1 else ds\n    return ds\n\ndef get100(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):\n    data_root = os.path.expanduser(os.path.join(data_root, 'cifar100-data'))\n    num_workers = kwargs.setdefault('num_workers', 1)\n    kwargs.pop('input_size', None)\n    print(\"Building CIFAR-100 data loader with {} workers\".format(num_workers))\n    ds = []\n    if train:\n        train_loader = 
torch.utils.data.DataLoader(\n            datasets.CIFAR100(\n                root=data_root, train=True, download=True,\n                transform=transforms.Compose([\n                    transforms.Pad(4),\n                    transforms.RandomCrop(32),\n                    transforms.RandomHorizontalFlip(),\n                    transforms.ToTensor(),\n                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n                ])),\n            batch_size=batch_size, shuffle=True, **kwargs)\n        ds.append(train_loader)\n\n    if val:\n        test_loader = torch.utils.data.DataLoader(\n            datasets.CIFAR100(\n                root=data_root, train=False, download=True,\n                transform=transforms.Compose([\n                    transforms.ToTensor(),\n                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n                ])),\n            batch_size=batch_size, shuffle=False, **kwargs)\n        ds.append(test_loader)\n    ds = ds[0] if len(ds) == 1 else ds\n    return ds\n\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/cifar/model.py",
    "content": "import torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\nfrom IPython import embed\nfrom collections import OrderedDict\n\nfrom utee import misc\nprint = misc.logger.info\n\nmodel_urls = {\n    'cifar10': 'http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/cifar10-d875770b.pth',\n}\n\nclass CIFAR(nn.Module):\n    def __init__(self, features, n_channel, num_classes):\n        super(CIFAR, self).__init__()\n        assert isinstance(features, nn.Sequential), type(features)\n        self.features = features\n        self.classifier = nn.Sequential(\n            nn.Linear(n_channel, num_classes)\n        )\n        print(self.features)\n        print(self.classifier)\n\n    def forward(self, x):\n        x = self.features(x)\n        x = x.view(x.size(0), -1)\n        x = self.classifier(x)\n        return x\n\ndef make_layers(cfg, batch_norm=False):\n    layers = []\n    in_channels = 3\n    for i, v in enumerate(cfg):\n        if v == 'M':\n            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]\n        else:\n            padding = v[1] if isinstance(v, tuple) else 1\n            out_channels = v[0] if isinstance(v, tuple) else v\n            conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding)\n            if batch_norm:\n                layers += [conv2d, nn.BatchNorm2d(out_channels, affine=False), nn.ReLU()]\n            else:\n                layers += [conv2d, nn.ReLU()]\n            in_channels = out_channels\n    return nn.Sequential(*layers)\n\ndef cifar10(n_channel=128):\n    cfg = [n_channel, n_channel, 'M', 2*n_channel, 2*n_channel, 'M', 4*n_channel, 4*n_channel, 'M', (8*n_channel, 0), 'M']\n    layers = make_layers(cfg, batch_norm=True)\n    model = CIFAR(layers, n_channel=8*n_channel, num_classes=10)\n    m = model_zoo.load_url(model_urls['cifar10'])\n    state_dict = m.state_dict() if isinstance(m, nn.Module) else m\n    assert isinstance(state_dict, (dict, OrderedDict)), type(state_dict)\n    
model.load_state_dict(state_dict)\n    print('loaded')\n    return model\n\n\nif __name__ == '__main__':\n    model = cifar10(128, pretrained='log/cifar10/best-135.pth')\n    embed()\n\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/cifar/train.py",
    "content": "import argparse\nimport os\nimport time\n\nfrom utee import misc\nimport torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.autograd import Variable\n\nimport dataset\nimport model\n\nfrom IPython import embed\n\nparser = argparse.ArgumentParser(description='PyTorch CIFAR-X Example')\nparser.add_argument('--type', default='cifar10', help='cifar10|cifar100')\nparser.add_argument('--channel', type=int, default=128, help='first conv channel (default: 32)')\nparser.add_argument('--wd', type=float, default=0.00, help='weight decay')\nparser.add_argument('--batch_size', type=int, default=200, help='input batch size for training (default: 64)')\nparser.add_argument('--epochs', type=int, default=150, help='number of epochs to train (default: 10)')\nparser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 1e-3)')\nparser.add_argument('--gpu', default=None, help='index of gpus to use')\nparser.add_argument('--ngpu', type=int, default=2, help='number of gpus to use')\nparser.add_argument('--seed', type=int, default=117, help='random seed (default: 1)')\nparser.add_argument('--log_interval', type=int, default=100,  help='how many batches to wait before logging training status')\nparser.add_argument('--test_interval', type=int, default=5,  help='how many epochs to wait before another test')\nparser.add_argument('--logdir', default='log/default', help='folder to save to the log')\nparser.add_argument('--decreasing_lr', default='80,120', help='decreasing strategy')\nargs = parser.parse_args()\nargs.logdir = os.path.join(os.path.dirname(__file__), args.logdir)\nmisc.logger.init(args.logdir, 'train_log')\nprint = misc.logger.info\n\n# select gpu\nargs.gpu = misc.auto_select_gpu(utility_bound=0, num_gpu=args.ngpu, selected_gpus=args.gpu)\nargs.ngpu = len(args.gpu)\n\n# logger\nmisc.ensure_dir(args.logdir)\nprint(\"=================FLAGS==================\")\nfor k, v in args.__dict__.items():\n    print('{}: 
{}'.format(k, v))\nprint(\"========================================\")\n\n# seed\nargs.cuda = torch.cuda.is_available()\ntorch.manual_seed(args.seed)\nif args.cuda:\n    torch.cuda.manual_seed(args.seed)\n\n# data loader and model\nassert args.type in ['cifar10', 'cifar100'], args.type\nif args.type == 'cifar10':\n    train_loader, test_loader = dataset.get10(batch_size=args.batch_size, num_workers=1)\n    model = model.cifar10(n_channel=args.channel)\nelse:\n    train_loader, test_loader = dataset.get100(batch_size=args.batch_size, num_workers=1)\n    model = model.cifar100(n_channel=args.channel)\nmodel = torch.nn.DataParallel(model, device_ids= range(args.ngpu))\nif args.cuda:\n    model.cuda()\n\n# optimizer\noptimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)\ndecreasing_lr = list(map(int, args.decreasing_lr.split(',')))\nprint('decreasing_lr: ' + str(decreasing_lr))\nbest_acc, old_file = 0, None\nt_begin = time.time()\ntry:\n    # ready to go\n    for epoch in range(args.epochs):\n        model.train()\n        if epoch in decreasing_lr:\n            optimizer.param_groups[0]['lr'] *= 0.1\n        for batch_idx, (data, target) in enumerate(train_loader):\n            indx_target = target.clone()\n            if args.cuda:\n                data, target = data.cuda(), target.cuda()\n            data, target = Variable(data), Variable(target)\n\n            optimizer.zero_grad()\n            output = model(data)\n            loss = F.cross_entropy(output, target)\n            loss.backward()\n            optimizer.step()\n\n            if batch_idx % args.log_interval == 0 and batch_idx > 0:\n                pred = output.data.max(1)[1]  # get the index of the max log-probability\n                correct = pred.cpu().eq(indx_target).sum()\n                acc = correct * 1.0 / len(data)\n                print('Train Epoch: {} [{}/{}] Loss: {:.6f} Acc: {:.4f} lr: {:.2e}'.format(\n                    epoch, batch_idx * len(data), 
len(train_loader.dataset),\n                    loss.data[0], acc, optimizer.param_groups[0]['lr']))\n\n        elapse_time = time.time() - t_begin\n        speed_epoch = elapse_time / (epoch + 1)\n        speed_batch = speed_epoch / len(train_loader)\n        eta = speed_epoch * args.epochs - elapse_time\n        print(\"Elapsed {:.2f}s, {:.2f} s/epoch, {:.2f} s/batch, ets {:.2f}s\".format(\n            elapse_time, speed_epoch, speed_batch, eta))\n        misc.model_snapshot(model, os.path.join(args.logdir, 'latest.pth'))\n\n        if epoch % args.test_interval == 0:\n            model.eval()\n            test_loss = 0\n            correct = 0\n            for data, target in test_loader:\n                indx_target = target.clone()\n                if args.cuda:\n                    data, target = data.cuda(), target.cuda()\n                data, target = Variable(data, volatile=True), Variable(target)\n                output = model(data)\n                test_loss += F.cross_entropy(output, target).data[0]\n                pred = output.data.max(1)[1]  # get the index of the max log-probability\n                correct += pred.cpu().eq(indx_target).sum()\n\n            test_loss = test_loss / len(test_loader) # average over number of mini-batch\n            acc = 100. * correct / len(test_loader.dataset)\n            print('\\tTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(\n                test_loss, correct, len(test_loader.dataset), acc))\n            if acc > best_acc:\n                new_file = os.path.join(args.logdir, 'best-{}.pth'.format(epoch))\n                misc.model_snapshot(model, new_file, old_file=old_file, verbose=True)\n                best_acc = acc\n                old_file = new_file\nexcept Exception as e:\n    import traceback\n    traceback.print_exc()\nfinally:\n    print(\"Total Elapse: {:.2f}, Best Result: {:.3f}%\".format(time.time()-t_begin, best_acc))\n\n\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/quantize.py",
    "content": "import argparse\nfrom utee import misc, quant, selector\nimport torch\nimport torch.backends.cudnn as cudnn\ncudnn.benchmark =True\nfrom collections import OrderedDict\n\ndef main():\n    parser = argparse.ArgumentParser(description='PyTorch SVHN Example')\n    parser.add_argument('--type', default='cifar10', help='|'.join(selector.known_models))\n    parser.add_argument('--quant_method', default='linear', help='linear|minmax|log|tanh')\n    parser.add_argument('--batch_size', type=int, default=100, help='input batch size for training (default: 64)')\n    parser.add_argument('--gpu', default=None, help='index of gpus to use')\n    parser.add_argument('--ngpu', type=int, default=8, help='number of gpus to use')\n    parser.add_argument('--seed', type=int, default=117, help='random seed (default: 1)')\n    parser.add_argument('--model_root', default='~/.torch/models/', help='folder to save the model')\n    parser.add_argument('--data_root', default='/data/public_dataset/pytorch/', help='folder to save the model')\n    parser.add_argument('--logdir', default='log/default', help='folder to save to the log')\n\n    parser.add_argument('--input_size', type=int, default=224, help='input size of image')\n    parser.add_argument('--n_sample', type=int, default=20, help='number of samples to infer the scaling factor')\n    parser.add_argument('--param_bits', type=int, default=8, help='bit-width for parameters')\n    parser.add_argument('--bn_bits', type=int, default=32, help='bit-width for running mean and std')\n    parser.add_argument('--fwd_bits', type=int, default=8, help='bit-width for layer output')\n    parser.add_argument('--overflow_rate', type=float, default=0.0, help='overflow rate')\n    args = parser.parse_args()\n\n    args.gpu = misc.auto_select_gpu(utility_bound=0, num_gpu=args.ngpu, selected_gpus=args.gpu)\n    args.ngpu = len(args.gpu)\n    misc.ensure_dir(args.logdir)\n    args.model_root = misc.expand_user(args.model_root)\n    
args.data_root = misc.expand_user(args.data_root)\n    args.input_size = 299 if 'inception' in args.type else args.input_size\n    assert args.quant_method in ['linear', 'minmax', 'log', 'tanh']\n    print(\"=================FLAGS==================\")\n    for k, v in args.__dict__.items():\n        print('{}: {}'.format(k, v))\n    print(\"========================================\")\n\n    assert torch.cuda.is_available(), 'no cuda'\n    torch.manual_seed(args.seed)\n    torch.cuda.manual_seed(args.seed)\n\n    # load model and dataset fetcher\n    model_raw, ds_fetcher, is_imagenet = selector.select(args.type, model_root=args.model_root)\n    args.ngpu = args.ngpu if is_imagenet else 1\n\n    # quantize parameters\n    if args.param_bits < 32:\n        state_dict = model_raw.state_dict()\n        state_dict_quant = OrderedDict()\n        sf_dict = OrderedDict()\n        for k, v in state_dict.items():\n            if 'running' in k:\n                if args.bn_bits >=32:\n                    print(\"Ignoring {}\".format(k))\n                    state_dict_quant[k] = v\n                    continue\n                else:\n                    bits = args.bn_bits\n            else:\n                bits = args.param_bits\n\n            if args.quant_method == 'linear':\n                sf = bits - 1. 
- quant.compute_integral_part(v, overflow_rate=args.overflow_rate)\n                v_quant  = quant.linear_quantize(v, sf, bits=bits)\n            elif args.quant_method == 'log':\n                v_quant = quant.log_minmax_quantize(v, bits=bits)\n            elif args.quant_method == 'minmax':\n                v_quant = quant.min_max_quantize(v, bits=bits)\n            else:\n                v_quant = quant.tanh_quantize(v, bits=bits)\n            state_dict_quant[k] = v_quant\n            print(k, bits)\n        model_raw.load_state_dict(state_dict_quant)\n\n    # quantize forward activation\n    if args.fwd_bits < 32:\n        model_raw = quant.duplicate_model_with_quant(model_raw, bits=args.fwd_bits, overflow_rate=args.overflow_rate,\n                                                     counter=args.n_sample, type=args.quant_method)\n        print(model_raw)\n        val_ds_tmp = ds_fetcher(10, data_root=args.data_root, train=False, input_size=args.input_size)\n        misc.eval_model(model_raw, val_ds_tmp, ngpu=1, n_sample=args.n_sample, is_imagenet=is_imagenet)\n\n    # eval model\n    val_ds = ds_fetcher(args.batch_size, data_root=args.data_root, train=False, input_size=args.input_size)\n    acc1, acc5 = misc.eval_model(model_raw, val_ds, ngpu=args.ngpu, is_imagenet=is_imagenet)\n\n    # print sf\n    print(model_raw)\n    res_str = \"type={}, quant_method={}, param_bits={}, bn_bits={}, fwd_bits={}, overflow_rate={}, acc1={:.4f}, acc5={:.4f}\".format(\n        args.type, args.quant_method, args.param_bits, args.bn_bits, args.fwd_bits, args.overflow_rate, acc1, acc5)\n    print(res_str)\n    with open('acc1_acc5.txt', 'a') as f:\n        f.write(res_str + '\\n')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/requirements.txt",
    "content": "Pillow==6.1\ntorchvision==0.4.2\ntqdm==4.41.1\nopencv-python==4.1.2.30\njoblib==0.14.1\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/roadmap_zh.md",
    "content": "# 定点化Roadmap\n首先定点化的setting分好几种，主要如下所示 (w代表weight，a代表activation，g代表gradient)\n\n截至2016年7月，最近两年共有13篇直接相关的论文\n\n## float转化为定点版本，不允许fine-tune\n- w定点，a浮点\n    - Resiliency of Deep Neural Networks under Quantization [Wonyong Sung, Sungho Shin, 2016.01.07, ICLR2016] {5bit在CIFAR10上恢复正确率}\n    - Fixed Point Quantization of Deep Convolutional Networks [Darryl D. Lin, Sachin S. Talathi, 2016.06.02] {每层定点化策略不同，解析解求出}\n- w+a定点\n    - Hardware-oriented approximation of convolutional neural networks [Philipp Gysel, Mohammad Motamedi, ICLR 2016 Workshop] {ImageNet上8bit-8bit掉0.9%，AlexNet}\n    - Energy-Efficient ConvNets Through Approximate Computing [Bert Moons, KU Leuven, 2016.03.22] {结合硬件的trick可以在ImageNet上4-10bit}\n    - Going Deeper with Embedded FPGA Platform for Convolutional Neural Network [Jiantao Qiu, Jie Wang, FPGA2016] {ImageNet上8bit-8bit掉1%，AlexNet}\n\n## float转化为定点版本，允许fine-tune\n- fine-tune整个网络\n    - w定点，a+g浮点\n        - Resiliency of Deep Neural Networks under Quantization [Wonyong Sung, Sungho Shin, 2016.01.07, ICLR2016] {2bit即三值网络在CIFAR10上恢复正确率}\n    - w+a定点，g浮点\n        - Fixed Point Quantization of Deep Convolutional Networks [Darryl D. Lin, Sachin S. Talathi, 2016.06.02] {每层定点化策略不同，解析解求出，CIFAR10上fine-tune后4bit-4bit掉1.32%}\n    - w+a+g定点\n        - Overcoming Challenges in Fixed Point Training of Deep Convolutional Networks [Darryl D. Lin, Sachin S. Talathi, Qualcomm Research，2016.07.08] {无随机rounding，ImageNet上4bit-16bit-16bit掉7.2%，a和g再小就不收敛}\n        - DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients [Shuchang Zhou, Zekun Ni, 2016.06.20] {1bit-2bit-4bit, 第一层和最后一层没有量化，ImageNet上掉5.2%}\n- fine-tune最高几层\n    - w+a+g定点\n        - Overcoming Challenges in Fixed Point Training of Deep Convolutional Networks [Darryl D. Lin, Sachin S. 
Talathi, Qualcomm Research，2016.07.08] {无随机rounding，ImageNet上4bit-4bit-4bit掉23.3%}\n- 分阶段地从低层到高层fine-tune网络\n    - w+a+g定点\n        - Overcoming Challenges in Fixed Point Training of Deep Convolutional Networks [Darryl D. Lin, Sachin S. Talathi, Qualcomm Research，2016.07.08] {无随机rounding，ImageNet上4bit-4bit-4bit Top5掉11.5%}\n\n## 直接定点从头开始训练\n- w定点，a+g浮点\n    - 二值网络\n        - BinaryConnect: Training Deep Neural Networks with binary weights during propagations [Matthieu Courbariaux, Yoshua Bengio, 2015.11.02, NIPS] {CIFAR10上8.27%, state-of-the-art}\n        - XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks [Mohammad Rastegari, University of Washington, 2016.03.16] {ImageNet上39.2%，掉2.8%, AlexNet}\n    - 三值网络\n        - Ternary Weight Networks [Fengfu Li, Bin Liu, UCAS, China, 2016.05.16] {ImageNet掉2.3%, ResNet-18B}\n        - Trained Ternary Quantization [Chenzhuo Zhu, Song Han, Huizi Mao, William J. Dally, ICLR2017] {ResNet上效果更佳}\n- w+a定点，g浮点\n    - 二值网络\n        - Binarized Neural Networks: Training Neural Networks with Weights and Activations Constrained to +1 or −1 [Matthieu Courbariaux, Yoshua Bengio, 2016.03.17] {CIFAR10上10.15%}\n        - XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks [Mohammad Rastegari, University of Washington, 2016.03.16] {ImageNet上55.8%，掉12.4%}\n- w+a+g定点\n    - Deep Learning with Limited Numerical Precision [Suyog Gupta, Ankur Agrawal, IBM, 2015.02.09] {随机rounding技巧，CIFAR10上16bit+16bit+16bit复现正确率}\n    - Training deep neural networks with low precision multiplications [Matthieu Courbariaux, Yoshua Bengio, ICLR 2015 Workshop] {CIFAR10上10bit+10bit+12bit复现正确率}\n    - DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients [Shuchang Zhou, Zekun Ni, 2016.06.20] {1bit-2bit-4bit, 第一层和最后一层没有量化，ImageNet上掉8.8%}\n    - Quantized Neural Networks: Training Neural Networks with Low Precision Weights and Activations [Itay Hubara, Matthieu Courbariaux, 
2016.09.22]{1bit-2bit-6bit，ImageNet上超过DoReFa 0.33%}"
  },
  {
    "path": "utils/classifiers/pytorch_playground/setup.py",
    "content": "from setuptools import setup, find_packages\n\nwith open(\"requirements.txt\") as requirements_file:\n    REQUIREMENTS = requirements_file.readlines()\n\nsetup(\n    name=\"pytorch-playground\",\n    version=\"1.0.0\",\n    author='Aaron Chen',\n    author_email='aaron.xichen@gmail.com',\n    packages=find_packages(),\n    entry_points = {\n        'console_scripts': [\n            'quantize=quantize:main',\n        ]\n    },\n    install_requires=REQUIREMENTS,\n\n)\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/utee/__init__.py",
    "content": ""
  },
  {
    "path": "utils/classifiers/pytorch_playground/utee/misc.py",
    "content": "import cv2\nimport os\nimport shutil\nimport pickle as pkl\nimport time\nimport numpy as np\nimport hashlib\n\nfrom IPython import embed\n\nclass Logger(object):\n    def __init__(self):\n        self._logger = None\n\n    def init(self, logdir, name='log'):\n        if self._logger is None:\n            import logging\n            if not os.path.exists(logdir):\n                os.makedirs(logdir)\n            log_file = os.path.join(logdir, name)\n            if os.path.exists(log_file):\n                os.remove(log_file)\n            self._logger = logging.getLogger()\n            self._logger.setLevel('INFO')\n            fh = logging.FileHandler(log_file)\n            ch = logging.StreamHandler()\n            self._logger.addHandler(fh)\n            self._logger.addHandler(ch)\n\n    def info(self, str_info):\n        self.init('/tmp', 'tmp.log')\n        self._logger.info(str_info)\nlogger = Logger()\n\nprint = logger.info\ndef ensure_dir(path, erase=False):\n    if os.path.exists(path) and erase:\n        print(\"Removing old folder {}\".format(path))\n        shutil.rmtree(path)\n    if not os.path.exists(path):\n        print(\"Creating folder {}\".format(path))\n        os.makedirs(path)\n\ndef load_pickle(path):\n    begin_st = time.time()\n    with open(path, 'rb') as f:\n        print(\"Loading pickle object from {}\".format(path))\n        v = pkl.load(f)\n    print(\"=> Done ({:.4f} s)\".format(time.time() - begin_st))\n    return v\n\ndef dump_pickle(obj, path):\n    with open(path, 'wb') as f:\n        print(\"Dumping pickle object to {}\".format(path))\n        pkl.dump(obj, f, protocol=pkl.HIGHEST_PROTOCOL)\n\ndef auto_select_gpu(mem_bound=500, utility_bound=0, gpus=(0, 1, 2, 3, 4, 5, 6, 7), num_gpu=1, selected_gpus=None):\n    import sys\n    import os\n    import subprocess\n    import re\n    import time\n    import numpy as np\n    if 'CUDA_VISIBLE_DEVCIES' in os.environ:\n        sys.exit(0)\n    if selected_gpus is None:\n 
       mem_trace = []\n        utility_trace = []\n        for i in range(5): # sample 5 times\n            info = subprocess.check_output('nvidia-smi', shell=True).decode('utf-8')\n            mem = [int(s[:-5]) for s in re.compile('\\d+MiB\\s/').findall(info)]\n            utility = [int(re.compile('\\d+').findall(s)[0]) for s in re.compile('\\d+%\\s+Default').findall(info)]\n            mem_trace.append(mem)\n            utility_trace.append(utility)\n            time.sleep(0.1)\n        mem = np.mean(mem_trace, axis=0)\n        utility = np.mean(utility_trace, axis=0)\n        assert(len(mem) == len(utility))\n        nGPU = len(utility)\n        ideal_gpus = [i for i in range(nGPU) if mem[i] <= mem_bound and utility[i] <= utility_bound and i in gpus]\n\n        if len(ideal_gpus) < num_gpu:\n            print(\"No sufficient resource, available: {}, require {} gpu\".format(ideal_gpus, num_gpu))\n            sys.exit(0)\n        else:\n            selected_gpus = list(map(str, ideal_gpus[:num_gpu]))\n    else:\n        selected_gpus = selected_gpus.split(',')\n\n    print(\"Setting GPU: {}\".format(selected_gpus))\n    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(selected_gpus)\n    return selected_gpus\n\ndef expand_user(path):\n    return os.path.abspath(os.path.expanduser(path))\n\ndef model_snapshot(model, new_file, old_file=None, verbose=False):\n    from collections import OrderedDict\n    import torch\n    if isinstance(model, torch.nn.DataParallel):\n        model = model.module\n    if old_file and os.path.exists(expand_user(old_file)):\n        if verbose:\n            print(\"Removing old model {}\".format(expand_user(old_file)))\n        os.remove(expand_user(old_file))\n    if verbose:\n        print(\"Saving model to {}\".format(expand_user(new_file)))\n\n    state_dict = OrderedDict()\n    for k, v in model.state_dict().items():\n        if v.is_cuda:\n            v = v.cpu()\n        state_dict[k] = v\n    torch.save(state_dict, 
expand_user(new_file))\n\n\ndef load_lmdb(lmdb_file, n_records=None):\n    import lmdb\n    import numpy as np\n    lmdb_file = expand_user(lmdb_file)\n    if os.path.exists(lmdb_file):\n        data = []\n        env = lmdb.open(lmdb_file, readonly=True, max_readers=512)\n        with env.begin() as txn:\n            cursor = txn.cursor()\n            begin_st = time.time()\n            print(\"Loading lmdb file {} into memory\".format(lmdb_file))\n            for key, value in cursor:\n                _, target, _ = key.decode('ascii').split(':')\n                target = int(target)\n                img = cv2.imdecode(np.fromstring(value, np.uint8), cv2.IMREAD_COLOR)\n                data.append((img, target))\n                if n_records is not None and len(data) >= n_records:\n                    break\n        env.close()\n        print(\"=> Done ({:.4f} s)\".format(time.time() - begin_st))\n        return data\n    else:\n        print(\"Not found lmdb file\".format(lmdb_file))\n\ndef str2img(str_b):\n    return cv2.imdecode(np.fromstring(str_b, np.uint8), cv2.IMREAD_COLOR)\n\ndef img2str(img):\n    return cv2.imencode('.jpg', img)[1].tostring()\n\ndef md5(s):\n    m = hashlib.md5()\n    m.update(s)\n    return m.hexdigest()\n\ndef eval_model(model, ds, n_sample=None, ngpu=1, is_imagenet=False):\n    import tqdm\n    import torch\n    from torch import nn\n    from torch.autograd import Variable\n\n    class ModelWrapper(nn.Module):\n        def __init__(self, model):\n            super(ModelWrapper, self).__init__()\n            self.model = model\n            self.mean = [0.485, 0.456, 0.406]\n            self.std = [0.229, 0.224, 0.225]\n\n        def forward(self, input):\n            input.data.div_(255.)\n            input.data[:, 0, :, :].sub_(self.mean[0]).div_(self.std[0])\n            input.data[:, 1, :, :].sub_(self.mean[1]).div_(self.std[1])\n            input.data[:, 2, :, :].sub_(self.mean[2]).div_(self.std[2])\n            return 
self.model(input)\n\n    correct1, correct5 = 0, 0\n    n_passed = 0\n    if is_imagenet:\n        model = ModelWrapper(model)\n    model = model.eval()\n    model = torch.nn.DataParallel(model, device_ids=range(ngpu)).cuda()\n\n    n_sample = len(ds) if n_sample is None else n_sample\n    for idx, (data, target) in enumerate(tqdm.tqdm(ds, total=n_sample)):\n        n_passed += len(data)\n        data =  Variable(torch.FloatTensor(data)).cuda()\n        indx_target = torch.LongTensor(target)\n        output = model(data)\n        bs = output.size(0)\n        idx_pred = output.data.sort(1, descending=True)[1]\n\n        idx_gt1 = indx_target.expand(1, bs).transpose_(0, 1)\n        idx_gt5 = idx_gt1.expand(bs, 5)\n\n        correct1 += idx_pred[:, :1].cpu().eq(idx_gt1).sum()\n        correct5 += idx_pred[:, :5].cpu().eq(idx_gt5).sum()\n\n        if idx >= n_sample - 1:\n            break\n\n    acc1 = correct1 * 1.0 / n_passed\n    acc5 = correct5 * 1.0 / n_passed\n    return acc1, acc5\n\ndef load_state_dict(model, model_urls, model_root):\n    from torch.utils import model_zoo\n    from torch import nn\n    import re\n    from collections import OrderedDict\n    own_state_old = model.state_dict()\n    own_state = OrderedDict() # remove all 'group' string\n    for k, v in own_state_old.items():\n        k = re.sub('group\\d+\\.', '', k)\n        own_state[k] = v\n\n    state_dict = model_zoo.load_url(model_urls, model_root)\n\n    for name, param in state_dict.items():\n        if name not in own_state:\n            print(own_state.keys())\n            raise KeyError('unexpected key \"{}\" in state_dict'\n                           .format(name))\n        if isinstance(param, nn.Parameter):\n            # backwards compatibility for serialized parameters\n            param = param.data\n        own_state[name].copy_(param)\n\n    missing = set(own_state.keys()) - set(state_dict.keys())\n    no_use = set(state_dict.keys()) - set(own_state.keys())\n    if len(no_use) 
> 0:\n        raise KeyError('some keys are not used: \"{}\"'.format(no_use))\n\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/utee/quant.py",
    "content": "from torch.autograd import Variable\nimport torch\nfrom torch import nn\nfrom collections import OrderedDict\nimport math\nfrom IPython import embed\n\ndef compute_integral_part(input, overflow_rate):\n    abs_value = input.abs().view(-1)\n    sorted_value = abs_value.sort(dim=0, descending=True)[0]\n    split_idx = int(overflow_rate * len(sorted_value))\n    v = sorted_value[split_idx]\n    if isinstance(v, Variable):\n        v = float(v.data.cpu())\n    sf = math.ceil(math.log2(v+1e-12))\n    return sf\n\ndef linear_quantize(input, sf, bits):\n    assert bits >= 1, bits\n    if bits == 1:\n        return torch.sign(input) - 1\n    delta = math.pow(2.0, -sf)\n    bound = math.pow(2.0, bits-1)\n    min_val = - bound\n    max_val = bound - 1\n    rounded = torch.floor(input / delta + 0.5)\n\n    clipped_value = torch.clamp(rounded, min_val, max_val) * delta\n    return clipped_value\n\ndef log_minmax_quantize(input, bits):\n    assert bits >= 1, bits\n    if bits == 1:\n        return torch.sign(input), 0.0, 0.0\n\n    s = torch.sign(input)\n    input0 = torch.log(torch.abs(input) + 1e-20)\n    v = min_max_quantize(input0, bits-1)\n    v = torch.exp(v) * s\n    return v\n\ndef log_linear_quantize(input, sf, bits):\n    assert bits >= 1, bits\n    if bits == 1:\n        return torch.sign(input), 0.0, 0.0\n\n    s = torch.sign(input)\n    input0 = torch.log(torch.abs(input) + 1e-20)\n    v = linear_quantize(input0, sf, bits-1)\n    v = torch.exp(v) * s\n    return v\n\ndef min_max_quantize(input, bits):\n    assert bits >= 1, bits\n    if bits == 1:\n        return torch.sign(input) - 1\n    min_val, max_val = input.min(), input.max()\n\n    if isinstance(min_val, Variable):\n        max_val = float(max_val.data.cpu().numpy()[0])\n        min_val = float(min_val.data.cpu().numpy()[0])\n\n    input_rescale = (input - min_val) / (max_val - min_val)\n\n    n = math.pow(2.0, bits) - 1\n    v = torch.floor(input_rescale * n + 0.5) / n\n\n    v =  v * 
(max_val - min_val) + min_val\n    return v\n\ndef tanh_quantize(input, bits):\n    assert bits >= 1, bits\n    if bits == 1:\n        return torch.sign(input)\n    input = torch.tanh(input) # [-1, 1]\n    input_rescale = (input + 1.0) / 2 #[0, 1]\n    n = math.pow(2.0, bits) - 1\n    v = torch.floor(input_rescale * n + 0.5) / n\n    v = 2 * v - 1 # [-1, 1]\n\n    v = 0.5 * torch.log((1 + v) / (1 - v)) # arctanh\n    return v\n\n\nclass LinearQuant(nn.Module):\n    def __init__(self, name, bits, sf=None, overflow_rate=0.0, counter=10):\n        super(LinearQuant, self).__init__()\n        self.name = name\n        self._counter = counter\n\n        self.bits = bits\n        self.sf = sf\n        self.overflow_rate = overflow_rate\n\n    @property\n    def counter(self):\n        return self._counter\n\n    def forward(self, input):\n        if self._counter > 0:\n            self._counter -= 1\n            sf_new = self.bits - 1 - compute_integral_part(input, self.overflow_rate)\n            self.sf = min(self.sf, sf_new) if self.sf is not None else sf_new\n            return input\n        else:\n            output = linear_quantize(input, self.sf, self.bits)\n            return output\n\n    def __repr__(self):\n        return '{}(sf={}, bits={}, overflow_rate={:.3f}, counter={})'.format(\n            self.__class__.__name__, self.sf, self.bits, self.overflow_rate, self.counter)\n\nclass LogQuant(nn.Module):\n    def __init__(self, name, bits, sf=None, overflow_rate=0.0, counter=10):\n        super(LogQuant, self).__init__()\n        self.name = name\n        self._counter = counter\n\n        self.bits = bits\n        self.sf = sf\n        self.overflow_rate = overflow_rate\n\n    @property\n    def counter(self):\n        return self._counter\n\n    def forward(self, input):\n        if self._counter > 0:\n            self._counter -= 1\n            log_abs_input = torch.log(torch.abs(input))\n            sf_new = self.bits - 1 - 
compute_integral_part(log_abs_input, self.overflow_rate)\n            self.sf = min(self.sf, sf_new) if self.sf is not None else sf_new\n            return input\n        else:\n            output = log_linear_quantize(input, self.sf, self.bits)\n            return output\n\n    def __repr__(self):\n        return '{}(sf={}, bits={}, overflow_rate={:.3f}, counter={})'.format(\n            self.__class__.__name__, self.sf, self.bits, self.overflow_rate, self.counter)\n\nclass NormalQuant(nn.Module):\n    def __init__(self, name, bits, quant_func):\n        super(NormalQuant, self).__init__()\n        self.name = name\n        self.bits = bits\n        self.quant_func = quant_func\n\n    @property\n    def counter(self):\n        return self._counter\n\n    def forward(self, input):\n        output = self.quant_func(input, self.bits)\n        return output\n\n    def __repr__(self):\n        return '{}(bits={})'.format(self.__class__.__name__, self.bits)\n\ndef duplicate_model_with_quant(model, bits, overflow_rate=0.0, counter=10, type='linear'):\n    \"\"\"assume that original model has at least a nn.Sequential\"\"\"\n    assert type in ['linear', 'minmax', 'log', 'tanh']\n    if isinstance(model, nn.Sequential):\n        l = OrderedDict()\n        for k, v in model._modules.items():\n            if isinstance(v, (nn.Conv2d, nn.Linear, nn.BatchNorm1d, nn.BatchNorm2d, nn.AvgPool2d)):\n                l[k] = v\n                if type == 'linear':\n                    quant_layer = LinearQuant('{}_quant'.format(k), bits=bits, overflow_rate=overflow_rate, counter=counter)\n                elif type == 'log':\n                    # quant_layer = LogQuant('{}_quant'.format(k), bits=bits, overflow_rate=overflow_rate, counter=counter)\n                    quant_layer = NormalQuant('{}_quant'.format(k), bits=bits, quant_func=log_minmax_quantize)\n                elif type == 'minmax':\n                    quant_layer = NormalQuant('{}_quant'.format(k), bits=bits, 
quant_func=min_max_quantize)\n                else:\n                    quant_layer = NormalQuant('{}_quant'.format(k), bits=bits, quant_func=tanh_quantize)\n                l['{}_{}_quant'.format(k, type)] = quant_layer\n            else:\n                l[k] = duplicate_model_with_quant(v, bits, overflow_rate, counter, type)\n        m = nn.Sequential(l)\n        return m\n    else:\n        for k, v in model._modules.items():\n            model._modules[k] = duplicate_model_with_quant(v, bits, overflow_rate, counter, type)\n        return model\n\n"
  },
  {
    "path": "utils/classifiers/pytorch_playground/utee/selector.py",
    "content": "from utee import misc\nimport os\nfrom imagenet import dataset\nprint = misc.logger.info\nfrom IPython import embed\n\nknown_models = [\n    'mnist', 'svhn', # 28x28\n    'cifar10', 'cifar100', # 32x32\n    'stl10', # 96x96\n    'alexnet', # 224x224\n    'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', # 224x224\n    'resnet18', 'resnet34', 'resnet50', 'resnet101','resnet152', # 224x224\n    'squeezenet_v0', 'squeezenet_v1', #224x224\n    'inception_v3', # 299x299\n]\n\ndef mnist(cuda=True, model_root=None):\n    print(\"Building and initializing mnist parameters\")\n    from mnist import model, dataset\n    m = model.mnist(pretrained=os.path.join(model_root, 'mnist.pth'))\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, False\n\ndef svhn(cuda=True, model_root=None):\n    print(\"Building and initializing svhn parameters\")\n    from svhn import model, dataset\n    m = model.svhn(32, pretrained=os.path.join(model_root, 'svhn.pth'))\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, False\n\ndef cifar10(cuda=True, model_root=None):\n    print(\"Building and initializing cifar10 parameters\")\n    from cifar import model, dataset\n    m = model.cifar10(128, pretrained=os.path.join(model_root, 'cifar10.pth'))\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get10, False\n\ndef cifar100(cuda=True, model_root=None):\n    print(\"Building and initializing cifar100 parameters\")\n    from cifar import model, dataset\n    m = model.cifar100(128, pretrained=os.path.join(model_root, 'cifar100.pth'))\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get100, False\n\ndef stl10(cuda=True, model_root=None):\n    print(\"Building and initializing stl10 parameters\")\n    from stl10 import model, dataset\n    m = model.stl10(32, pretrained=os.path.join(model_root, 'stl10.pth'))\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, False\n\ndef alexnet(cuda=True, model_root=None):\n    print(\"Building and 
initializing alexnet parameters\")\n    from imagenet import alexnet as alx\n    m = alx.alexnet(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef vgg16(cuda=True, model_root=None):\n    print(\"Building and initializing vgg16 parameters\")\n    from imagenet import vgg\n    m = vgg.vgg16(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef vgg16_bn(cuda=True, model_root=None):\n    print(\"Building vgg16_bn parameters\")\n    from imagenet import vgg\n    m = vgg.vgg19_bn(model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef vgg19(cuda=True, model_root=None):\n    print(\"Building and initializing vgg19 parameters\")\n    from imagenet import vgg\n    m = vgg.vgg19(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef vgg19_bn(cuda=True, model_root=None):\n    print(\"Building vgg19_bn parameters\")\n    from imagenet import vgg\n    m = vgg.vgg19_bn(model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef inception_v3(cuda=True, model_root=None):\n    print(\"Building and initializing inception_v3 parameters\")\n    from imagenet import inception\n    m = inception.inception_v3(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef resnet18(cuda=True, model_root=None):\n    print(\"Building and initializing resnet-18 parameters\")\n    from imagenet import resnet\n    m = resnet.resnet18(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef resnet34(cuda=True, model_root=None):\n    print(\"Building and initializing resnet-34 parameters\")\n    from imagenet import resnet\n    m = resnet.resnet34(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef resnet50(cuda=True, model_root=None):\n    print(\"Building and initializing resnet-50 parameters\")\n    from 
imagenet import resnet\n    m = resnet.resnet50(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef resnet101(cuda=True, model_root=None):\n    print(\"Building and initializing resnet-101 parameters\")\n    from imagenet import resnet\n    m = resnet.resnet101(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef resnet152(cuda=True, model_root=None):\n    print(\"Building and initializing resnet-152 parameters\")\n    from imagenet import resnet\n    m = resnet.resnet152(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef squeezenet_v0(cuda=True, model_root=None):\n    print(\"Building and initializing squeezenet_v0 parameters\")\n    from imagenet import squeezenet\n    m = squeezenet.squeezenet1_0(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef squeezenet_v1(cuda=True, model_root=None):\n    print(\"Building and initializing squeezenet_v1 parameters\")\n    from imagenet import squeezenet\n    m = squeezenet.squeezenet1_1(True, model_root)\n    if cuda:\n        m = m.cuda()\n    return m, dataset.get, True\n\ndef select(model_name, **kwargs):\n    assert model_name in known_models, model_name\n    kwargs.setdefault('model_root', os.path.expanduser('~/.torch/models'))\n    return eval('{}'.format(model_name))(**kwargs)\n\nif __name__ == '__main__':\n    m1 = alexnet()\n    embed()\n\n\n"
  },
  {
    "path": "utils/classifiers/stacked_mnist.py",
    "content": "import torch\nfrom torch import nn\nimport torch.utils.model_zoo as model_zoo\nfrom collections import OrderedDict\nfrom torchvision import datasets\nfrom torch.nn import functional as F\nfrom torchvision import transforms\n\nCLASSIFIER_PATH = 'mnist_model.pt'\n\nclass Classifier():\n    '''Stacked-MNIST classifier: labels a 3-channel image with one of 1000 classes.\n\n    Each channel is classified by the same MNIST digit classifier and the three\n    digits are combined as d0 + 10 * d1 + 100 * d2.\n    '''\n    def __init__(self):\n        self.mnist = MNISTClassifier().cuda()\n\n        try:\n            self.mnist.load(CLASSIFIER_PATH)\n        except Exception as e:\n            # no usable checkpoint: train from scratch (training also saves one)\n            print(e)\n            self.mnist.train()\n\n        # switch off dropout so predictions are deterministic\n        self.mnist.eval()\n\n    def get_predictions(self, x):\n        assert(x.size(1) == 3)\n        result = self.mnist.get_predictions(x[:, 0, :, :])\n        for channel_number in range(1, 3):\n            result = result + self.mnist.get_predictions(x[:, channel_number, :, :]) * 10**channel_number\n        return result\n\ndef get_mnist_dataloader(batch_size=100):\n    dataset = datasets.MNIST('data/MNIST', train=True, transform=transforms.Compose([\n                                    transforms.Resize(32),\n                                    transforms.CenterCrop(32),\n                                    transforms.ToTensor(),\n                                    transforms.Normalize((0.5, ), (0.5, ))\n                                ]))\n\n    return torch.utils.data.DataLoader(\n                dataset,\n                batch_size=batch_size,\n                num_workers=12,\n                shuffle=True,\n                pin_memory=True,\n                sampler=None,\n                drop_last=True)\n\nclass MNISTClassifier(nn.Module):\n    '''MLP digit classifier; inputs are flattened and must have input_dims features.'''\n    def __init__(self, input_dims=1024, n_hiddens=(256, 256), n_class=10):\n        super(MNISTClassifier, self).__init__()\n        self.input_dims = input_dims\n\n        current_dims = input_dims\n        layers = OrderedDict()\n        for i, n_hidden in enumerate(n_hiddens):\n            layers['fc{}'.format(i+1)] = nn.Linear(current_dims, n_hidden)\n            layers['relu{}'.format(i+1)] = nn.ReLU()\n            layers['drop{}'.format(i+1)] = nn.Dropout(0.2)\n            current_dims = n_hidden\n        layers['out'] = nn.Linear(current_dims, n_class)\n\n        self.model= nn.Sequential(layers)\n        print(self.model)\n\n    def forward(self, input):\n        input = input.view(input.size(0), -1)\n        assert input.size(1) == self.input_dims\n        return self.model(input)\n\n    @torch.no_grad()\n    def get_predictions(self, input):\n        logits = self.forward(input)\n        return logits.argmax(dim=1)\n\n    def load(self, path):\n        self.load_state_dict(torch.load(path))\n        print('Loaded pretrained MNIST classifier')\n\n    def train(self, mode=None):\n        '''Trains the classifier, or sets the train/eval mode like nn.Module.train.\n\n        With no argument (the original usage) this runs fit(). With a bool, which is\n        how nn.Module.eval() calls it, it only toggles the mode, so .eval() works.\n        '''\n        if mode is not None:\n            return super(MNISTClassifier, self).train(mode)\n        return self.fit()\n\n    def fit(self):\n        '''Trains for 10 epochs with Adam and saves the weights to CLASSIFIER_PATH.'''\n        print('Training MNIST classifier')\n        super(MNISTClassifier, self).train(True)  # enable dropout for training\n        dataloader = get_mnist_dataloader()\n        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n\n        for epoch in range(10):\n            for it, (x, y) in enumerate(dataloader):\n                optimizer.zero_grad()\n                x, y = x.cuda(), y.cuda()\n                logits = self.forward(x)\n                loss = F.cross_entropy(logits, y)\n                loss.backward()\n                optimizer.step()\n                if it % 100 == 0:\n                    acc = (self.get_predictions(x) == y).float().mean().item()\n                    print(f'[{epoch}, {it}], closs={loss}, acc={acc}')\n\n        torch.save(self.state_dict(), CLASSIFIER_PATH)\n        return self\n\n\nif __name__ == '__main__':\n    classifier = Classifier()\n    train_loader = get_mnist_dataloader(10)\n    xs, ys = [], []\n    for i, (x, y) in enumerate(train_loader):\n        if i == 3:\n            break\n        xs.append(x.cuda())\n        ys.append(y)\n    print(ys)\n    print(classifier.get_predictions(torch.cat(xs, dim=1)))\n"
  },
  {
    "path": "utils/get_empirical_distribution.py",
    "content": "import argparse\nimport os\nfrom tqdm import tqdm\n\nimport json\nimport numpy as np\n\nfrom classifiers import classifier_dict\nfrom np_to_pt_img import np_to_pt\n\n\ndef get_empirical_distribution(path_to_samples):\n    ''' gets the fake and real distributions induced by the classifier\n\n    Uses the module-level `classifier` and `batch_size` set in __main__.\n    Returns {datatype: {str(class_id): fraction of samples predicted as class_id}}.\n    '''\n    results = {}\n\n    with np.load(path_to_samples, allow_pickle=True) as data:\n        for datatype in ['fake']:  # , 'real'\n            counts = {}\n            imgs = data[datatype]\n            print(f'Found {len(imgs)} samples in {path_to_samples}')\n            # step by batch_size so a trailing partial batch is classified too\n            for start in tqdm(range(0, len(imgs), batch_size)):\n                x_batch = np_to_pt(imgs[start:start + batch_size]).cuda()\n                y_pred = classifier.get_predictions(x_batch)\n                for yi in y_pred:\n                    yi = yi.item()\n                    counts[yi] = counts.get(yi, 0) + 1\n            total = sum(counts.values())\n            # store the normalized distribution in results (it was previously\n            # assigned to a rebound local and dropped)\n            results[datatype] = {str(k): v / total for k, v in counts.items()}\n    return results\n\n\ndef get_kl(fake, nclasses):\n    '''computes the log10 kl between empirical distributions.\n\n    Values of `fake` (counts or probabilities) are renormalized, so this is\n    KL(p_fake || uniform over nclasses) in base 10.\n    '''\n    result = 0\n    total = sum(fake.values())\n    for c, count in fake.items():\n        pi = count / total\n        # log10 seems to reproduce pacgan results\n        result += pi * np.log10(pi * nclasses)\n    return result\n\n\ndef load_results(path):\n    '''Loads a json results file, or returns an empty dict if it does not exist yet.'''\n    if not os.path.exists(path):\n        return {}\n    with open(path) as f:\n        return json.load(f)\n\n\nnmodes_gt = {'places': 365, 'cifar': 10, 'imagenet': 1000, 'stacked_mnist': 1000}\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser('compute empirical distributions and reverse-kl metrics')\n    parser.add_argument('--samples', help='path to samples')\n    parser.add_argument('--it', type=str, help='iteration number (can be \\'pretrained\\') of samples')\n    parser.add_argument('--results_dir', help='path to results_dir')\n    parser.add_argument('--dataset', type=str, required=True)\n    parser.add_argument('--batch_size', type=int, default=100)\n    args = parser.parse_args()\n\n    batch_size = args.batch_size\n    classifier = classifier_dict[args.dataset]()\n    it = args.it\n    results_dir = args.results_dir\n    result = get_empirical_distribution(args.samples)\n    nmodes = len(result['fake'])\n    nclasses = nmodes_gt[args.dataset]\n\n    kl = get_kl(result['fake'], nclasses)\n\n    kl_path = os.path.join(results_dir, 'kl_results.json')\n    nmodes_path = os.path.join(results_dir, 'nmodes_results.json')\n    # start fresh on the first run instead of crashing on a missing file\n    kl_results = load_results(kl_path)\n    nmodes_results = load_results(nmodes_path)\n\n    kl_results[it] = kl\n    nmodes_results[it] = nmodes\n\n    print(f'{results_dir} iteration {it} KL: {kl} Covered {nmodes} out of {nclasses} total modes')\n\n    os.makedirs(results_dir, exist_ok=True)\n    with open(kl_path, 'w') as f:\n        f.write(json.dumps(kl_results))\n    with open(nmodes_path, 'w') as f:\n        f.write(json.dumps(nmodes_results))\n"
  },
  {
    "path": "utils/get_gt_imgs.py",
    "content": "import os\nimport argparse\nfrom tqdm import tqdm\nfrom PIL import Image\nimport torch\nfrom torchvision import transforms, datasets\nimport numpy as np\nimport random\n\n# compared against lower-cased file names, so e.g. .JPEG and .jpg both match\nIMG_EXTENSIONS = ('.png', '.webp', '.jpg', '.jpeg')\n\n\ndef get_images(root, N):\n    '''Returns a shuffled list of image file paths found under root.\n\n    Only the first 1000 files of each directory are considered. N is not used\n    here; callers slice the returned list themselves.\n    '''\n    all_files = []\n    for dp, dn, fn in os.walk(os.path.expanduser(root)):\n        for j, f in enumerate(fn):\n            if j >= 1000:\n                break     # don't get whole dataset, just get enough images per class\n            if f.lower().endswith(IMG_EXTENSIONS):\n                all_files.append(os.path.join(dp, f))\n    random.shuffle(all_files)\n    return all_files\n\n\ndef pt_to_np(imgs):\n    '''normalizes pytorch image in [-1, 1] to [0, 255]'''\n    return (imgs.permute(0, 2, 3, 1).mul_(0.5).add_(0.5).mul_(255)).clamp_(0, 255).numpy()\n\n\ndef get_transform(size):\n    return transforms.Compose([\n        transforms.Resize(size),\n        transforms.CenterCrop(size),\n        transforms.ToTensor(),\n        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n    ])\n\n\ndef get_gt_samples(dataset, nimgs=50000):\n    '''Returns up to nimgs ground-truth images of dataset as an NHWC array in [0, 255].'''\n    transform = get_transform(sizes[dataset])\n    images = []\n    if dataset != 'cifar':\n        for file_path in tqdm(get_images(paths[dataset], nimgs)[:nimgs]):\n            # the with-block closes each image file deterministically\n            with Image.open(file_path) as img:\n                images.append(transform(img.convert('RGB')))\n    else:\n        data = datasets.CIFAR10(paths[dataset], transform=transform)\n        # honour nimgs for CIFAR as well instead of always taking the whole split\n        for idx in tqdm(range(min(nimgs, len(data)))):\n            images.append(data[idx][0])\n    return pt_to_np(torch.stack(images))\n\n\npaths = {\n    'imagenet': 'data/ImageNet',\n    'places': 'data/Places365',\n    'cifar': 'data/CIFAR'\n}\n\nsizes = {'imagenet': 128, 'places': 128, 'cifar': 32}\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser('Save a batch of ground truth train set images for evaluation')\n    parser.add_argument('--cifar', action='store_true')\n    parser.add_argument('--imagenet', action='store_true')\n    parser.add_argument('--places', action='store_true')\n    args = parser.parse_args()\n\n    os.makedirs('output', exist_ok=True)\n\n    if args.cifar:\n        cifar_samples = get_gt_samples('cifar', nimgs=50000)\n        np.savez('output/cifar_gt_imgs.npz', fake=cifar_samples, real=cifar_samples)\n    if args.imagenet:\n        imagenet_samples = get_gt_samples('imagenet', nimgs=50000)\n        np.savez('output/imagenet_gt_imgs.npz', fake=imagenet_samples, real=imagenet_samples)\n    if args.places:\n        places_samples = get_gt_samples('places', nimgs=50000)\n        np.savez('output/places_gt_imgs.npz', fake=places_samples, real=places_samples)\n"
  },
  {
    "path": "utils/np_to_pt_img.py",
    "content": "import torch\n\n# permutation from numpy channel-last to pytorch channel-first layout, keyed by rank\n_CHANNEL_FIRST_ORDER = {4: (0, 3, 1, 2), 3: (2, 0, 1)}\n\n\ndef np_to_pt(x):\n    ''' permutes the appropriate channels to turn numpy formatted images to pt formatted images. does NOT renormalize\n\n    A rank-4 array is permuted (0, 3, 1, 2) and a rank-3 array (2, 0, 1); any other\n    rank raises NotImplementedError. The result is a permuted view of torch.from_numpy(x).\n    '''\n    tensor = torch.from_numpy(x)\n    order = _CHANNEL_FIRST_ORDER.get(tensor.dim())\n    if order is None:\n        raise NotImplementedError\n    return tensor.permute(*order)\n"
  },
  {
    "path": "visualize_clusters.py",
    "content": "import argparse\nimport os\nimport shutil\nimport torch\nimport torchvision\n\nfrom torch import nn\nfrom gan_training import utils\nfrom gan_training.inputs import get_dataset\nfrom gan_training.checkpoints import CheckpointIO\nfrom gan_training.config import load_config\nfrom seeded_sampler import SeededSampler\n\ntorch.backends.cudnn.benchmark = True\n\n# Arguments\nparser = argparse.ArgumentParser(description='Visualize the samples/clusters of a class-conditional GAN')\nparser.add_argument('config', type=str, help='Path to config file.')\nparser.add_argument('--model_it', type=int, help='If you want to load from a specific model iteration')\nparser.add_argument('--show_clusters', action='store_true', help='show the real images. Requires a path to the real image train directory')\nargs = parser.parse_args()\n\nconfig = load_config(args.config, 'configs/default.yaml')\nout_dir = config['training']['out_dir']\n\n\ndef main():\n    '''Writes, for every cluster/label i, a grid of generated samples ({i}_gen.png)\n    and, with --show_clusters, a grid of real images assigned to it ({i}_real.png)\n    into out_dir/clusters.\n    '''\n    checkpoint_dir = os.path.join(out_dir, 'chkpts')\n\n    # resolve the checkpoint iteration once and reuse it for every artifact below\n    most_recent = utils.get_most_recent(checkpoint_dir, 'model') if args.model_it is None else args.model_it\n\n    cluster_path = os.path.join(out_dir, 'clusters')\n    print('Saving clusters/samples to', cluster_path)\n\n    os.makedirs(cluster_path, exist_ok=True)\n\n    shutil.copyfile('seeing/lightbox.html', os.path.join(cluster_path, '+lightbox.html'))\n\n    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)\n    clusterer = checkpoint_io.load_clusterer(most_recent, pretrained=config['pretrained'], load_samples=False)\n\n    if isinstance(clusterer.discriminator, nn.DataParallel):\n        clusterer.discriminator = clusterer.discriminator.module\n\n    model_path = os.path.join(checkpoint_dir, 'model_%08d.pt' % most_recent)\n    sampler = SeededSampler(args.config,\n                            model_path=model_path,\n                            clusterer_path=os.path.join(checkpoint_dir, f'clusterer{most_recent}.pkl'),\n                            pretrained=config['pretrained'])\n\n    nimgs = 20\n    nrows = 4\n\n    if args.show_clusters:\n        clusters = [[] for _ in range(config['generator']['nlabels'])]\n        train_dataset, _ = get_dataset(\n            name='webp'\n            if 'cifar' not in config['data']['train_dir'].lower() else 'cifar10',\n            data_dir=config['data']['train_dir'],\n            size=config['data']['img_size'])\n\n        train_loader = torch.utils.data.DataLoader(\n            train_dataset,\n            batch_size=config['training']['batch_size'],\n            num_workers=config['training']['nworkers'],\n            shuffle=True,\n            pin_memory=True,\n            sampler=None,\n            drop_last=True)\n\n        print('Generating clusters')\n        # inference only: no autograd graphs needed\n        with torch.no_grad():\n            for batch_num, (x_real, y_gt) in enumerate(train_loader):\n                x_real = x_real.cuda()\n                y_pred = clusterer.get_labels(x_real, y_gt)\n\n                for i, yi in enumerate(y_pred):\n                    clusters[yi].append(x_real[i].cpu())\n\n                # don't generate too many, we're only visualizing 20 per cluster\n                if batch_num * config['training']['batch_size'] >= 10000:\n                    break\n    else:\n        clusters = [None] * config['generator']['nlabels']\n\n    for i in range(len(clusters)):\n        # save up to nimgs real images; small but non-empty clusters are shown too\n        if clusters[i]:\n            cluster = torch.stack(clusters[i][:nimgs])\n            torchvision.utils.save_image(cluster * 0.5 + 0.5,\n                                         os.path.join(cluster_path, f'{i}_real.png'),\n                                         nrow=nrows)\n\n        with torch.no_grad():\n            generated = [sampler.conditional_sample(i, seed=seed).detach().cpu() for seed in range(nimgs)]\n        generated = torch.cat(generated)\n\n        torchvision.utils.save_image(generated * 0.5 + 0.5,\n                                     os.path.join(cluster_path, f'{i}_gen.png'),\n                                     nrow=nrows)\n\n    print('Clusters/samples can be visualized under', cluster_path)\n\n\nif __name__ == '__main__':\n    main()\n"
  }
]